diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..1c285ec6bb0e3609918290470d9c63eebc089374
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,140 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# Temp files
+doctr/version.py
+logs/
+wandb/
+.idea/
+
+# Checkpoints
+*.pt
+*.pb
+*.index
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8e66bb443e139f7a948506346e1c550e65c9a260
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,23 @@
+repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.5.0
+ hooks:
+ - id: check-ast
+ - id: check-yaml
+ exclude: .conda
+ - id: check-toml
+ - id: check-json
+ - id: check-added-large-files
+ exclude: docs/images/
+ - id: end-of-file-fixer
+ - id: trailing-whitespace
+ - id: debug-statements
+ - id: check-merge-conflict
+ - id: no-commit-to-branch
+ args: ['--branch', 'main']
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.3.2
+ hooks:
+ - id: ruff
+ args: [ --fix ]
+ - id: ruff-format
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..ee84f1d7db0b6babc8b3c95015f9a30f05af9731
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,128 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our
+community a harassment-free experience for everyone, regardless of age, body
+size, visible or invisible disability, ethnicity, sex characteristics, gender
+identity and expression, level of experience, education, socio-economic status,
+nationality, personal appearance, race, religion, or sexual identity
+and orientation.
+
+We pledge to act and interact in ways that contribute to an open, welcoming,
+diverse, inclusive, and healthy community.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment for our
+community include:
+
+* Demonstrating empathy and kindness toward other people
+* Being respectful of differing opinions, viewpoints, and experiences
+* Giving and gracefully accepting constructive feedback
+* Accepting responsibility and apologizing to those affected by our mistakes,
+ and learning from the experience
+* Focusing on what is best not just for us as individuals, but for the
+ overall community
+
+Examples of unacceptable behavior include:
+
+* The use of sexualized language or imagery, and sexual attention or
+ advances of any kind
+* Trolling, insulting or derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or email
+ address, without their explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Enforcement Responsibilities
+
+Community leaders are responsible for clarifying and enforcing our standards of
+acceptable behavior and will take appropriate and fair corrective action in
+response to any behavior that they deem inappropriate, threatening, offensive,
+or harmful.
+
+Community leaders have the right and responsibility to remove, edit, or reject
+comments, commits, code, wiki edits, issues, and other contributions that are
+not aligned to this Code of Conduct, and will communicate reasons for moderation
+decisions when appropriate.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when
+an individual is officially representing the community in public spaces.
+Examples of representing our community include using an official e-mail address,
+posting via an official social media account, or acting as an appointed
+representative at an online or offline event.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported to the community leaders responsible for enforcement at
+contact@mindee.com.
+All complaints will be reviewed and investigated promptly and fairly.
+
+All community leaders are obligated to respect the privacy and security of the
+reporter of any incident.
+
+## Enforcement Guidelines
+
+Community leaders will follow these Community Impact Guidelines in determining
+the consequences for any action they deem in violation of this Code of Conduct:
+
+### 1. Correction
+
+**Community Impact**: Use of inappropriate language or other behavior deemed
+unprofessional or unwelcome in the community.
+
+**Consequence**: A private, written warning from community leaders, providing
+clarity around the nature of the violation and an explanation of why the
+behavior was inappropriate. A public apology may be requested.
+
+### 2. Warning
+
+**Community Impact**: A violation through a single incident or series
+of actions.
+
+**Consequence**: A warning with consequences for continued behavior. No
+interaction with the people involved, including unsolicited interaction with
+those enforcing the Code of Conduct, for a specified period of time. This
+includes avoiding interactions in community spaces as well as external channels
+like social media. Violating these terms may lead to a temporary or
+permanent ban.
+
+### 3. Temporary Ban
+
+**Community Impact**: A serious violation of community standards, including
+sustained inappropriate behavior.
+
+**Consequence**: A temporary ban from any sort of interaction or public
+communication with the community for a specified period of time. No public or
+private interaction with the people involved, including unsolicited interaction
+with those enforcing the Code of Conduct, is allowed during this period.
+Violating these terms may lead to a permanent ban.
+
+### 4. Permanent Ban
+
+**Community Impact**: Demonstrating a pattern of violation of community
+standards, including sustained inappropriate behavior, harassment of an
+individual, or aggression toward or disparagement of classes of individuals.
+
+**Consequence**: A permanent ban from any sort of public interaction within
+the community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+version 2.0, available at
+https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
+
+Community Impact Guidelines were inspired by [Mozilla's code of conduct
+enforcement ladder](https://github.com/mozilla/diversity).
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see the FAQ at
+https://www.contributor-covenant.org/faq. Translations are available at
+https://www.contributor-covenant.org/translations.
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..7e2a849de3d5ce17850b9ad2308874100b656164
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,92 @@
+# Contributing to docTR
+
+Everything you need to know to contribute efficiently to the project.
+
+## Codebase structure
+
+- [doctr](https://github.com/mindee/doctr/blob/main/doctr) - The package codebase
+- [tests](https://github.com/mindee/doctr/blob/main/tests) - Python unit tests
+- [docs](https://github.com/mindee/doctr/blob/main/docs) - Library documentation building
+- [scripts](https://github.com/mindee/doctr/blob/main/scripts) - Example scripts
+- [references](https://github.com/mindee/doctr/blob/main/references) - Reference training scripts
+- [demo](https://github.com/mindee/doctr/blob/main/demo) - Small demo app to showcase docTR capabilities
+- [api](https://github.com/mindee/doctr/blob/main/api) - A minimal template to deploy a REST API with docTR
+
+## Continuous Integration
+
+This project uses the following integrations to ensure proper codebase maintenance:
+
+- [GitHub Workflow](https://help.github.com/en/actions/configuring-and-managing-workflows/configuring-a-workflow) - run jobs for package build and coverage
+- [Codecov](https://codecov.io/) - reports back coverage results
+
+As a contributor, you will only have to ensure coverage of your code by adding appropriate unit tests.
+
+## Feedback
+
+### Feature requests & bug reports
+
+Whether you have encountered a problem or have a feature suggestion, your input is valuable and can be referenced by other contributors in their work. For this purpose, we advise you to use GitHub [issues](https://github.com/mindee/doctr/issues).
+
+First, check whether the topic is already covered in an open or closed issue. If not, feel free to open a new one! When doing so, use issue templates whenever possible and provide enough information for other contributors to jump in.
+
+### Questions
+
+If you are wondering how to do something with docTR, or have a more general question, consider checking out GitHub [discussions](https://github.com/mindee/doctr/discussions). See it as a Q&A forum, or the docTR-specific Stack Overflow!
+
+## Developing docTR
+
+### Developer mode installation
+
+Install all additional dependencies with the following command:
+
+```shell
+python -m pip install --upgrade pip
+pip install -e .[dev]
+pre-commit install
+```
+
+### Commits
+
+- **Code**: ensure to provide docstrings to your Python code. In doing so, please follow [Google-style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) so it can ease the process of documentation later.
+- **Commit message**: please follow [Udacity guide](http://udacity.github.io/git-styleguide/)
+
+### Unit tests
+
+To run the same unit tests as the CI workflows, you can run them locally:
+
+```shell
+make test
+```
+
+### Code quality
+
+To run all quality checks together:
+
+```shell
+make quality
+```
+
+#### Code style verification
+
+To run all style checks together:
+
+```shell
+make style
+```
+
+### Modifying the documentation
+
+The current documentation is built using `sphinx` thanks to our CI.
+You can build the documentation locally:
+
+```shell
+make docs-single-version
+```
+
+Please note that files that have not been modified will not be rebuilt. If you want to force a complete rebuild, you can delete the `_build` directory. Additionally, you may need to clear your web browser's cache to see the modifications.
+
+You can now open your local version of the documentation located at `docs/_build/index.html` in your browser.
+
+## Let's connect
+
+Should you wish to connect somewhere else than on GitHub, feel free to join us on [Slack](https://join.slack.com/t/mindee-community/shared_invite/zt-uzgmljfl-MotFVfH~IdEZxjp~0zldww), where you will find a `#doctr` channel!
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..beb8c8693ecd50347e346d987647f508ae164697
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,75 @@
+FROM ubuntu:22.04
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV LANG=C.UTF-8
+ENV PYTHONUNBUFFERED=1
+ENV PYTHONDONTWRITEBYTECODE=1
+
+ARG SYSTEM=gpu
+
+# Enroll NVIDIA GPG public key and install CUDA
+RUN if [ "$SYSTEM" = "gpu" ]; then \
+ apt-get update && \
+ apt-get install -y gnupg ca-certificates wget && \
+ # - Install Nvidia repo keys
+ # - See: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#network-repo-installation-for-ubuntu
+ wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb && \
+ dpkg -i cuda-keyring_1.1-1_all.deb && \
+ apt-get update && apt-get install -y --no-install-recommends \
+ cuda-command-line-tools-11-8 \
+ cuda-cudart-dev-11-8 \
+ cuda-nvcc-11-8 \
+ cuda-cupti-11-8 \
+ cuda-nvprune-11-8 \
+ cuda-libraries-11-8 \
+ cuda-nvrtc-11-8 \
+ libcufft-11-8 \
+ libcurand-11-8 \
+ libcusolver-11-8 \
+ libcusparse-11-8 \
+ libcublas-11-8 \
+ # - CuDNN: https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#ubuntu-network-installation
+ libcudnn8=8.6.0.163-1+cuda11.8 \
+ libnvinfer-plugin8=8.6.1.6-1+cuda11.8 \
+ libnvinfer8=8.6.1.6-1+cuda11.8; \
+fi
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ # - Other packages
+ build-essential \
+ pkg-config \
+ curl \
+ wget \
+ software-properties-common \
+ unzip \
+ git \
+ # - Packages to build Python
+ tar make gcc zlib1g-dev libffi-dev libssl-dev liblzma-dev libbz2-dev libsqlite3-dev \
+ # - Packages for docTR
+ libgl1-mesa-dev libsm6 libxext6 libxrender-dev libpangocairo-1.0-0 \
+ && apt-get clean \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install Python
+ARG PYTHON_VERSION=3.10.13
+
+RUN wget https://www.python.org/ftp/python/$PYTHON_VERSION/Python-$PYTHON_VERSION.tgz && \
+ tar -zxf Python-$PYTHON_VERSION.tgz && \
+ cd Python-$PYTHON_VERSION && \
+ mkdir /opt/python/ && \
+ ./configure --prefix=/opt/python && \
+ make && \
+ make install && \
+ cd .. && \
+ rm Python-$PYTHON_VERSION.tgz && \
+ rm -r Python-$PYTHON_VERSION
+
+ENV PATH=/opt/python/bin:$PATH
+
+# Install docTR
+ARG FRAMEWORK=tf
+ARG DOCTR_REPO='mindee/doctr'
+ARG DOCTR_VERSION=main
+RUN pip3 install -U pip setuptools wheel && \
+ pip3 install "python-doctr[$FRAMEWORK]@git+https://github.com/$DOCTR_REPO.git@$DOCTR_VERSION"
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..63d22de4a6df22f7f83c7511e6ce983968ed4996
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2022 Mindee
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..7ed5339cdab029655f5334997287b718955c3648
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,33 @@
+.PHONY: quality style test test-common test-tf test-torch docs-single-version docs
+# this target runs checks on all files
+quality:
+ ruff check .
+ mypy doctr/
+
+# this target runs checks on all files and potentially modifies some of them
+style:
+ ruff check --fix .
+ ruff format .
+
+# Run tests for the library
+test:
+ coverage run -m pytest tests/common/
+ USE_TF='1' coverage run -m pytest tests/tensorflow/
+ USE_TORCH='1' coverage run -m pytest tests/pytorch/
+
+test-common:
+ coverage run -m pytest tests/common/
+
+test-tf:
+ USE_TF='1' coverage run -m pytest tests/tensorflow/
+
+test-torch:
+ USE_TORCH='1' coverage run -m pytest tests/pytorch/
+
+# Check that docs can build
+docs-single-version:
+ sphinx-build docs/source docs/_build -a
+
+# Check that docs can build
+docs:
+ cd docs && bash build.sh
diff --git a/README.md b/README.md
index 9c06ed3242541c31614fe8be691f7a2d6f1885f5..82cc3c2480c3362911de17b2fbf46ccdb3c3ac1a 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,384 @@
----
-title: Doctr
-emoji: 📚
-colorFrom: yellow
-colorTo: blue
-sdk: streamlit
-sdk_version: 1.35.0
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+
+
+[Slack](https://slack.mindee.com) [License](LICENSE) [Docker Images](https://github.com/mindee/doctr/pkgs/container/doctr) [Codecov](https://codecov.io/gh/mindee/doctr) [CodeFactor](https://www.codefactor.io/repository/github/mindee/doctr) [Codacy](https://app.codacy.com/gh/mindee/doctr?utm_source=github.com&utm_medium=referral&utm_content=mindee/doctr&utm_campaign=Badge_Grade) [Documentation](https://mindee.github.io/doctr) [PyPI](https://pypi.org/project/python-doctr/) [Hugging Face Spaces](https://huggingface.co/spaces/mindee/doctr) [Colab](https://colab.research.google.com/github/mindee/notebooks/blob/main/doctr/quicktour.ipynb)
+
+
+**Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch**
+
+What you can expect from this repository:
+
+- efficient ways to parse textual information (localize and identify each word) from your documents
+- guidance on how to integrate this in your current architecture
+
+
+
+## Quick Tour
+
+### Getting your pretrained model
+
+End-to-End OCR is achieved in docTR using a two-stage approach: text detection (localizing words), then text recognition (identifying all characters in the word).
+As such, you can select the architecture used for [text detection](https://mindee.github.io/doctr/latest/modules/models.html#doctr-models-detection), and the one for [text recognition](https://mindee.github.io/doctr/latest/modules/models.html#doctr-models-recognition) from the list of available implementations.
+
+```python
+from doctr.models import ocr_predictor
+
+model = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
+```
+
+### Reading files
+
+Documents can be interpreted from PDF or images:
+
+```python
+from doctr.io import DocumentFile
+# PDF
+pdf_doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
+# Image
+single_img_doc = DocumentFile.from_images("path/to/your/img.jpg")
+# Webpage
+webpage_doc = DocumentFile.from_url("https://www.yoursite.com")
+# Multiple page images
+multi_img_doc = DocumentFile.from_images(["path/to/page1.jpg", "path/to/page2.jpg"])
+```
+
+### Putting it together
+
+Let's use the default pretrained model for an example:
+
+```python
+from doctr.io import DocumentFile
+from doctr.models import ocr_predictor
+
+model = ocr_predictor(pretrained=True)
+# PDF
+doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
+# Analyze
+result = model(doc)
+```
+
+### Dealing with rotated documents
+
+Should you use docTR on documents that include rotated pages, or pages with multiple box orientations,
+you have multiple options to handle it:
+
+- If you only use straight document pages with straight words (horizontal, same reading direction),
+consider passing `assume_straight_pages=True` to the ocr_predictor. It will directly fit straight boxes
+on your page and return straight boxes, which makes it the fastest option.
+
+- If you want the predictor to output straight boxes (no matter the orientation of your pages, the final localizations
+will be converted to straight boxes), you need to pass `export_as_straight_boxes=True` in the predictor. Otherwise, if `assume_straight_pages=False`, it will return rotated bounding boxes (potentially with an angle of 0°).
+
+If both options are set to False, the predictor will always fit and return rotated boxes.
+
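+For instance, here is a minimal sketch of a predictor configured for rotated documents, using the keyword arguments described above:
+
+```python
+from doctr.models import ocr_predictor
+
+# Fit rotated boxes internally, but convert the final localizations to straight boxes
+model = ocr_predictor(pretrained=True, assume_straight_pages=False, export_as_straight_boxes=True)
+```
+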
+To interpret your model's predictions, you can visualize them interactively as follows:
+
+```python
+result.show()
+```
+
+
+
+Or even rebuild the original document from its predictions:
+
+```python
+import matplotlib.pyplot as plt
+
+synthetic_pages = result.synthesize()
+plt.imshow(synthetic_pages[0]); plt.axis('off'); plt.show()
+```
+
+
+
+The `ocr_predictor` returns a `Document` object with a nested structure (with `Page`, `Block`, `Line`, `Word`, `Artefact`).
+To get a better understanding of our document model, check our [documentation](https://mindee.github.io/doctr/modules/io.html#document-structure).
+
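+For instance, you can iterate over it to collect every recognized word (a minimal sketch; attribute names follow the document model above):
+
+```python
+for page in result.pages:
+    for block in page.blocks:
+        for line in block.lines:
+            for word in line.words:
+                print(f"{word.value} (confidence: {word.confidence:.2%})")
+```
+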
+You can also export them as a nested dict, more appropriate for JSON format:
+
+```python
+json_output = result.export()
+```
+
+### Use the KIE predictor
+
+The KIE predictor is more flexible than the OCR predictor, as its detection model can detect multiple classes in a document. For example, you can have a detection model that detects just dates and addresses in a document.
+
+The KIE predictor makes it possible to combine a multi-class detector with a recognition model, and to have the whole pipeline already set up for you.
+
+```python
+from doctr.io import DocumentFile
+from doctr.models import kie_predictor
+
+# Model
+model = kie_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
+# PDF
+doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
+# Analyze
+result = model(doc)
+
+predictions = result.pages[0].predictions
+for class_name in predictions.keys():
+ list_predictions = predictions[class_name]
+ for prediction in list_predictions:
+ print(f"Prediction for {class_name}: {prediction}")
+```
+
+The KIE predictor's results per page are a dictionary, with each key representing a class name and its value the list of predictions for that class.
+
+### If you are looking for support from the Mindee team
+
+[Mindee docTR product page](https://mindee.com/product/doctr)
+
+## Installation
+
+### Prerequisites
+
+Python 3.9 (or higher) and [pip](https://pip.pypa.io/en/stable/) are required to install docTR.
+
+Since we use [weasyprint](https://weasyprint.org/), you will need extra dependencies if you are not running Linux.
+
+For macOS users, you can install them as follows:
+
+```shell
+brew install cairo pango gdk-pixbuf libffi
+```
+
+For Windows users, those dependencies are included in GTK. You can find the latest installer over [here](https://github.com/tschoonj/GTK-for-Windows-Runtime-Environment-Installer/releases).
+
+### Latest release
+
+You can then install the latest release of the package using [pypi](https://pypi.org/project/python-doctr/) as follows:
+
+```shell
+pip install python-doctr
+```
+
+> :warning: Please note that the basic installation is not standalone, as it does not provide a deep learning framework, which is required for the package to run.
+
+We try to keep framework-specific dependencies to a minimum. You can install framework-specific builds as follows:
+
+```shell
+# for TensorFlow
+pip install "python-doctr[tf]"
+# for PyTorch
+pip install "python-doctr[torch]"
+```
+
+For MacBooks with an M1 chip, you will need some additional packages or specific versions:
+
+- TensorFlow 2: [metal plugin](https://developer.apple.com/metal/tensorflow-plugin/)
+- PyTorch: [version >= 1.12.0](https://pytorch.org/get-started/locally/#start-locally)
+
+### Developer mode
+
+Alternatively, you can install it from source, which will require you to install [Git](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git).
+First clone the project repository:
+
+```shell
+git clone https://github.com/mindee/doctr.git
+pip install -e doctr/.
+```
+
+Again, if you prefer to avoid the risk of missing dependencies, you can install the TensorFlow or the PyTorch build:
+
+```shell
+# for TensorFlow
+pip install -e doctr/.[tf]
+# for PyTorch
+pip install -e doctr/.[torch]
+```
+
+## Models architectures
+
+Credit where credit is due: this repository implements, among others, architectures from published research papers.
+
+### Text Detection
+
+- DBNet: [Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/pdf/1911.08947.pdf).
+- LinkNet: [LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation](https://arxiv.org/pdf/1707.03718.pdf)
+- FAST: [FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation](https://arxiv.org/pdf/2111.02394.pdf)
+
+### Text Recognition
+
+- CRNN: [An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition](https://arxiv.org/pdf/1507.05717.pdf).
+- SAR: [Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition](https://arxiv.org/pdf/1811.00751.pdf).
+- MASTER: [MASTER: Multi-Aspect Non-local Network for Scene Text Recognition](https://arxiv.org/pdf/1910.02562.pdf).
+- ViTSTR: [Vision Transformer for Fast and Efficient Scene Text Recognition](https://arxiv.org/pdf/2105.08582.pdf).
+- PARSeq: [Scene Text Recognition with Permuted Autoregressive Sequence Models](https://arxiv.org/pdf/2207.06966).
+
+## More goodies
+
+### Documentation
+
+The full package documentation is available [here](https://mindee.github.io/doctr/) for detailed specifications.
+
+### Demo app
+
+A minimal demo app is provided for you to play with our end-to-end OCR models!
+
+
+
+#### Live demo
+
+Courtesy of :hugs: [Hugging Face](https://huggingface.co/) :hugs:, docTR now has a fully deployed version available on [Spaces](https://huggingface.co/spaces)!
+Check it out on [Hugging Face Spaces](https://huggingface.co/spaces/mindee/doctr)
+
+#### Running it locally
+
+If you prefer to use it locally, there is an extra dependency ([Streamlit](https://streamlit.io/)) that is required.
+
+##### TensorFlow version
+
+```shell
+pip install -r demo/tf-requirements.txt
+```
+
+Then run your app in your default browser with:
+
+```shell
+USE_TF=1 streamlit run demo/app.py
+```
+
+##### PyTorch version
+
+```shell
+pip install -r demo/pt-requirements.txt
+```
+
+Then run your app in your default browser with:
+
+```shell
+USE_TORCH=1 streamlit run demo/app.py
+```
+
+#### TensorFlow.js
+
+Would you prefer to run everything in your web browser instead of having your demo actually run Python?
+Check out our [TensorFlow.js demo](https://github.com/mindee/doctr-tfjs-demo) to get started!
+
+
+
+### Docker container
+
+[We offer Docker container support for easy testing and deployment](https://github.com/mindee/doctr/pkgs/container/doctr).
+
+#### Using GPU with docTR Docker Images
+
+The docTR Docker images are GPU-ready and based on CUDA `11.8`.
+However, to use GPU support with these Docker images, please ensure that Docker is configured to use your GPU.
+
+To verify and configure GPU support for Docker, please follow the instructions provided in the [NVIDIA Container Toolkit Installation Guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
+
+Once Docker is configured to use GPUs, you can run docTR Docker containers with GPU support:
+
+```shell
+docker run -it --gpus all ghcr.io/mindee/doctr:tf-py3.8.18-gpu-2023-09 bash
+```
+
+#### Available Tags
+
+The Docker images for docTR follow a specific tag nomenclature: `<framework>-py<python version>-<system>-<doctr version|YYYY-MM>`. Here's a breakdown of the tag structure:
+
+- `<framework>`: `tf` (TensorFlow) or `torch` (PyTorch).
+- `<python version>`: `3.8.18`, `3.9.18`, or `3.10.13`.
+- `<system>`: `cpu` or `gpu`
+- `<doctr version>`: a tag >= `v0.7.1`
+- `<YYYY-MM>`: e.g. `2023-09`
+
+Here are examples of different image tags:
+
+| Tag | Description |
+|----------------------------|---------------------------------------------------|
+| `tf-py3.8.18-cpu-v0.7.1`   | TensorFlow with Python `3.8.18`, CPU, and docTR `v0.7.1`. |
+| `torch-py3.9.18-gpu-2023-09`| PyTorch with Python `3.9.18`, GPU support, and a monthly build from `2023-09`. |
+
+#### Building Docker Images Locally
+
+You can also build docTR Docker images locally on your computer.
+
+```shell
+docker build -t doctr .
+```
+
+You can specify custom Python versions and docTR versions using build arguments. For example, to build a docTR image with TensorFlow, Python version `3.9.10`, and docTR version `v0.7.0`, run the following command:
+
+```shell
+docker build -t doctr --build-arg FRAMEWORK=tf --build-arg PYTHON_VERSION=3.9.10 --build-arg DOCTR_VERSION=v0.7.0 .
+```
+
+### Example script
+
+An example script is provided for a simple document analysis of a PDF or image file:
+
+```shell
+python scripts/analyze.py path/to/your/doc.pdf
+```
+
+All script arguments can be checked using `python scripts/analyze.py --help`.
+
+### Minimal API integration
+
+Looking to integrate docTR into your API? Here is a template to get you started with a fully working API using the wonderful [FastAPI](https://github.com/tiangolo/fastapi) framework.
+
+#### Deploy your API locally
+
+Specific dependencies are required to run the API template, which you can install as follows:
+
+```shell
+cd api/
+pip install poetry
+make lock
+pip install -r requirements.txt
+```
+
+You can now run your API locally:
+
+```shell
+uvicorn --reload --workers 1 --host 0.0.0.0 --port=8002 --app-dir api/ app.main:app
+```
+
+Alternatively, if you prefer, you can run the same server in a Docker container:
+
+```shell
+PORT=8002 docker-compose up -d --build
+```
+
+#### What you have deployed
+
+Your API should now be running locally on port 8002. Access your automatically-built documentation at [http://localhost:8002/redoc](http://localhost:8002/redoc) and enjoy your four functional routes ("/detection", "/recognition", "/ocr", "/kie"). Here is an example with Python to send a request to the OCR route:
+
+```python
+import requests
+with open('/path/to/your/doc.jpg', 'rb') as f:
+ data = f.read()
+response = requests.post("http://localhost:8002/ocr", files={'file': data}).json()
+```
+
+### Example notebooks
+
+Looking for more illustrations of docTR features? You might want to check the [Jupyter notebooks](https://github.com/mindee/doctr/tree/main/notebooks) designed to give you a broader overview.
+
+## Citation
+
+If you wish to cite this project, feel free to use this [BibTeX](http://www.bibtex.org/) reference:
+
+```bibtex
+@misc{doctr2021,
+ title={docTR: Document Text Recognition},
+ author={Mindee},
+ year={2021},
+ publisher = {GitHub},
+ howpublished = {\url{https://github.com/mindee/doctr}}
+}
+```
+
+## Contributing
+
+If you scrolled down to this section, you most likely appreciate open source. Do you feel like extending the range of our supported characters? Or perhaps submitting a paper implementation? Or contributing in any other way?
+
+You're in luck, we compiled a short guide (cf. [`CONTRIBUTING`](https://mindee.github.io/doctr/contributing/contributing.html)) for you to easily do so!
+
+## License
+
+Distributed under the Apache 2.0 License. See [`LICENSE`](https://github.com/mindee/doctr?tab=Apache-2.0-1-ov-file#readme) for more information.
diff --git a/backend/pytorch.py b/backend/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf0d725cbbf80a2abc5be770bea0b177867858a
--- /dev/null
+++ b/backend/pytorch.py
@@ -0,0 +1,93 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import numpy as np
+import torch
+
+from doctr.models import ocr_predictor
+from doctr.models.predictor import OCRPredictor
+
+DET_ARCHS = [
+ "db_resnet50",
+ "db_resnet34",
+ "db_mobilenet_v3_large",
+ "linknet_resnet18",
+ "linknet_resnet34",
+ "linknet_resnet50",
+ "fast_tiny",
+ "fast_small",
+ "fast_base",
+]
+
+RECO_ARCHS = [
+ "crnn_vgg16_bn",
+ "crnn_mobilenet_v3_small",
+ "crnn_mobilenet_v3_large",
+ "master",
+ "sar_resnet31",
+ "vitstr_small",
+ "vitstr_base",
+ "parseq",
+]
+
+
+def load_predictor(
+ det_arch: str,
+ reco_arch: str,
+ assume_straight_pages: bool,
+ straighten_pages: bool,
+ bin_thresh: float,
+ box_thresh: float,
+ device: torch.device,
+) -> OCRPredictor:
+ """Load a predictor from doctr.models
+
+ Args:
+ ----
+ det_arch: detection architecture
+ reco_arch: recognition architecture
+ assume_straight_pages: whether to assume straight pages or not
+ straighten_pages: whether to straighten rotated pages or not
+ bin_thresh: binarization threshold for the segmentation map
+ box_thresh: minimal objectness score to consider a box
+ device: torch.device, the device to load the predictor on
+
+ Returns:
+ -------
+ instance of OCRPredictor
+ """
+ predictor = ocr_predictor(
+ det_arch,
+ reco_arch,
+ pretrained=True,
+ assume_straight_pages=assume_straight_pages,
+ straighten_pages=straighten_pages,
+ export_as_straight_boxes=straighten_pages,
+ detect_orientation=not assume_straight_pages,
+ ).to(device)
+ predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
+ predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
+ return predictor
+
+
+def forward_image(predictor: OCRPredictor, image: np.ndarray, device: torch.device) -> np.ndarray:
+ """Forward an image through the predictor
+
+ Args:
+ ----
+ predictor: instance of OCRPredictor
+ image: image to process
+ device: torch.device, the device to process the image on
+
+ Returns:
+ -------
+ segmentation map
+ """
+ with torch.no_grad():
+ processed_batches = predictor.det_predictor.pre_processor([image])
+ out = predictor.det_predictor.model(processed_batches[0].to(device), return_model_output=True)
+ seg_map = out["out_map"].to("cpu").numpy()
+
+ return seg_map
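+
+
+# Example usage (a minimal sketch; assumes a CUDA-capable machine, otherwise fall back to CPU):
+# >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# >>> predictor = load_predictor(
+# ...     det_arch="db_resnet50",
+# ...     reco_arch="crnn_vgg16_bn",
+# ...     assume_straight_pages=True,
+# ...     straighten_pages=False,
+# ...     bin_thresh=0.3,
+# ...     box_thresh=0.1,
+# ...     device=device,
+# ... )
+# >>> seg_map = forward_image(predictor, image, device)  # image: np.ndarray of shape (H, W, 3)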
diff --git a/doctr/__init__.py b/doctr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf0a22430288bbe0d5618ae7dae2e8e97126fa38
--- /dev/null
+++ b/doctr/__init__.py
@@ -0,0 +1,3 @@
+from . import io, models, datasets, transforms, utils
+from .file_utils import is_tf_available, is_torch_available
+from .version import __version__ # noqa: F401
diff --git a/doctr/datasets/__init__.py b/doctr/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b52035ff062341470f53badc8e023477fbfb2b3
--- /dev/null
+++ b/doctr/datasets/__init__.py
@@ -0,0 +1,26 @@
+from doctr.file_utils import is_tf_available
+
+from .generator import *
+from .cord import *
+from .detection import *
+from .doc_artefacts import *
+from .funsd import *
+from .ic03 import *
+from .ic13 import *
+from .iiit5k import *
+from .iiithws import *
+from .imgur5k import *
+from .mjsynth import *
+from .ocr import *
+from .recognition import *
+from .orientation import *
+from .sroie import *
+from .svhn import *
+from .svt import *
+from .synthtext import *
+from .utils import *
+from .vocabs import *
+from .wildreceipt import *
+
+if is_tf_available():
+ from .loader import *
diff --git a/doctr/datasets/cord.py b/doctr/datasets/cord.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88fbb28e89e3327b5ce5603bf6cd865b8febb3b
--- /dev/null
+++ b/doctr/datasets/cord.py
@@ -0,0 +1,121 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import numpy as np
+from tqdm import tqdm
+
+from .datasets import VisionDataset
+from .utils import convert_target_to_relative, crop_bboxes_from_image
+
+__all__ = ["CORD"]
+
+
+class CORD(VisionDataset):
+ """CORD dataset from `"CORD: A Consolidated Receipt Dataset forPost-OCR Parsing"
+ `_.
+
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/cord-grid.png&src=0
+ :align: center
+
+ >>> from doctr.datasets import CORD
+ >>> train_set = CORD(train=True, download=True)
+ >>> img, target = train_set[0]
+
+ Args:
+ ----
+ train: whether the subset should be the training one
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
+        recognition_task: whether the dataset should be used for a recognition task
+ **kwargs: keyword arguments from `VisionDataset`.
+ """
+
+ TRAIN = (
+ "https://doctr-static.mindee.com/models?id=v0.1.1/cord_train.zip&src=0",
+ "45f9dc77f126490f3e52d7cb4f70ef3c57e649ea86d19d862a2757c9c455d7f8",
+ "cord_train.zip",
+ )
+
+ TEST = (
+ "https://doctr-static.mindee.com/models?id=v0.1.1/cord_test.zip&src=0",
+ "8c895e3d6f7e1161c5b7245e3723ce15c04d84be89eaa6093949b75a66fb3c58",
+ "cord_test.zip",
+ )
+
+ def __init__(
+ self,
+ train: bool = True,
+ use_polygons: bool = False,
+ recognition_task: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ url, sha256, name = self.TRAIN if train else self.TEST
+ super().__init__(
+ url,
+ name,
+ sha256,
+ True,
+ pre_transforms=convert_target_to_relative if not recognition_task else None,
+ **kwargs,
+ )
+
+ # List images
+ tmp_root = os.path.join(self.root, "image")
+ self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+ self.train = train
+ np_dtype = np.float32
+ for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking CORD", total=len(os.listdir(tmp_root))):
+ # File existence check
+ if not os.path.exists(os.path.join(tmp_root, img_path)):
+ raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
+
+ stem = Path(img_path).stem
+ _targets = []
+ with open(os.path.join(self.root, "json", f"{stem}.json"), "rb") as f:
+ label = json.load(f)
+ for line in label["valid_line"]:
+ for word in line["words"]:
+ if len(word["text"]) > 0:
+ x = word["quad"]["x1"], word["quad"]["x2"], word["quad"]["x3"], word["quad"]["x4"]
+ y = word["quad"]["y1"], word["quad"]["y2"], word["quad"]["y3"], word["quad"]["y4"]
+ box: Union[List[float], np.ndarray]
+ if use_polygons:
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
+ box = np.array(
+ [
+ [x[0], y[0]],
+ [x[1], y[1]],
+ [x[2], y[2]],
+ [x[3], y[3]],
+ ],
+ dtype=np_dtype,
+ )
+ else:
+ # Reduce 8 coords to 4 -> xmin, ymin, xmax, ymax
+ box = [min(x), min(y), max(x), max(y)]
+ _targets.append((word["text"], box))
+
+ text_targets, box_targets = zip(*_targets)
+
+ if recognition_task:
+ crops = crop_bboxes_from_image(
+ img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0)
+ )
+ for crop, label in zip(crops, list(text_targets)):
+ self.data.append((crop, label))
+ else:
+ self.data.append((
+ img_path,
+ dict(boxes=np.asarray(box_targets, dtype=int).clip(min=0), labels=list(text_targets)),
+ ))
+
+ self.root = tmp_root
+
+ def extra_repr(self) -> str:
+ return f"train={self.train}"
diff --git a/doctr/datasets/datasets/__init__.py b/doctr/datasets/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/datasets/datasets/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/datasets/datasets/base.py b/doctr/datasets/datasets/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..58f1ca29f6b5e3eae62d587a9444fb63b2e4c340
--- /dev/null
+++ b/doctr/datasets/datasets/base.py
@@ -0,0 +1,132 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import os
+import shutil
+from pathlib import Path
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import numpy as np
+
+from doctr.io.image import get_img_shape
+from doctr.utils.data import download_from_url
+
+from ...models.utils import _copy_tensor
+
+__all__ = ["_AbstractDataset", "_VisionDataset"]
+
+
+class _AbstractDataset:
+ data: List[Any] = []
+ _pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None
+
+ def __init__(
+ self,
+ root: Union[str, Path],
+ img_transforms: Optional[Callable[[Any], Any]] = None,
+ sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
+ pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
+ ) -> None:
+ if not Path(root).is_dir():
+ raise ValueError(f"expected a path to a reachable folder: {root}")
+
+ self.root = root
+ self.img_transforms = img_transforms
+ self.sample_transforms = sample_transforms
+ self._pre_transforms = pre_transforms
+ self._get_img_shape = get_img_shape
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+ def _read_sample(self, index: int) -> Tuple[Any, Any]:
+ raise NotImplementedError
+
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
+ # Read image
+ img, target = self._read_sample(index)
+ # Pre-transforms (format conversion at run-time etc.)
+ if self._pre_transforms is not None:
+ img, target = self._pre_transforms(img, target)
+
+ if self.img_transforms is not None:
+ # typing issue cf. https://github.com/python/mypy/issues/5485
+ img = self.img_transforms(img)
+
+ if self.sample_transforms is not None:
+ # Conditions to assess it is detection model with multiple classes and avoid confusion with other tasks.
+ if (
+ isinstance(target, dict)
+ and all(isinstance(item, np.ndarray) for item in target.values())
+ and set(target.keys()) != {"boxes", "labels"} # avoid confusion with obj detection target
+ ):
+ img_transformed = _copy_tensor(img)
+ for class_name, bboxes in target.items():
+ img_transformed, target[class_name] = self.sample_transforms(img, bboxes)
+ img = img_transformed
+ else:
+ img, target = self.sample_transforms(img, target)
+
+ return img, target
+
+ def extra_repr(self) -> str:
+ return ""
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({self.extra_repr()})"
+
+
+class _VisionDataset(_AbstractDataset):
+ """Implements an abstract dataset
+
+ Args:
+ ----
+ url: URL of the dataset
+ file_name: name of the file once downloaded
+ file_hash: expected SHA256 of the file
+ extract_archive: whether the downloaded file is an archive to be extracted
+ download: whether the dataset should be downloaded if not present on disk
+ overwrite: whether the archive should be re-extracted
+ cache_dir: cache directory
+ cache_subdir: subfolder to use in the cache
+ """
+
+ def __init__(
+ self,
+ url: str,
+ file_name: Optional[str] = None,
+ file_hash: Optional[str] = None,
+ extract_archive: bool = False,
+ download: bool = False,
+ overwrite: bool = False,
+ cache_dir: Optional[str] = None,
+ cache_subdir: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ cache_dir = (
+ str(os.environ.get("DOCTR_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "doctr")))
+ if cache_dir is None
+ else cache_dir
+ )
+
+ cache_subdir = "datasets" if cache_subdir is None else cache_subdir
+
+ file_name = file_name if isinstance(file_name, str) else os.path.basename(url)
+ # Download the file if not present
+ archive_path: Union[str, Path] = os.path.join(cache_dir, cache_subdir, file_name)
+
+ if not os.path.exists(archive_path) and not download:
+ raise ValueError("the dataset needs to be downloaded first with download=True")
+
+ archive_path = download_from_url(url, file_name, file_hash, cache_dir=cache_dir, cache_subdir=cache_subdir)
+
+ # Extract the archive
+ if extract_archive:
+ archive_path = Path(archive_path)
+ dataset_path = archive_path.parent.joinpath(archive_path.stem)
+ if not dataset_path.is_dir() or overwrite:
+ shutil.unpack_archive(archive_path, dataset_path)
+
+ super().__init__(dataset_path if extract_archive else archive_path, **kwargs)
diff --git a/doctr/datasets/datasets/pytorch.py b/doctr/datasets/datasets/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd4d8401680152c35cdd10e16d970115751ef596
--- /dev/null
+++ b/doctr/datasets/datasets/pytorch.py
@@ -0,0 +1,59 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import os
+from copy import deepcopy
+from typing import Any, List, Tuple
+
+import numpy as np
+import torch
+
+from doctr.io import read_img_as_tensor, tensor_from_numpy
+
+from .base import _AbstractDataset, _VisionDataset
+
+__all__ = ["AbstractDataset", "VisionDataset"]
+
+
+class AbstractDataset(_AbstractDataset):
+ """Abstract class for all datasets"""
+
+ def _read_sample(self, index: int) -> Tuple[torch.Tensor, Any]:
+ img_name, target = self.data[index]
+
+ # Check target
+ if isinstance(target, dict):
+ assert "boxes" in target, "Target should contain 'boxes' key"
+ assert "labels" in target, "Target should contain 'labels' key"
+ elif isinstance(target, tuple):
+ assert len(target) == 2
+ assert isinstance(target[0], str) or isinstance(
+ target[0], np.ndarray
+ ), "first element of the tuple should be a string or a numpy array"
+ assert isinstance(target[1], list), "second element of the tuple should be a list"
+ else:
+ assert isinstance(target, str) or isinstance(
+ target, np.ndarray
+ ), "Target should be a string or a numpy array"
+
+ # Read image
+ img = (
+ tensor_from_numpy(img_name, dtype=torch.float32)
+ if isinstance(img_name, np.ndarray)
+ else read_img_as_tensor(os.path.join(self.root, img_name), dtype=torch.float32)
+ )
+
+ return img, deepcopy(target)
+
+ @staticmethod
+ def collate_fn(samples: List[Tuple[torch.Tensor, Any]]) -> Tuple[torch.Tensor, List[Any]]:
+ images, targets = zip(*samples)
+ images = torch.stack(images, dim=0) # type: ignore[assignment]
+
+ return images, list(targets) # type: ignore[return-value]
+
+
+class VisionDataset(AbstractDataset, _VisionDataset): # noqa: D101
+ pass
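+
+
+# Example (a sketch; assumes a concrete dataset such as doctr.datasets.CORD):
+# >>> from torch.utils.data import DataLoader
+# >>> train_set = CORD(train=True, download=True)
+# >>> loader = DataLoader(train_set, batch_size=2, collate_fn=train_set.collate_fn)
+# >>> images, targets = next(iter(loader))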
diff --git a/doctr/datasets/datasets/tensorflow.py b/doctr/datasets/datasets/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..86b7b79289a9cf89f7f3e0c4cf0e7046ba002f75
--- /dev/null
+++ b/doctr/datasets/datasets/tensorflow.py
@@ -0,0 +1,59 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import os
+from copy import deepcopy
+from typing import Any, List, Tuple
+
+import numpy as np
+import tensorflow as tf
+
+from doctr.io import read_img_as_tensor, tensor_from_numpy
+
+from .base import _AbstractDataset, _VisionDataset
+
+__all__ = ["AbstractDataset", "VisionDataset"]
+
+
+class AbstractDataset(_AbstractDataset):
+ """Abstract class for all datasets"""
+
+ def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]:
+ img_name, target = self.data[index]
+
+ # Check target
+ if isinstance(target, dict):
+ assert "boxes" in target, "Target should contain 'boxes' key"
+ assert "labels" in target, "Target should contain 'labels' key"
+ elif isinstance(target, tuple):
+ assert len(target) == 2
+ assert isinstance(target[0], str) or isinstance(
+ target[0], np.ndarray
+ ), "first element of the tuple should be a string or a numpy array"
+ assert isinstance(target[1], list), "second element of the tuple should be a list"
+ else:
+ assert isinstance(target, str) or isinstance(
+ target, np.ndarray
+ ), "Target should be a string or a numpy array"
+
+ # Read image
+ img = (
+ tensor_from_numpy(img_name, dtype=tf.float32)
+ if isinstance(img_name, np.ndarray)
+ else read_img_as_tensor(os.path.join(self.root, img_name), dtype=tf.float32)
+ )
+
+ return img, deepcopy(target)
+
+ @staticmethod
+ def collate_fn(samples: List[Tuple[tf.Tensor, Any]]) -> Tuple[tf.Tensor, List[Any]]:
+ images, targets = zip(*samples)
+ images = tf.stack(images, axis=0)
+
+ return images, list(targets)
+
+
+class VisionDataset(AbstractDataset, _VisionDataset): # noqa: D101
+ pass
diff --git a/doctr/datasets/detection.py b/doctr/datasets/detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..0000704dfa2a2e074924a9efef5e6fae75d4d018
--- /dev/null
+++ b/doctr/datasets/detection.py
@@ -0,0 +1,98 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import json
+import os
+from typing import Any, Dict, List, Tuple, Type, Union
+
+import numpy as np
+
+from doctr.file_utils import CLASS_NAME
+
+from .datasets import AbstractDataset
+from .utils import pre_transform_multiclass
+
+__all__ = ["DetectionDataset"]
+
+
+class DetectionDataset(AbstractDataset):
+ """Implements a text detection dataset
+
+ >>> from doctr.datasets import DetectionDataset
+ >>> train_set = DetectionDataset(img_folder="/path/to/images",
+ >>> label_path="/path/to/labels.json")
+ >>> img, target = train_set[0]
+
+ Args:
+ ----
+ img_folder: folder with all the images of the dataset
+ label_path: path to the annotations of each image
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
+ **kwargs: keyword arguments from `AbstractDataset`.
+ """
+
+ def __init__(
+ self,
+ img_folder: str,
+ label_path: str,
+ use_polygons: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ img_folder,
+ pre_transforms=pre_transform_multiclass,
+ **kwargs,
+ )
+
+ # File existence check
+ self._class_names: List = []
+ if not os.path.exists(label_path):
+ raise FileNotFoundError(f"unable to locate {label_path}")
+ with open(label_path, "rb") as f:
+ labels = json.load(f)
+
+ self.data: List[Tuple[str, Tuple[np.ndarray, List[str]]]] = []
+ np_dtype = np.float32
+ for img_name, label in labels.items():
+ # File existence check
+ if not os.path.exists(os.path.join(self.root, img_name)):
+ raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")
+
+ geoms, polygons_classes = self.format_polygons(label["polygons"], use_polygons, np_dtype)
+
+ self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))
+
+ def format_polygons(
+ self, polygons: Union[List, Dict], use_polygons: bool, np_dtype: Type
+ ) -> Tuple[np.ndarray, List[str]]:
+ """Format polygons into an array
+
+ Args:
+ ----
+ polygons: the bounding boxes
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
+ np_dtype: dtype of array
+
+ Returns:
+ -------
+ geoms: bounding boxes as np array
+ polygons_classes: list of classes for each bounding box
+ """
+ if isinstance(polygons, list):
+ self._class_names += [CLASS_NAME]
+ polygons_classes = [CLASS_NAME for _ in polygons]
+ _polygons: np.ndarray = np.asarray(polygons, dtype=np_dtype)
+ elif isinstance(polygons, dict):
+ self._class_names += list(polygons.keys())
+ polygons_classes = [k for k, v in polygons.items() for _ in v]
+ _polygons = np.concatenate([np.asarray(poly, dtype=np_dtype) for poly in polygons.values() if poly], axis=0)
+ else:
+ raise TypeError(f"polygons should be a dictionary or list, it was {type(polygons)}")
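+        # Convert the (N, 4, 2) polygons to straight (xmin, ymin, xmax, ymax) boxes unless rotated boxes were requested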
+ geoms = _polygons if use_polygons else np.concatenate((_polygons.min(axis=1), _polygons.max(axis=1)), axis=1)
+ return geoms, polygons_classes
+
+ @property
+ def class_names(self):
+ return sorted(set(self._class_names))
diff --git a/doctr/datasets/doc_artefacts.py b/doctr/datasets/doc_artefacts.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a05a01150316970521086e7722ce628808287a2
--- /dev/null
+++ b/doctr/datasets/doc_artefacts.py
@@ -0,0 +1,82 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import json
+import os
+from typing import Any, Dict, List, Tuple
+
+import numpy as np
+
+from .datasets import VisionDataset
+
+__all__ = ["DocArtefacts"]
+
+
+class DocArtefacts(VisionDataset):
+ """Object detection dataset for non-textual elements in documents.
+ The dataset includes a variety of synthetic document pages with non-textual elements.
+
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/artefacts-grid.png&src=0
+ :align: center
+
+ >>> from doctr.datasets import DocArtefacts
+ >>> train_set = DocArtefacts(train=True, download=True)
+ >>> img, target = train_set[0]
+
+ Args:
+ ----
+ train: whether the subset should be the training one
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
+ **kwargs: keyword arguments from `VisionDataset`.
+ """
+
+ URL = "https://doctr-static.mindee.com/models?id=v0.4.0/artefact_detection-13fab8ce.zip&src=0"
+ SHA256 = "13fab8ced7f84583d9dccd0c634f046c3417e62a11fe1dea6efbbaba5052471b"
+ CLASSES = ["background", "qr_code", "bar_code", "logo", "photo"]
+
+ def __init__(
+ self,
+ train: bool = True,
+ use_polygons: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(self.URL, None, self.SHA256, True, **kwargs)
+ self.train = train
+
+ # Update root
+ self.root = os.path.join(self.root, "train" if train else "val")
+ # List images
+ tmp_root = os.path.join(self.root, "images")
+ with open(os.path.join(self.root, "labels.json"), "rb") as f:
+ labels = json.load(f)
+ self.data: List[Tuple[str, Dict[str, Any]]] = []
+ img_list = os.listdir(tmp_root)
+ if len(labels) != len(img_list):
+ raise AssertionError("the number of images and labels do not match")
+ np_dtype = np.float32
+ for img_name, label in labels.items():
+ # File existence check
+ if not os.path.exists(os.path.join(tmp_root, img_name)):
+ raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}")
+
+ # xmin, ymin, xmax, ymax
+ boxes: np.ndarray = np.asarray([obj["geometry"] for obj in label], dtype=np_dtype)
+ classes: np.ndarray = np.asarray([self.CLASSES.index(obj["label"]) for obj in label], dtype=np.int64)
+ if use_polygons:
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
+ boxes = np.stack(
+ [
+ np.stack([boxes[:, 0], boxes[:, 1]], axis=-1),
+ np.stack([boxes[:, 2], boxes[:, 1]], axis=-1),
+ np.stack([boxes[:, 2], boxes[:, 3]], axis=-1),
+ np.stack([boxes[:, 0], boxes[:, 3]], axis=-1),
+ ],
+ axis=1,
+ )
+ self.data.append((img_name, dict(boxes=boxes, labels=classes)))
+ self.root = tmp_root
+
+ def extra_repr(self) -> str:
+ return f"train={self.train}"
diff --git a/doctr/datasets/funsd.py b/doctr/datasets/funsd.py
new file mode 100644
index 0000000000000000000000000000000000000000..0580b473a7ad39b56c3a6593948d7234b7a787bf
--- /dev/null
+++ b/doctr/datasets/funsd.py
@@ -0,0 +1,112 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import numpy as np
+from tqdm import tqdm
+
+from .datasets import VisionDataset
+from .utils import convert_target_to_relative, crop_bboxes_from_image
+
+__all__ = ["FUNSD"]
+
+
+class FUNSD(VisionDataset):
+ """FUNSD dataset from `"FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents"
+    <https://arxiv.org/pdf/1905.13538.pdf>`_.
+
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/funsd-grid.png&src=0
+ :align: center
+
+ >>> from doctr.datasets import FUNSD
+ >>> train_set = FUNSD(train=True, download=True)
+ >>> img, target = train_set[0]
+
+ Args:
+ ----
+ train: whether the subset should be the training one
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
+ recognition_task: whether the dataset should be used for recognition task
+ **kwargs: keyword arguments from `VisionDataset`.
+ """
+
+ URL = "https://guillaumejaume.github.io/FUNSD/dataset.zip"
+ SHA256 = "c31735649e4f441bcbb4fd0f379574f7520b42286e80b01d80b445649d54761f"
+ FILE_NAME = "funsd.zip"
+
+ def __init__(
+ self,
+ train: bool = True,
+ use_polygons: bool = False,
+ recognition_task: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ self.URL,
+ self.FILE_NAME,
+ self.SHA256,
+ True,
+ pre_transforms=convert_target_to_relative if not recognition_task else None,
+ **kwargs,
+ )
+ self.train = train
+ np_dtype = np.float32
+
+ # Use the subset
+ subfolder = os.path.join("dataset", "training_data" if train else "testing_data")
+
+        # List images
+ tmp_root = os.path.join(self.root, subfolder, "images")
+ self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+ for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking FUNSD", total=len(os.listdir(tmp_root))):
+ # File existence check
+ if not os.path.exists(os.path.join(tmp_root, img_path)):
+ raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
+
+ stem = Path(img_path).stem
+ with open(os.path.join(self.root, subfolder, "annotations", f"{stem}.json"), "rb") as f:
+ data = json.load(f)
+
+ _targets = [
+ (word["text"], word["box"])
+ for block in data["form"]
+ for word in block["words"]
+ if len(word["text"]) > 0
+ ]
+ text_targets, box_targets = zip(*_targets)
+ if use_polygons:
+ # xmin, ymin, xmax, ymax -> (x, y) coordinates of top left, top right, bottom right, bottom left corners
+ box_targets = [ # type: ignore[assignment]
+ [
+ [box[0], box[1]],
+ [box[2], box[1]],
+ [box[2], box[3]],
+ [box[0], box[3]],
+ ]
+ for box in box_targets
+ ]
+
+ if recognition_task:
+ crops = crop_bboxes_from_image(
+ img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=np_dtype)
+ )
+ for crop, label in zip(crops, list(text_targets)):
+ # filter labels with unknown characters
+ if not any(char in label for char in ["☑", "☐", "\uf703", "\uf702"]):
+ self.data.append((crop, label))
+ else:
+ self.data.append((
+ img_path,
+ dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=list(text_targets)),
+ ))
+
+ self.root = tmp_root
+
+ def extra_repr(self) -> str:
+ return f"train={self.train}"
diff --git a/doctr/datasets/generator/__init__.py b/doctr/datasets/generator/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/datasets/generator/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/datasets/generator/base.py b/doctr/datasets/generator/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..424f59563d1165989dfe12ea06ab6410e7241fb9
--- /dev/null
+++ b/doctr/datasets/generator/base.py
@@ -0,0 +1,155 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import random
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+from PIL import Image, ImageDraw
+
+from doctr.io.image import tensor_from_pil
+from doctr.utils.fonts import get_font
+
+from ..datasets import AbstractDataset
+
+
+def synthesize_text_img(
+ text: str,
+ font_size: int = 32,
+ font_family: Optional[str] = None,
+ background_color: Optional[Tuple[int, int, int]] = None,
+ text_color: Optional[Tuple[int, int, int]] = None,
+) -> Image.Image:
+ """Generate a synthetic text image
+
+ Args:
+ ----
+ text: the text to render as an image
+ font_size: the size of the font
+ font_family: the font family (has to be installed on your system)
+ background_color: background color of the final image
+ text_color: text color on the final image
+
+ Returns:
+ -------
+ PIL image of the text
+ """
+ background_color = (0, 0, 0) if background_color is None else background_color
+ text_color = (255, 255, 255) if text_color is None else text_color
+
+ font = get_font(font_family, font_size)
+ left, top, right, bottom = font.getbbox(text)
+ text_w, text_h = right - left, bottom - top
+ h, w = int(round(1.3 * text_h)), int(round(1.1 * text_w))
+ # If single letter, make the image square, otherwise expand to meet the text size
+ img_size = (h, w) if len(text) > 1 else (max(h, w), max(h, w))
+
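+    # PIL expects (width, height) while img_size is (height, width), hence the reversal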
+ img = Image.new("RGB", img_size[::-1], color=background_color)
+ d = ImageDraw.Draw(img)
+
+ # Offset so that the text is centered
+ text_pos = (int(round((img_size[1] - text_w) / 2)), int(round((img_size[0] - text_h) / 2)))
+ # Draw the text
+ d.text(text_pos, text, font=font, fill=text_color)
+ return img
+
+
+class _CharacterGenerator(AbstractDataset):
+ def __init__(
+ self,
+ vocab: str,
+ num_samples: int,
+ cache_samples: bool = False,
+ font_family: Optional[Union[str, List[str]]] = None,
+ img_transforms: Optional[Callable[[Any], Any]] = None,
+ sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
+ ) -> None:
+ self.vocab = vocab
+ self._num_samples = num_samples
+ self.font_family = font_family if isinstance(font_family, list) else [font_family] # type: ignore[list-item]
+ # Validate fonts
+ if isinstance(font_family, list):
+ for font in self.font_family:
+ try:
+ _ = get_font(font, 10)
+ except OSError:
+ raise ValueError(f"unable to locate font: {font}")
+ self.img_transforms = img_transforms
+ self.sample_transforms = sample_transforms
+
+ self._data: List[Image.Image] = []
+ if cache_samples:
+ self._data = [
+ (synthesize_text_img(char, font_family=font), idx) # type: ignore[misc]
+ for idx, char in enumerate(self.vocab)
+ for font in self.font_family
+ ]
+
+ def __len__(self) -> int:
+ return self._num_samples
+
+ def _read_sample(self, index: int) -> Tuple[Any, int]:
+ # Samples are already cached
+ if len(self._data) > 0:
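+            # Wrap the index so that any num_samples can be served from the fixed-size cache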
+ idx = index % len(self._data)
+ pil_img, target = self._data[idx] # type: ignore[misc]
+ else:
+ target = index % len(self.vocab)
+ pil_img = synthesize_text_img(self.vocab[target], font_family=random.choice(self.font_family))
+ img = tensor_from_pil(pil_img)
+
+ return img, target
+
+
+class _WordGenerator(AbstractDataset):
+ def __init__(
+ self,
+ vocab: str,
+ min_chars: int,
+ max_chars: int,
+ num_samples: int,
+ cache_samples: bool = False,
+ font_family: Optional[Union[str, List[str]]] = None,
+ img_transforms: Optional[Callable[[Any], Any]] = None,
+ sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
+ ) -> None:
+ self.vocab = vocab
+ self.wordlen_range = (min_chars, max_chars)
+ self._num_samples = num_samples
+ self.font_family = font_family if isinstance(font_family, list) else [font_family] # type: ignore[list-item]
+ # Validate fonts
+ if isinstance(font_family, list):
+ for font in self.font_family:
+ try:
+ _ = get_font(font, 10)
+ except OSError:
+ raise ValueError(f"unable to locate font: {font}")
+ self.img_transforms = img_transforms
+ self.sample_transforms = sample_transforms
+
+ self._data: List[Image.Image] = []
+ if cache_samples:
+ _words = [self._generate_string(*self.wordlen_range) for _ in range(num_samples)]
+ self._data = [
+ (synthesize_text_img(text, font_family=random.choice(self.font_family)), text) # type: ignore[misc]
+ for text in _words
+ ]
+
+ def _generate_string(self, min_chars: int, max_chars: int) -> str:
+ num_chars = random.randint(min_chars, max_chars)
+ return "".join(random.choice(self.vocab) for _ in range(num_chars))
+
+ def __len__(self) -> int:
+ return self._num_samples
+
+ def _read_sample(self, index: int) -> Tuple[Any, str]:
+ # Samples are already cached
+ if len(self._data) > 0:
+ pil_img, target = self._data[index] # type: ignore[misc]
+ else:
+ target = self._generate_string(*self.wordlen_range)
+ pil_img = synthesize_text_img(target, font_family=random.choice(self.font_family))
+ img = tensor_from_pil(pil_img)
+
+ return img, target
diff --git a/doctr/datasets/generator/pytorch.py b/doctr/datasets/generator/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b254c91e4a383f49721a58939bdba92660b00cba
--- /dev/null
+++ b/doctr/datasets/generator/pytorch.py
@@ -0,0 +1,54 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from torch.utils.data._utils.collate import default_collate
+
+from .base import _CharacterGenerator, _WordGenerator
+
+__all__ = ["CharacterGenerator", "WordGenerator"]
+
+
+class CharacterGenerator(_CharacterGenerator):
+ """Implements a character image generation dataset
+
+ >>> from doctr.datasets import CharacterGenerator
+ >>> ds = CharacterGenerator(vocab='abdef', num_samples=100)
+ >>> img, target = ds[0]
+
+ Args:
+ ----
+ vocab: vocabulary to take the character from
+ num_samples: number of samples that will be generated iterating over the dataset
+        cache_samples: whether generated images should be cached beforehand
+ font_family: font to use to generate the text images
+ img_transforms: composable transformations that will be applied to each image
+ sample_transforms: composable transformations that will be applied to both the image and the target
+ """
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ setattr(self, "collate_fn", default_collate)
+
+
+class WordGenerator(_WordGenerator):
+ """Implements a character image generation dataset
+
+ >>> from doctr.datasets import WordGenerator
+ >>> ds = WordGenerator(vocab='abdef', min_chars=1, max_chars=32, num_samples=100)
+ >>> img, target = ds[0]
+
+ Args:
+ ----
+ vocab: vocabulary to take the character from
+ min_chars: minimum number of characters in a word
+ max_chars: maximum number of characters in a word
+ num_samples: number of samples that will be generated iterating over the dataset
+        cache_samples: whether generated images should be cached beforehand
+ font_family: font to use to generate the text images
+ img_transforms: composable transformations that will be applied to each image
+ sample_transforms: composable transformations that will be applied to both the image and the target
+ """
+
+ pass
diff --git a/doctr/datasets/generator/tensorflow.py b/doctr/datasets/generator/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..82e205e03862db57360a7ec3c38a350c12cd7bb7
--- /dev/null
+++ b/doctr/datasets/generator/tensorflow.py
@@ -0,0 +1,60 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import tensorflow as tf
+
+from .base import _CharacterGenerator, _WordGenerator
+
+__all__ = ["CharacterGenerator", "WordGenerator"]
+
+
+class CharacterGenerator(_CharacterGenerator):
+ """Implements a character image generation dataset
+
+ >>> from doctr.datasets import CharacterGenerator
+ >>> ds = CharacterGenerator(vocab='abdef', num_samples=100)
+ >>> img, target = ds[0]
+
+ Args:
+ ----
+ vocab: vocabulary to take the character from
+ num_samples: number of samples that will be generated iterating over the dataset
+        cache_samples: whether generated images should be cached beforehand
+ font_family: font to use to generate the text images
+ img_transforms: composable transformations that will be applied to each image
+ sample_transforms: composable transformations that will be applied to both the image and the target
+ """
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+ @staticmethod
+ def collate_fn(samples):
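+        # Stack the images into a batch and convert the integer class targets to a tensor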
+ images, targets = zip(*samples)
+ images = tf.stack(images, axis=0)
+
+ return images, tf.convert_to_tensor(targets)
+
+
+class WordGenerator(_WordGenerator):
+ """Implements a character image generation dataset
+
+ >>> from doctr.datasets import WordGenerator
+ >>> ds = WordGenerator(vocab='abdef', min_chars=1, max_chars=32, num_samples=100)
+ >>> img, target = ds[0]
+
+ Args:
+ ----
+ vocab: vocabulary to take the character from
+ min_chars: minimum number of characters in a word
+ max_chars: maximum number of characters in a word
+ num_samples: number of samples that will be generated iterating over the dataset
+        cache_samples: whether generated images should be cached beforehand
+ font_family: font to use to generate the text images
+ img_transforms: composable transformations that will be applied to each image
+ sample_transforms: composable transformations that will be applied to both the image and the target
+ """
+
+ pass
diff --git a/doctr/datasets/ic03.py b/doctr/datasets/ic03.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f080e4d450b1eac9f630435f3d93b239095db0f
--- /dev/null
+++ b/doctr/datasets/ic03.py
@@ -0,0 +1,126 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import os
+from typing import Any, Dict, List, Tuple, Union
+
+import defusedxml.ElementTree as ET
+import numpy as np
+from tqdm import tqdm
+
+from .datasets import VisionDataset
+from .utils import convert_target_to_relative, crop_bboxes_from_image
+
+__all__ = ["IC03"]
+
+
+class IC03(VisionDataset):
+ """IC03 dataset from `"ICDAR 2003 Robust Reading Competitions: Entries, Results and Future Directions"
+    <http://www.iapr-tc11.org/mediawiki/index.php?title=ICDAR_2003_Robust_Reading_Competitions>`_.
+
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/ic03-grid.png&src=0
+ :align: center
+
+ >>> from doctr.datasets import IC03
+ >>> train_set = IC03(train=True, download=True)
+ >>> img, target = train_set[0]
+
+ Args:
+ ----
+ train: whether the subset should be the training one
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
+ recognition_task: whether the dataset should be used for recognition task
+ **kwargs: keyword arguments from `VisionDataset`.
+ """
+
+ TRAIN = (
+ "http://www.iapr-tc11.org/dataset/ICDAR2003_RobustReading/TrialTrain/scene.zip",
+ "9d86df514eb09dd693fb0b8c671ef54a0cfe02e803b1bbef9fc676061502eb94",
+ "ic03_train.zip",
+ )
+ TEST = (
+ "http://www.iapr-tc11.org/dataset/ICDAR2003_RobustReading/TrialTest/scene.zip",
+ "dbc4b5fd5d04616b8464a1b42ea22db351ee22c2546dd15ac35611857ea111f8",
+ "ic03_test.zip",
+ )
+
+ def __init__(
+ self,
+ train: bool = True,
+ use_polygons: bool = False,
+ recognition_task: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ url, sha256, file_name = self.TRAIN if train else self.TEST
+ super().__init__(
+ url,
+ file_name,
+ sha256,
+ True,
+ pre_transforms=convert_target_to_relative if not recognition_task else None,
+ **kwargs,
+ )
+ self.train = train
+ self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+ np_dtype = np.float32
+
+ # Load xml data
+ tmp_root = (
+ os.path.join(self.root, "SceneTrialTrain" if self.train else "SceneTrialTest") if sha256 else self.root
+ )
+ xml_tree = ET.parse(os.path.join(tmp_root, "words.xml"))
+ xml_root = xml_tree.getroot()
+
+ for image in tqdm(iterable=xml_root, desc="Unpacking IC03", total=len(xml_root)):
+ name, _resolution, rectangles = image
+
+ # File existence check
+ if not os.path.exists(os.path.join(tmp_root, name.text)):
+ raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, name.text)}")
+
+ if use_polygons:
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
+ _boxes = [
+ [
+ [float(rect.attrib["x"]), float(rect.attrib["y"])],
+ [float(rect.attrib["x"]) + float(rect.attrib["width"]), float(rect.attrib["y"])],
+ [
+ float(rect.attrib["x"]) + float(rect.attrib["width"]),
+ float(rect.attrib["y"]) + float(rect.attrib["height"]),
+ ],
+ [float(rect.attrib["x"]), float(rect.attrib["y"]) + float(rect.attrib["height"])],
+ ]
+ for rect in rectangles
+ ]
+ else:
+ # x_min, y_min, x_max, y_max
+ _boxes = [
+ [
+ float(rect.attrib["x"]), # type: ignore[list-item]
+ float(rect.attrib["y"]), # type: ignore[list-item]
+ float(rect.attrib["x"]) + float(rect.attrib["width"]), # type: ignore[list-item]
+ float(rect.attrib["y"]) + float(rect.attrib["height"]), # type: ignore[list-item]
+ ]
+ for rect in rectangles
+ ]
+
+ # filter images without boxes
+ if len(_boxes) > 0:
+ boxes: np.ndarray = np.asarray(_boxes, dtype=np_dtype)
+ # Get the labels
+ labels = [lab.text for rect in rectangles for lab in rect if lab.text]
+
+ if recognition_task:
+ crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, name.text), geoms=boxes)
+ for crop, label in zip(crops, labels):
+ if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
+ self.data.append((crop, label))
+ else:
+ self.data.append((name.text, dict(boxes=boxes, labels=labels)))
+
+ self.root = tmp_root
+
+ def extra_repr(self) -> str:
+ return f"train={self.train}"
diff --git a/doctr/datasets/ic13.py b/doctr/datasets/ic13.py
new file mode 100644
index 0000000000000000000000000000000000000000..81ba62f00145487d5af0f6937305019da65ce210
--- /dev/null
+++ b/doctr/datasets/ic13.py
@@ -0,0 +1,99 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import csv
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import numpy as np
+from tqdm import tqdm
+
+from .datasets import AbstractDataset
+from .utils import convert_target_to_relative, crop_bboxes_from_image
+
+__all__ = ["IC13"]
+
+
+class IC13(AbstractDataset):
+ """IC13 dataset from `"ICDAR 2013 Robust Reading Competition" `_.
+
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/ic13-grid.png&src=0
+ :align: center
+
+ >>> # NOTE: You need to download both image and label parts from Focused Scene Text challenge Task2.1 2013-2015.
+ >>> from doctr.datasets import IC13
+ >>> train_set = IC13(img_folder="/path/to/Challenge2_Training_Task12_Images",
+ >>> label_folder="/path/to/Challenge2_Training_Task1_GT")
+ >>> img, target = train_set[0]
+ >>> test_set = IC13(img_folder="/path/to/Challenge2_Test_Task12_Images",
+ >>> label_folder="/path/to/Challenge2_Test_Task1_GT")
+ >>> img, target = test_set[0]
+
+ Args:
+ ----
+ img_folder: folder with all the images of the dataset
+ label_folder: folder with all annotation files for the images
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
+ recognition_task: whether the dataset should be used for recognition task
+ **kwargs: keyword arguments from `AbstractDataset`.
+ """
+
+ def __init__(
+ self,
+ img_folder: str,
+ label_folder: str,
+ use_polygons: bool = False,
+ recognition_task: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
+ )
+
+ # File existence check
+ if not os.path.exists(label_folder) or not os.path.exists(img_folder):
+ raise FileNotFoundError(
+ f"unable to locate {label_folder if not os.path.exists(label_folder) else img_folder}"
+ )
+
+ self.data: List[Tuple[Union[Path, np.ndarray], Union[str, Dict[str, Any]]]] = []
+ np_dtype = np.float32
+
+ img_names = os.listdir(img_folder)
+
+ for img_name in tqdm(iterable=img_names, desc="Unpacking IC13", total=len(img_names)):
+ img_path = Path(img_folder, img_name)
+ label_path = Path(label_folder, "gt_" + Path(img_name).stem + ".txt")
+
+ with open(label_path, newline="\n") as f:
+ _lines = [
+ [val[:-1] if val.endswith(",") else val for val in row]
+ for row in csv.reader(f, delimiter=" ", quotechar="'")
+ ]
+ labels = [line[-1].replace('"', "") for line in _lines]
+ # xmin, ymin, xmax, ymax
+ box_targets: np.ndarray = np.array([list(map(int, line[:4])) for line in _lines], dtype=np_dtype)
+ if use_polygons:
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
+ box_targets = np.array(
+ [
+ [
+ [coords[0], coords[1]],
+ [coords[2], coords[1]],
+ [coords[2], coords[3]],
+ [coords[0], coords[3]],
+ ]
+ for coords in box_targets
+ ],
+ dtype=np_dtype,
+ )
+
+ if recognition_task:
+ crops = crop_bboxes_from_image(img_path=img_path, geoms=box_targets)
+ for crop, label in zip(crops, labels):
+ self.data.append((crop, label))
+ else:
+ self.data.append((img_path, dict(boxes=box_targets, labels=labels)))
diff --git a/doctr/datasets/iiit5k.py b/doctr/datasets/iiit5k.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b33ebb50b3297b27a0db4184af283fb6a2e0d2f
--- /dev/null
+++ b/doctr/datasets/iiit5k.py
@@ -0,0 +1,103 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import os
+from typing import Any, Dict, List, Tuple, Union
+
+import numpy as np
+import scipy.io as sio
+from tqdm import tqdm
+
+from .datasets import VisionDataset
+from .utils import convert_target_to_relative
+
+__all__ = ["IIIT5K"]
+
+
+class IIIT5K(VisionDataset):
+ """IIIT-5K character-level localization dataset from
+ `"BMVC 2012 Scene Text Recognition using Higher Order Language Priors"
+    <https://cvit.iiit.ac.in/research/projects/cvit-projects/the-iiit-5k-word-dataset>`_.
+
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/iiit5k-grid.png&src=0
+ :align: center
+
+ >>> # NOTE: this dataset is for character-level localization
+ >>> from doctr.datasets import IIIT5K
+ >>> train_set = IIIT5K(train=True, download=True)
+ >>> img, target = train_set[0]
+
+ Args:
+ ----
+ train: whether the subset should be the training one
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
+ recognition_task: whether the dataset should be used for recognition task
+ **kwargs: keyword arguments from `VisionDataset`.
+ """
+
+ URL = "https://cvit.iiit.ac.in/images/Projects/SceneTextUnderstanding/IIIT5K-Word_V3.0.tar.gz"
+ SHA256 = "7872c9efbec457eb23f3368855e7738f72ce10927f52a382deb4966ca0ffa38e"
+
+ def __init__(
+ self,
+ train: bool = True,
+ use_polygons: bool = False,
+ recognition_task: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ self.URL,
+ None,
+ file_hash=self.SHA256,
+ extract_archive=True,
+ pre_transforms=convert_target_to_relative if not recognition_task else None,
+ **kwargs,
+ )
+ self.train = train
+
+ # Load mat data
+ tmp_root = os.path.join(self.root, "IIIT5K") if self.SHA256 else self.root
+ mat_file = "trainCharBound" if self.train else "testCharBound"
+ mat_data = sio.loadmat(os.path.join(tmp_root, f"{mat_file}.mat"))[mat_file][0]
+
+ self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+ np_dtype = np.float32
+
+ for img_path, label, box_targets in tqdm(iterable=mat_data, desc="Unpacking IIIT5K", total=len(mat_data)):
+ _raw_path = img_path[0]
+ _raw_label = label[0]
+
+ # File existence check
+ if not os.path.exists(os.path.join(tmp_root, _raw_path)):
+ raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, _raw_path)}")
+
+ if recognition_task:
+ self.data.append((_raw_path, _raw_label))
+ else:
+ if use_polygons:
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
+ box_targets = [
+ [
+ [box[0], box[1]],
+ [box[0] + box[2], box[1]],
+ [box[0] + box[2], box[1] + box[3]],
+ [box[0], box[1] + box[3]],
+ ]
+ for box in box_targets
+ ]
+ else:
+ # xmin, ymin, xmax, ymax
+ box_targets = [[box[0], box[1], box[0] + box[2], box[1] + box[3]] for box in box_targets]
+
+                # labels are cast to a list where each character corresponds to its bounding box
+ self.data.append((
+ _raw_path,
+ dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=list(_raw_label)),
+ ))
+
+ self.root = tmp_root
+
+ def extra_repr(self) -> str:
+ return f"train={self.train}"
diff --git a/doctr/datasets/iiithws.py b/doctr/datasets/iiithws.py
new file mode 100644
index 0000000000000000000000000000000000000000..e33e3acd536af612c65125e1a138ec29d8b62727
--- /dev/null
+++ b/doctr/datasets/iiithws.py
@@ -0,0 +1,75 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import os
+from random import sample
+from typing import Any, List, Tuple
+
+from tqdm import tqdm
+
+from .datasets import AbstractDataset
+
+__all__ = ["IIITHWS"]
+
+
+class IIITHWS(AbstractDataset):
+ """IIITHWS dataset from `"Generating Synthetic Data for Text Recognition"
+    <https://arxiv.org/pdf/1608.04224.pdf>`_ | `"repository" <https://github.com/kris314/hwnet>`_ |
+    `"website" <https://cvit.iiit.ac.in/research/projects/cvit-projects/matchdocimgs>`_.
+
+ >>> # NOTE: This is a pure recognition dataset without bounding box labels.
+ >>> # NOTE: You need to download the dataset.
+ >>> from doctr.datasets import IIITHWS
+ >>> train_set = IIITHWS(img_folder="/path/to/iiit-hws/Images_90K_Normalized",
+ >>> label_path="/path/to/IIIT-HWS-90K.txt",
+ >>> train=True)
+ >>> img, target = train_set[0]
+ >>> test_set = IIITHWS(img_folder="/path/to/iiit-hws/Images_90K_Normalized",
+ >>> label_path="/path/to/IIIT-HWS-90K.txt")
+ >>> train=False)
+ >>> img, target = test_set[0]
+
+ Args:
+ ----
+ img_folder: folder with all the images of the dataset
+ label_path: path to the file with the labels
+ train: whether the subset should be the training one
+ **kwargs: keyword arguments from `AbstractDataset`.
+ """
+
+ def __init__(
+ self,
+ img_folder: str,
+ label_path: str,
+ train: bool = True,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(img_folder, **kwargs)
+
+ # File existence check
+ if not os.path.exists(label_path) or not os.path.exists(img_folder):
+ raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")
+
+ self.data: List[Tuple[str, str]] = []
+ self.train = train
+
+ with open(label_path) as f:
+ annotations = f.readlines()
+
+        # Shuffle the dataset, otherwise the test set would contain the same labels n times
+ annotations = sample(annotations, len(annotations))
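+        # 90% of the samples form the training split, the remaining 10% the test split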
+ train_samples = int(len(annotations) * 0.9)
+ set_slice = slice(train_samples) if self.train else slice(train_samples, None)
+
+ for annotation in tqdm(
+ iterable=annotations[set_slice], desc="Unpacking IIITHWS", total=len(annotations[set_slice])
+ ):
+ img_path, label = annotation.split()[0:2]
+ img_path = os.path.join(img_folder, img_path)
+
+ self.data.append((img_path, label))
+
+ def extra_repr(self) -> str:
+ return f"train={self.train}"
diff --git a/doctr/datasets/imgur5k.py b/doctr/datasets/imgur5k.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce70c9f3bc982b75c69ebce7f016bd76adcf6147
--- /dev/null
+++ b/doctr/datasets/imgur5k.py
@@ -0,0 +1,147 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import glob
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+from .datasets import AbstractDataset
+from .utils import convert_target_to_relative, crop_bboxes_from_image
+
+__all__ = ["IMGUR5K"]
+
+
+class IMGUR5K(AbstractDataset):
+ """IMGUR5K dataset from `"TextStyleBrush: Transfer of Text Aesthetics from a Single Example"
+    <https://arxiv.org/abs/2106.08385>`_ |
+    `repository <https://github.com/facebookresearch/IMGUR5K-Handwriting-Dataset>`_.
+
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/imgur5k-grid.png&src=0
+ :align: center
+ :width: 630
+ :height: 400
+
+ >>> # NOTE: You need to download/generate the dataset from the repository.
+ >>> from doctr.datasets import IMGUR5K
+ >>> train_set = IMGUR5K(train=True, img_folder="/path/to/IMGUR5K-Handwriting-Dataset/images",
+ >>> label_path="/path/to/IMGUR5K-Handwriting-Dataset/dataset_info/imgur5k_annotations.json")
+ >>> img, target = train_set[0]
+ >>> test_set = IMGUR5K(train=False, img_folder="/path/to/IMGUR5K-Handwriting-Dataset/images",
+ >>> label_path="/path/to/IMGUR5K-Handwriting-Dataset/dataset_info/imgur5k_annotations.json")
+ >>> img, target = test_set[0]
+
+ Args:
+ ----
+ img_folder: folder with all the images of the dataset
+ label_path: path to the annotations file of the dataset
+ train: whether the subset should be the training one
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
+ recognition_task: whether the dataset should be used for recognition task
+ **kwargs: keyword arguments from `AbstractDataset`.
+ """
+
+ def __init__(
+ self,
+ img_folder: str,
+ label_path: str,
+ train: bool = True,
+ use_polygons: bool = False,
+ recognition_task: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
+ )
+
+ # File existence check
+ if not os.path.exists(label_path) or not os.path.exists(img_folder):
+ raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")
+
+ self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any]]]] = []
+ self.train = train
+ np_dtype = np.float32
+
+ img_names = os.listdir(img_folder)
+ train_samples = int(len(img_names) * 0.9)
+ set_slice = slice(train_samples) if self.train else slice(train_samples, None)
+
+ # define folder to write IMGUR5K recognition dataset
+ reco_folder_name = "IMGUR5K_recognition_train" if self.train else "IMGUR5K_recognition_test"
+ reco_folder_name = "Poly_" + reco_folder_name if use_polygons else reco_folder_name
+ reco_folder_path = os.path.join(os.path.dirname(self.root), reco_folder_name)
+ reco_images_counter = 0
+
+ if recognition_task and os.path.isdir(reco_folder_path):
+ self._read_from_folder(reco_folder_path)
+ return
+ elif recognition_task and not os.path.isdir(reco_folder_path):
+ os.makedirs(reco_folder_path, exist_ok=False)
+
+ with open(label_path) as f:
+ annotation_file = json.load(f)
+
+ for img_name in tqdm(iterable=img_names[set_slice], desc="Unpacking IMGUR5K", total=len(img_names[set_slice])):
+ img_path = Path(img_folder, img_name)
+ img_id = img_name.split(".")[0]
+
+ # File existence check
+ if not os.path.exists(os.path.join(self.root, img_name)):
+ raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")
+
+            # some files have no annotations; these entries are marked with only a dot in the 'word' key
+ # ref: https://github.com/facebookresearch/IMGUR5K-Handwriting-Dataset/blob/main/README.md
+ if img_id not in annotation_file["index_to_ann_map"].keys():
+ continue
+ ann_ids = annotation_file["index_to_ann_map"][img_id]
+ annotations = [annotation_file["ann_id"][a_id] for a_id in ann_ids]
+
+ labels = [ann["word"] for ann in annotations if ann["word"] != "."]
+ # x_center, y_center, width, height, angle
+ _boxes = [
+ list(map(float, ann["bounding_box"].strip("[ ]").split(", ")))
+ for ann in annotations
+ if ann["word"] != "."
+ ]
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
+ box_targets = [cv2.boxPoints(((box[0], box[1]), (box[2], box[3]), box[4])) for box in _boxes] # type: ignore[arg-type]
+
+ if not use_polygons:
+ # xmin, ymin, xmax, ymax
+ box_targets = [np.concatenate((points.min(0), points.max(0)), axis=-1) for points in box_targets]
+
+ # filter images without boxes
+ if len(box_targets) > 0:
+ if recognition_task:
+ crops = crop_bboxes_from_image(
+ img_path=os.path.join(self.root, img_name), geoms=np.asarray(box_targets, dtype=np_dtype)
+ )
+ for crop, label in zip(crops, labels):
+ if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
+ # write data to disk
+ with open(os.path.join(reco_folder_path, f"{reco_images_counter}.txt"), "w") as f:
+ f.write(label)
+ tmp_img = Image.fromarray(crop)
+ tmp_img.save(os.path.join(reco_folder_path, f"{reco_images_counter}.png"))
+ reco_images_counter += 1
+ else:
+ self.data.append((img_path, dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=labels)))
+
+ if recognition_task:
+ self._read_from_folder(reco_folder_path)
+
+ def extra_repr(self) -> str:
+ return f"train={self.train}"
+
+ def _read_from_folder(self, path: str) -> None:
+ for img_path in glob.glob(os.path.join(path, "*.png")):
+ with open(os.path.join(path, f"{os.path.basename(img_path)[:-4]}.txt"), "r") as f:
+ self.data.append((img_path, f.read()))
diff --git a/doctr/datasets/loader.py b/doctr/datasets/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..f08f7496afa155389116b8d3b52922152938b237
--- /dev/null
+++ b/doctr/datasets/loader.py
@@ -0,0 +1,102 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import math
+from typing import Callable, Optional
+
+import numpy as np
+import tensorflow as tf
+
+from doctr.utils.multithreading import multithread_exec
+
+__all__ = ["DataLoader"]
+
+
+def default_collate(samples):
+ """Collate multiple elements into batches
+
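+    >>> # illustrative example: batch two (image, target) samples into stacked tensors
+    >>> import tensorflow as tf
+    >>> images, targets = default_collate([(tf.zeros((3,)), 0), (tf.ones((3,)), 1)])
+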
+ Args:
+ ----
+ samples: list of N tuples containing M elements
+
+ Returns:
+ -------
+        Tuple of M sequences containing N elements each
+ """
+ batch_data = zip(*samples)
+
+ tf_data = tuple(tf.stack(elt, axis=0) for elt in batch_data)
+
+ return tf_data
+
+
+class DataLoader:
+ """Implements a dataset wrapper for fast data loading
+
+ >>> from doctr.datasets import CORD, DataLoader
+ >>> train_set = CORD(train=True, download=True)
+ >>> train_loader = DataLoader(train_set, batch_size=32)
+ >>> train_iter = iter(train_loader)
+ >>> images, targets = next(train_iter)
+
+ Args:
+ ----
+ dataset: the dataset
+ shuffle: whether the samples should be shuffled before passing it to the iterator
+ batch_size: number of elements in each batch
+ drop_last: if `True`, drops the last batch if it isn't full
+ num_workers: number of workers to use for data loading
+ collate_fn: function to merge samples into a batch
+ """
+
+ def __init__(
+ self,
+ dataset,
+ shuffle: bool = True,
+ batch_size: int = 1,
+ drop_last: bool = False,
+ num_workers: Optional[int] = None,
+ collate_fn: Optional[Callable] = None,
+ ) -> None:
+ self.dataset = dataset
+ self.shuffle = shuffle
+ self.batch_size = batch_size
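+        # With drop_last the trailing partial batch is discarded, otherwise it is kept as a smaller batch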
+ nb = len(self.dataset) / batch_size
+ self.num_batches = math.floor(nb) if drop_last else math.ceil(nb)
+ if collate_fn is None:
+ self.collate_fn = self.dataset.collate_fn if hasattr(self.dataset, "collate_fn") else default_collate
+ else:
+ self.collate_fn = collate_fn
+ self.num_workers = num_workers
+ self.reset()
+
+ def __len__(self) -> int:
+ return self.num_batches
+
+ def reset(self) -> None:
+ # Updates indices after each epoch
+ self._num_yielded = 0
+ self.indices = np.arange(len(self.dataset))
+        if self.shuffle:
+ np.random.shuffle(self.indices)
+
+ def __iter__(self):
+ self.reset()
+ return self
+
+ def __next__(self):
+ if self._num_yielded < self.num_batches:
+ # Get next indices
+ idx = self._num_yielded * self.batch_size
+ indices = self.indices[idx : min(len(self.dataset), idx + self.batch_size)]
+
+ samples = list(multithread_exec(self.dataset.__getitem__, indices, threads=self.num_workers))
+
+ batch_data = self.collate_fn(samples)
+
+ self._num_yielded += 1
+ return batch_data
+ else:
+ raise StopIteration
diff --git a/doctr/datasets/mjsynth.py b/doctr/datasets/mjsynth.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8b16caebe22e0812be38d26ed8d83003533d493
--- /dev/null
+++ b/doctr/datasets/mjsynth.py
@@ -0,0 +1,106 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import os
+from typing import Any, List, Tuple
+
+from tqdm import tqdm
+
+from .datasets import AbstractDataset
+
+__all__ = ["MJSynth"]
+
+
+class MJSynth(AbstractDataset):
+ """MJSynth dataset from `"Synthetic Data and Artificial Neural Networks for Natural Scene Text Recognition"
+    <https://www.robots.ox.ac.uk/~vgg/data/text/>`_.
+
+ >>> # NOTE: This is a pure recognition dataset without bounding box labels.
+ >>> # NOTE: You need to download the dataset.
+ >>> from doctr.datasets import MJSynth
+ >>> train_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px",
+ >>> label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt",
+ >>> train=True)
+ >>> img, target = train_set[0]
+ >>> test_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px",
+ >>> label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt")
+ >>> train=False)
+ >>> img, target = test_set[0]
+
+ Args:
+ ----
+ img_folder: folder with all the images of the dataset
+ label_path: path to the file with the labels
+ train: whether the subset should be the training one
+ **kwargs: keyword arguments from `AbstractDataset`.
+ """
+
+ # filter corrupted or missing images
+ BLACKLIST = [
+ "./1881/4/225_Marbling_46673.jpg\n",
+ "./2069/4/192_whittier_86389.jpg\n",
+ "./869/4/234_TRIASSIC_80582.jpg\n",
+ "./173/2/358_BURROWING_10395.jpg\n",
+ "./913/4/231_randoms_62372.jpg\n",
+ "./596/2/372_Ump_81662.jpg\n",
+ "./936/2/375_LOCALITIES_44992.jpg\n",
+ "./2540/4/246_SQUAMOUS_73902.jpg\n",
+ "./1332/4/224_TETHERED_78397.jpg\n",
+ "./627/6/83_PATRIARCHATE_55931.jpg\n",
+ "./2013/2/370_refract_63890.jpg\n",
+ "./2911/6/77_heretical_35885.jpg\n",
+ "./1730/2/361_HEREON_35880.jpg\n",
+ "./2194/2/334_EFFLORESCENT_24742.jpg\n",
+ "./2025/2/364_SNORTERS_72304.jpg\n",
+ "./368/4/232_friar_30876.jpg\n",
+ "./275/6/96_hackle_34465.jpg\n",
+ "./384/4/220_bolts_8596.jpg\n",
+ "./905/4/234_Postscripts_59142.jpg\n",
+ "./2749/6/101_Chided_13155.jpg\n",
+ "./495/6/81_MIDYEAR_48332.jpg\n",
+ "./2852/6/60_TOILSOME_79481.jpg\n",
+ "./554/2/366_Teleconferences_77948.jpg\n",
+ "./1696/4/211_Queened_61779.jpg\n",
+ "./2128/2/369_REDACTED_63458.jpg\n",
+ "./2557/2/351_DOWN_23492.jpg\n",
+ "./2489/4/221_snored_72290.jpg\n",
+ "./1650/2/355_stony_74902.jpg\n",
+ "./1863/4/223_Diligently_21672.jpg\n",
+ "./264/2/362_FORETASTE_30276.jpg\n",
+ "./429/4/208_Mainmasts_46140.jpg\n",
+ "./1817/2/363_actuating_904.jpg\n",
+ ]
+
+ def __init__(
+ self,
+ img_folder: str,
+ label_path: str,
+ train: bool = True,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(img_folder, **kwargs)
+
+ # File existence check
+ if not os.path.exists(label_path) or not os.path.exists(img_folder):
+ raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")
+
+ self.data: List[Tuple[str, str]] = []
+ self.train = train
+
+ with open(label_path) as f:
+ img_paths = f.readlines()
+
+ train_samples = int(len(img_paths) * 0.9)
+ set_slice = slice(train_samples) if self.train else slice(train_samples, None)
+
+ for path in tqdm(iterable=img_paths[set_slice], desc="Unpacking MJSynth", total=len(img_paths[set_slice])):
+ if path not in self.BLACKLIST:
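+                # the label is encoded in the filename, e.g. "./1881/4/225_Marbling_46673.jpg" -> "Marbling"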
+ label = path.split("_")[1]
+ img_path = os.path.join(img_folder, path[2:]).strip()
+
+ self.data.append((img_path, label))
+
+ def extra_repr(self) -> str:
+ return f"train={self.train}"
diff --git a/doctr/datasets/ocr.py b/doctr/datasets/ocr.py
new file mode 100644
index 0000000000000000000000000000000000000000..b93c124ce74dcadc143abe3792e447454e391b01
--- /dev/null
+++ b/doctr/datasets/ocr.py
@@ -0,0 +1,71 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import numpy as np
+
+from .datasets import AbstractDataset
+
+__all__ = ["OCRDataset"]
+
+
+class OCRDataset(AbstractDataset):
+ """Implements an OCR dataset
+
+ >>> from doctr.datasets import OCRDataset
+ >>> train_set = OCRDataset(img_folder="/path/to/images",
+ >>> label_file="/path/to/labels.json")
+ >>> img, target = train_set[0]
+
+ Args:
+ ----
+ img_folder: local path to image folder (all jpg at the root)
+ label_file: local path to the label file
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
+ **kwargs: keyword arguments from `AbstractDataset`.
+ """
+
+ def __init__(
+ self,
+ img_folder: str,
+ label_file: str,
+ use_polygons: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(img_folder, **kwargs)
+
+ # List images
+ self.data: List[Tuple[str, Dict[str, Any]]] = []
+ np_dtype = np.float32
+ with open(label_file, "rb") as f:
+ data = json.load(f)
+
+ for img_name, annotations in data.items():
+ # Get image path
+ img_name = Path(img_name)
+ # File existence check
+ if not os.path.exists(os.path.join(self.root, img_name)):
+ raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")
+
+ # handle empty images
+ if len(annotations["typed_words"]) == 0:
+ self.data.append((img_name, dict(boxes=np.zeros((0, 4), dtype=np_dtype), labels=[])))
+ continue
+ # Unpack the straight boxes (xmin, ymin, xmax, ymax)
+ geoms = [list(map(float, obj["geometry"][:4])) for obj in annotations["typed_words"]]
+ if use_polygons:
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
+ geoms = [
+ [geom[:2], [geom[2], geom[1]], geom[2:], [geom[0], geom[3]]] # type: ignore[list-item]
+ for geom in geoms
+ ]
+
+ text_targets = [obj["value"] for obj in annotations["typed_words"]]
+
+ self.data.append((img_name, dict(boxes=np.asarray(geoms, dtype=np_dtype), labels=text_targets)))
diff --git a/doctr/datasets/orientation.py b/doctr/datasets/orientation.py
new file mode 100644
index 0000000000000000000000000000000000000000..10bd55444e65cb2770ce9a4d15711a82cba5e06a
--- /dev/null
+++ b/doctr/datasets/orientation.py
@@ -0,0 +1,40 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import os
+from typing import Any, List, Tuple
+
+import numpy as np
+
+from .datasets import AbstractDataset
+
+__all__ = ["OrientationDataset"]
+
+
+class OrientationDataset(AbstractDataset):
+ """Implements a basic image dataset where targets are filled with zeros.
+
+ >>> from doctr.datasets import OrientationDataset
+ >>> train_set = OrientationDataset(img_folder="/path/to/images")
+ >>> img, target = train_set[0]
+
+ Args:
+ ----
+ img_folder: folder with all the images of the dataset
+ **kwargs: keyword arguments from `AbstractDataset`.
+ """
+
+ def __init__(
+ self,
+ img_folder: str,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ img_folder,
+ **kwargs,
+ )
+
+ # initialize dataset with 0 degree rotation targets
+ self.data: List[Tuple[str, np.ndarray]] = [(img_name, np.array([0])) for img_name in os.listdir(self.root)]
diff --git a/doctr/datasets/recognition.py b/doctr/datasets/recognition.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebf37a20acfa2c3787bc8e4a1e88692d13fcdd15
--- /dev/null
+++ b/doctr/datasets/recognition.py
@@ -0,0 +1,56 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import json
+import os
+from pathlib import Path
+from typing import Any, List, Tuple
+
+from .datasets import AbstractDataset
+
+__all__ = ["RecognitionDataset"]
+
+
+class RecognitionDataset(AbstractDataset):
+ """Dataset implementation for text recognition tasks
+
+ >>> from doctr.datasets import RecognitionDataset
+ >>> train_set = RecognitionDataset(img_folder="/path/to/images",
+ >>> labels_path="/path/to/labels.json")
+ >>> img, target = train_set[0]
+
+ Args:
+ ----
+ img_folder: path to the images folder
+        labels_path: path to the json file containing all labels (character sequences)
+ **kwargs: keyword arguments from `AbstractDataset`.
+ """
+
+ def __init__(
+ self,
+ img_folder: str,
+ labels_path: str,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(img_folder, **kwargs)
+
+ self.data: List[Tuple[str, str]] = []
+ with open(labels_path, encoding="utf-8") as f:
+ labels = json.load(f)
+
+ for img_name, label in labels.items():
+ if not os.path.exists(os.path.join(self.root, img_name)):
+ raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")
+
+ self.data.append((img_name, label))
+
+ def merge_dataset(self, ds: AbstractDataset) -> None:
+ # Update data with new root for self
+ self.data = [(str(Path(self.root).joinpath(img_path)), label) for img_path, label in self.data]
+ # Define new root
+ self.root = Path("/")
+ # Merge with ds data
+ for img_path, label in ds.data:
+ self.data.append((str(Path(ds.root).joinpath(img_path)), label))
diff --git a/doctr/datasets/sroie.py b/doctr/datasets/sroie.py
new file mode 100644
index 0000000000000000000000000000000000000000..e72fde68a1f5e54333b5f7ab68c21069286770a1
--- /dev/null
+++ b/doctr/datasets/sroie.py
@@ -0,0 +1,103 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import csv
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import numpy as np
+from tqdm import tqdm
+
+from .datasets import VisionDataset
+from .utils import convert_target_to_relative, crop_bboxes_from_image
+
+__all__ = ["SROIE"]
+
+
+class SROIE(VisionDataset):
+ """SROIE dataset from `"ICDAR2019 Competition on Scanned Receipt OCR and Information Extraction"
+    <https://arxiv.org/pdf/2103.10213.pdf>`_.
+
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/sroie-grid.png&src=0
+ :align: center
+
+ >>> from doctr.datasets import SROIE
+ >>> train_set = SROIE(train=True, download=True)
+ >>> img, target = train_set[0]
+
+ Args:
+ ----
+ train: whether the subset should be the training one
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
+ recognition_task: whether the dataset should be used for recognition task
+ **kwargs: keyword arguments from `VisionDataset`.
+ """
+
+ TRAIN = (
+ "https://doctr-static.mindee.com/models?id=v0.1.1/sroie2019_train_task1.zip&src=0",
+ "d4fa9e60abb03500d83299c845b9c87fd9c9430d1aeac96b83c5d0bb0ab27f6f",
+ "sroie2019_train_task1.zip",
+ )
+ TEST = (
+ "https://doctr-static.mindee.com/models?id=v0.1.1/sroie2019_test.zip&src=0",
+ "41b3c746a20226fddc80d86d4b2a903d43b5be4f521dd1bbe759dbf8844745e2",
+ "sroie2019_test.zip",
+ )
+
+ def __init__(
+ self,
+ train: bool = True,
+ use_polygons: bool = False,
+ recognition_task: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ url, sha256, name = self.TRAIN if train else self.TEST
+ super().__init__(
+ url,
+ name,
+ sha256,
+ True,
+ pre_transforms=convert_target_to_relative if not recognition_task else None,
+ **kwargs,
+ )
+ self.train = train
+
+ tmp_root = os.path.join(self.root, "images")
+ self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+ np_dtype = np.float32
+
+ for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking SROIE", total=len(os.listdir(tmp_root))):
+ # File existence check
+ if not os.path.exists(os.path.join(tmp_root, img_path)):
+ raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
+
+ stem = Path(img_path).stem
+ with open(os.path.join(self.root, "annotations", f"{stem}.txt"), encoding="latin") as f:
+ _rows = [row for row in list(csv.reader(f, delimiter=",")) if len(row) > 0]
+
+ labels = [",".join(row[8:]) for row in _rows]
+            # reshape the 8 raw values into (4, 2) arrays:
+            # (x, y) coordinates of top left, top right, bottom right, bottom left corners
+ coords: np.ndarray = np.stack(
+ [np.array(list(map(int, row[:8])), dtype=np_dtype).reshape((4, 2)) for row in _rows], axis=0
+ )
+
+ if not use_polygons:
+ # xmin, ymin, xmax, ymax
+ coords = np.concatenate((coords.min(axis=1), coords.max(axis=1)), axis=1)
+
+ if recognition_task:
+ crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_path), geoms=coords)
+ for crop, label in zip(crops, labels):
+ if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
+ self.data.append((crop, label))
+ else:
+ self.data.append((img_path, dict(boxes=coords, labels=labels)))
+
+ self.root = tmp_root
+
+ def extra_repr(self) -> str:
+ return f"train={self.train}"
diff --git a/doctr/datasets/svhn.py b/doctr/datasets/svhn.py
new file mode 100644
index 0000000000000000000000000000000000000000..57085c5213a549f276858e6623d7e2a91006ad65
--- /dev/null
+++ b/doctr/datasets/svhn.py
@@ -0,0 +1,131 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import os
+from typing import Any, Dict, List, Tuple, Union
+
+import h5py
+import numpy as np
+from tqdm import tqdm
+
+from .datasets import VisionDataset
+from .utils import convert_target_to_relative, crop_bboxes_from_image
+
+__all__ = ["SVHN"]
+
+
+class SVHN(VisionDataset):
+ """SVHN dataset from `"The Street View House Numbers (SVHN) Dataset"
+    <http://ufldl.stanford.edu/housenumbers/>`_.
+
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/svhn-grid.png&src=0
+ :align: center
+
+ >>> from doctr.datasets import SVHN
+ >>> train_set = SVHN(train=True, download=True)
+ >>> img, target = train_set[0]
+
+ Args:
+ ----
+ train: whether the subset should be the training one
+        use_polygons: whether polygons should be considered as rotated bounding boxes (instead of straight ones)
+        recognition_task: whether the dataset should be used for a recognition task
+ **kwargs: keyword arguments from `VisionDataset`.
+ """
+
+ TRAIN = (
+ "http://ufldl.stanford.edu/housenumbers/train.tar.gz",
+ "4b17bb33b6cd8f963493168f80143da956f28ec406cc12f8e5745a9f91a51898",
+ "svhn_train.tar",
+ )
+
+ TEST = (
+ "http://ufldl.stanford.edu/housenumbers/test.tar.gz",
+ "57ac9ceb530e4aa85b55d991be8fc49c695b3d71c6f6a88afea86549efde7fb5",
+ "svhn_test.tar",
+ )
+
+ def __init__(
+ self,
+ train: bool = True,
+ use_polygons: bool = False,
+ recognition_task: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ url, sha256, name = self.TRAIN if train else self.TEST
+ super().__init__(
+ url,
+ file_name=name,
+ file_hash=sha256,
+ extract_archive=True,
+ pre_transforms=convert_target_to_relative if not recognition_task else None,
+ **kwargs,
+ )
+ self.train = train
+ self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+ np_dtype = np.float32
+
+ tmp_root = os.path.join(self.root, "train" if train else "test")
+
+        # Load mat data (MATLAB v7.3 - cannot be loaded with scipy)
+ with h5py.File(os.path.join(tmp_root, "digitStruct.mat"), "r") as f:
+ img_refs = f["digitStruct/name"]
+ box_refs = f["digitStruct/bbox"]
+ for img_ref, box_ref in tqdm(iterable=zip(img_refs, box_refs), desc="Unpacking SVHN", total=len(img_refs)):
+ # convert ascii matrix to string
+ img_name = "".join(map(chr, f[img_ref[0]][()].flatten()))
+
+ # File existence check
+ if not os.path.exists(os.path.join(tmp_root, img_name)):
+ raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}")
+
+ # Unpack the information
+ box = f[box_ref[0]]
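+                # NOTE (added for clarity): each field ("left", "top", "width", "height",
+                # "label") is either a scalar dataset (single digit) or an array of HDF5
+                # object references (several digits) that must be dereferenced via `f[...]`.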
+ if box["left"].shape[0] == 1:
+ box_dict = {k: [int(vals[0][0])] for k, vals in box.items()}
+ else:
+ box_dict = {k: [int(f[v[0]][()].item()) for v in vals] for k, vals in box.items()}
+
+ # Convert it to the right format
+ coords: np.ndarray = np.array(
+ [box_dict["left"], box_dict["top"], box_dict["width"], box_dict["height"]], dtype=np_dtype
+ ).transpose()
+ label_targets = list(map(str, box_dict["label"]))
+
+ if use_polygons:
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
+ box_targets: np.ndarray = np.stack(
+ [
+ np.stack([coords[:, 0], coords[:, 1]], axis=-1),
+ np.stack([coords[:, 0] + coords[:, 2], coords[:, 1]], axis=-1),
+ np.stack([coords[:, 0] + coords[:, 2], coords[:, 1] + coords[:, 3]], axis=-1),
+ np.stack([coords[:, 0], coords[:, 1] + coords[:, 3]], axis=-1),
+ ],
+ axis=1,
+ )
+ else:
+ # x, y, width, height -> xmin, ymin, xmax, ymax
+ box_targets = np.stack(
+ [
+ coords[:, 0],
+ coords[:, 1],
+ coords[:, 0] + coords[:, 2],
+ coords[:, 1] + coords[:, 3],
+ ],
+ axis=-1,
+ )
+
+ if recognition_task:
+ crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_name), geoms=box_targets)
+ for crop, label in zip(crops, label_targets):
+ if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
+ self.data.append((crop, label))
+ else:
+ self.data.append((img_name, dict(boxes=box_targets, labels=label_targets)))
+
+ self.root = tmp_root
+
+ def extra_repr(self) -> str:
+ return f"train={self.train}"
diff --git a/doctr/datasets/svt.py b/doctr/datasets/svt.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eb7b6d599e6e2dc5cf4c424da6f9c61a579adf0
--- /dev/null
+++ b/doctr/datasets/svt.py
@@ -0,0 +1,117 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import os
+from typing import Any, Dict, List, Tuple, Union
+
+import defusedxml.ElementTree as ET
+import numpy as np
+from tqdm import tqdm
+
+from .datasets import VisionDataset
+from .utils import convert_target_to_relative, crop_bboxes_from_image
+
+__all__ = ["SVT"]
+
+
+class SVT(VisionDataset):
+ """SVT dataset from `"The Street View Text Dataset - UCSD Computer Vision"
+    <http://vision.ucsd.edu/~kai/svt/>`_.
+
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/svt-grid.png&src=0
+ :align: center
+
+ >>> from doctr.datasets import SVT
+ >>> train_set = SVT(train=True, download=True)
+ >>> img, target = train_set[0]
+
+ Args:
+ ----
+ train: whether the subset should be the training one
+        use_polygons: whether polygons should be considered as rotated bounding boxes (instead of straight ones)
+        recognition_task: whether the dataset should be used for a recognition task
+ **kwargs: keyword arguments from `VisionDataset`.
+ """
+
+ URL = "http://vision.ucsd.edu/~kai/svt/svt.zip"
+ SHA256 = "63b3d55e6b6d1e036e2a844a20c034fe3af3c32e4d914d6e0c4a3cd43df3bebf"
+
+ def __init__(
+ self,
+ train: bool = True,
+ use_polygons: bool = False,
+ recognition_task: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ self.URL,
+ None,
+ self.SHA256,
+ True,
+ pre_transforms=convert_target_to_relative if not recognition_task else None,
+ **kwargs,
+ )
+ self.train = train
+ self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+ np_dtype = np.float32
+
+ # Load xml data
+ tmp_root = os.path.join(self.root, "svt1") if self.SHA256 else self.root
+ xml_tree = (
+ ET.parse(os.path.join(tmp_root, "train.xml"))
+ if self.train
+ else ET.parse(os.path.join(tmp_root, "test.xml"))
+ )
+ xml_root = xml_tree.getroot()
+
+ for image in tqdm(iterable=xml_root, desc="Unpacking SVT", total=len(xml_root)):
+ name, _, _, _resolution, rectangles = image
+
+ # File existence check
+ if not os.path.exists(os.path.join(tmp_root, name.text)):
+ raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, name.text)}")
+
+ if use_polygons:
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
+ _boxes = [
+ [
+ [float(rect.attrib["x"]), float(rect.attrib["y"])],
+ [float(rect.attrib["x"]) + float(rect.attrib["width"]), float(rect.attrib["y"])],
+ [
+ float(rect.attrib["x"]) + float(rect.attrib["width"]),
+ float(rect.attrib["y"]) + float(rect.attrib["height"]),
+ ],
+ [float(rect.attrib["x"]), float(rect.attrib["y"]) + float(rect.attrib["height"])],
+ ]
+ for rect in rectangles
+ ]
+ else:
+ # x_min, y_min, x_max, y_max
+ _boxes = [
+ [
+ float(rect.attrib["x"]), # type: ignore[list-item]
+ float(rect.attrib["y"]), # type: ignore[list-item]
+ float(rect.attrib["x"]) + float(rect.attrib["width"]), # type: ignore[list-item]
+ float(rect.attrib["y"]) + float(rect.attrib["height"]), # type: ignore[list-item]
+ ]
+ for rect in rectangles
+ ]
+
+ boxes: np.ndarray = np.asarray(_boxes, dtype=np_dtype)
+ # Get the labels
+ labels = [lab.text for rect in rectangles for lab in rect]
+
+ if recognition_task:
+ crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, name.text), geoms=boxes)
+ for crop, label in zip(crops, labels):
+ if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
+ self.data.append((crop, label))
+ else:
+ self.data.append((name.text, dict(boxes=boxes, labels=labels)))
+
+ self.root = tmp_root
+
+ def extra_repr(self) -> str:
+ return f"train={self.train}"
diff --git a/doctr/datasets/synthtext.py b/doctr/datasets/synthtext.py
new file mode 100644
index 0000000000000000000000000000000000000000..a60e22e83212d2586159612d00e651dd66f82a5f
--- /dev/null
+++ b/doctr/datasets/synthtext.py
@@ -0,0 +1,128 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import glob
+import os
+from typing import Any, Dict, List, Tuple, Union
+
+import numpy as np
+from PIL import Image
+from scipy import io as sio
+from tqdm import tqdm
+
+from .datasets import VisionDataset
+from .utils import convert_target_to_relative, crop_bboxes_from_image
+
+__all__ = ["SynthText"]
+
+
+class SynthText(VisionDataset):
+ """SynthText dataset from `"Synthetic Data for Text Localisation in Natural Images"
+    <https://arxiv.org/abs/1604.06646>`_ | `"repository" <https://github.com/ankush-me/SynthText>`_ |
+    `"website" <https://www.robots.ox.ac.uk/~vgg/data/scenetext/>`_.
+
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/svt-grid.png&src=0
+ :align: center
+
+ >>> from doctr.datasets import SynthText
+ >>> train_set = SynthText(train=True, download=True)
+ >>> img, target = train_set[0]
+
+ Args:
+ ----
+ train: whether the subset should be the training one
+        use_polygons: whether polygons should be considered as rotated bounding boxes (instead of straight ones)
+        recognition_task: whether the dataset should be used for a recognition task
+ **kwargs: keyword arguments from `VisionDataset`.
+ """
+
+ URL = "https://thor.robots.ox.ac.uk/~vgg/data/scenetext/SynthText.zip"
+ SHA256 = "28ab030485ec8df3ed612c568dd71fb2793b9afbfa3a9d9c6e792aef33265bf1"
+
+ def __init__(
+ self,
+ train: bool = True,
+ use_polygons: bool = False,
+ recognition_task: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ self.URL,
+ None,
+ file_hash=None,
+ extract_archive=True,
+ pre_transforms=convert_target_to_relative if not recognition_task else None,
+ **kwargs,
+ )
+ self.train = train
+ self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
+ np_dtype = np.float32
+
+ # Load mat data
+ tmp_root = os.path.join(self.root, "SynthText") if self.SHA256 else self.root
+ # define folder to write SynthText recognition dataset
+ reco_folder_name = "SynthText_recognition_train" if self.train else "SynthText_recognition_test"
+ reco_folder_name = "Poly_" + reco_folder_name if use_polygons else reco_folder_name
+ reco_folder_path = os.path.join(tmp_root, reco_folder_name)
+ reco_images_counter = 0
+
+ if recognition_task and os.path.isdir(reco_folder_path):
+ self._read_from_folder(reco_folder_path)
+ return
+ elif recognition_task and not os.path.isdir(reco_folder_path):
+ os.makedirs(reco_folder_path, exist_ok=False)
+
+ mat_data = sio.loadmat(os.path.join(tmp_root, "gt.mat"))
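+        # NOTE (added for clarity): SynthText ships no official train/test split, so the
+        # first 90% of the samples are used for training and the remaining 10% for testing.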
+ train_samples = int(len(mat_data["imnames"][0]) * 0.9)
+ set_slice = slice(train_samples) if self.train else slice(train_samples, None)
+ paths = mat_data["imnames"][0][set_slice]
+ boxes = mat_data["wordBB"][0][set_slice]
+ labels = mat_data["txt"][0][set_slice]
+ del mat_data
+
+ for img_path, word_boxes, txt in tqdm(
+ iterable=zip(paths, boxes, labels), desc="Unpacking SynthText", total=len(paths)
+ ):
+ # File existence check
+ if not os.path.exists(os.path.join(tmp_root, img_path[0])):
+ raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path[0])}")
+
+            # renamed from `labels` to avoid shadowing the full label array loaded from gt.mat
+            word_labels = [elt for word in txt.tolist() for elt in word.split()]
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
+ word_boxes = (
+ word_boxes.transpose(2, 1, 0)
+ if word_boxes.ndim == 3
+ else np.expand_dims(word_boxes.transpose(1, 0), axis=0)
+ )
+
+ if not use_polygons:
+ # xmin, ymin, xmax, ymax
+ word_boxes = np.concatenate((word_boxes.min(axis=1), word_boxes.max(axis=1)), axis=1)
+
+ if recognition_task:
+ crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_path[0]), geoms=word_boxes)
+                for crop, label in zip(crops, word_labels):
+ if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
+ # write data to disk
+ with open(os.path.join(reco_folder_path, f"{reco_images_counter}.txt"), "w") as f:
+ f.write(label)
+ tmp_img = Image.fromarray(crop)
+ tmp_img.save(os.path.join(reco_folder_path, f"{reco_images_counter}.png"))
+ reco_images_counter += 1
+ else:
+                self.data.append((img_path[0], dict(boxes=np.asarray(word_boxes, dtype=np_dtype), labels=word_labels)))
+
+ if recognition_task:
+ self._read_from_folder(reco_folder_path)
+
+ self.root = tmp_root
+
+ def extra_repr(self) -> str:
+ return f"train={self.train}"
+
+ def _read_from_folder(self, path: str) -> None:
+ for img_path in glob.glob(os.path.join(path, "*.png")):
+ with open(os.path.join(path, f"{os.path.basename(img_path)[:-4]}.txt"), "r") as f:
+ self.data.append((img_path, f.read()))
diff --git a/doctr/datasets/utils.py b/doctr/datasets/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4cc0f165725a410aceccc4e0e7150a52d807e94
--- /dev/null
+++ b/doctr/datasets/utils.py
@@ -0,0 +1,216 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import string
+import unicodedata
+from collections.abc import Sequence
+from functools import partial
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
+from typing import Sequence as SequenceType
+
+import numpy as np
+from PIL import Image
+
+from doctr.io.image import get_img_shape
+from doctr.utils.geometry import convert_to_relative_coords, extract_crops, extract_rcrops
+
+from .vocabs import VOCABS
+
+__all__ = ["translate", "encode_string", "decode_sequence", "encode_sequences", "pre_transform_multiclass"]
+
+ImageTensor = TypeVar("ImageTensor")
+
+
+def translate(
+ input_string: str,
+ vocab_name: str,
+ unknown_char: str = "■",
+) -> str:
+ """Translate a string input in a given vocabulary
+
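+    >>> # Illustrative example: out-of-vocab characters are unicode-normalized
+    >>> # or replaced by `unknown_char`, and whitespace is removed
+    >>> from doctr.datasets.utils import translate
+    >>> translate("café", "ascii_letters")
+    'cafe'
+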
+ Args:
+ ----
+ input_string: input string to translate
+ vocab_name: vocabulary to use (french, latin, ...)
+ unknown_char: unknown character for non-translatable characters
+
+ Returns:
+ -------
+ A string translated in a given vocab
+ """
+ if VOCABS.get(vocab_name) is None:
+        raise KeyError("output vocabulary must be in vocabs dictionary")
+
+ translated = ""
+ for char in input_string:
+ if char not in VOCABS[vocab_name]:
+ # we need to translate char into a vocab char
+ if char in string.whitespace:
+ # remove whitespaces
+ continue
+ # normalize character if it is not in vocab
+ char = unicodedata.normalize("NFD", char).encode("ascii", "ignore").decode("ascii")
+ if char == "" or char not in VOCABS[vocab_name]:
+            # if normalization fails or the char is still not in the vocab, use the unknown character
+ char = unknown_char
+ translated += char
+ return translated
+
+
+def encode_string(
+ input_string: str,
+ vocab: str,
+) -> List[int]:
+ """Given a predefined mapping, encode the string to a sequence of numbers
+
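+    >>> # Illustrative example: each character maps to its index in the vocab string
+    >>> from doctr.datasets.utils import encode_string
+    >>> encode_string("abc", "abcdef")
+    [0, 1, 2]
+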
+ Args:
+ ----
+ input_string: string to encode
+ vocab: vocabulary (string), the encoding is given by the indexing of the character sequence
+
+ Returns:
+ -------
+ A list encoding the input_string
+ """
+ try:
+ return list(map(vocab.index, input_string))
+ except ValueError:
+        raise ValueError(
+            "some characters cannot be found in 'vocab'. "
+            f"Please check the input string {input_string} and the vocabulary {vocab}"
+        )
+
+
+def decode_sequence(
+ input_seq: Union[np.ndarray, SequenceType[int]],
+ mapping: str,
+) -> str:
+ """Given a predefined mapping, decode the sequence of numbers to a string
+
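+    >>> # Illustrative example: the inverse of encode_string for in-vocab indices
+    >>> from doctr.datasets.utils import decode_sequence
+    >>> decode_sequence([0, 1, 2], "abcdef")
+    'abc'
+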
+ Args:
+ ----
+ input_seq: array to decode
+ mapping: vocabulary (string), the encoding is given by the indexing of the character sequence
+
+ Returns:
+ -------
+ A string, decoded from input_seq
+ """
+ if not isinstance(input_seq, (Sequence, np.ndarray)):
+ raise TypeError("Invalid sequence type")
+ if isinstance(input_seq, np.ndarray) and (input_seq.dtype != np.int_ or input_seq.max() >= len(mapping)):
+ raise AssertionError("Input must be an array of int, with max less than mapping size")
+
+ return "".join(map(mapping.__getitem__, input_seq))
+
+
+def encode_sequences(
+ sequences: List[str],
+ vocab: str,
+ target_size: Optional[int] = None,
+ eos: int = -1,
+ sos: Optional[int] = None,
+ pad: Optional[int] = None,
+ dynamic_seq_length: bool = False,
+) -> np.ndarray:
+ """Encode character sequences using a given vocab as mapping
+
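+    >>> # Illustrative example: with no `pad`/`sos`, unused slots are filled with `eos`
+    >>> from doctr.datasets.utils import encode_sequences
+    >>> encode_sequences(["ab", "c"], vocab="abc", eos=3)
+    array([[0, 1, 3],
+           [2, 3, 3]], dtype=int32)
+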
+ Args:
+ ----
+ sequences: the list of character sequences of size N
+ vocab: the ordered vocab to use for encoding
+ target_size: maximum length of the encoded data
+ eos: encoding of End Of String
+ sos: optional encoding of Start Of String
+ pad: optional encoding for padding. In case of padding, all sequences are followed by 1 EOS then PAD
+ dynamic_seq_length: if `target_size` is specified, uses it as upper bound and enables dynamic sequence size
+
+ Returns:
+ -------
+ the padded encoded data as a tensor
+ """
+ if 0 <= eos < len(vocab):
+ raise ValueError("argument 'eos' needs to be outside of vocab possible indices")
+
+ if not isinstance(target_size, int) or dynamic_seq_length:
+ # Maximum string length + EOS
+ max_length = max(len(w) for w in sequences) + 1
+ if isinstance(sos, int):
+ max_length += 1
+ if isinstance(pad, int):
+ max_length += 1
+ target_size = max_length if not isinstance(target_size, int) else min(max_length, target_size)
+
+ # Pad all sequences
+ if isinstance(pad, int): # pad with padding symbol
+ if 0 <= pad < len(vocab):
+ raise ValueError("argument 'pad' needs to be outside of vocab possible indices")
+ # In that case, add EOS at the end of the word before padding
+ default_symbol = pad
+ else: # pad with eos symbol
+ default_symbol = eos
+ encoded_data: np.ndarray = np.full([len(sequences), target_size], default_symbol, dtype=np.int32)
+
+ # Encode the strings
+ for idx, seq in enumerate(map(partial(encode_string, vocab=vocab), sequences)):
+ if isinstance(pad, int): # add eos at the end of the sequence
+ seq.append(eos)
+ encoded_data[idx, : min(len(seq), target_size)] = seq[: min(len(seq), target_size)]
+
+ if isinstance(sos, int): # place sos symbol at the beginning of each sequence
+ if 0 <= sos < len(vocab):
+ raise ValueError("argument 'sos' needs to be outside of vocab possible indices")
+        # shift each sequence one slot to the right to make room for the SOS symbol
+        encoded_data = np.roll(encoded_data, 1, axis=-1)
+ encoded_data[:, 0] = sos
+
+ return encoded_data
+
+
+def convert_target_to_relative(img: ImageTensor, target: Dict[str, Any]) -> Tuple[ImageTensor, Dict[str, Any]]:
+ target["boxes"] = convert_to_relative_coords(target["boxes"], get_img_shape(img))
+ return img, target
+
+
+def crop_bboxes_from_image(img_path: Union[str, Path], geoms: np.ndarray) -> List[np.ndarray]:
+ """Crop a set of bounding boxes from an image
+
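+    >>> # Illustrative usage (assuming "path/to/doc.jpg" exists):
+    >>> import numpy as np
+    >>> from doctr.datasets.utils import crop_bboxes_from_image
+    >>> crops = crop_bboxes_from_image("path/to/doc.jpg", np.array([[0, 0, 100, 50]]))
+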
+ Args:
+ ----
+ img_path: path to the image
+        geoms: an array of polygons of shape (N, 4, 2) or of straight boxes of shape (N, 4)
+
+ Returns:
+ -------
+ a list of cropped images
+ """
+ img: np.ndarray = np.array(Image.open(img_path).convert("RGB"))
+ # Polygon
+ if geoms.ndim == 3 and geoms.shape[1:] == (4, 2):
+ return extract_rcrops(img, geoms.astype(dtype=int))
+ if geoms.ndim == 2 and geoms.shape[1] == 4:
+ return extract_crops(img, geoms.astype(dtype=int))
+ raise ValueError("Invalid geometry format")
+
+
+def pre_transform_multiclass(img, target: Tuple[np.ndarray, List]) -> Tuple[np.ndarray, Dict[str, List]]:
+ """Converts multiclass target to relative coordinates.
+
+ Args:
+ ----
+ img: Image
+        target: tuple of target polygons and their class names
+
+ Returns:
+ -------
+ Image and dictionary of boxes, with class names as keys
+ """
+ boxes = convert_to_relative_coords(target[0], get_img_shape(img))
+ boxes_classes = target[1]
+ boxes_dict: Dict = {k: [] for k in sorted(set(boxes_classes))}
+ for k, poly in zip(boxes_classes, boxes):
+ boxes_dict[k].append(poly)
+ boxes_dict = {k: np.stack(v, axis=0) for k, v in boxes_dict.items()}
+ return img, boxes_dict
diff --git a/doctr/datasets/vocabs.py b/doctr/datasets/vocabs.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddc32d866581c931419072f97b641846b25a18db
--- /dev/null
+++ b/doctr/datasets/vocabs.py
@@ -0,0 +1,71 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import string
+from typing import Dict
+
+__all__ = ["VOCABS"]
+
+
+VOCABS: Dict[str, str] = {
+ "digits": string.digits,
+ "ascii_letters": string.ascii_letters,
+ "punctuation": string.punctuation,
+ "currency": "£€¥¢฿",
+ "ancient_greek": "αβγδεζηθικλμνξοπρστυφχψωΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩ",
+ "arabic_letters": "ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىي",
+ "persian_letters": "پچڢڤگ",
+ "hindi_digits": "٠١٢٣٤٥٦٧٨٩",
+ "arabic_diacritics": "ًٌٍَُِّْ",
+ "arabic_punctuation": "؟؛«»—",
+}
+
+VOCABS["latin"] = VOCABS["digits"] + VOCABS["ascii_letters"] + VOCABS["punctuation"]
+VOCABS["english"] = VOCABS["latin"] + "°" + VOCABS["currency"]
+VOCABS["legacy_french"] = VOCABS["latin"] + "°" + "àâéèêëîïôùûçÀÂÉÈËÎÏÔÙÛÇ" + VOCABS["currency"]
+VOCABS["french"] = VOCABS["english"] + "àâéèêëîïôùûüçÀÂÉÈÊËÎÏÔÙÛÜÇ"
+VOCABS["portuguese"] = VOCABS["english"] + "áàâãéêíïóôõúüçÁÀÂÃÉÊÍÏÓÔÕÚÜÇ"
+VOCABS["spanish"] = VOCABS["english"] + "áéíóúüñÁÉÍÓÚÜÑ" + "¡¿"
+VOCABS["italian"] = VOCABS["english"] + "àèéìíîòóùúÀÈÉÌÍÎÒÓÙÚ"
+VOCABS["german"] = VOCABS["english"] + "äöüßÄÖÜẞ"
+VOCABS["arabic"] = (
+ VOCABS["digits"]
+ + VOCABS["hindi_digits"]
+ + VOCABS["arabic_letters"]
+ + VOCABS["persian_letters"]
+ + VOCABS["arabic_diacritics"]
+ + VOCABS["arabic_punctuation"]
+ + VOCABS["punctuation"]
+)
+VOCABS["czech"] = VOCABS["english"] + "áčďéěíňóřšťúůýžÁČĎÉĚÍŇÓŘŠŤÚŮÝŽ"
+VOCABS["polish"] = VOCABS["english"] + "ąćęłńóśźżĄĆĘŁŃÓŚŹŻ"
+VOCABS["dutch"] = VOCABS["english"] + "áéíóúüñÁÉÍÓÚÜÑ"
+VOCABS["norwegian"] = VOCABS["english"] + "æøåÆØÅ"
+VOCABS["danish"] = VOCABS["english"] + "æøåÆØÅ"
+VOCABS["finnish"] = VOCABS["english"] + "äöÄÖ"
+VOCABS["swedish"] = VOCABS["english"] + "åäöÅÄÖ"
+VOCABS["vietnamese"] = (
+ VOCABS["english"]
+ + "áàảạãăắằẳẵặâấầẩẫậéèẻẽẹêếềểễệóòỏõọôốồổộỗơớờởợỡúùủũụưứừửữựiíìỉĩịýỳỷỹỵ"
+ + "ÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ"
+)
+VOCABS["hebrew"] = VOCABS["english"] + "אבגדהוזחטיכלמנסעפצקרשת" + "₪"
+VOCABS["multilingual"] = "".join(
+ dict.fromkeys(
+ VOCABS["french"]
+ + VOCABS["portuguese"]
+ + VOCABS["spanish"]
+ + VOCABS["german"]
+ + VOCABS["czech"]
+ + VOCABS["polish"]
+ + VOCABS["dutch"]
+ + VOCABS["italian"]
+ + VOCABS["norwegian"]
+ + VOCABS["danish"]
+ + VOCABS["finnish"]
+ + VOCABS["swedish"]
+ + "§"
+ )
+)
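+
+# Example (illustrative): every vocab is a plain string indexed by position,
+# which is what encode_string/decode_sequence rely on.
+# >>> VOCABS["digits"]
+# '0123456789'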
diff --git a/doctr/datasets/wildreceipt.py b/doctr/datasets/wildreceipt.py
new file mode 100644
index 0000000000000000000000000000000000000000..19108d77612af08cb227750abea1beae938605ff
--- /dev/null
+++ b/doctr/datasets/wildreceipt.py
@@ -0,0 +1,111 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import numpy as np
+
+from .datasets import AbstractDataset
+from .utils import convert_target_to_relative, crop_bboxes_from_image
+
+__all__ = ["WILDRECEIPT"]
+
+
+class WILDRECEIPT(AbstractDataset):
+ """WildReceipt dataset from `"Spatial Dual-Modality Graph Reasoning for Key Information Extraction"
+    <https://arxiv.org/abs/2103.14470v1>`_ |
+    `repository <https://github.com/open-mmlab/mmocr>`_.
+
+ .. image:: https://doctr-static.mindee.com/models?id=v0.7.0/wildreceipt-dataset.jpg&src=0
+ :align: center
+
+ >>> # NOTE: You need to download the dataset first.
+ >>> from doctr.datasets import WILDRECEIPT
+ >>> train_set = WILDRECEIPT(train=True, img_folder="/path/to/wildreceipt/",
+ >>> label_path="/path/to/wildreceipt/train.txt")
+ >>> img, target = train_set[0]
+ >>> test_set = WILDRECEIPT(train=False, img_folder="/path/to/wildreceipt/",
+ >>> label_path="/path/to/wildreceipt/test.txt")
+ >>> img, target = test_set[0]
+
+ Args:
+ ----
+ img_folder: folder with all the images of the dataset
+ label_path: path to the annotations file of the dataset
+ train: whether the subset should be the training one
+        use_polygons: whether polygons should be considered as rotated bounding boxes (instead of straight ones)
+        recognition_task: whether the dataset should be used for a recognition task
+ **kwargs: keyword arguments from `AbstractDataset`.
+ """
+
+ def __init__(
+ self,
+ img_folder: str,
+ label_path: str,
+ train: bool = True,
+ use_polygons: bool = False,
+ recognition_task: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(
+ img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
+ )
+ # File existence check
+ if not os.path.exists(label_path) or not os.path.exists(img_folder):
+ raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")
+
+ tmp_root = img_folder
+ self.train = train
+ np_dtype = np.float32
+ self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any]]]] = []
+
+ with open(label_path, "r") as file:
+ data = file.read()
+ # Split the text file into separate JSON strings
+ json_strings = data.strip().split("\n")
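+            # NOTE (added for clarity): each line is a standalone JSON object, roughly
+            # {"file_name": ..., "height": ..., "width": ..., "annotations":
+            #  [{"box": [x1, y1, ..., y4], "text": ...}, ...]}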
+ box: Union[List[float], np.ndarray]
+            for json_string in json_strings:
+                # reset per image so each sample only carries its own annotations
+                _targets = []
+ json_data = json.loads(json_string)
+ img_path = json_data["file_name"]
+ annotations = json_data["annotations"]
+ for annotation in annotations:
+ coordinates = annotation["box"]
+ if use_polygons:
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
+ box = np.array(
+ [
+ [coordinates[0], coordinates[1]],
+ [coordinates[2], coordinates[3]],
+ [coordinates[4], coordinates[5]],
+ [coordinates[6], coordinates[7]],
+ ],
+ dtype=np_dtype,
+ )
+ else:
+ x, y = coordinates[::2], coordinates[1::2]
+ box = [min(x), min(y), max(x), max(y)]
+ _targets.append((annotation["text"], box))
+ text_targets, box_targets = zip(*_targets)
+
+ if recognition_task:
+ crops = crop_bboxes_from_image(
+ img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0)
+ )
+ for crop, label in zip(crops, list(text_targets)):
+ if label and " " not in label:
+ self.data.append((crop, label))
+ else:
+ self.data.append((
+ img_path,
+ dict(boxes=np.asarray(box_targets, dtype=int).clip(min=0), labels=list(text_targets)),
+ ))
+ self.root = tmp_root
+
+ def extra_repr(self) -> str:
+ return f"train={self.train}"
diff --git a/doctr/file_utils.py b/doctr/file_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..93f12a55cbe79af3f6d67fa119606e3dc53a9e52
--- /dev/null
+++ b/doctr/file_utils.py
@@ -0,0 +1,92 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
+
+import importlib.util
+import logging
+import os
+import sys
+
+CLASS_NAME: str = "words"
+
+
+if sys.version_info < (3, 8): # pragma: no cover
+ import importlib_metadata
+else:
+ import importlib.metadata as importlib_metadata
+
+
+__all__ = ["is_tf_available", "is_torch_available", "CLASS_NAME"]
+
+ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
+ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
+
+USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+
+
+if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
+ _torch_available = importlib.util.find_spec("torch") is not None
+ if _torch_available:
+ try:
+ _torch_version = importlib_metadata.version("torch")
+ logging.info(f"PyTorch version {_torch_version} available.")
+ except importlib_metadata.PackageNotFoundError: # pragma: no cover
+ _torch_available = False
+else: # pragma: no cover
+ logging.info("Disabling PyTorch because USE_TF is set")
+ _torch_available = False
+
+
+if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
+ _tf_available = importlib.util.find_spec("tensorflow") is not None
+ if _tf_available:
+ candidates = (
+ "tensorflow",
+ "tensorflow-cpu",
+ "tensorflow-gpu",
+ "tf-nightly",
+ "tf-nightly-cpu",
+ "tf-nightly-gpu",
+ "intel-tensorflow",
+ "tensorflow-rocm",
+ "tensorflow-macos",
+ )
+ _tf_version = None
+ # For the metadata, we have to look for both tensorflow and tensorflow-cpu
+ for pkg in candidates:
+ try:
+ _tf_version = importlib_metadata.version(pkg)
+ break
+ except importlib_metadata.PackageNotFoundError:
+ pass
+ _tf_available = _tf_version is not None
+ if _tf_available:
+ if int(_tf_version.split(".")[0]) < 2: # type: ignore[union-attr] # pragma: no cover
+ logging.info(f"TensorFlow found but with version {_tf_version}. DocTR requires version 2 minimum.")
+ _tf_available = False
+ else:
+ logging.info(f"TensorFlow version {_tf_version} available.")
+else: # pragma: no cover
+    logging.info("Disabling TensorFlow because USE_TORCH is set")
+ _tf_available = False
+
+
+if not _torch_available and not _tf_available: # pragma: no cover
+ raise ModuleNotFoundError(
+ "DocTR requires either TensorFlow or PyTorch to be installed. Please ensure one of them"
+ " is installed and that either USE_TF or USE_TORCH is enabled."
+ )
+
+
+def is_torch_available():
+ """Whether PyTorch is installed."""
+ return _torch_available
+
+
+def is_tf_available():
+ """Whether TensorFlow is installed."""
+ return _tf_available
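+
+# Example (illustrative): downstream modules gate backend-specific imports on
+# the detected framework, e.g.:
+# >>> from doctr.file_utils import is_tf_available, is_torch_available
+# >>> if is_torch_available():
+# ...     from doctr.io.image.pytorch import read_img_as_tensor
+# ... elif is_tf_available():
+# ...     from doctr.io.image.tensorflow import read_img_as_tensor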
diff --git a/doctr/io/__init__.py b/doctr/io/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6eab8c240615522894f0b8a4ef09aaf59636a811
--- /dev/null
+++ b/doctr/io/__init__.py
@@ -0,0 +1,5 @@
+from .elements import *
+from .html import *
+from .image import *
+from .pdf import *
+from .reader import *
diff --git a/doctr/io/elements.py b/doctr/io/elements.py
new file mode 100644
index 0000000000000000000000000000000000000000..4862b17b6bea5fadb1f183008ea5849888810fd1
--- /dev/null
+++ b/doctr/io/elements.py
@@ -0,0 +1,621 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from defusedxml import defuse_stdlib
+
+defuse_stdlib()
+from xml.etree import ElementTree as ET
+from xml.etree.ElementTree import Element as ETElement
+from xml.etree.ElementTree import SubElement
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+import doctr
+from doctr.utils.common_types import BoundingBox
+from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox
+from doctr.utils.repr import NestedObject
+from doctr.utils.visualization import synthesize_kie_page, synthesize_page, visualize_kie_page, visualize_page
+
+__all__ = ["Element", "Word", "Artefact", "Line", "Prediction", "Block", "Page", "KIEPage", "Document"]
+
+
+class Element(NestedObject):
+ """Implements an abstract document element with exporting and text rendering capabilities"""
+
+ _children_names: List[str] = []
+ _exported_keys: List[str] = []
+
+ def __init__(self, **kwargs: Any) -> None:
+ for k, v in kwargs.items():
+ if k in self._children_names:
+ setattr(self, k, v)
+ else:
+ raise KeyError(f"{self.__class__.__name__} object does not have any attribute named '{k}'")
+
+ def export(self) -> Dict[str, Any]:
+ """Exports the object into a nested dict format"""
+ export_dict = {k: getattr(self, k) for k in self._exported_keys}
+ for children_name in self._children_names:
+ if children_name in ["predictions"]:
+ export_dict[children_name] = {
+ k: [item.export() for item in c] for k, c in getattr(self, children_name).items()
+ }
+ else:
+ export_dict[children_name] = [c.export() for c in getattr(self, children_name)]
+
+ return export_dict
+
+ @classmethod
+ def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
+ raise NotImplementedError
+
+ def render(self) -> str:
+ raise NotImplementedError
+
+
+class Word(Element):
+ """Implements a word element
+
+ Args:
+ ----
+ value: the text string of the word
+ confidence: the confidence associated with the text prediction
+ geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
+ the page's size
+ crop_orientation: the general orientation of the crop in degrees and its confidence
+ """
+
+ _exported_keys: List[str] = ["value", "confidence", "geometry", "crop_orientation"]
+ _children_names: List[str] = []
+
+ def __init__(
+ self,
+ value: str,
+ confidence: float,
+ geometry: Union[BoundingBox, np.ndarray],
+ crop_orientation: Dict[str, Any],
+ ) -> None:
+ super().__init__()
+ self.value = value
+ self.confidence = confidence
+ self.geometry = geometry
+ self.crop_orientation = crop_orientation
+
+ def render(self) -> str:
+ """Renders the full text of the element"""
+ return self.value
+
+ def extra_repr(self) -> str:
+ return f"value='{self.value}', confidence={self.confidence:.2}"
+
+ @classmethod
+ def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
+ kwargs = {k: save_dict[k] for k in cls._exported_keys}
+ return cls(**kwargs)
+
+
+class Artefact(Element):
+ """Implements a non-textual element
+
+ Args:
+ ----
+ artefact_type: the type of artefact
+ confidence: the confidence of the type prediction
+ geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
+ the page's size.
+ """
+
+ _exported_keys: List[str] = ["geometry", "type", "confidence"]
+ _children_names: List[str] = []
+
+ def __init__(self, artefact_type: str, confidence: float, geometry: BoundingBox) -> None:
+ super().__init__()
+ self.geometry = geometry
+ self.type = artefact_type
+ self.confidence = confidence
+
+ def render(self) -> str:
+ """Renders the full text of the element"""
+ return f"[{self.type.upper()}]"
+
+ def extra_repr(self) -> str:
+ return f"type='{self.type}', confidence={self.confidence:.2}"
+
+ @classmethod
+ def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
+ kwargs = {k: save_dict[k] for k in cls._exported_keys}
+ return cls(**kwargs)
+
+
+class Line(Element):
+ """Implements a line element as a collection of words
+
+ Args:
+ ----
+ words: list of word elements
+ geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
+ the page's size. If not specified, it will be resolved by default to the smallest bounding box enclosing
+ all words in it.
+ """
+
+ _exported_keys: List[str] = ["geometry"]
+ _children_names: List[str] = ["words"]
+ words: List[Word] = []
+
+ def __init__(
+ self,
+ words: List[Word],
+ geometry: Optional[Union[BoundingBox, np.ndarray]] = None,
+ ) -> None:
+ # Resolve the geometry using the smallest enclosing bounding box
+ if geometry is None:
+ # Check whether this is a rotated or straight box
+ box_resolution_fn = resolve_enclosing_rbbox if len(words[0].geometry) == 4 else resolve_enclosing_bbox
+ geometry = box_resolution_fn([w.geometry for w in words]) # type: ignore[operator]
+
+ super().__init__(words=words)
+ self.geometry = geometry
+
+ def render(self) -> str:
+ """Renders the full text of the element"""
+ return " ".join(w.render() for w in self.words)
+
+ @classmethod
+ def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
+ kwargs = {k: save_dict[k] for k in cls._exported_keys}
+ kwargs.update({
+ "words": [Word.from_dict(_dict) for _dict in save_dict["words"]],
+ })
+ return cls(**kwargs)
+
+
+class Prediction(Word):
+ """Implements a prediction element"""
+
+ def render(self) -> str:
+ """Renders the full text of the element"""
+ return self.value
+
+ def extra_repr(self) -> str:
+ return f"value='{self.value}', confidence={self.confidence:.2}, bounding_box={self.geometry}"
+
+
+class Block(Element):
+ """Implements a block element as a collection of lines and artefacts
+
+ Args:
+ ----
+ lines: list of line elements
+ artefacts: list of artefacts
+ geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
+ the page's size. If not specified, it will be resolved by default to the smallest bounding box enclosing
+ all lines and artefacts in it.
+ """
+
+ _exported_keys: List[str] = ["geometry"]
+ _children_names: List[str] = ["lines", "artefacts"]
+ lines: List[Line] = []
+ artefacts: List[Artefact] = []
+
+ def __init__(
+ self,
+ lines: List[Line] = [],
+ artefacts: List[Artefact] = [],
+ geometry: Optional[Union[BoundingBox, np.ndarray]] = None,
+ ) -> None:
+ # Resolve the geometry using the smallest enclosing bounding box
+ if geometry is None:
+ line_boxes = [word.geometry for line in lines for word in line.words]
+ artefact_boxes = [artefact.geometry for artefact in artefacts]
+ box_resolution_fn = (
+ resolve_enclosing_rbbox if isinstance(lines[0].geometry, np.ndarray) else resolve_enclosing_bbox
+ )
+ geometry = box_resolution_fn(line_boxes + artefact_boxes) # type: ignore[operator]
+
+ super().__init__(lines=lines, artefacts=artefacts)
+ self.geometry = geometry
+
+ def render(self, line_break: str = "\n") -> str:
+ """Renders the full text of the element"""
+ return line_break.join(line.render() for line in self.lines)
+
+ @classmethod
+ def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
+ kwargs = {k: save_dict[k] for k in cls._exported_keys}
+ kwargs.update({
+ "lines": [Line.from_dict(_dict) for _dict in save_dict["lines"]],
+ "artefacts": [Artefact.from_dict(_dict) for _dict in save_dict["artefacts"]],
+ })
+ return cls(**kwargs)
+
+
+class Page(Element):
+ """Implements a page element as a collection of blocks
+
+ Args:
+ ----
+ page: image encoded as a numpy array in uint8
+ blocks: list of block elements
+ page_idx: the index of the page in the input raw document
+ dimensions: the page size in pixels in format (height, width)
+        orientation: a dictionary with the value of the rotation angle in degrees and confidence of the prediction
+ language: a dictionary with the language value and confidence of the prediction
+ """
+
+ _exported_keys: List[str] = ["page_idx", "dimensions", "orientation", "language"]
+ _children_names: List[str] = ["blocks"]
+ blocks: List[Block] = []
+
+ def __init__(
+ self,
+ page: np.ndarray,
+ blocks: List[Block],
+ page_idx: int,
+ dimensions: Tuple[int, int],
+ orientation: Optional[Dict[str, Any]] = None,
+ language: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ super().__init__(blocks=blocks)
+ self.page = page
+ self.page_idx = page_idx
+ self.dimensions = dimensions
+ self.orientation = orientation if isinstance(orientation, dict) else dict(value=None, confidence=None)
+ self.language = language if isinstance(language, dict) else dict(value=None, confidence=None)
+
+ def render(self, block_break: str = "\n\n") -> str:
+ """Renders the full text of the element"""
+ return block_break.join(b.render() for b in self.blocks)
+
+ def extra_repr(self) -> str:
+ return f"dimensions={self.dimensions}"
+
+ def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, **kwargs) -> None:
+ """Overlay the result on a given image
+
+ Args:
+ interactive: whether the display should be interactive
+ preserve_aspect_ratio: pass True if you passed True to the predictor
+ **kwargs: additional keyword arguments passed to the matplotlib.pyplot.show method
+ """
+ visualize_page(self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio)
+ plt.show(**kwargs)
+
+ def synthesize(self, **kwargs) -> np.ndarray:
+ """Synthesize the page from the predictions
+
+ Returns
+ -------
+ synthesized page
+ """
+ return synthesize_page(self.export(), **kwargs)
+
+ def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> Tuple[bytes, ET.ElementTree]:
+ """Export the page as XML (hOCR-format)
+ convention: https://github.com/kba/hocr-spec/blob/master/1.2/spec.md
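+
+        >>> # Illustrative usage (assuming `page` is a Page instance):
+        >>> xml_bytes, xml_tree = page.export_as_xml()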
+
+ Args:
+ ----
+ file_title: the title of the XML file
+
+ Returns:
+ -------
+ a tuple of the XML byte string, and its ElementTree
+ """
+ p_idx = self.page_idx
+ block_count: int = 1
+ line_count: int = 1
+ word_count: int = 1
+ height, width = self.dimensions
+        # self.language stores {"value": ..., "confidence": ...}; fall back to "en" when unset
+        language = self.language["value"] if isinstance(self.language.get("value"), str) else "en"
+ # Create the XML root element
+ page_hocr = ETElement("html", attrib={"xmlns": "http://www.w3.org/1999/xhtml", "xml:lang": str(language)})
+ # Create the header / SubElements of the root element
+ head = SubElement(page_hocr, "head")
+ SubElement(head, "title").text = file_title
+ SubElement(head, "meta", attrib={"http-equiv": "Content-Type", "content": "text/html; charset=utf-8"})
+ SubElement(
+ head,
+ "meta",
+ attrib={"name": "ocr-system", "content": f"python-doctr {doctr.__version__}"}, # type: ignore[attr-defined]
+ )
+ SubElement(
+ head,
+ "meta",
+ attrib={"name": "ocr-capabilities", "content": "ocr_page ocr_carea ocr_par ocr_line ocrx_word"},
+ )
+ # Create the body
+ body = SubElement(page_hocr, "body")
+ SubElement(
+ body,
+ "div",
+ attrib={
+ "class": "ocr_page",
+ "id": f"page_{p_idx + 1}",
+ "title": f"image; bbox 0 0 {width} {height}; ppageno 0",
+ },
+ )
+ # iterate over the blocks / lines / words and create the XML elements in body line by line with the attributes
+ for block in self.blocks:
+ if len(block.geometry) != 2:
+ raise TypeError("XML export is only available for straight bounding boxes for now.")
+ (xmin, ymin), (xmax, ymax) = block.geometry
+ block_div = SubElement(
+ body,
+ "div",
+ attrib={
+ "class": "ocr_carea",
+ "id": f"block_{block_count}",
+                    "title": (
+                        f"bbox {int(round(xmin * width))} {int(round(ymin * height))} "
+                        f"{int(round(xmax * width))} {int(round(ymax * height))}"
+                    ),
+ },
+ )
+ paragraph = SubElement(
+ block_div,
+ "p",
+ attrib={
+ "class": "ocr_par",
+ "id": f"par_{block_count}",
+                    "title": (
+                        f"bbox {int(round(xmin * width))} {int(round(ymin * height))} "
+                        f"{int(round(xmax * width))} {int(round(ymax * height))}"
+                    ),
+ },
+ )
+ block_count += 1
+ for line in block.lines:
+ (xmin, ymin), (xmax, ymax) = line.geometry
+                # NOTE: baseline, x_size, x_descenders, x_ascenders are currently initialized to 0
+ line_span = SubElement(
+ paragraph,
+ "span",
+ attrib={
+ "class": "ocr_line",
+ "id": f"line_{line_count}",
+                        "title": (
+                            f"bbox {int(round(xmin * width))} {int(round(ymin * height))} "
+                            f"{int(round(xmax * width))} {int(round(ymax * height))}; "
+                            "baseline 0 0; x_size 0; x_descenders 0; x_ascenders 0"
+                        ),
+ },
+ )
+ line_count += 1
+ for word in line.words:
+ (xmin, ymin), (xmax, ymax) = word.geometry
+ conf = word.confidence
+ word_div = SubElement(
+ line_span,
+ "span",
+ attrib={
+ "class": "ocrx_word",
+ "id": f"word_{word_count}",
+                            "title": (
+                                f"bbox {int(round(xmin * width))} {int(round(ymin * height))} "
+                                f"{int(round(xmax * width))} {int(round(ymax * height))}; "
+                                f"x_wconf {int(round(conf * 100))}"
+                            ),
+ },
+ )
+ # set the text
+ word_div.text = word.value
+ word_count += 1
+
+ return (ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr))
+
+ @classmethod
+ def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
+ kwargs = {k: save_dict[k] for k in cls._exported_keys}
+ kwargs.update({"blocks": [Block.from_dict(block_dict) for block_dict in save_dict["blocks"]]})
+ return cls(**kwargs)
+
+
+class KIEPage(Element):
+ """Implements a KIE page element as a collection of predictions
+
+ Args:
+ ----
+ predictions: Dictionary with list of block elements for each detection class
+ page: image encoded as a numpy array in uint8
+ page_idx: the index of the page in the input raw document
+ dimensions: the page size in pixels in format (height, width)
+        orientation: a dictionary with the value of the rotation angle in degrees and confidence of the prediction
+ language: a dictionary with the language value and confidence of the prediction
+ """
+
+ _exported_keys: List[str] = ["page_idx", "dimensions", "orientation", "language"]
+ _children_names: List[str] = ["predictions"]
+ predictions: Dict[str, List[Prediction]] = {}
+
+ def __init__(
+ self,
+ page: np.ndarray,
+ predictions: Dict[str, List[Prediction]],
+ page_idx: int,
+ dimensions: Tuple[int, int],
+ orientation: Optional[Dict[str, Any]] = None,
+ language: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ super().__init__(predictions=predictions)
+ self.page = page
+ self.page_idx = page_idx
+ self.dimensions = dimensions
+ self.orientation = orientation if isinstance(orientation, dict) else dict(value=None, confidence=None)
+ self.language = language if isinstance(language, dict) else dict(value=None, confidence=None)
+
+ def render(self, prediction_break: str = "\n\n") -> str:
+ """Renders the full text of the element"""
+ return prediction_break.join(
+ f"{class_name}: {p.render()}" for class_name, predictions in self.predictions.items() for p in predictions
+ )
+
+ def extra_repr(self) -> str:
+ return f"dimensions={self.dimensions}"
+
+ def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, **kwargs) -> None:
+ """Overlay the result on a given image
+
+ Args:
+ interactive: whether the display should be interactive
+ preserve_aspect_ratio: pass True if you passed True to the predictor
+ **kwargs: keyword arguments passed to the matplotlib.pyplot.show method
+ """
+ visualize_kie_page(
+ self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio
+ )
+ plt.show(**kwargs)
+
+ def synthesize(self, **kwargs) -> np.ndarray:
+ """Synthesize the page from the predictions
+
+ Args:
+ ----
+ **kwargs: keyword arguments passed to the matplotlib.pyplot.show method
+
+ Returns:
+ -------
+ synthesized page
+ """
+ return synthesize_kie_page(self.export(), **kwargs)
+
+ def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> Tuple[bytes, ET.ElementTree]:
+ """Export the page as XML (hOCR-format)
+ convention: https://github.com/kba/hocr-spec/blob/master/1.2/spec.md
+
+ Args:
+ ----
+ file_title: the title of the XML file
+
+ Returns:
+ -------
+ a tuple of the XML byte string, and its ElementTree
+ """
+ p_idx = self.page_idx
+ prediction_count: int = 1
+ height, width = self.dimensions
+        # self.language stores {"value": ..., "confidence": ...}; fall back to "en" when unset
+        language = self.language["value"] if isinstance(self.language.get("value"), str) else "en"
+ # Create the XML root element
+ page_hocr = ETElement("html", attrib={"xmlns": "http://www.w3.org/1999/xhtml", "xml:lang": str(language)})
+ # Create the header / SubElements of the root element
+ head = SubElement(page_hocr, "head")
+ SubElement(head, "title").text = file_title
+ SubElement(head, "meta", attrib={"http-equiv": "Content-Type", "content": "text/html; charset=utf-8"})
+ SubElement(
+ head,
+ "meta",
+ attrib={"name": "ocr-system", "content": f"python-doctr {doctr.__version__}"}, # type: ignore[attr-defined]
+ )
+ SubElement(
+ head,
+ "meta",
+ attrib={"name": "ocr-capabilities", "content": "ocr_page ocr_carea ocr_par ocr_line ocrx_word"},
+ )
+ # Create the body
+ body = SubElement(page_hocr, "body")
+ SubElement(
+ body,
+ "div",
+ attrib={
+ "class": "ocr_page",
+ "id": f"page_{p_idx + 1}",
+ "title": f"image; bbox 0 0 {width} {height}; ppageno 0",
+ },
+ )
+ # iterate over the blocks / lines / words and create the XML elements in body line by line with the attributes
+ for class_name, predictions in self.predictions.items():
+ for prediction in predictions:
+ if len(prediction.geometry) != 2:
+ raise TypeError("XML export is only available for straight bounding boxes for now.")
+ (xmin, ymin), (xmax, ymax) = prediction.geometry
+ prediction_div = SubElement(
+ body,
+ "div",
+ attrib={
+ "class": "ocr_carea",
+ "id": f"{class_name}_prediction_{prediction_count}",
+                        "title": (
+                            f"bbox {int(round(xmin * width))} {int(round(ymin * height))} "
+                            f"{int(round(xmax * width))} {int(round(ymax * height))}"
+                        ),
+ },
+ )
+ prediction_div.text = prediction.value
+ prediction_count += 1
+
+ return ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr)
+
+ @classmethod
+ def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
+ kwargs = {k: save_dict[k] for k in cls._exported_keys}
+        # predictions are exported as a dict mapping class names to lists of prediction dicts
+        kwargs.update({
+            "predictions": {
+                class_name: [Prediction.from_dict(p) for p in predictions]
+                for class_name, predictions in save_dict["predictions"].items()
+            }
+        })
+ return cls(**kwargs)
+
+
+class Document(Element):
+ """Implements a document element as a collection of pages
+
+ Args:
+ ----
+ pages: list of page elements
+ """
+
+ _children_names: List[str] = ["pages"]
+ pages: List[Page] = []
+
+ def __init__(
+ self,
+ pages: List[Page],
+ ) -> None:
+ super().__init__(pages=pages)
+
+ def render(self, page_break: str = "\n\n\n\n") -> str:
+ """Renders the full text of the element"""
+ return page_break.join(p.render() for p in self.pages)
+
+ def show(self, **kwargs) -> None:
+ """Overlay the result on a given image"""
+ for result in self.pages:
+ result.show(**kwargs)
+
+ def synthesize(self, **kwargs) -> List[np.ndarray]:
+ """Synthesize all pages from their predictions
+
+ Returns
+ -------
+ list of synthesized pages
+ """
+        return [page.synthesize(**kwargs) for page in self.pages]
+
+ def export_as_xml(self, **kwargs) -> List[Tuple[bytes, ET.ElementTree]]:
+ """Export the document as XML (hOCR-format)
+
+ Args:
+ ----
+ **kwargs: additional keyword arguments passed to the Page.export_as_xml method
+
+ Returns:
+ -------
+ list of tuple of (bytes, ElementTree)
+ """
+ return [page.export_as_xml(**kwargs) for page in self.pages]
+
+ @classmethod
+ def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
+ kwargs = {k: save_dict[k] for k in cls._exported_keys}
+ kwargs.update({"pages": [Page.from_dict(page_dict) for page_dict in save_dict["pages"]]})
+ return cls(**kwargs)
+
+
+class KIEDocument(Document):
+ """Implements a document element as a collection of pages
+
+ Args:
+ ----
+ pages: list of page elements
+ """
+
+ _children_names: List[str] = ["pages"]
+ pages: List[KIEPage] = [] # type: ignore[assignment]
+
+ def __init__(
+ self,
+ pages: List[KIEPage],
+ ) -> None:
+ super().__init__(pages=pages) # type: ignore[arg-type]
diff --git a/doctr/io/html.py b/doctr/io/html.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd2d00c1023335a0ccc39b7422bb2d13a3fad3e9
--- /dev/null
+++ b/doctr/io/html.py
@@ -0,0 +1,28 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any
+
+from weasyprint import HTML
+
+__all__ = ["read_html"]
+
+
+def read_html(url: str, **kwargs: Any) -> bytes:
+    """Read an HTML document from a URL and render it as a PDF byte stream
+
+ >>> from doctr.io import read_html
+ >>> doc = read_html("https://www.yoursite.com")
+
+ Args:
+ ----
+ url: URL of the target web page
+ **kwargs: keyword arguments from `weasyprint.HTML`
+
+ Returns:
+ -------
+        the rendered PDF document as a bytes stream
+ """
+ return HTML(url, **kwargs).write_pdf()
diff --git a/doctr/io/image/__init__.py b/doctr/io/image/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..393c70c359df5bcebea5d8cdcb2277d60923e8e6
--- /dev/null
+++ b/doctr/io/image/__init__.py
@@ -0,0 +1,8 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+from .base import *
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import *
diff --git a/doctr/io/image/base.py b/doctr/io/image/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c2ed3065bb326c99b70e7d9c52cdb9d36ef809
--- /dev/null
+++ b/doctr/io/image/base.py
@@ -0,0 +1,56 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from pathlib import Path
+from typing import Optional, Tuple
+
+import cv2
+import numpy as np
+
+from doctr.utils.common_types import AbstractFile
+
+__all__ = ["read_img_as_numpy"]
+
+
+def read_img_as_numpy(
+ file: AbstractFile,
+ output_size: Optional[Tuple[int, int]] = None,
+ rgb_output: bool = True,
+) -> np.ndarray:
+ """Read an image file into numpy format
+
+ >>> from doctr.io import read_img_as_numpy
+ >>> page = read_img_as_numpy("path/to/your/doc.jpg")
+
+ Args:
+ ----
+ file: the path to the image file
+ output_size: the expected output size of each page in format H x W
+ rgb_output: whether the output ndarray channel order should be RGB instead of BGR.
+
+ Returns:
+ -------
+ the page decoded as numpy ndarray of shape H x W x 3
+ """
+ if isinstance(file, (str, Path)):
+ if not Path(file).is_file():
+ raise FileNotFoundError(f"unable to access {file}")
+ img = cv2.imread(str(file), cv2.IMREAD_COLOR)
+ elif isinstance(file, bytes):
+ _file: np.ndarray = np.frombuffer(file, np.uint8)
+ img = cv2.imdecode(_file, cv2.IMREAD_COLOR)
+ else:
+ raise TypeError("unsupported object type for argument 'file'")
+
+ # Validity check
+ if img is None:
+ raise ValueError("unable to read file.")
+ # Resizing
+ if isinstance(output_size, tuple):
+ img = cv2.resize(img, output_size[::-1], interpolation=cv2.INTER_LINEAR)
+ # Switch the channel order
+ if rgb_output:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return img
diff --git a/doctr/io/image/pytorch.py b/doctr/io/image/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e8450e840de9b96692102bb44706aa70adc325a
--- /dev/null
+++ b/doctr/io/image/pytorch.py
@@ -0,0 +1,109 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from io import BytesIO
+from typing import Tuple
+
+import numpy as np
+import torch
+from PIL import Image
+from torchvision.transforms.functional import to_tensor
+
+from doctr.utils.common_types import AbstractPath
+
+__all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"]
+
+
+def tensor_from_pil(pil_img: Image.Image, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+ """Convert a PIL Image to a PyTorch tensor
+
+ Args:
+ ----
+ pil_img: a PIL image
+ dtype: the output tensor data type
+
+ Returns:
+ -------
+ decoded image as tensor
+ """
+ if dtype == torch.float32:
+ img = to_tensor(pil_img)
+ else:
+ img = tensor_from_numpy(np.array(pil_img, np.uint8, copy=True), dtype)
+
+ return img
+
+
+def read_img_as_tensor(img_path: AbstractPath, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+ """Read an image file as a PyTorch tensor
+
+ Args:
+ ----
+ img_path: location of the image file
+ dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.
+
+ Returns:
+ -------
+ decoded image as a tensor
+ """
+ if dtype not in (torch.uint8, torch.float16, torch.float32):
+        raise ValueError("unsupported value for dtype")
+
+ pil_img = Image.open(img_path, mode="r").convert("RGB")
+
+ return tensor_from_pil(pil_img, dtype)
+
+
+def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+ """Read a byte stream as a PyTorch tensor
+
+ Args:
+ ----
+ img_content: bytes of a decoded image
+ dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.
+
+ Returns:
+ -------
+ decoded image as a tensor
+ """
+ if dtype not in (torch.uint8, torch.float16, torch.float32):
+        raise ValueError("unsupported value for dtype")
+
+ pil_img = Image.open(BytesIO(img_content), mode="r").convert("RGB")
+
+ return tensor_from_pil(pil_img, dtype)
+
+
+def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    """Convert a numpy array image (H, W, C) to a PyTorch tensor
+
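+    >>> # Illustrative example: a (H, W, C) uint8 array becomes a float (C, H, W) tensor
+    >>> import numpy as np
+    >>> tensor_from_numpy(np.zeros((32, 32, 3), dtype=np.uint8)).shape
+    torch.Size([3, 32, 32])
+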
+ Args:
+ ----
+ npy_img: image encoded as a numpy array of shape (H, W, C) in np.uint8
+ dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.
+
+ Returns:
+ -------
+ same image as a tensor of shape (C, H, W)
+ """
+ if dtype not in (torch.uint8, torch.float16, torch.float32):
+        raise ValueError("unsupported value for dtype")
+
+ if dtype == torch.float32:
+ img = to_tensor(npy_img)
+ else:
+ img = torch.from_numpy(npy_img)
+ # put it from HWC to CHW format
+ img = img.permute((2, 0, 1)).contiguous()
+ if dtype == torch.float16:
+ # Switch to FP16
+ img = img.to(dtype=torch.float16).div(255)
+
+ return img
+
+
+def get_img_shape(img: torch.Tensor) -> Tuple[int, int]:
+ """Get the shape of an image"""
+ return img.shape[-2:] # type: ignore[return-value]
diff --git a/doctr/io/image/tensorflow.py b/doctr/io/image/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..28fb2fadd5cb103bb0c581a5fc10083c7570013b
--- /dev/null
+++ b/doctr/io/image/tensorflow.py
@@ -0,0 +1,110 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Tuple
+
+import numpy as np
+import tensorflow as tf
+from PIL import Image
+from tensorflow.keras.utils import img_to_array
+
+from doctr.utils.common_types import AbstractPath
+
+__all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"]
+
+
+def tensor_from_pil(pil_img: Image.Image, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
+ """Convert a PIL Image to a TensorFlow tensor
+
+ Args:
+ ----
+ pil_img: a PIL image
+ dtype: the output tensor data type
+
+ Returns:
+ -------
+ decoded image as tensor
+ """
+ npy_img = img_to_array(pil_img)
+
+ return tensor_from_numpy(npy_img, dtype)
+
+
+def read_img_as_tensor(img_path: AbstractPath, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
+ """Read an image file as a TensorFlow tensor
+
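+    A minimal usage sketch (illustrative only):
+
+    >>> import tensorflow as tf
+    >>> from doctr.io.image.tensorflow import read_img_as_tensor
+    >>> img = read_img_as_tensor("path/to/your/img.jpg", dtype=tf.float16)
+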
+ Args:
+ ----
+ img_path: location of the image file
+ dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.
+
+ Returns:
+ -------
+ decoded image as a tensor
+ """
+ if dtype not in (tf.uint8, tf.float16, tf.float32):
+        raise ValueError("unsupported value for dtype")
+
+ img = tf.io.read_file(img_path)
+ img = tf.image.decode_jpeg(img, channels=3)
+
+ if dtype != tf.uint8:
+ img = tf.image.convert_image_dtype(img, dtype=dtype)
+ img = tf.clip_by_value(img, 0, 1)
+
+ return img
+
+
+def decode_img_as_tensor(img_content: bytes, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
+ """Read a byte stream as a TensorFlow tensor
+
+ Args:
+ ----
+        img_content: raw bytes of an encoded image
+ dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.
+
+ Returns:
+ -------
+ decoded image as a tensor
+ """
+ if dtype not in (tf.uint8, tf.float16, tf.float32):
+        raise ValueError("unsupported value for dtype")
+
+ img = tf.io.decode_image(img_content, channels=3)
+
+ if dtype != tf.uint8:
+ img = tf.image.convert_image_dtype(img, dtype=dtype)
+ img = tf.clip_by_value(img, 0, 1)
+
+ return img
+
+
+def tensor_from_numpy(npy_img: np.ndarray, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
+ """Read an image file as a TensorFlow tensor
+
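+    A minimal sketch showing that the channel-last layout is preserved (illustrative only):
+
+    >>> import numpy as np
+    >>> from doctr.io.image.tensorflow import tensor_from_numpy
+    >>> tensor_from_numpy(np.zeros((32, 128, 3), dtype=np.uint8)).shape
+    TensorShape([32, 128, 3])
+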
+ Args:
+ ----
+ npy_img: image encoded as a numpy array of shape (H, W, C) in np.uint8
+ dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.
+
+ Returns:
+ -------
+ same image as a tensor of shape (H, W, C)
+ """
+ if dtype not in (tf.uint8, tf.float16, tf.float32):
+        raise ValueError("unsupported value for dtype")
+
+ if dtype == tf.uint8:
+ img = tf.convert_to_tensor(npy_img, dtype=dtype)
+ else:
+ img = tf.image.convert_image_dtype(npy_img, dtype=dtype)
+ img = tf.clip_by_value(img, 0, 1)
+
+ return img
+
+
+def get_img_shape(img: tf.Tensor) -> Tuple[int, int]:
+ """Get the shape of an image"""
+ return img.shape[:2]
diff --git a/doctr/io/pdf.py b/doctr/io/pdf.py
new file mode 100644
index 0000000000000000000000000000000000000000..e91413f7b1b50cf11061f986cc2f4d2a3a9daacf
--- /dev/null
+++ b/doctr/io/pdf.py
@@ -0,0 +1,42 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to for full license details.
+
+from typing import Any, List, Optional
+
+import numpy as np
+import pypdfium2 as pdfium
+
+from doctr.utils.common_types import AbstractFile
+
+__all__ = ["read_pdf"]
+
+
+def read_pdf(
+ file: AbstractFile,
+ scale: float = 2,
+ rgb_mode: bool = True,
+ password: Optional[str] = None,
+ **kwargs: Any,
+) -> List[np.ndarray]:
+ """Read a PDF file and convert it into an image in numpy format
+
+ >>> from doctr.io import read_pdf
+ >>> doc = read_pdf("path/to/your/doc.pdf")
+
+ Args:
+ ----
+ file: the path to the PDF file
+ scale: rendering scale (1 corresponds to 72dpi)
+ rgb_mode: if True, the output will be RGB, otherwise BGR
+ password: a password to unlock the document, if encrypted
+ **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render`
+
+ Returns:
+ -------
+ the list of pages decoded as numpy ndarray of shape H x W x C
+ """
+ # Rasterise pages to numpy ndarrays with pypdfium2
+ pdf = pdfium.PdfDocument(file, password=password, autoclose=True)
+ return [page.render(scale=scale, rev_byteorder=rgb_mode, **kwargs).to_numpy() for page in pdf]
diff --git a/doctr/io/reader.py b/doctr/io/reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..37af393e461d26807dca8899b91db84833d63b52
--- /dev/null
+++ b/doctr/io/reader.py
@@ -0,0 +1,79 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to for full license details.
+
+from pathlib import Path
+from typing import List, Sequence, Union
+
+import numpy as np
+
+from doctr.utils.common_types import AbstractFile
+
+from .html import read_html
+from .image import read_img_as_numpy
+from .pdf import read_pdf
+
+__all__ = ["DocumentFile"]
+
+
+class DocumentFile:
+ """Read a document from multiple extensions"""
+
+ @classmethod
+ def from_pdf(cls, file: AbstractFile, **kwargs) -> List[np.ndarray]:
+ """Read a PDF file
+
+ >>> from doctr.io import DocumentFile
+ >>> doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
+
+ Args:
+ ----
+ file: the path to the PDF file or a binary stream
+ **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render`
+
+ Returns:
+ -------
+ the list of pages decoded as numpy ndarray of shape H x W x 3
+ """
+ return read_pdf(file, **kwargs)
+
+ @classmethod
+ def from_url(cls, url: str, **kwargs) -> List[np.ndarray]:
+ """Interpret a web page as a PDF document
+
+ >>> from doctr.io import DocumentFile
+ >>> doc = DocumentFile.from_url("https://www.yoursite.com")
+
+ Args:
+ ----
+ url: the URL of the target web page
+ **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render`
+
+ Returns:
+ -------
+ the list of pages decoded as numpy ndarray of shape H x W x 3
+ """
+ pdf_stream = read_html(url)
+ return cls.from_pdf(pdf_stream, **kwargs)
+
+ @classmethod
+ def from_images(cls, files: Union[Sequence[AbstractFile], AbstractFile], **kwargs) -> List[np.ndarray]:
+ """Read an image file (or a collection of image files) and convert it into an image in numpy format
+
+ >>> from doctr.io import DocumentFile
+ >>> pages = DocumentFile.from_images(["path/to/your/page1.png", "path/to/your/page2.png"])
+
+ Args:
+ ----
+ files: the path to the image file or a binary stream, or a collection of those
+ **kwargs: additional parameters to :meth:`doctr.io.image.read_img_as_numpy`
+
+ Returns:
+ -------
+ the list of pages decoded as numpy ndarray of shape H x W x 3
+ """
+ if isinstance(files, (str, Path, bytes)):
+ files = [files]
+
+ return [read_img_as_numpy(file, **kwargs) for file in files]
diff --git a/doctr/models/__init__.py b/doctr/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bec6f8f86cdfe233391ab1c2fd7800b16c784ceb
--- /dev/null
+++ b/doctr/models/__init__.py
@@ -0,0 +1,6 @@
+from . import artefacts
+from .classification import *
+from .detection import *
+from .recognition import *
+from .zoo import *
+from .factory import *
diff --git a/doctr/models/_utils.py b/doctr/models/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4394363cfd0666d7827883d4ee4ab10634db7d0
--- /dev/null
+++ b/doctr/models/_utils.py
@@ -0,0 +1,163 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to for full license details.
+
+from math import floor
+from statistics import median_low
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+from langdetect import LangDetectException, detect_langs
+
+__all__ = ["estimate_orientation", "get_language", "invert_data_structure"]
+
+
+def get_max_width_length_ratio(contour: np.ndarray) -> float:
+ """Get the maximum shape ratio of a contour.
+
+ Args:
+ ----
+ contour: the contour from cv2.findContour
+
+ Returns:
+ -------
+ the maximum shape ratio
+ """
+ _, (w, h), _ = cv2.minAreaRect(contour)
+ return max(w / h, h / w)
+
+
+def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> int:
+ """Estimate the angle of the general document orientation based on the
+ lines of the document and the assumption that they should be horizontal.
+
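+    A minimal usage sketch (illustrative only, `page` being a H x W x C uint8 numpy array):
+
+    >>> angle = estimate_orientation(page)  # e.g. 2 for a slightly rotated page
+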
+ Args:
+ ----
+ img: the img or bitmap to analyze (H, W, C)
+ n_ct: the number of contours used for the orientation estimation
+        ratio_threshold_for_lines: the w/h ratio used to discriminate lines
+
+ Returns:
+ -------
+ the angle of the general document orientation
+ """
+ assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported"
+ max_value = np.max(img)
+ min_value = np.min(img)
+ if max_value <= 1 and min_value >= 0 or (max_value <= 255 and min_value >= 0 and img.shape[-1] == 1):
+ thresh = img.astype(np.uint8)
+ if max_value <= 255 and min_value >= 0 and img.shape[-1] == 3:
+ gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ gray_img = cv2.medianBlur(gray_img, 5)
+ thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1] # type: ignore[assignment]
+
+ # try to merge words in lines
+ (h, w) = img.shape[:2]
+ k_x = max(1, (floor(w / 100)))
+ k_y = max(1, (floor(h / 100)))
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
+ thresh = cv2.dilate(thresh, kernel, iterations=1) # type: ignore[assignment]
+
+ # extract contours
+ contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+
+ # Sort contours
+ contours = sorted(contours, key=get_max_width_length_ratio, reverse=True)
+
+ angles = []
+ for contour in contours[:n_ct]:
+ _, (w, h), angle = cv2.minAreaRect(contour)
+ if w / h > ratio_threshold_for_lines: # select only contours with ratio like lines
+ angles.append(angle)
+        elif w / h < 1 / ratio_threshold_for_lines:  # if lines are vertical, subtract 90 degrees
+ angles.append(angle - 90)
+
+ if len(angles) == 0:
+        return 0  # in case no angle is found
+ else:
+ median = -median_low(angles)
+ return round(median) if abs(median) != 0 else 0
+
+
+def rectify_crops(
+ crops: List[np.ndarray],
+ orientations: List[int],
+) -> List[np.ndarray]:
+ """Rotate each crop of the list according to the predicted orientation:
+ 0: already straight, no rotation
+ 1: 90 ccw, rotate 3 times ccw
+ 2: 180, rotate 2 times ccw
+ 3: 270 ccw, rotate 1 time ccw
+ """
+ # Inverse predictions (if angle of +90 is detected, rotate by -90)
+ orientations = [4 - pred if pred != 0 else 0 for pred in orientations]
+ return (
+ [crop if orientation == 0 else np.rot90(crop, orientation) for orientation, crop in zip(orientations, crops)]
+ if len(orientations) > 0
+ else []
+ )
+
+
+def rectify_loc_preds(
+ page_loc_preds: np.ndarray,
+ orientations: List[int],
+) -> Optional[np.ndarray]:
+ """Orient the quadrangle (Polygon4P) according to the predicted orientation,
+ so that the points are in this order: top L, top R, bot R, bot L if the crop is readable
+ """
+ return (
+ np.stack(
+ [
+ np.roll(page_loc_pred, orientation, axis=0)
+ for orientation, page_loc_pred in zip(orientations, page_loc_preds)
+ ],
+ axis=0,
+ )
+ if len(orientations) > 0
+ else None
+ )
+
+
+def get_language(text: str) -> Tuple[str, float]:
+ """Get languages of a text using langdetect model.
+ Get the language with the highest probability or no language if only a few words or a low probability
+
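+    A minimal usage sketch (illustrative; the exact confidence depends on the langdetect model):
+
+    >>> lang, prob = get_language("This is a sample sentence written in English.")
+    >>> lang
+    'en'
+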
+ Args:
+ ----
+ text (str): text
+
+ Returns:
+ -------
+ The detected language in ISO 639 code and confidence score
+ """
+ try:
+ lang = detect_langs(text.lower())[0]
+ except LangDetectException:
+ return "unknown", 0.0
+ if len(text) <= 1 or (len(text) <= 5 and lang.prob <= 0.2):
+ return "unknown", 0.0
+ return lang.lang, lang.prob
+
+
+def invert_data_structure(
+ x: Union[List[Dict[str, Any]], Dict[str, List[Any]]],
+) -> Union[List[Dict[str, Any]], Dict[str, List[Any]]]:
+ """Invert a List of Dict of elements to a Dict of list of elements and the other way around
+
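+    A minimal round-trip sketch (deterministic, illustrative only):
+
+    >>> invert_data_structure([{"a": 1, "b": 2}, {"a": 3, "b": 4}])
+    {'a': [1, 3], 'b': [2, 4]}
+    >>> invert_data_structure({"a": [1, 3], "b": [2, 4]})
+    [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]
+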
+ Args:
+ ----
+ x: a list of dictionaries with the same keys or a dictionary of lists of the same length
+
+ Returns:
+ -------
+ dictionary of list when x is a list of dictionaries or a list of dictionaries when x is dictionary of lists
+ """
+ if isinstance(x, dict):
+        assert len({len(v) for v in x.values()}) == 1, "All the lists in the dictionary should have the same length."
+ return [dict(zip(x, t)) for t in zip(*x.values())]
+ elif isinstance(x, list):
+ return {k: [dic[k] for dic in x] for k in x[0]}
+ else:
+ raise TypeError(f"Expected input to be either a dict or a list, got {type(input)} instead.")
diff --git a/doctr/models/artefacts/__init__.py b/doctr/models/artefacts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..875f48a8750a8948ca911f41ca81eee06ebbf4f4
--- /dev/null
+++ b/doctr/models/artefacts/__init__.py
@@ -0,0 +1,2 @@
+from .barcode import *
+from .face import *
diff --git a/doctr/models/artefacts/barcode.py b/doctr/models/artefacts/barcode.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9f911f47e683f39e333aa7bb75ebed887299131
--- /dev/null
+++ b/doctr/models/artefacts/barcode.py
@@ -0,0 +1,74 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to for full license details.
+
+from typing import List, Tuple
+
+import cv2
+import numpy as np
+
+__all__ = ["BarCodeDetector"]
+
+
+class BarCodeDetector:
+ """Implements a Bar-code detector.
+ For now, only horizontal (or with a small angle) bar-codes are supported
+
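+    A minimal usage sketch (illustrative only, `img` being a BGR numpy image):
+
+    >>> detector = BarCodeDetector()
+    >>> barcodes = detector(img)  # list of relative (xmin, ymin, xmax, ymax) tuples
+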
+ Args:
+ ----
+ min_size: minimum relative size of a barcode on the page
+ canny_minval: lower bound for canny hysteresis
+ canny_maxval: upper-bound for canny hysteresis
+ """
+
+ def __init__(self, min_size: float = 1 / 6, canny_minval: int = 50, canny_maxval: int = 150) -> None:
+ self.min_size = min_size
+ self.canny_minval = canny_minval
+ self.canny_maxval = canny_maxval
+
+ def __call__(
+ self,
+ img: np.ndarray,
+ ) -> List[Tuple[float, float, float, float]]:
+ """Detect Barcodes on the image
+ Args:
+ img: np image
+
+ Returns
+ -------
+ A list of tuples: [(xmin, ymin, xmax, ymax), ...] containing barcodes rel. coordinates
+ """
+ # get image size and define parameters
+ height, width = img.shape[:2]
+ k = (1 + int(width / 512)) * 10 # spatial extension of kernels, 512 -> 20, 1024 -> 30, ...
+ min_w = int(width * self.min_size) # minimal size of a possible barcode
+
+ # Detect edges
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ edges = cv2.Canny(gray, self.canny_minval, self.canny_maxval, apertureSize=3)
+
+ # Horizontal dilation to aggregate bars of the potential barcode
+ # without aggregating text lines of the page vertically
+ edges = cv2.dilate(edges, np.ones((1, k), np.uint8))
+
+ # Instantiate a barcode-shaped kernel and erode to keep only vertical-bar structures
+ bar_code_kernel: np.ndarray = np.zeros((k, 3), np.uint8)
+ bar_code_kernel[..., [0, 2]] = 1
+ edges = cv2.erode(edges, bar_code_kernel, iterations=1)
+
+ # Opening to remove noise
+ edges = cv2.morphologyEx(edges, cv2.MORPH_OPEN, np.ones((k, k), np.uint8))
+
+ # Dilation to retrieve vertical length (lost at the first dilation)
+ edges = cv2.dilate(edges, np.ones((k, 1), np.uint8))
+
+ # Find contours, and keep the widest as barcodes
+ contours, _ = cv2.findContours(edges, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+ barcodes = []
+ for contour in contours:
+ x, y, w, h = cv2.boundingRect(contour)
+ if w >= min_w:
+ barcodes.append((x / width, y / height, (x + w) / width, (y + h) / height))
+
+ return barcodes
diff --git a/doctr/models/artefacts/face.py b/doctr/models/artefacts/face.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d6167522ed253d705b9cd9bf5b1e931f874d85c
--- /dev/null
+++ b/doctr/models/artefacts/face.py
@@ -0,0 +1,63 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to for full license details.
+
+from typing import List, Tuple
+
+import cv2
+import numpy as np
+
+from doctr.utils.repr import NestedObject
+
+__all__ = ["FaceDetector"]
+
+
+class FaceDetector(NestedObject):
+ """Implements a face detector to detect profile pictures on resumes, IDS, driving licenses, passports...
+ Based on open CV CascadeClassifier (haarcascades)
+
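+    A minimal usage sketch (illustrative only, `img` being a BGR numpy image):
+
+    >>> detector = FaceDetector(n_faces=1)
+    >>> faces = detector(img)  # list of relative (xmin, ymin, xmax, ymax) tuples
+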
+ Args:
+ ----
+ n_faces: maximal number of faces to detect on a single image, default = 1
+ """
+
+ def __init__(
+ self,
+ n_faces: int = 1,
+ ) -> None:
+ self.n_faces = n_faces
+ # Instantiate classifier
+ self.detector = cv2.CascadeClassifier(
+ cv2.data.haarcascades + "haarcascade_frontalface_default.xml" # type: ignore[attr-defined]
+ )
+
+ def extra_repr(self) -> str:
+ return f"n_faces={self.n_faces}"
+
+ def __call__(
+ self,
+ img: np.ndarray,
+ ) -> List[Tuple[float, float, float, float]]:
+ """Detect n_faces on the img
+
+ Args:
+ ----
+ img: image to detect faces on
+
+ Returns:
+ -------
+ A list of size n_faces, each face is a tuple of relative xmin, ymin, xmax, ymax
+ """
+ height, width = img.shape[:2]
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+ faces = self.detector.detectMultiScale(gray, 1.5, 3)
+        # If faces are detected, keep only the biggest ones (up to n_faces)
+        rel_faces = []
+        if len(faces) > 0:
+            # sort by bounding box size (w + h) and keep the largest n_faces entries
+            for x, y, w, h in sorted(faces, key=lambda f: f[2] + f[3])[-min(self.n_faces, len(faces)) :]:
+                xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height
+                rel_faces.append((xmin, ymin, xmax, ymax))
+
+ return rel_faces
diff --git a/doctr/models/builder.py b/doctr/models/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..397b3b4244ff5b9cfe5efee589677c0b5385a291
--- /dev/null
+++ b/doctr/models/builder.py
@@ -0,0 +1,487 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to for full license details.
+
+
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+from scipy.cluster.hierarchy import fclusterdata
+
+from doctr.io.elements import Block, Document, KIEDocument, KIEPage, Line, Page, Prediction, Word
+from doctr.utils.geometry import estimate_page_angle, resolve_enclosing_bbox, resolve_enclosing_rbbox, rotate_boxes
+from doctr.utils.repr import NestedObject
+
+__all__ = ["DocumentBuilder"]
+
+
+class DocumentBuilder(NestedObject):
+ """Implements a document builder
+
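+    A minimal usage sketch (illustrative; the arguments are the aligned outputs of the
+    detection and recognition stages, see `__call__` for their exact layout):
+
+    >>> builder = DocumentBuilder(resolve_lines=True, resolve_blocks=True)
+    >>> document = builder(pages, boxes, text_preds, page_shapes, crop_orientations)
+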
+ Args:
+ ----
+ resolve_lines: whether words should be automatically grouped into lines
+ resolve_blocks: whether lines should be automatically grouped into blocks
+ paragraph_break: relative length of the minimum space separating paragraphs
+ export_as_straight_boxes: if True, force straight boxes in the export (fit a rectangle
+ box to all rotated boxes). Else, keep the boxes format unchanged, no matter what it is.
+ """
+
+ def __init__(
+ self,
+ resolve_lines: bool = True,
+ resolve_blocks: bool = True,
+ paragraph_break: float = 0.035,
+ export_as_straight_boxes: bool = False,
+ ) -> None:
+ self.resolve_lines = resolve_lines
+ self.resolve_blocks = resolve_blocks
+ self.paragraph_break = paragraph_break
+ self.export_as_straight_boxes = export_as_straight_boxes
+
+ @staticmethod
+ def _sort_boxes(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Sort bounding boxes from top to bottom, left to right
+
+ Args:
+ ----
+ boxes: bounding boxes of shape (N, 4) or (N, 4, 2) (in case of rotated bbox)
+
+ Returns:
+ -------
+ tuple: indices of ordered boxes of shape (N,), boxes
+            If straight boxes are passed to the function, the boxes are unchanged,
+            else the boxes returned are straight boxes fitted to the straightened rotated boxes,
+            so that we can fit the lines afterwards to the straightened page
+ """
+ if boxes.ndim == 3:
+ boxes = rotate_boxes(
+ loc_preds=boxes,
+ angle=-estimate_page_angle(boxes),
+ orig_shape=(1024, 1024),
+ min_angle=5.0,
+ )
+ boxes = np.concatenate((boxes.min(1), boxes.max(1)), -1)
+ return (boxes[:, 0] + 2 * boxes[:, 3] / np.median(boxes[:, 3] - boxes[:, 1])).argsort(), boxes
+
+ def _resolve_sub_lines(self, boxes: np.ndarray, word_idcs: List[int]) -> List[List[int]]:
+ """Split a line in sub_lines
+
+ Args:
+ ----
+ boxes: bounding boxes of shape (N, 4)
+ word_idcs: list of indexes for the words of the line
+
+ Returns:
+ -------
+ A list of (sub-)lines computed from the original line (words)
+ """
+ lines = []
+ # Sort words horizontally
+ word_idcs = [word_idcs[idx] for idx in boxes[word_idcs, 0].argsort().tolist()]
+
+        # Possibly split the line horizontally
+ if len(word_idcs) < 2:
+ lines.append(word_idcs)
+ else:
+ sub_line = [word_idcs[0]]
+ for i in word_idcs[1:]:
+ horiz_break = True
+
+ prev_box = boxes[sub_line[-1]]
+ # Compute distance between boxes
+ dist = boxes[i, 0] - prev_box[2]
+ # If distance between boxes is lower than paragraph break, same sub-line
+ if dist < self.paragraph_break:
+ horiz_break = False
+
+ if horiz_break:
+ lines.append(sub_line)
+ sub_line = []
+
+ sub_line.append(i)
+ lines.append(sub_line)
+
+ return lines
+
+ def _resolve_lines(self, boxes: np.ndarray) -> List[List[int]]:
+ """Order boxes to group them in lines
+
+ Args:
+ ----
+ boxes: bounding boxes of shape (N, 4) or (N, 4, 2) in case of rotated bbox
+
+ Returns:
+ -------
+ nested list of box indices
+ """
+ # Sort boxes, and straighten the boxes if they are rotated
+ idxs, boxes = self._sort_boxes(boxes)
+
+ # Compute median for boxes heights
+ y_med = np.median(boxes[:, 3] - boxes[:, 1])
+
+ lines = []
+ words = [idxs[0]] # Assign the top-left word to the first line
+ # Define a mean y-center for the line
+ y_center_sum = boxes[idxs[0]][[1, 3]].mean()
+
+ for idx in idxs[1:]:
+ vert_break = True
+
+ # Compute y_dist
+ y_dist = abs(boxes[idx][[1, 3]].mean() - y_center_sum / len(words))
+ # If y-center of the box is close enough to mean y-center of the line, same line
+ if y_dist < y_med / 2:
+ vert_break = False
+
+ if vert_break:
+ # Compute sub-lines (horizontal split)
+ lines.extend(self._resolve_sub_lines(boxes, words))
+ words = []
+ y_center_sum = 0
+
+ words.append(idx)
+ y_center_sum += boxes[idx][[1, 3]].mean()
+
+        # Use the remaining words to form the last line(s)
+ if len(words) > 0:
+ # Compute sub-lines (horizontal split)
+ lines.extend(self._resolve_sub_lines(boxes, words))
+
+ return lines
+
+ @staticmethod
+ def _resolve_blocks(boxes: np.ndarray, lines: List[List[int]]) -> List[List[List[int]]]:
+ """Order lines to group them in blocks
+
+ Args:
+ ----
+ boxes: bounding boxes of shape (N, 4) or (N, 4, 2)
+ lines: list of lines, each line is a list of idx
+
+ Returns:
+ -------
+ nested list of box indices
+ """
+ # Resolve enclosing boxes of lines
+ if boxes.ndim == 3:
+ box_lines: np.ndarray = np.asarray([
+ resolve_enclosing_rbbox([tuple(boxes[idx, :, :]) for idx in line]) # type: ignore[misc]
+ for line in lines
+ ])
+ else:
+ _box_lines = [
+ resolve_enclosing_bbox([(tuple(boxes[idx, :2]), tuple(boxes[idx, 2:])) for idx in line])
+ for line in lines
+ ]
+ box_lines = np.asarray([(x1, y1, x2, y2) for ((x1, y1), (x2, y2)) in _box_lines])
+
+        # Compute geometrical features of lines to cluster them
+        # Clustering only on box centers yields poor results for complex documents
+ if boxes.ndim == 3:
+ box_features: np.ndarray = np.stack(
+ (
+ (box_lines[:, 0, 0] + box_lines[:, 0, 1]) / 2,
+ (box_lines[:, 0, 0] + box_lines[:, 2, 0]) / 2,
+ (box_lines[:, 0, 0] + box_lines[:, 2, 1]) / 2,
+ (box_lines[:, 0, 1] + box_lines[:, 2, 1]) / 2,
+ (box_lines[:, 0, 1] + box_lines[:, 2, 0]) / 2,
+ (box_lines[:, 2, 0] + box_lines[:, 2, 1]) / 2,
+ ),
+ axis=-1,
+ )
+ else:
+ box_features = np.stack(
+ (
+ (box_lines[:, 0] + box_lines[:, 3]) / 2,
+ (box_lines[:, 1] + box_lines[:, 2]) / 2,
+ (box_lines[:, 0] + box_lines[:, 2]) / 2,
+ (box_lines[:, 1] + box_lines[:, 3]) / 2,
+ box_lines[:, 0],
+ box_lines[:, 1],
+ ),
+ axis=-1,
+ )
+ # Compute clusters
+ clusters = fclusterdata(box_features, t=0.1, depth=4, criterion="distance", metric="euclidean")
+
+ _blocks: Dict[int, List[int]] = {}
+ # Form clusters
+ for line_idx, cluster_idx in enumerate(clusters):
+ if cluster_idx in _blocks.keys():
+ _blocks[cluster_idx].append(line_idx)
+ else:
+ _blocks[cluster_idx] = [line_idx]
+
+ # Retrieve word-box level to return a fully nested structure
+ blocks = [[lines[idx] for idx in block] for block in _blocks.values()]
+
+ return blocks
+
+ def _build_blocks(
+ self,
+ boxes: np.ndarray,
+ word_preds: List[Tuple[str, float]],
+ crop_orientations: List[Dict[str, Any]],
+ ) -> List[Block]:
+ """Gather independent words in structured blocks
+
+ Args:
+ ----
+ boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2)
+ word_preds: list of all detected words of the page, of shape N
+            crop_orientations: list of dictionaries containing
+ the general orientation (orientations + confidences) of the crops
+
+ Returns:
+ -------
+ list of block elements
+ """
+ if boxes.shape[0] != len(word_preds):
+ raise ValueError(f"Incompatible argument lengths: {boxes.shape[0]}, {len(word_preds)}")
+
+ if boxes.shape[0] == 0:
+ return []
+
+ # Decide whether we try to form lines
+ _boxes = boxes
+ if self.resolve_lines:
+ lines = self._resolve_lines(_boxes if _boxes.ndim == 3 else _boxes[:, :4])
+ # Decide whether we try to form blocks
+ if self.resolve_blocks and len(lines) > 1:
+ _blocks = self._resolve_blocks(_boxes if _boxes.ndim == 3 else _boxes[:, :4], lines)
+ else:
+ _blocks = [lines]
+ else:
+ # Sort bounding boxes, one line for all boxes, one block for the line
+ lines = [self._sort_boxes(_boxes if _boxes.ndim == 3 else _boxes[:, :4])[0]] # type: ignore[list-item]
+ _blocks = [lines]
+
+ blocks = [
+ Block([
+ Line([
+ Word(
+ *word_preds[idx],
+ tuple([tuple(pt) for pt in boxes[idx].tolist()]), # type: ignore[arg-type]
+ crop_orientations[idx],
+ )
+ if boxes.ndim == 3
+ else Word(
+ *word_preds[idx],
+ ((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])),
+ crop_orientations[idx],
+ )
+ for idx in line
+ ])
+ for line in lines
+ ])
+ for lines in _blocks
+ ]
+
+ return blocks
+
+ def extra_repr(self) -> str:
+ return (
+ f"resolve_lines={self.resolve_lines}, resolve_blocks={self.resolve_blocks}, "
+ f"paragraph_break={self.paragraph_break}, "
+ f"export_as_straight_boxes={self.export_as_straight_boxes}"
+ )
+
+ def __call__(
+ self,
+ pages: List[np.ndarray],
+ boxes: List[np.ndarray],
+ text_preds: List[List[Tuple[str, float]]],
+ page_shapes: List[Tuple[int, int]],
+ crop_orientations: List[Dict[str, Any]],
+ orientations: Optional[List[Dict[str, Any]]] = None,
+ languages: Optional[List[Dict[str, Any]]] = None,
+ ) -> Document:
+ """Re-arrange detected words into structured blocks
+
+ Args:
+ ----
+ pages: list of N elements, where each element represents the page image
+ boxes: list of N elements, where each element represents the localization predictions, of shape (*, 5)
+ or (*, 6) for all words for a given page
+ text_preds: list of N elements, where each element is the list of all word prediction (text + confidence)
+ page_shapes: shape of each page, of size N
+ crop_orientations: list of N elements, where each element is
+ a dictionary containing the general orientation (orientations + confidences) of the crops
+ orientations: optional, list of N elements,
+ where each element is a dictionary containing the orientation (orientation + confidence)
+ languages: optional, list of N elements,
+ where each element is a dictionary containing the language (language + confidence)
+
+ Returns:
+ -------
+ document object
+ """
+ if len(boxes) != len(text_preds) != len(crop_orientations) or len(boxes) != len(page_shapes) != len(
+ crop_orientations
+ ):
+ raise ValueError("All arguments are expected to be lists of the same size")
+
+ _orientations = (
+ orientations if isinstance(orientations, list) else [None] * len(boxes) # type: ignore[list-item]
+ )
+ _languages = languages if isinstance(languages, list) else [None] * len(boxes) # type: ignore[list-item]
+ if self.export_as_straight_boxes and len(boxes) > 0:
+ # If boxes are already straight OK, else fit a bounding rect
+ if boxes[0].ndim == 3:
+ # Iterate over pages and boxes
+ boxes = [np.concatenate((p_boxes.min(1), p_boxes.max(1)), 1) for p_boxes in boxes]
+
+ _pages = [
+ Page(
+ page,
+ self._build_blocks(
+ page_boxes,
+ word_preds,
+ word_crop_orientations,
+ ),
+ _idx,
+ shape,
+ orientation,
+ language,
+ )
+ for page, _idx, shape, page_boxes, word_preds, word_crop_orientations, orientation, language in zip(
+ pages, range(len(boxes)), page_shapes, boxes, text_preds, crop_orientations, _orientations, _languages
+ )
+ ]
+
+ return Document(_pages)
+
+
+class KIEDocumentBuilder(DocumentBuilder):
+ """Implements a KIE document builder
+
+ Args:
+ ----
+ resolve_lines: whether words should be automatically grouped into lines
+ resolve_blocks: whether lines should be automatically grouped into blocks
+ paragraph_break: relative length of the minimum space separating paragraphs
+ export_as_straight_boxes: if True, force straight boxes in the export (fit a rectangle
+ box to all rotated boxes). Else, keep the boxes format unchanged, no matter what it is.
+ """
+
+ def __call__( # type: ignore[override]
+ self,
+ pages: List[np.ndarray],
+ boxes: List[Dict[str, np.ndarray]],
+ text_preds: List[Dict[str, List[Tuple[str, float]]]],
+ page_shapes: List[Tuple[int, int]],
+ crop_orientations: List[Dict[str, List[Dict[str, Any]]]],
+ orientations: Optional[List[Dict[str, Any]]] = None,
+ languages: Optional[List[Dict[str, Any]]] = None,
+ ) -> KIEDocument:
+ """Re-arrange detected words into structured predictions
+
+ Args:
+ ----
+ pages: list of N elements, where each element represents the page image
+ boxes: list of N dictionaries, where each element represents the localization predictions for a class,
+ of shape (*, 5) or (*, 6) for all predictions
+ text_preds: list of N dictionaries, where each element is the list of all word prediction
+ page_shapes: shape of each page, of size N
+            crop_orientations: list of N dictionaries, where each element is
+ a list containing the general crop orientations (orientations + confidences) of the crops
+ orientations: optional, list of N elements,
+ where each element is a dictionary containing the orientation (orientation + confidence)
+ languages: optional, list of N elements,
+ where each element is a dictionary containing the language (language + confidence)
+
+ Returns:
+ -------
+ document object
+ """
+ if len(boxes) != len(text_preds) != len(crop_orientations) or len(boxes) != len(page_shapes) != len(
+ crop_orientations
+ ):
+ raise ValueError("All arguments are expected to be lists of the same size")
+ _orientations = (
+ orientations if isinstance(orientations, list) else [None] * len(boxes) # type: ignore[list-item]
+ )
+ _languages = languages if isinstance(languages, list) else [None] * len(boxes) # type: ignore[list-item]
+ if self.export_as_straight_boxes and len(boxes) > 0:
+ # If boxes are already straight OK, else fit a bounding rect
+ if next(iter(boxes[0].values())).ndim == 3:
+ straight_boxes: List[Dict[str, np.ndarray]] = []
+ # Iterate over pages
+ for p_boxes in boxes:
+ # Iterate over boxes of the pages
+ straight_boxes_dict = {}
+ for k, box in p_boxes.items():
+ straight_boxes_dict[k] = np.concatenate((box.min(1), box.max(1)), 1)
+ straight_boxes.append(straight_boxes_dict)
+ boxes = straight_boxes
+
+ _pages = [
+ KIEPage(
+ page,
+ {
+ k: self._build_blocks(
+ page_boxes[k],
+ word_preds[k],
+ word_crop_orientations[k],
+ )
+ for k in page_boxes.keys()
+ },
+ _idx,
+ shape,
+ orientation,
+ language,
+ )
+ for page, _idx, shape, page_boxes, word_preds, word_crop_orientations, orientation, language in zip(
+ pages, range(len(boxes)), page_shapes, boxes, text_preds, crop_orientations, _orientations, _languages
+ )
+ ]
+
+ return KIEDocument(_pages)
+
+ def _build_blocks( # type: ignore[override]
+ self,
+ boxes: np.ndarray,
+ word_preds: List[Tuple[str, float]],
+ crop_orientations: List[Dict[str, Any]],
+ ) -> List[Prediction]:
+ """Gather independent words in structured blocks
+
+ Args:
+ ----
+ boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2)
+ word_preds: list of all detected words of the page, of shape N
+ crop_orientations: list of orientations for each word crop
+
+ Returns:
+ -------
+ list of block elements
+ """
+ if boxes.shape[0] != len(word_preds):
+ raise ValueError(f"Incompatible argument lengths: {boxes.shape[0]}, {len(word_preds)}")
+
+ if boxes.shape[0] == 0:
+ return []
+
+ # Decide whether we try to form lines
+ _boxes = boxes
+ idxs, _ = self._sort_boxes(_boxes if _boxes.ndim == 3 else _boxes[:, :4])
+ predictions = [
+ Prediction(
+ value=word_preds[idx][0],
+ confidence=word_preds[idx][1],
+ geometry=tuple([tuple(pt) for pt in boxes[idx].tolist()]), # type: ignore[arg-type]
+ crop_orientation=crop_orientations[idx],
+ )
+ if boxes.ndim == 3
+ else Prediction(
+ value=word_preds[idx][0],
+ confidence=word_preds[idx][1],
+ geometry=((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])),
+ crop_orientation=crop_orientations[idx],
+ )
+ for idx in idxs
+ ]
+ return predictions
diff --git a/doctr/models/classification/__init__.py b/doctr/models/classification/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..496079740010bce5f1da5161808ed66ccfb9a898
--- /dev/null
+++ b/doctr/models/classification/__init__.py
@@ -0,0 +1,7 @@
+from .mobilenet import *
+from .resnet import *
+from .vgg import *
+from .magc_resnet import *
+from .vit import *
+from .textnet import *
+from .zoo import *
diff --git a/doctr/models/classification/magc_resnet/__init__.py b/doctr/models/classification/magc_resnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/models/classification/magc_resnet/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/classification/magc_resnet/pytorch.py b/doctr/models/classification/magc_resnet/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..f503d7c7fadbe0026b2c93032172258eb94145ad
--- /dev/null
+++ b/doctr/models/classification/magc_resnet/pytorch.py
@@ -0,0 +1,177 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to for full license details.
+
+
+import math
+from copy import deepcopy
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from doctr.datasets import VOCABS
+
+from ...utils.pytorch import load_pretrained_params
+from ..resnet.pytorch import ResNet
+
+__all__ = ["magc_resnet31"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "magc_resnet31": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/magc_resnet31-857391d8.pt&src=0",
+ },
+}
+
+
+class MAGC(nn.Module):
+ """Implements the Multi-Aspect Global Context Attention, as described in
+ `_.
+
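+    A minimal usage sketch (illustrative only; the module preserves the input shape):
+
+    >>> import torch
+    >>> attn = MAGC(inplanes=512)
+    >>> out = attn(torch.rand((1, 512, 8, 32)))  # same (N, C, H, W) shape as the input
+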
+ Args:
+ ----
+ inplanes: input channels
+ headers: number of headers to split channels
+        attn_scale: if True, re-scale attention to counteract the variance distributions
+ ratio: bottleneck ratio
+ **kwargs
+ """
+
+ def __init__(
+ self,
+ inplanes: int,
+ headers: int = 8,
+ attn_scale: bool = False,
+ ratio: float = 0.0625, # bottleneck ratio of 1/16 as described in paper
+ cfg: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ super().__init__()
+
+ self.headers = headers
+ self.inplanes = inplanes
+ self.attn_scale = attn_scale
+ self.planes = int(inplanes * ratio)
+
+ self.single_header_inplanes = int(inplanes / headers)
+
+ self.conv_mask = nn.Conv2d(self.single_header_inplanes, 1, kernel_size=1)
+ self.softmax = nn.Softmax(dim=1)
+
+ self.transform = nn.Sequential(
+ nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
+ nn.LayerNorm([self.planes, 1, 1]),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(self.planes, self.inplanes, kernel_size=1),
+ )
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ batch, _, height, width = inputs.size()
+        # (N * headers, C / headers, H, W)
+ x = inputs.view(batch * self.headers, self.single_header_inplanes, height, width)
+ shortcut = x
+ # (N * headers, C / headers, H * W)
+ shortcut = shortcut.view(batch * self.headers, self.single_header_inplanes, height * width)
+
+ # (N * headers, 1, H, W)
+ context_mask = self.conv_mask(x)
+ # (N * headers, H * W)
+ context_mask = context_mask.view(batch * self.headers, -1)
+
+ # scale variance
+ if self.attn_scale and self.headers > 1:
+ context_mask = context_mask / math.sqrt(self.single_header_inplanes)
+
+ # (N * headers, H * W)
+ context_mask = self.softmax(context_mask)
+
+ # (N * headers, C / headers)
+ context = (shortcut * context_mask.unsqueeze(1)).sum(-1)
+
+ # (N, C, 1, 1)
+ context = context.view(batch, self.headers * self.single_header_inplanes, 1, 1)
+
+ # Transform: B, C, 1, 1 -> B, C, 1, 1
+ transformed = self.transform(context)
+ return inputs + transformed
+
+
+def _magc_resnet(
+ arch: str,
+ pretrained: bool,
+ num_blocks: List[int],
+ output_channels: List[int],
+ stage_stride: List[int],
+ stage_conv: List[bool],
+ stage_pooling: List[Optional[Tuple[int, int]]],
+ ignore_keys: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> ResNet:
+ kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+ kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["num_classes"] = kwargs["num_classes"]
+ _cfg["classes"] = kwargs["classes"]
+ kwargs.pop("classes")
+
+ # Build the model
+ model = ResNet(
+ num_blocks,
+ output_channels,
+ stage_stride,
+ stage_conv,
+ stage_pooling,
+ attn_module=partial(MAGC, headers=8, attn_scale=True),
+ cfg=_cfg,
+ **kwargs,
+ )
+ # Load pretrained parameters
+ if pretrained:
+ # The number of classes is not the same as the number of classes in the pretrained model =>
+ # remove the last layer weights
+ _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
+ load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+ return model
+
+
+def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
+ """Resnet31 architecture with Multi-Aspect Global Context Attention as described in
+ `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition",
+ `_.
+
+ >>> import torch
+ >>> from doctr.models import magc_resnet31
+ >>> model = magc_resnet31(pretrained=False)
+    >>> input_tensor = torch.rand((1, 3, 224, 224), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the ResNet architecture
+
+ Returns:
+ -------
+ A feature extractor model
+ """
+ return _magc_resnet(
+ "magc_resnet31",
+ pretrained,
+ [1, 2, 5, 3],
+ [256, 256, 512, 512],
+ [1, 1, 1, 1],
+ [True] * 4,
+ [(2, 2), (2, 1), None, None],
+ origin_stem=False,
+ stem_channels=128,
+ ignore_keys=["13.weight", "13.bias"],
+ **kwargs,
+ )
diff --git a/doctr/models/classification/magc_resnet/tensorflow.py b/doctr/models/classification/magc_resnet/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..e791e661bfc9653e69302ce9d5f11315cd19ff6e
--- /dev/null
+++ b/doctr/models/classification/magc_resnet/tensorflow.py
@@ -0,0 +1,192 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to for full license details.
+
+import math
+from copy import deepcopy
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple
+
+import tensorflow as tf
+from tensorflow.keras import layers
+from tensorflow.keras.models import Sequential
+
+from doctr.datasets import VOCABS
+
+from ...utils import load_pretrained_params
+from ..resnet.tensorflow import ResNet
+
+__all__ = ["magc_resnet31"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "magc_resnet31": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 32, 3),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.6.0/magc_resnet31-addbb705.zip&src=0",
+ },
+}
+
+
+class MAGC(layers.Layer):
+ """Implements the Multi-Aspect Global Context Attention, as described in
+ `_.
+
+ Args:
+ ----
+ inplanes: input channels
+ headers: number of headers to split channels
+        attn_scale: if True, re-scale attention to counteract the variance distributions
+ ratio: bottleneck ratio
+ **kwargs
+ """
+
+ def __init__(
+ self,
+ inplanes: int,
+ headers: int = 8,
+ attn_scale: bool = False,
+ ratio: float = 0.0625, # bottleneck ratio of 1/16 as described in paper
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.headers = headers # h
+ self.inplanes = inplanes # C
+ self.attn_scale = attn_scale
+ self.planes = int(inplanes * ratio)
+
+ self.single_header_inplanes = int(inplanes / headers) # C / h
+
+ self.conv_mask = layers.Conv2D(filters=1, kernel_size=1, kernel_initializer=tf.initializers.he_normal())
+
+ self.transform = Sequential(
+ [
+ layers.Conv2D(filters=self.planes, kernel_size=1, kernel_initializer=tf.initializers.he_normal()),
+ layers.LayerNormalization([1, 2, 3]),
+ layers.ReLU(),
+ layers.Conv2D(filters=self.inplanes, kernel_size=1, kernel_initializer=tf.initializers.he_normal()),
+ ],
+ name="transform",
+ )
+
+ def context_modeling(self, inputs: tf.Tensor) -> tf.Tensor:
+ b, h, w, c = (tf.shape(inputs)[i] for i in range(4))
+
+ # B, H, W, C -->> B*h, H, W, C/h
+ x = tf.reshape(inputs, shape=(b, h, w, self.headers, self.single_header_inplanes))
+ x = tf.transpose(x, perm=(0, 3, 1, 2, 4))
+ x = tf.reshape(x, shape=(b * self.headers, h, w, self.single_header_inplanes))
+
+        # Compute shortcut
+ shortcut = x
+ # B*h, 1, H*W, C/h
+ shortcut = tf.reshape(shortcut, shape=(b * self.headers, 1, h * w, self.single_header_inplanes))
+ # B*h, 1, C/h, H*W
+ shortcut = tf.transpose(shortcut, perm=[0, 1, 3, 2])
+
+ # Compute context mask
+ # B*h, H, W, 1
+ context_mask = self.conv_mask(x)
+ # B*h, 1, H*W, 1
+ context_mask = tf.reshape(context_mask, shape=(b * self.headers, 1, h * w, 1))
+ # scale variance
+ if self.attn_scale and self.headers > 1:
+ context_mask = context_mask / math.sqrt(self.single_header_inplanes)
+ # B*h, 1, H*W, 1
+ context_mask = tf.keras.activations.softmax(context_mask, axis=2)
+
+ # Compute context
+ # B*h, 1, C/h, 1
+ context = tf.matmul(shortcut, context_mask)
+ context = tf.reshape(context, shape=(b, 1, c, 1))
+ # B, 1, 1, C
+ context = tf.transpose(context, perm=(0, 1, 3, 2))
+ # Set shape to resolve shape when calling this module in the Sequential MAGCResnet
+ batch, chan = inputs.get_shape().as_list()[0], inputs.get_shape().as_list()[-1]
+ context.set_shape([batch, 1, 1, chan])
+ return context
+
+ def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
+ # Context modeling: B, H, W, C -> B, 1, 1, C
+ context = self.context_modeling(inputs)
+ # Transform: B, 1, 1, C -> B, 1, 1, C
+ transformed = self.transform(context)
+ return inputs + transformed
+
+
+def _magc_resnet(
+ arch: str,
+ pretrained: bool,
+ num_blocks: List[int],
+ output_channels: List[int],
+ stage_downsample: List[bool],
+ stage_conv: List[bool],
+ stage_pooling: List[Optional[Tuple[int, int]]],
+ origin_stem: bool = True,
+ **kwargs: Any,
+) -> ResNet:
+ kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+ kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+ kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["num_classes"] = kwargs["num_classes"]
+ _cfg["classes"] = kwargs["classes"]
+ _cfg["input_shape"] = kwargs["input_shape"]
+ kwargs.pop("classes")
+
+ # Build the model
+ model = ResNet(
+ num_blocks,
+ output_channels,
+ stage_downsample,
+ stage_conv,
+ stage_pooling,
+ origin_stem,
+ attn_module=partial(MAGC, headers=8, attn_scale=True),
+ cfg=_cfg,
+ **kwargs,
+ )
+ # Load pretrained parameters
+ if pretrained:
+ load_pretrained_params(model, default_cfgs[arch]["url"])
+
+ return model
+
+
+def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
+ """Resnet31 architecture with Multi-Aspect Global Context Attention as described in
+ `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition",
+ `_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import magc_resnet31
+ >>> model = magc_resnet31(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 224, 224, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the ResNet architecture
+
+ Returns:
+ -------
+ A feature extractor model
+ """
+ return _magc_resnet(
+ "magc_resnet31",
+ pretrained,
+ [1, 2, 5, 3],
+ [256, 256, 512, 512],
+ [False] * 4,
+ [True] * 4,
+ [(2, 2), (2, 1), None, None],
+ False,
+ stem_channels=128,
+ **kwargs,
+ )
diff --git a/doctr/models/classification/mobilenet/__init__.py b/doctr/models/classification/mobilenet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..64556e403a5697432f805a5af28dab812fa8b932
--- /dev/null
+++ b/doctr/models/classification/mobilenet/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import *
diff --git a/doctr/models/classification/mobilenet/pytorch.py b/doctr/models/classification/mobilenet/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc3d18f59e7fba95e0017b8312ed17d98a3ba1c5
--- /dev/null
+++ b/doctr/models/classification/mobilenet/pytorch.py
@@ -0,0 +1,240 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to for full license details.
+
+# Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional
+
+from torchvision.models import mobilenetv3
+
+from doctr.datasets import VOCABS
+
+from ...utils import load_pretrained_params
+
+__all__ = [
+ "mobilenet_v3_small",
+ "mobilenet_v3_small_r",
+ "mobilenet_v3_large",
+ "mobilenet_v3_large_r",
+ "mobilenet_v3_small_orientation",
+]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "mobilenet_v3_large": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large-11fc8cb9.pt&src=0",
+ },
+ "mobilenet_v3_large_r": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large_r-74a22066.pt&src=0",
+ },
+ "mobilenet_v3_small": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small-6a4bfa6b.pt&src=0",
+ },
+ "mobilenet_v3_small_r": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-1a8a3530.pt&src=0",
+ },
+ "mobilenet_v3_small_orientation": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 128, 128),
+ "classes": [0, 90, 180, 270],
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-24f8ff57.pt&src=0",
+ },
+}
+
+
+def _mobilenet_v3(
+ arch: str,
+ pretrained: bool,
+ rect_strides: Optional[List[str]] = None,
+ ignore_keys: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> mobilenetv3.MobileNetV3:
+ kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+ kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["num_classes"] = kwargs["num_classes"]
+ _cfg["classes"] = kwargs["classes"]
+ kwargs.pop("classes")
+
+ if arch.startswith("mobilenet_v3_small"):
+ model = mobilenetv3.mobilenet_v3_small(**kwargs, weights=None)
+ else:
+ model = mobilenetv3.mobilenet_v3_large(**kwargs, weights=None)
+
+ # Rectangular strides
+ if isinstance(rect_strides, list):
+ for layer_name in rect_strides:
+ m = model
+ for child in layer_name.split("."):
+ m = getattr(m, child)
+ m.stride = (2, 1)
+
+ # Load pretrained parameters
+ if pretrained:
+ # The number of classes is not the same as the number of classes in the pretrained model =>
+ # remove the last layer weights
+ _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
+ load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+ model.cfg = _cfg
+
+ return model
+
+
+def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
+ """MobileNetV3-Small architecture as described in
+ `"Searching for MobileNetV3",
+ `_.
+
+ >>> import torch
+ >>> from doctr.models import mobilenet_v3_small
+    >>> model = mobilenet_v3_small(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the MobileNetV3 architecture
+
+ Returns:
+ -------
+ a torch.nn.Module
+ """
+ return _mobilenet_v3(
+ "mobilenet_v3_small", pretrained, ignore_keys=["classifier.3.weight", "classifier.3.bias"], **kwargs
+ )
+
+
+def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
+ """MobileNetV3-Small architecture as described in
+ `"Searching for MobileNetV3",
+ `_, with rectangular pooling.
+
+ >>> import torch
+ >>> from doctr.models import mobilenet_v3_small_r
+ >>> model = mobilenet_v3_small_r(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the MobileNetV3 architecture
+
+ Returns:
+ -------
+ a torch.nn.Module
+ """
+ return _mobilenet_v3(
+ "mobilenet_v3_small_r",
+ pretrained,
+ ["features.2.block.1.0", "features.4.block.1.0", "features.9.block.1.0"],
+ ignore_keys=["classifier.3.weight", "classifier.3.bias"],
+ **kwargs,
+ )
+
+
+def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
+ """MobileNetV3-Large architecture as described in
+ `"Searching for MobileNetV3",
+ `_.
+
+ >>> import torch
+ >>> from doctr.models import mobilenet_v3_large
+ >>> model = mobilenet_v3_large(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the MobileNetV3 architecture
+
+ Returns:
+ -------
+ a torch.nn.Module
+ """
+ return _mobilenet_v3(
+ "mobilenet_v3_large",
+ pretrained,
+ ignore_keys=["classifier.3.weight", "classifier.3.bias"],
+ **kwargs,
+ )
+
+
+def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
+ """MobileNetV3-Large architecture as described in
+ `"Searching for MobileNetV3",
+ `_, with rectangular pooling.
+
+ >>> import torch
+ >>> from doctr.models import mobilenet_v3_large_r
+ >>> model = mobilenet_v3_large_r(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the MobileNetV3 architecture
+
+ Returns:
+ -------
+ a torch.nn.Module
+ """
+ return _mobilenet_v3(
+ "mobilenet_v3_large_r",
+ pretrained,
+ ["features.4.block.1.0", "features.7.block.1.0", "features.13.block.1.0"],
+ ignore_keys=["classifier.3.weight", "classifier.3.bias"],
+ **kwargs,
+ )
+
+
+def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
+ """MobileNetV3-Small architecture as described in
+ `"Searching for MobileNetV3",
+ `_.
+
+ >>> import torch
+ >>> from doctr.models import mobilenet_v3_small_orientation
+ >>> model = mobilenet_v3_small_orientation(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the MobileNetV3 architecture
+
+ Returns:
+ -------
+ a torch.nn.Module
+ """
+ return _mobilenet_v3(
+ "mobilenet_v3_small_orientation",
+ pretrained,
+ ignore_keys=["classifier.3.weight", "classifier.3.bias"],
+ **kwargs,
+ )
diff --git a/doctr/models/classification/mobilenet/tensorflow.py b/doctr/models/classification/mobilenet/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..a12521865c5d1667ae998a447516586d29fedffa
--- /dev/null
+++ b/doctr/models/classification/mobilenet/tensorflow.py
@@ -0,0 +1,409 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to for full license details.
+
+# Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import tensorflow as tf
+from tensorflow.keras import layers
+from tensorflow.keras.models import Sequential
+
+from ....datasets import VOCABS
+from ...utils import conv_sequence, load_pretrained_params
+
+__all__ = [
+ "MobileNetV3",
+ "mobilenet_v3_small",
+ "mobilenet_v3_small_r",
+ "mobilenet_v3_large",
+ "mobilenet_v3_large_r",
+ "mobilenet_v3_small_orientation",
+]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "mobilenet_v3_large": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 32, 3),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large-47d25d7e.zip&src=0",
+ },
+ "mobilenet_v3_large_r": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 32, 3),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large_r-a108e192.zip&src=0",
+ },
+ "mobilenet_v3_small": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 32, 3),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small-8a32c32c.zip&src=0",
+ },
+ "mobilenet_v3_small_r": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 32, 3),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-3d61452e.zip&src=0",
+ },
+ "mobilenet_v3_small_orientation": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (128, 128, 3),
+ "classes": [0, 90, 180, 270],
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-1ea8db03.zip&src=0",
+ },
+}
+
+
+def hard_swish(x: tf.Tensor) -> tf.Tensor:
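+    # Hard-swish activation: x * relu6(x + 3) / 6 (cf. the MobileNetV3 paper)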
+ return x * tf.nn.relu6(x + 3.0) / 6.0
+
+
+def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
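+    # e.g. _make_divisible(37.5, 8) -> 40 (the nearest multiple of the divisor)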
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that rounding down does not go more than 10% below the original value.
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+class SqueezeExcitation(Sequential):
+ """Squeeze and Excitation."""
+
+ def __init__(self, chan: int, squeeze_factor: int = 4) -> None:
+ super().__init__([
+ layers.GlobalAveragePooling2D(),
+ layers.Dense(chan // squeeze_factor, activation="relu"),
+ layers.Dense(chan, activation="hard_sigmoid"),
+ layers.Reshape((1, 1, chan)),
+ ])
+
+ def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor:
+ x = super().call(inputs, **kwargs)
+ x = tf.math.multiply(inputs, x)
+ return x
+
+
+class InvertedResidualConfig:
+ def __init__(
+ self,
+ input_channels: int,
+ kernel: int,
+ expanded_channels: int,
+ out_channels: int,
+ use_se: bool,
+ activation: str,
+ stride: Union[int, Tuple[int, int]],
+ width_mult: float = 1,
+ ) -> None:
+ self.input_channels = self.adjust_channels(input_channels, width_mult)
+ self.kernel = kernel
+ self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
+ self.out_channels = self.adjust_channels(out_channels, width_mult)
+ self.use_se = use_se
+ self.use_hs = activation == "HS"
+ self.stride = stride
+
+ @staticmethod
+ def adjust_channels(channels: int, width_mult: float):
+ return _make_divisible(channels * width_mult, 8)
+
+
+class InvertedResidual(layers.Layer):
+ """InvertedResidual for mobilenet
+
+ Args:
+ ----
+ conf: configuration object for inverted residual
+ """
+
+ def __init__(
+ self,
+ conf: InvertedResidualConfig,
+ **kwargs: Any,
+ ) -> None:
+ _kwargs = {"input_shape": kwargs.pop("input_shape")} if isinstance(kwargs.get("input_shape"), tuple) else {}
+ super().__init__(**kwargs)
+
+ act_fn = hard_swish if conf.use_hs else tf.nn.relu
+
+ _is_s1 = (isinstance(conf.stride, tuple) and conf.stride == (1, 1)) or conf.stride == 1
+ self.use_res_connect = _is_s1 and conf.input_channels == conf.out_channels
+
+ _layers = []
+ # expand
+ if conf.expanded_channels != conf.input_channels:
+ _layers.extend(conv_sequence(conf.expanded_channels, act_fn, kernel_size=1, bn=True, **_kwargs))
+
+ # depth-wise
+ _layers.extend(
+ conv_sequence(
+ conf.expanded_channels,
+ act_fn,
+ kernel_size=conf.kernel,
+ strides=conf.stride,
+ bn=True,
+ groups=conf.expanded_channels,
+ )
+ )
+
+ if conf.use_se:
+ _layers.append(SqueezeExcitation(conf.expanded_channels))
+
+ # project
+ _layers.extend(
+ conv_sequence(
+ conf.out_channels,
+ None,
+ kernel_size=1,
+ bn=True,
+ )
+ )
+
+ self.block = Sequential(_layers)
+
+ def call(
+ self,
+ inputs: tf.Tensor,
+ **kwargs: Any,
+ ) -> tf.Tensor:
+ out = self.block(inputs, **kwargs)
+ if self.use_res_connect:
+ out = tf.add(out, inputs)
+
+ return out
+
+
+class MobileNetV3(Sequential):
+    """Implements MobileNetV3 from `"Searching for MobileNetV3",
+    <https://arxiv.org/pdf/1905.02244.pdf>`_, inspired by the torchvision
+    implementation referenced in the module header.
+ """
+
+ def __init__(
+ self,
+ layout: List[InvertedResidualConfig],
+ include_top: bool = True,
+ head_chans: int = 1024,
+ num_classes: int = 1000,
+ cfg: Optional[Dict[str, Any]] = None,
+ input_shape: Optional[Tuple[int, int, int]] = None,
+ ) -> None:
+ _layers = [
+ Sequential(
+ conv_sequence(
+ layout[0].input_channels, hard_swish, True, kernel_size=3, strides=2, input_shape=input_shape
+ ),
+ name="stem",
+ )
+ ]
+
+ for idx, conf in enumerate(layout):
+ _layers.append(
+ InvertedResidual(conf, name=f"inverted_{idx}"),
+ )
+
+ _layers.append(
+ Sequential(conv_sequence(6 * layout[-1].out_channels, hard_swish, True, kernel_size=1), name="final_block")
+ )
+
+ if include_top:
+ _layers.extend([
+ layers.GlobalAveragePooling2D(),
+ layers.Dense(head_chans, activation=hard_swish),
+ layers.Dropout(0.2),
+ layers.Dense(num_classes),
+ ])
+
+ super().__init__(_layers)
+ self.cfg = cfg
+
+
+def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwargs: Any) -> MobileNetV3:
+ kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+ kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+ kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["num_classes"] = kwargs["num_classes"]
+ _cfg["classes"] = kwargs["classes"]
+ _cfg["input_shape"] = kwargs["input_shape"]
+ kwargs.pop("classes")
+
+ # cf. Table 1 & 2 of the paper
+ if arch.startswith("mobilenet_v3_small"):
+ inverted_residual_setting = [
+ InvertedResidualConfig(16, 3, 16, 16, True, "RE", 2), # C1
+ InvertedResidualConfig(16, 3, 72, 24, False, "RE", (2, 1) if rect_strides else 2), # C2
+ InvertedResidualConfig(24, 3, 88, 24, False, "RE", 1),
+ InvertedResidualConfig(24, 5, 96, 40, True, "HS", (2, 1) if rect_strides else 2), # C3
+ InvertedResidualConfig(40, 5, 240, 40, True, "HS", 1),
+ InvertedResidualConfig(40, 5, 240, 40, True, "HS", 1),
+ InvertedResidualConfig(40, 5, 120, 48, True, "HS", 1),
+ InvertedResidualConfig(48, 5, 144, 48, True, "HS", 1),
+ InvertedResidualConfig(48, 5, 288, 96, True, "HS", (2, 1) if rect_strides else 2), # C4
+ InvertedResidualConfig(96, 5, 576, 96, True, "HS", 1),
+ InvertedResidualConfig(96, 5, 576, 96, True, "HS", 1),
+ ]
+ head_chans = 1024
+ else:
+ inverted_residual_setting = [
+ InvertedResidualConfig(16, 3, 16, 16, False, "RE", 1),
+ InvertedResidualConfig(16, 3, 64, 24, False, "RE", 2), # C1
+ InvertedResidualConfig(24, 3, 72, 24, False, "RE", 1),
+ InvertedResidualConfig(24, 5, 72, 40, True, "RE", (2, 1) if rect_strides else 2), # C2
+ InvertedResidualConfig(40, 5, 120, 40, True, "RE", 1),
+ InvertedResidualConfig(40, 5, 120, 40, True, "RE", 1),
+ InvertedResidualConfig(40, 3, 240, 80, False, "HS", (2, 1) if rect_strides else 2), # C3
+ InvertedResidualConfig(80, 3, 200, 80, False, "HS", 1),
+ InvertedResidualConfig(80, 3, 184, 80, False, "HS", 1),
+ InvertedResidualConfig(80, 3, 184, 80, False, "HS", 1),
+ InvertedResidualConfig(80, 3, 480, 112, True, "HS", 1),
+ InvertedResidualConfig(112, 3, 672, 112, True, "HS", 1),
+ InvertedResidualConfig(112, 5, 672, 160, True, "HS", (2, 1) if rect_strides else 2), # C4
+ InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1),
+ InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1),
+ ]
+ head_chans = 1280
+
+ kwargs["num_classes"] = _cfg["num_classes"]
+ kwargs["input_shape"] = _cfg["input_shape"]
+
+ # Build the model
+ model = MobileNetV3(
+ inverted_residual_setting,
+ head_chans=head_chans,
+ cfg=_cfg,
+ **kwargs,
+ )
+ # Load pretrained parameters
+ if pretrained:
+ load_pretrained_params(model, default_cfgs[arch]["url"])
+
+ return model
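+
+
+# Note on `rect_strides` (a sketch of the "_r" variants defined below): the (2, 2)
+# strides at the C2/C3/C4 downsampling points are swapped for (2, 1), halving height
+# only, so wide text crops keep their horizontal resolution:
+#   model = _mobilenet_v3("mobilenet_v3_small_r", pretrained=False, rect_strides=True)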
+
+
+def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
+ """MobileNetV3-Small architecture as described in
+ `"Searching for MobileNetV3",
+    <https://arxiv.org/pdf/1905.02244.pdf>`_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import mobilenet_v3_small
+ >>> model = mobilenet_v3_small(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the MobileNetV3 architecture
+
+ Returns:
+ -------
+ a keras.Model
+ """
+ return _mobilenet_v3("mobilenet_v3_small", pretrained, False, **kwargs)
+
+
+def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
+ """MobileNetV3-Small architecture as described in
+ `"Searching for MobileNetV3",
+    <https://arxiv.org/pdf/1905.02244.pdf>`_, with rectangular pooling.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import mobilenet_v3_small_r
+ >>> model = mobilenet_v3_small_r(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the MobileNetV3 architecture
+
+ Returns:
+ -------
+ a keras.Model
+ """
+ return _mobilenet_v3("mobilenet_v3_small_r", pretrained, True, **kwargs)
+
+
+def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
+ """MobileNetV3-Large architecture as described in
+ `"Searching for MobileNetV3",
+    <https://arxiv.org/pdf/1905.02244.pdf>`_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import mobilenet_v3_large
+ >>> model = mobilenet_v3_large(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the MobileNetV3 architecture
+
+ Returns:
+ -------
+ a keras.Model
+ """
+ return _mobilenet_v3("mobilenet_v3_large", pretrained, False, **kwargs)
+
+
+def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
+ """MobileNetV3-Large architecture as described in
+ `"Searching for MobileNetV3",
+    <https://arxiv.org/pdf/1905.02244.pdf>`_, with rectangular pooling.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import mobilenet_v3_large_r
+ >>> model = mobilenet_v3_large_r(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the MobileNetV3 architecture
+
+ Returns:
+ -------
+ a keras.Model
+ """
+ return _mobilenet_v3("mobilenet_v3_large_r", pretrained, True, **kwargs)
+
+
+def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
+ """MobileNetV3-Small architecture as described in
+ `"Searching for MobileNetV3",
+    <https://arxiv.org/pdf/1905.02244.pdf>`_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import mobilenet_v3_small_orientation
+ >>> model = mobilenet_v3_small_orientation(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the MobileNetV3 architecture
+
+ Returns:
+ -------
+ a keras.Model
+ """
+ return _mobilenet_v3("mobilenet_v3_small_orientation", pretrained, include_top=True, **kwargs)
diff --git a/doctr/models/classification/predictor/__init__.py b/doctr/models/classification/predictor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/models/classification/predictor/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/classification/predictor/pytorch.py b/doctr/models/classification/predictor/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8bed39eb8c6f0ca2461648232519cc56f172b10
--- /dev/null
+++ b/doctr/models/classification/predictor/pytorch.py
@@ -0,0 +1,67 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import List, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+from doctr.models.preprocessor import PreProcessor
+from doctr.models.utils import set_device_and_dtype
+
+__all__ = ["CropOrientationPredictor"]
+
+
+class CropOrientationPredictor(nn.Module):
+ """Implements an object able to detect the reading direction of a text box.
+    4 possible orientations: 0, 90, 180, 270 degrees counterclockwise.
+
+ Args:
+ ----
+ pre_processor: transform inputs for easier batched model inference
+ model: core classification architecture (backbone + classification head)
+ """
+
+ def __init__(
+ self,
+ pre_processor: PreProcessor,
+ model: nn.Module,
+ ) -> None:
+ super().__init__()
+ self.pre_processor = pre_processor
+ self.model = model.eval()
+
+ @torch.inference_mode()
+ def forward(
+ self,
+ crops: List[Union[np.ndarray, torch.Tensor]],
+ ) -> List[Union[List[int], List[float]]]:
+ # Dimension check
+ if any(crop.ndim != 3 for crop in crops):
+            raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.")
+
+ processed_batches = self.pre_processor(crops)
+ _params = next(self.model.parameters())
+ self.model, processed_batches = set_device_and_dtype(
+ self.model, processed_batches, _params.device, _params.dtype
+ )
+ predicted_batches = [self.model(batch) for batch in processed_batches]
+ # confidence
+ probs = [
+ torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches
+ ]
+ # Postprocess predictions
+ predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches]
+
+ class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
+        # Keep unified with the page orientation range (counterclockwise rotation => negative), so 270 -> -90
+ classes = [
+ int(self.model.cfg["classes"][idx]) if int(self.model.cfg["classes"][idx]) != 270 else -90
+ for idx in class_idxs
+ ]
+ confs = [round(float(p), 2) for prob in probs for p in prob]
+
+ return [class_idxs, classes, confs]
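+
+
+# Usage sketch (illustrative; in docTR these objects are typically assembled by the
+# classification zoo rather than built by hand):
+#   predictor = CropOrientationPredictor(pre_processor, model)
+#   class_idxs, angles, confs = predictor([np.zeros((64, 64, 3), dtype=np.uint8)])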
diff --git a/doctr/models/classification/predictor/tensorflow.py b/doctr/models/classification/predictor/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..2299bacbd70a4fa7cd5f802020a4ed49ce73f160
--- /dev/null
+++ b/doctr/models/classification/predictor/tensorflow.py
@@ -0,0 +1,62 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import List, Union
+
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+
+from doctr.models.preprocessor import PreProcessor
+from doctr.utils.repr import NestedObject
+
+__all__ = ["CropOrientationPredictor"]
+
+
+class CropOrientationPredictor(NestedObject):
+ """Implements an object able to detect the reading direction of a text box.
+    4 possible orientations: 0, 90, 180, 270 degrees counterclockwise.
+
+ Args:
+ ----
+ pre_processor: transform inputs for easier batched model inference
+ model: core classification architecture (backbone + classification head)
+ """
+
+ _children_names: List[str] = ["pre_processor", "model"]
+
+ def __init__(
+ self,
+ pre_processor: PreProcessor,
+ model: keras.Model,
+ ) -> None:
+ self.pre_processor = pre_processor
+ self.model = model
+
+ def __call__(
+ self,
+ crops: List[Union[np.ndarray, tf.Tensor]],
+ ) -> List[Union[List[int], List[float]]]:
+ # Dimension check
+ if any(crop.ndim != 3 for crop in crops):
+ raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.")
+
+ processed_batches = self.pre_processor(crops)
+ predicted_batches = [self.model(batch, training=False) for batch in processed_batches]
+
+ # confidence
+ probs = [tf.math.reduce_max(tf.nn.softmax(batch, axis=1), axis=1).numpy() for batch in predicted_batches]
+ # Postprocess predictions
+ predicted_batches = [out_batch.numpy().argmax(1) for out_batch in predicted_batches]
+
+ class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
+        # Keep unified with the page orientation range (counterclockwise rotation => negative), so 270 -> -90
+ classes = [
+ int(self.model.cfg["classes"][idx]) if int(self.model.cfg["classes"][idx]) != 270 else -90
+ for idx in class_idxs
+ ]
+ confs = [round(float(p), 2) for prob in probs for p in prob]
+
+ return [class_idxs, classes, confs]
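+
+
+# Usage sketch (illustrative; in docTR these objects are typically assembled by the
+# classification zoo rather than built by hand):
+#   predictor = CropOrientationPredictor(pre_processor, model)
+#   class_idxs, angles, confs = predictor([np.zeros((64, 64, 3), dtype=np.uint8)])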
diff --git a/doctr/models/classification/resnet/__init__.py b/doctr/models/classification/resnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/models/classification/resnet/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/classification/resnet/pytorch.py b/doctr/models/classification/resnet/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..7591741c2985f48f3be3b18e766d293d754e04ed
--- /dev/null
+++ b/doctr/models/classification/resnet/pytorch.py
@@ -0,0 +1,366 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+from torch import nn
+from torchvision.models.resnet import BasicBlock
+from torchvision.models.resnet import ResNet as TVResNet
+from torchvision.models.resnet import resnet18 as tv_resnet18
+from torchvision.models.resnet import resnet34 as tv_resnet34
+from torchvision.models.resnet import resnet50 as tv_resnet50
+
+from doctr.datasets import VOCABS
+
+from ...utils import conv_sequence_pt, load_pretrained_params
+
+__all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide", "resnet_stage"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "resnet18": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/resnet18-244bf390.pt&src=0",
+ },
+ "resnet31": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/resnet31-1056cc5c.pt&src=0",
+ },
+ "resnet34": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34-bd8725db.pt&src=0",
+ },
+ "resnet50": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet50-1a6c155e.pt&src=0",
+ },
+ "resnet34_wide": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.6.0/resnet34_wide-b4b3e39e.pt&src=0",
+ },
+}
+
+
+def resnet_stage(in_channels: int, out_channels: int, num_blocks: int, stride: int) -> List[nn.Module]:
+ """Build a ResNet stage"""
+ _layers: List[nn.Module] = []
+
+ in_chan = in_channels
+ s = stride
+ for _ in range(num_blocks):
+ downsample = None
+ if in_chan != out_channels:
+ downsample = nn.Sequential(*conv_sequence_pt(in_chan, out_channels, False, True, kernel_size=1, stride=s))
+
+ _layers.append(BasicBlock(in_chan, out_channels, stride=s, downsample=downsample))
+ in_chan = out_channels
+ # Only the first block can have stride != 1
+ s = 1
+
+ return _layers
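+
+
+# Example (sketch): a 2-block stage mapping 64 -> 128 channels; only the first block
+# can downsample, and a 1x1 projection aligns the shortcut when channels change:
+#   stage = nn.Sequential(*resnet_stage(64, 128, num_blocks=2, stride=2))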
+
+
+class ResNet(nn.Sequential):
+    """Implements a ResNet-31 architecture from `"Show, Attend and Read: A Simple and Strong Baseline for Irregular
+    Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_.
+
+ Args:
+ ----
+    num_blocks: number of residual blocks in each stage
+    output_channels: number of channels in each stage
+    stage_stride: stride of the first residual block in each stage
+    stage_conv: whether to add a conv_sequence after each stage
+    stage_pooling: pooling to add after each stage (if None, no pooling)
+    origin_stem: whether to use the original ResNet stem or ResNet-31's
+ stem_channels: number of output channels of the stem convolutions
+ attn_module: attention module to use in each stage
+ include_top: whether the classifier head should be instantiated
+ num_classes: number of output classes
+ """
+
+ def __init__(
+ self,
+ num_blocks: List[int],
+ output_channels: List[int],
+ stage_stride: List[int],
+ stage_conv: List[bool],
+ stage_pooling: List[Optional[Tuple[int, int]]],
+ origin_stem: bool = True,
+ stem_channels: int = 64,
+ attn_module: Optional[Callable[[int], nn.Module]] = None,
+ include_top: bool = True,
+ num_classes: int = 1000,
+ cfg: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ _layers: List[nn.Module]
+ if origin_stem:
+ _layers = [
+ *conv_sequence_pt(3, stem_channels, True, True, kernel_size=7, padding=3, stride=2),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ ]
+ else:
+ _layers = [
+ *conv_sequence_pt(3, stem_channels // 2, True, True, kernel_size=3, padding=1),
+ *conv_sequence_pt(stem_channels // 2, stem_channels, True, True, kernel_size=3, padding=1),
+ nn.MaxPool2d(2),
+ ]
+ in_chans = [stem_channels] + output_channels[:-1]
+ for n_blocks, in_chan, out_chan, stride, conv, pool in zip(
+ num_blocks, in_chans, output_channels, stage_stride, stage_conv, stage_pooling
+ ):
+ _stage = resnet_stage(in_chan, out_chan, n_blocks, stride)
+ if attn_module is not None:
+ _stage.append(attn_module(out_chan))
+ if conv:
+ _stage.extend(conv_sequence_pt(out_chan, out_chan, True, True, kernel_size=3, padding=1))
+ if pool is not None:
+ _stage.append(nn.MaxPool2d(pool))
+ _layers.append(nn.Sequential(*_stage))
+
+ if include_top:
+ _layers.extend([
+ nn.AdaptiveAvgPool2d(1),
+ nn.Flatten(1),
+ nn.Linear(output_channels[-1], num_classes, bias=True),
+ ])
+
+ super().__init__(*_layers)
+ self.cfg = cfg
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+
+def _resnet(
+ arch: str,
+ pretrained: bool,
+ num_blocks: List[int],
+ output_channels: List[int],
+ stage_stride: List[int],
+ stage_conv: List[bool],
+ stage_pooling: List[Optional[Tuple[int, int]]],
+ ignore_keys: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> ResNet:
+ kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+ kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["num_classes"] = kwargs["num_classes"]
+ _cfg["classes"] = kwargs["classes"]
+ kwargs.pop("classes")
+
+ # Build the model
+ model = ResNet(num_blocks, output_channels, stage_stride, stage_conv, stage_pooling, cfg=_cfg, **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ # The number of classes is not the same as the number of classes in the pretrained model =>
+ # remove the last layer weights
+ _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
+ load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+ return model
+
+
+def _tv_resnet(
+ arch: str,
+ pretrained: bool,
+ arch_fn,
+ ignore_keys: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> TVResNet:
+ kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+ kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["num_classes"] = kwargs["num_classes"]
+ _cfg["classes"] = kwargs["classes"]
+ kwargs.pop("classes")
+
+ # Build the model
+ model = arch_fn(**kwargs, weights=None)
+ # Load pretrained parameters
+ if pretrained:
+ # The number of classes is not the same as the number of classes in the pretrained model =>
+ # remove the last layer weights
+ _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
+ load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+ model.cfg = _cfg
+
+ return model
+
+
+def resnet18(pretrained: bool = False, **kwargs: Any) -> TVResNet:
+ """ResNet-18 architecture as described in `"Deep Residual Learning for Image Recognition",
+    <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+ >>> import torch
+ >>> from doctr.models import resnet18
+ >>> model = resnet18(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the ResNet architecture
+
+ Returns:
+ -------
+ A resnet18 model
+ """
+ return _tv_resnet(
+ "resnet18",
+ pretrained,
+ tv_resnet18,
+ ignore_keys=["fc.weight", "fc.bias"],
+ **kwargs,
+ )
+
+
+def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
+ """Resnet31 architecture with rectangular pooling windows as described in
+    `"Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition",
+    <https://arxiv.org/pdf/1811.00751.pdf>`_. Downsizing: (H, W) --> (H/8, W/4)
+
+ >>> import torch
+ >>> from doctr.models import resnet31
+ >>> model = resnet31(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the ResNet architecture
+
+ Returns:
+ -------
+ A resnet31 model
+ """
+ return _resnet(
+ "resnet31",
+ pretrained,
+ [1, 2, 5, 3],
+ [256, 256, 512, 512],
+ [1, 1, 1, 1],
+ [True] * 4,
+ [(2, 2), (2, 1), None, None],
+ origin_stem=False,
+ stem_channels=128,
+ ignore_keys=["13.weight", "13.bias"],
+ **kwargs,
+ )
+
+
+def resnet34(pretrained: bool = False, **kwargs: Any) -> TVResNet:
+ """ResNet-34 architecture as described in `"Deep Residual Learning for Image Recognition",
+    <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+ >>> import torch
+ >>> from doctr.models import resnet34
+ >>> model = resnet34(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the ResNet architecture
+
+ Returns:
+ -------
+ A resnet34 model
+ """
+ return _tv_resnet(
+ "resnet34",
+ pretrained,
+ tv_resnet34,
+ ignore_keys=["fc.weight", "fc.bias"],
+ **kwargs,
+ )
+
+
+def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet:
+ """ResNet-34 architecture as described in `"Deep Residual Learning for Image Recognition",
+    <https://arxiv.org/pdf/1512.03385.pdf>`_ with twice as many output channels.
+
+ >>> import torch
+ >>> from doctr.models import resnet34_wide
+ >>> model = resnet34_wide(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the ResNet architecture
+
+ Returns:
+ -------
+ A resnet34_wide model
+ """
+ return _resnet(
+ "resnet34_wide",
+ pretrained,
+ [3, 4, 6, 3],
+ [128, 256, 512, 1024],
+ [1, 2, 2, 2],
+ [False] * 4,
+ [None] * 4,
+ origin_stem=True,
+ stem_channels=128,
+ ignore_keys=["10.weight", "10.bias"],
+ **kwargs,
+ )
+
+
+def resnet50(pretrained: bool = False, **kwargs: Any) -> TVResNet:
+ """ResNet-50 architecture as described in `"Deep Residual Learning for Image Recognition",
+    <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+ >>> import torch
+ >>> from doctr.models import resnet50
+ >>> model = resnet50(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the ResNet architecture
+
+ Returns:
+ -------
+ A resnet50 model
+ """
+ return _tv_resnet(
+ "resnet50",
+ pretrained,
+ tv_resnet50,
+ ignore_keys=["fc.weight", "fc.bias"],
+ **kwargs,
+ )
diff --git a/doctr/models/classification/resnet/tensorflow.py b/doctr/models/classification/resnet/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..7648e5f8d064390adae1f3b9b9386913498ef9a5
--- /dev/null
+++ b/doctr/models/classification/resnet/tensorflow.py
@@ -0,0 +1,395 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import tensorflow as tf
+from tensorflow.keras import layers
+from tensorflow.keras.applications import ResNet50
+from tensorflow.keras.models import Sequential
+
+from doctr.datasets import VOCABS
+
+from ...utils import conv_sequence, load_pretrained_params
+
+__all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "resnet18": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 32, 3),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/resnet18-d4634669.zip&src=0",
+ },
+ "resnet31": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 32, 3),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet31-5a47a60b.zip&src=0",
+ },
+ "resnet34": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 32, 3),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34-5dcc97ca.zip&src=0",
+ },
+ "resnet50": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 32, 3),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet50-e75e4cdf.zip&src=0",
+ },
+ "resnet34_wide": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 32, 3),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34_wide-c1271816.zip&src=0",
+ },
+}
+
+
+class ResnetBlock(layers.Layer):
+ """Implements a resnet31 block with shortcut
+
+ Args:
+ ----
+    output_channels: number of channels to use in Conv2D
+    conv_shortcut: whether to use a convolutional projection on the shortcut branch
+    strides: stride of the first convolution of the block
+ """
+
+ def __init__(self, output_channels: int, conv_shortcut: bool, strides: int = 1, **kwargs) -> None:
+ super().__init__(**kwargs)
+ if conv_shortcut:
+ self.shortcut = Sequential([
+ layers.Conv2D(
+ filters=output_channels,
+ strides=strides,
+ padding="same",
+ kernel_size=1,
+ use_bias=False,
+ kernel_initializer="he_normal",
+ ),
+ layers.BatchNormalization(),
+ ])
+ else:
+ self.shortcut = layers.Lambda(lambda x: x)
+ self.conv_block = Sequential(self.conv_resnetblock(output_channels, 3, strides))
+ self.act = layers.Activation("relu")
+
+ @staticmethod
+ def conv_resnetblock(
+ output_channels: int,
+ kernel_size: int,
+ strides: int = 1,
+ ) -> List[layers.Layer]:
+ return [
+ *conv_sequence(output_channels, "relu", bn=True, strides=strides, kernel_size=kernel_size),
+ *conv_sequence(output_channels, None, bn=True, kernel_size=kernel_size),
+ ]
+
+ def call(self, inputs: tf.Tensor) -> tf.Tensor:
+ clone = self.shortcut(inputs)
+ conv_out = self.conv_block(inputs)
+ out = self.act(clone + conv_out)
+
+ return out
+
+
+def resnet_stage(
+ num_blocks: int, out_channels: int, shortcut: bool = False, downsample: bool = False
+) -> List[layers.Layer]:
+ _layers: List[layers.Layer] = [ResnetBlock(out_channels, conv_shortcut=shortcut, strides=2 if downsample else 1)]
+
+ for _ in range(1, num_blocks):
+ _layers.append(ResnetBlock(out_channels, conv_shortcut=False))
+
+ return _layers
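+
+
+# Example (sketch): a 2-block stage with 128 output channels; the first block gets a
+# convolutional shortcut and a stride of 2 when `downsample` is requested:
+#   stage_layers = resnet_stage(2, 128, shortcut=True, downsample=True)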
+
+
+class ResNet(Sequential):
+ """Implements a ResNet architecture
+
+ Args:
+ ----
+    num_blocks: number of residual blocks in each stage
+ output_channels: number of channels in each stage
+ stage_downsample: whether the first residual block of a stage should downsample
+ stage_conv: whether to add a conv_sequence after each stage
+ stage_pooling: pooling to add after each stage (if None, no pooling)
+    origin_stem: whether to use the original ResNet stem or ResNet-31's
+ stem_channels: number of output channels of the stem convolutions
+ attn_module: attention module to use in each stage
+ include_top: whether the classifier head should be instantiated
+ num_classes: number of output classes
+ input_shape: shape of inputs
+ """
+
+ def __init__(
+ self,
+ num_blocks: List[int],
+ output_channels: List[int],
+ stage_downsample: List[bool],
+ stage_conv: List[bool],
+ stage_pooling: List[Optional[Tuple[int, int]]],
+ origin_stem: bool = True,
+ stem_channels: int = 64,
+ attn_module: Optional[Callable[[int], layers.Layer]] = None,
+ include_top: bool = True,
+ num_classes: int = 1000,
+ cfg: Optional[Dict[str, Any]] = None,
+ input_shape: Optional[Tuple[int, int, int]] = None,
+ ) -> None:
+ inplanes = stem_channels
+ if origin_stem:
+ _layers = [
+ *conv_sequence(inplanes, "relu", True, kernel_size=7, strides=2, input_shape=input_shape),
+ layers.MaxPool2D(pool_size=(3, 3), strides=2, padding="same"),
+ ]
+ else:
+ _layers = [
+ *conv_sequence(inplanes // 2, "relu", True, kernel_size=3, input_shape=input_shape),
+ *conv_sequence(inplanes, "relu", True, kernel_size=3),
+ layers.MaxPool2D(pool_size=2, strides=2, padding="valid"),
+ ]
+
+ for n_blocks, out_chan, down, conv, pool in zip(
+ num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling
+ ):
+ _layers.extend(resnet_stage(n_blocks, out_chan, out_chan != inplanes, down))
+ if attn_module is not None:
+ _layers.append(attn_module(out_chan))
+ if conv:
+ _layers.extend(conv_sequence(out_chan, activation="relu", bn=True, kernel_size=3))
+ if pool:
+ _layers.append(layers.MaxPool2D(pool_size=pool, strides=pool, padding="valid"))
+ inplanes = out_chan
+
+ if include_top:
+ _layers.extend([
+ layers.GlobalAveragePooling2D(),
+ layers.Dense(num_classes),
+ ])
+
+ super().__init__(_layers)
+ self.cfg = cfg
+
+
+def _resnet(
+ arch: str,
+ pretrained: bool,
+ num_blocks: List[int],
+ output_channels: List[int],
+ stage_downsample: List[bool],
+ stage_conv: List[bool],
+ stage_pooling: List[Optional[Tuple[int, int]]],
+ origin_stem: bool = True,
+ **kwargs: Any,
+) -> ResNet:
+ kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+ kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+ kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["num_classes"] = kwargs["num_classes"]
+ _cfg["classes"] = kwargs["classes"]
+ _cfg["input_shape"] = kwargs["input_shape"]
+ kwargs.pop("classes")
+
+ # Build the model
+ model = ResNet(
+ num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling, origin_stem, cfg=_cfg, **kwargs
+ )
+ # Load pretrained parameters
+ if pretrained:
+ load_pretrained_params(model, default_cfgs[arch]["url"])
+
+ return model
+
+
+def resnet18(pretrained: bool = False, **kwargs: Any) -> ResNet:
+ """Resnet-18 architecture as described in `"Deep Residual Learning for Image Recognition",
+    <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import resnet18
+ >>> model = resnet18(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the ResNet architecture
+
+ Returns:
+ -------
+ A classification model
+ """
+ return _resnet(
+ "resnet18",
+ pretrained,
+ [2, 2, 2, 2],
+ [64, 128, 256, 512],
+ [False, True, True, True],
+ [False] * 4,
+ [None] * 4,
+ True,
+ **kwargs,
+ )
+
+
+def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet:
+ """Resnet31 architecture with rectangular pooling windows as described in
+    `"Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition",
+    <https://arxiv.org/pdf/1811.00751.pdf>`_. Downsizing: (H, W) --> (H/8, W/4)
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import resnet31
+ >>> model = resnet31(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the ResNet architecture
+
+ Returns:
+ -------
+ A classification model
+ """
+ return _resnet(
+ "resnet31",
+ pretrained,
+ [1, 2, 5, 3],
+ [256, 256, 512, 512],
+ [False] * 4,
+ [True] * 4,
+ [(2, 2), (2, 1), None, None],
+ False,
+ stem_channels=128,
+ **kwargs,
+ )
+
+
+def resnet34(pretrained: bool = False, **kwargs: Any) -> ResNet:
+ """Resnet-34 architecture as described in `"Deep Residual Learning for Image Recognition",
+    <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import resnet34
+ >>> model = resnet34(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the ResNet architecture
+
+ Returns:
+ -------
+ A classification model
+ """
+ return _resnet(
+ "resnet34",
+ pretrained,
+ [3, 4, 6, 3],
+ [64, 128, 256, 512],
+ [False, True, True, True],
+ [False] * 4,
+ [None] * 4,
+ True,
+ **kwargs,
+ )
+
+
+def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
+ """Resnet-50 architecture as described in `"Deep Residual Learning for Image Recognition",
+    <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import resnet50
+ >>> model = resnet50(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the ResNet architecture
+
+ Returns:
+ -------
+ A classification model
+ """
+ kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs["resnet50"]["classes"]))
+ kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs["resnet50"]["input_shape"])
+ kwargs["classes"] = kwargs.get("classes", default_cfgs["resnet50"]["classes"])
+
+ _cfg = deepcopy(default_cfgs["resnet50"])
+ _cfg["num_classes"] = kwargs["num_classes"]
+ _cfg["classes"] = kwargs["classes"]
+ _cfg["input_shape"] = kwargs["input_shape"]
+ kwargs.pop("classes")
+
+ model = ResNet50(
+ weights=None,
+ include_top=True,
+ pooling=True,
+ input_shape=kwargs["input_shape"],
+ classes=kwargs["num_classes"],
+ classifier_activation=None,
+ )
+
+ model.cfg = _cfg
+
+ # Load pretrained parameters
+ if pretrained:
+ load_pretrained_params(model, default_cfgs["resnet50"]["url"])
+
+ return model
+
+
+def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet:
+ """Resnet-34 architecture as described in `"Deep Residual Learning for Image Recognition",
+    <https://arxiv.org/pdf/1512.03385.pdf>`_ with twice as many output channels for each stage.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import resnet34_wide
+ >>> model = resnet34_wide(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the ResNet architecture
+
+ Returns:
+ -------
+ A classification model
+ """
+ return _resnet(
+ "resnet34_wide",
+ pretrained,
+ [3, 4, 6, 3],
+ [128, 256, 512, 1024],
+ [False, True, True, True],
+ [False] * 4,
+ [None] * 4,
+ True,
+ stem_channels=128,
+ **kwargs,
+ )
diff --git a/doctr/models/classification/textnet/__init__.py b/doctr/models/classification/textnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/models/classification/textnet/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/classification/textnet/pytorch.py b/doctr/models/classification/textnet/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdbb719f8bad20653ae3337d000db12f1d73168e
--- /dev/null
+++ b/doctr/models/classification/textnet/pytorch.py
@@ -0,0 +1,275 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple
+
+from torch import nn
+
+from doctr.datasets import VOCABS
+
+from ...modules.layers.pytorch import FASTConvLayer
+from ...utils import conv_sequence_pt, load_pretrained_params
+
+__all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "textnet_tiny": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-27288d12.pt&src=0",
+ },
+ "textnet_small": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-43166ee6.pt&src=0",
+ },
+ "textnet_base": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-7f68d7e0.pt&src=0",
+ },
+}
+
+
+class TextNet(nn.Sequential):
+ """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
+
+ Args:
+ ----
+    stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage.
+    input_shape (Tuple[int, int, int], optional): Shape of the input tensor. Defaults to (3, 32, 32).
+ include_top (bool, optional): Whether to include the classifier head. Defaults to True.
+ num_classes (int, optional): Number of output classes. Defaults to 1000.
+ cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ stages: List[Dict[str, List[int]]],
+ input_shape: Tuple[int, int, int] = (3, 32, 32),
+ num_classes: int = 1000,
+ include_top: bool = True,
+ cfg: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ _layers: List[nn.Module] = [
+ *conv_sequence_pt(
+ in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2, padding=(1, 1)
+ ),
+ *[
+ nn.Sequential(*[
+ FASTConvLayer(**params) # type: ignore[arg-type]
+ for params in [{key: stage[key][i] for key in stage} for i in range(len(stage["in_channels"]))]
+ ])
+ for stage in stages
+ ],
+ ]
+
+ if include_top:
+ _layers.append(
+ nn.Sequential(
+ nn.AdaptiveAvgPool2d(1),
+ nn.Flatten(1),
+ nn.Linear(stages[-1]["out_channels"][-1], num_classes),
+ )
+ )
+
+ super().__init__(*_layers)
+ self.cfg = cfg
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+
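+# Note on the stage format (a sketch of how `stages` is consumed above): each stage
+# dict is a struct-of-lists that is unzipped into one kwargs dict per conv layer, e.g.
+#   {"in_channels": [64, 64], "out_channels": [64, 64], "kernel_size": [(3, 3)] * 2, "stride": [1, 2]}
+# yields FASTConvLayer(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=1),
+# then the same layer with stride=2.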
+def _textnet(
+ arch: str,
+ pretrained: bool,
+ ignore_keys: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> TextNet:
+ kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+ kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["num_classes"] = kwargs["num_classes"]
+ _cfg["classes"] = kwargs["classes"]
+ kwargs.pop("classes")
+
+ # Build the model
+ model = TextNet(**kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ # The number of classes is not the same as the number of classes in the pretrained model =>
+ # remove the last layer weights
+ _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
+ load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+ model.cfg = _cfg
+
+ return model
+
+
+def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
+ """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
+
+ >>> import torch
+ >>> from doctr.models import textnet_tiny
+ >>> model = textnet_tiny(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the TextNet architecture
+
+ Returns:
+ -------
+ A textnet tiny model
+ """
+ return _textnet(
+ "textnet_tiny",
+ pretrained,
+ stages=[
+ {"in_channels": [64] * 3, "out_channels": [64] * 3, "kernel_size": [(3, 3)] * 3, "stride": [1, 2, 1]},
+ {
+ "in_channels": [64, 128, 128, 128],
+ "out_channels": [128] * 4,
+ "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1)],
+ "stride": [2, 1, 1, 1],
+ },
+ {
+ "in_channels": [128, 256, 256, 256],
+ "out_channels": [256] * 4,
+ "kernel_size": [(3, 3), (3, 3), (3, 1), (1, 3)],
+ "stride": [2, 1, 1, 1],
+ },
+ {
+ "in_channels": [256, 512, 512, 512],
+ "out_channels": [512] * 4,
+ "kernel_size": [(3, 3), (3, 1), (1, 3), (3, 3)],
+ "stride": [2, 1, 1, 1],
+ },
+ ],
+ ignore_keys=["7.2.weight", "7.2.bias"],
+ **kwargs,
+ )
+
+
+def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
+ """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
+
+ >>> import torch
+ >>> from doctr.models import textnet_small
+ >>> model = textnet_small(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the TextNet architecture
+
+ Returns:
+ -------
+ A TextNet small model
+ """
+ return _textnet(
+ "textnet_small",
+ pretrained,
+ stages=[
+ {"in_channels": [64] * 2, "out_channels": [64] * 2, "kernel_size": [(3, 3)] * 2, "stride": [1, 2]},
+ {
+ "in_channels": [64, 128, 128, 128, 128, 128, 128, 128],
+ "out_channels": [128] * 8,
+ "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1), (1, 3), (3, 3)],
+ "stride": [2, 1, 1, 1, 1, 1, 1, 1],
+ },
+ {
+ "in_channels": [128, 256, 256, 256, 256, 256, 256, 256],
+ "out_channels": [256] * 8,
+ "kernel_size": [(3, 3), (3, 3), (1, 3), (3, 1), (3, 3), (1, 3), (3, 1), (3, 3)],
+ "stride": [2, 1, 1, 1, 1, 1, 1, 1],
+ },
+ {
+ "in_channels": [256, 512, 512, 512, 512],
+ "out_channels": [512] * 5,
+ "kernel_size": [(3, 3), (3, 1), (1, 3), (1, 3), (3, 1)],
+ "stride": [2, 1, 1, 1, 1],
+ },
+ ],
+ ignore_keys=["7.2.weight", "7.2.bias"],
+ **kwargs,
+ )
+
+
+def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
+ """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
+
+ >>> import torch
+ >>> from doctr.models import textnet_base
+ >>> model = textnet_base(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the TextNet architecture
+
+ Returns:
+ -------
+ A TextNet base model
+ """
+ return _textnet(
+ "textnet_base",
+ pretrained,
+ stages=[
+ {
+ "in_channels": [64] * 10,
+ "out_channels": [64] * 10,
+ "kernel_size": [(3, 3), (3, 3), (3, 1), (3, 3), (3, 1), (3, 3), (3, 3), (1, 3), (3, 3), (3, 3)],
+ "stride": [1, 2, 1, 1, 1, 1, 1, 1, 1, 1],
+ },
+ {
+ "in_channels": [64, 128, 128, 128, 128, 128, 128, 128, 128, 128],
+ "out_channels": [128] * 10,
+ "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 3), (3, 1), (3, 1), (3, 3), (3, 3)],
+ "stride": [2, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+ },
+ {
+ "in_channels": [128, 256, 256, 256, 256, 256, 256, 256],
+ "out_channels": [256] * 8,
+ "kernel_size": [(3, 3), (3, 3), (3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1)],
+ "stride": [2, 1, 1, 1, 1, 1, 1, 1],
+ },
+ {
+ "in_channels": [256, 512, 512, 512, 512],
+ "out_channels": [512] * 5,
+ "kernel_size": [(3, 3), (1, 3), (3, 1), (3, 1), (1, 3)],
+ "stride": [2, 1, 1, 1, 1],
+ },
+ ],
+ ignore_keys=["7.2.weight", "7.2.bias"],
+ **kwargs,
+ )
diff --git a/doctr/models/classification/textnet/tensorflow.py b/doctr/models/classification/textnet/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..f30d5d823ce84e358e6d3d497d7091bea9db1b04
--- /dev/null
+++ b/doctr/models/classification/textnet/tensorflow.py
@@ -0,0 +1,267 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple
+
+from tensorflow.keras import Sequential, layers
+
+from doctr.datasets import VOCABS
+
+from ...modules.layers.tensorflow import FASTConvLayer
+from ...utils import conv_sequence, load_pretrained_params
+
+__all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "textnet_tiny": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 32, 3),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-fe9cc245.zip&src=0",
+ },
+ "textnet_small": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 32, 3),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-29c39c82.zip&src=0",
+ },
+ "textnet_base": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 32, 3),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-168aa82c.zip&src=0",
+ },
+}
+
+
+class TextNet(Sequential):
+ """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
+
+ Args:
+ ----
+    stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage.
+    input_shape (Tuple[int, int, int], optional): Shape of the input tensor. Defaults to (32, 32, 3).
+ include_top (bool, optional): Whether to include the classifier head. Defaults to True.
+ num_classes (int, optional): Number of output classes. Defaults to 1000.
+ cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ stages: List[Dict[str, List[int]]],
+ input_shape: Tuple[int, int, int] = (32, 32, 3),
+ num_classes: int = 1000,
+ include_top: bool = True,
+ cfg: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ _layers = [
+ *conv_sequence(
+ out_channels=64, activation="relu", bn=True, kernel_size=3, strides=2, input_shape=input_shape
+ ),
+ *[
+ Sequential(
+ [
+ FASTConvLayer(**params) # type: ignore[arg-type]
+ for params in [{key: stage[key][i] for key in stage} for i in range(len(stage["in_channels"]))]
+ ],
+ name=f"stage_{i}",
+ )
+ for i, stage in enumerate(stages)
+ ],
+ ]
+
+ if include_top:
+ _layers.append(
+ Sequential(
+ [
+ layers.AveragePooling2D(1),
+ layers.Flatten(),
+ layers.Dense(num_classes),
+ ],
+ name="classifier",
+ )
+ )
+
+ super().__init__(_layers)
+ self.cfg = cfg
+
+
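+# Note on the stage format (sketch, mirroring the PyTorch variant): each stage dict is
+# a struct-of-lists unzipped into one kwargs dict per FASTConvLayer, so
+#   {"in_channels": [64, 64], "out_channels": [64, 64], "kernel_size": [(3, 3)] * 2, "stride": [1, 2]}
+# builds two layers, the first with stride 1 and the second with stride 2.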
+def _textnet(
+ arch: str,
+ pretrained: bool,
+ **kwargs: Any,
+) -> TextNet:
+ kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+ kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+ kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["num_classes"] = kwargs["num_classes"]
+ _cfg["input_shape"] = kwargs["input_shape"]
+ _cfg["classes"] = kwargs["classes"]
+ kwargs.pop("classes")
+
+ # Build the model
+ model = TextNet(cfg=_cfg, **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ load_pretrained_params(model, default_cfgs[arch]["url"])
+
+ return model
+
+
+def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
+ """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import textnet_tiny
+ >>> model = textnet_tiny(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the TextNet architecture
+
+ Returns:
+ -------
+ A textnet tiny model
+ """
+ return _textnet(
+ "textnet_tiny",
+ pretrained,
+ stages=[
+ {"in_channels": [64] * 3, "out_channels": [64] * 3, "kernel_size": [(3, 3)] * 3, "stride": [1, 2, 1]},
+ {
+ "in_channels": [64, 128, 128, 128],
+ "out_channels": [128] * 4,
+ "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1)],
+ "stride": [2, 1, 1, 1],
+ },
+ {
+ "in_channels": [128, 256, 256, 256],
+ "out_channels": [256] * 4,
+ "kernel_size": [(3, 3), (3, 3), (3, 1), (1, 3)],
+ "stride": [2, 1, 1, 1],
+ },
+ {
+ "in_channels": [256, 512, 512, 512],
+ "out_channels": [512] * 4,
+ "kernel_size": [(3, 3), (3, 1), (1, 3), (3, 3)],
+ "stride": [2, 1, 1, 1],
+ },
+ ],
+ **kwargs,
+ )
+
+
+def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
+ """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import textnet_small
+ >>> model = textnet_small(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the TextNet architecture
+
+ Returns:
+ -------
+ A TextNet small model
+ """
+ return _textnet(
+ "textnet_small",
+ pretrained,
+ stages=[
+ {"in_channels": [64] * 2, "out_channels": [64] * 2, "kernel_size": [(3, 3)] * 2, "stride": [1, 2]},
+ {
+ "in_channels": [64, 128, 128, 128, 128, 128, 128, 128],
+ "out_channels": [128] * 8,
+ "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1), (1, 3), (3, 3)],
+ "stride": [2, 1, 1, 1, 1, 1, 1, 1],
+ },
+ {
+ "in_channels": [128, 256, 256, 256, 256, 256, 256, 256],
+ "out_channels": [256] * 8,
+ "kernel_size": [(3, 3), (3, 3), (1, 3), (3, 1), (3, 3), (1, 3), (3, 1), (3, 3)],
+ "stride": [2, 1, 1, 1, 1, 1, 1, 1],
+ },
+ {
+ "in_channels": [256, 512, 512, 512, 512],
+ "out_channels": [512] * 5,
+ "kernel_size": [(3, 3), (3, 1), (1, 3), (1, 3), (3, 1)],
+ "stride": [2, 1, 1, 1, 1],
+ },
+ ],
+ **kwargs,
+ )
+
+
+def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
+ """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import textnet_base
+ >>> model = textnet_base(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the TextNet architecture
+
+ Returns:
+ -------
+ A TextNet base model
+ """
+ return _textnet(
+ "textnet_base",
+ pretrained,
+ stages=[
+ {
+ "in_channels": [64] * 10,
+ "out_channels": [64] * 10,
+ "kernel_size": [(3, 3), (3, 3), (3, 1), (3, 3), (3, 1), (3, 3), (3, 3), (1, 3), (3, 3), (3, 3)],
+ "stride": [1, 2, 1, 1, 1, 1, 1, 1, 1, 1],
+ },
+ {
+ "in_channels": [64, 128, 128, 128, 128, 128, 128, 128, 128, 128],
+ "out_channels": [128] * 10,
+ "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 3), (3, 1), (3, 1), (3, 3), (3, 3)],
+ "stride": [2, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+ },
+ {
+ "in_channels": [128, 256, 256, 256, 256, 256, 256, 256],
+ "out_channels": [256] * 8,
+ "kernel_size": [(3, 3), (3, 3), (3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1)],
+ "stride": [2, 1, 1, 1, 1, 1, 1, 1],
+ },
+ {
+ "in_channels": [256, 512, 512, 512, 512],
+ "out_channels": [512] * 5,
+ "kernel_size": [(3, 3), (1, 3), (3, 1), (3, 1), (1, 3)],
+ "stride": [2, 1, 1, 1, 1],
+ },
+ ],
+ **kwargs,
+ )
diff --git a/doctr/models/classification/vgg/__init__.py b/doctr/models/classification/vgg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..64556e403a5697432f805a5af28dab812fa8b932
--- /dev/null
+++ b/doctr/models/classification/vgg/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import *
diff --git a/doctr/models/classification/vgg/pytorch.py b/doctr/models/classification/vgg/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e16b1178841f3096daa9acd7c920fea4a178c7f
--- /dev/null
+++ b/doctr/models/classification/vgg/pytorch.py
@@ -0,0 +1,95 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional
+
+from torch import nn
+from torchvision.models import vgg as tv_vgg
+
+from doctr.datasets import VOCABS
+
+from ...utils import load_pretrained_params
+
+__all__ = ["vgg16_bn_r"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "vgg16_bn_r": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/vgg16_bn_r-d108c19c.pt&src=0",
+ },
+}
+
+
+def _vgg(
+ arch: str,
+ pretrained: bool,
+ tv_arch: str,
+ num_rect_pools: int = 3,
+ ignore_keys: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> tv_vgg.VGG:
+ kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+ kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["num_classes"] = kwargs["num_classes"]
+ _cfg["classes"] = kwargs["classes"]
+ kwargs.pop("classes")
+
+ # Build the model
+ model = tv_vgg.__dict__[tv_arch](**kwargs, weights=None)
+ # List the MaxPool2d
+ pool_idcs = [idx for idx, m in enumerate(model.features) if isinstance(m, nn.MaxPool2d)]
+ # Replace their kernel with rectangular ones
+ for idx in pool_idcs[-num_rect_pools:]:
+ model.features[idx] = nn.MaxPool2d((2, 1))
+ # Patch average pool & classification head
+ model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ model.classifier = nn.Linear(512, kwargs["num_classes"])
+ # Load pretrained parameters
+ if pretrained:
+ # The number of classes is not the same as the number of classes in the pretrained model =>
+ # remove the last layer weights
+ _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
+ load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+ model.cfg = _cfg
+
+ return model
+
+
+def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> tv_vgg.VGG:
+ """VGG-16 architecture as described in `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
+ <https://arxiv.org/abs/1409.1556>`_, modified by adding batch normalization, rectangular pooling and a simpler
+ classification head.
+
+ >>> import torch
+ >>> from doctr.models import vgg16_bn_r
+ >>> model = vgg16_bn_r(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ **kwargs: keyword arguments of the VGG architecture
+
+ Returns:
+ -------
+ VGG feature extractor
+ """
+ return _vgg(
+ "vgg16_bn_r",
+ pretrained,
+ "vgg16_bn",
+ 3,
+ ignore_keys=["classifier.weight", "classifier.bias"],
+ **kwargs,
+ )
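`_vgg` swaps the last `num_rect_pools` max-pool kernels for (2, 1) ones, so those layers keep halving the height while preserving the width, which suits wide text crops. A minimal sketch of the resulting geometry (shapes are illustrative):

```python
import torch
from torch import nn

# Standard (2, 2) pooling halves both dimensions; the patched (2, 1)
# pooling halves only the height and keeps the horizontal resolution.
square = nn.MaxPool2d((2, 2))
rect = nn.MaxPool2d((2, 1))

x = torch.rand(1, 512, 4, 128)  # toy feature map (N, C, H, W)
print(square(x).shape)  # torch.Size([1, 512, 2, 64])
print(rect(x).shape)    # torch.Size([1, 512, 2, 128]) -- width preserved
```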
diff --git a/doctr/models/classification/vgg/tensorflow.py b/doctr/models/classification/vgg/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..259ed9f88875fb00c6c006a314f0210e77857bd8
--- /dev/null
+++ b/doctr/models/classification/vgg/tensorflow.py
@@ -0,0 +1,113 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple
+
+from tensorflow.keras import layers
+from tensorflow.keras.models import Sequential
+
+from doctr.datasets import VOCABS
+
+from ...utils import conv_sequence, load_pretrained_params
+
+__all__ = ["VGG", "vgg16_bn_r"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "vgg16_bn_r": {
+ "mean": (0.5, 0.5, 0.5),
+ "std": (1.0, 1.0, 1.0),
+ "input_shape": (32, 32, 3),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/vgg16_bn_r-c5836cea.zip&src=0",
+ },
+}
+
+
+class VGG(Sequential):
+ """Implements the VGG architecture from `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
+ <https://arxiv.org/abs/1409.1556>`_.
+
+ Args:
+ ----
+ num_blocks: number of convolutional blocks in each stage
+ planes: number of output channels in each stage
+ rect_pools: whether square pooling kernels should be replaced with rectangular ones
+ include_top: whether the classifier head should be instantiated
+ num_classes: number of output classes
+ input_shape: shapes of the input tensor
+ """
+
+ def __init__(
+ self,
+ num_blocks: List[int],
+ planes: List[int],
+ rect_pools: List[bool],
+ include_top: bool = False,
+ num_classes: int = 1000,
+ input_shape: Optional[Tuple[int, int, int]] = None,
+ cfg: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ _layers = []
+ # Specify input_shape only for the first layer
+ kwargs = {"input_shape": input_shape}
+ for nb_blocks, out_chan, rect_pool in zip(num_blocks, planes, rect_pools):
+ for _ in range(nb_blocks):
+ _layers.extend(conv_sequence(out_chan, "relu", True, kernel_size=3, **kwargs)) # type: ignore[arg-type]
+ kwargs = {}
+ _layers.append(layers.MaxPooling2D((2, 1 if rect_pool else 2)))
+
+ if include_top:
+ _layers.extend([layers.GlobalAveragePooling2D(), layers.Dense(num_classes)])
+ super().__init__(_layers)
+ self.cfg = cfg
+
+
+def _vgg(
+ arch: str, pretrained: bool, num_blocks: List[int], planes: List[int], rect_pools: List[bool], **kwargs: Any
+) -> VGG:
+ kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+ kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+ kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["num_classes"] = kwargs["num_classes"]
+ _cfg["classes"] = kwargs["classes"]
+ _cfg["input_shape"] = kwargs["input_shape"]
+ kwargs.pop("classes")
+
+ # Build the model
+ model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ load_pretrained_params(model, default_cfgs[arch]["url"])
+
+ return model
+
+
+def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> VGG:
+ """VGG-16 architecture as described in `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
+ <https://arxiv.org/abs/1409.1556>`_, modified by adding batch normalization, rectangular pooling and a simpler
+ classification head.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import vgg16_bn_r
+ >>> model = vgg16_bn_r(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ **kwargs: keyword arguments of the VGG architecture
+
+ Returns:
+ -------
+ VGG feature extractor
+ """
+ return _vgg(
+ "vgg16_bn_r", pretrained, [2, 2, 3, 3, 3], [64, 128, 256, 512, 512], [False, False, True, True, True], **kwargs
+ )
diff --git a/doctr/models/classification/vit/__init__.py b/doctr/models/classification/vit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/models/classification/vit/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/classification/vit/pytorch.py b/doctr/models/classification/vit/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..335e92559fd9fa44de136323a6b50456eb605c7a
--- /dev/null
+++ b/doctr/models/classification/vit/pytorch.py
@@ -0,0 +1,195 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from doctr.datasets import VOCABS
+from doctr.models.modules.transformer import EncoderBlock
+from doctr.models.modules.vision_transformer.pytorch import PatchEmbedding
+
+from ...utils.pytorch import load_pretrained_params
+
+__all__ = ["vit_s", "vit_b"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "vit_s": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_s-5d05442d.pt&src=0",
+ },
+ "vit_b": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_b-0fbef167.pt&src=0",
+ },
+}
+
+
+class ClassifierHead(nn.Module):
+ """Classifier head for Vision Transformer
+
+ Args:
+ ----
+ in_channels: number of input channels
+ num_classes: number of output classes
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ num_classes: int,
+ ) -> None:
+ super().__init__()
+
+ self.head = nn.Linear(in_channels, num_classes)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # (batch_size, num_classes) cls token
+ return self.head(x[:, 0])
+
+
+class VisionTransformer(nn.Sequential):
+ """VisionTransformer architecture as described in
+ `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
+ <https://arxiv.org/abs/2010.11929>`_.
+
+ Args:
+ ----
+ d_model: dimension of the transformer layers
+ num_layers: number of transformer layers
+ num_heads: number of attention heads
+ ffd_ratio: multiplier for the hidden dimension of the feedforward layer
+ patch_size: size of the patches
+ input_shape: size of the input image
+ dropout: dropout rate
+ num_classes: number of output classes
+ include_top: whether the classifier head should be instantiated
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ num_layers: int,
+ num_heads: int,
+ ffd_ratio: int,
+ patch_size: Tuple[int, int] = (4, 4),
+ input_shape: Tuple[int, int, int] = (3, 32, 32),
+ dropout: float = 0.0,
+ num_classes: int = 1000,
+ include_top: bool = True,
+ cfg: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ _layers: List[nn.Module] = [
+ PatchEmbedding(input_shape, d_model, patch_size),
+ EncoderBlock(num_layers, num_heads, d_model, d_model * ffd_ratio, dropout, nn.GELU()),
+ ]
+ if include_top:
+ _layers.append(ClassifierHead(d_model, num_classes))
+
+ super().__init__(*_layers)
+ self.cfg = cfg
+
+
+def _vit(
+ arch: str,
+ pretrained: bool,
+ ignore_keys: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> VisionTransformer:
+ kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+ kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+ kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["num_classes"] = kwargs["num_classes"]
+ _cfg["input_shape"] = kwargs["input_shape"]
+ _cfg["classes"] = kwargs["classes"]
+ kwargs.pop("classes")
+
+ # Build the model
+ model = VisionTransformer(cfg=_cfg, **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ # The number of classes is not the same as the number of classes in the pretrained model =>
+ # remove the last layer weights
+ _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
+ load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+ return model
+
+
+def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
+ """VisionTransformer-S architecture
+ `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
+ <https://arxiv.org/abs/2010.11929>`_. Patches: (H, W) -> (H/8, W/8)
+
+ NOTE: unofficial config used in ViTSTR and ParSeq
+
+ >>> import torch
+ >>> from doctr.models import vit_s
+ >>> model = vit_s(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the VisionTransformer architecture
+
+ Returns:
+ -------
+ A feature extractor model
+ """
+ return _vit(
+ "vit_s",
+ pretrained,
+ d_model=384,
+ num_layers=12,
+ num_heads=6,
+ ffd_ratio=4,
+ ignore_keys=["2.head.weight", "2.head.bias"],
+ **kwargs,
+ )
+
+
+def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
+ """VisionTransformer-B architecture as described in
+ `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
+ <https://arxiv.org/abs/2010.11929>`_. Patches: (H, W) -> (H/8, W/8)
+
+ >>> import torch
+ >>> from doctr.models import vit_b
+ >>> model = vit_b(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the VisionTransformer architecture
+
+ Returns:
+ -------
+ A feature extractor model
+ """
+ return _vit(
+ "vit_b",
+ pretrained,
+ d_model=768,
+ num_layers=12,
+ num_heads=12,
+ ffd_ratio=4,
+ ignore_keys=["2.head.weight", "2.head.bias"],
+ **kwargs,
+ )
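Both factories share the same head behaviour: `ClassifierHead` discards every patch token and classifies from the cls token alone (`x[:, 0]`). A minimal sketch of that selection; the token count and class count below are illustrative assumptions, not values from this diff:

```python
import torch
from torch import nn

# Toy encoder output: (batch, tokens, d_model), e.g. patch tokens + a cls token
seq = torch.rand(2, 65, 384)   # 65 tokens and d_model=384 are assumptions
head = nn.Linear(384, 126)     # 126 output classes is an assumption

# Same selection as ClassifierHead.forward: keep only the leading cls token
logits = head(seq[:, 0])
print(logits.shape)  # torch.Size([2, 126])
```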
diff --git a/doctr/models/classification/vit/tensorflow.py b/doctr/models/classification/vit/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b73b49ac9c11dd430fe7b8ca1f612b745122f14
--- /dev/null
+++ b/doctr/models/classification/vit/tensorflow.py
@@ -0,0 +1,192 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from typing import Any, Dict, Optional, Tuple
+
+import tensorflow as tf
+from tensorflow.keras import Sequential, layers
+
+from doctr.datasets import VOCABS
+from doctr.models.modules.transformer import EncoderBlock
+from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding
+from doctr.utils.repr import NestedObject
+
+from ...utils import load_pretrained_params
+
+__all__ = ["vit_s", "vit_b"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "vit_s": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 32),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_s-6300fcc9.zip&src=0",
+ },
+ "vit_b": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 32, 3),
+ "classes": list(VOCABS["french"]),
+ "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_b-57158446.zip&src=0",
+ },
+}
+
+
+class ClassifierHead(layers.Layer, NestedObject):
+ """Classifier head for Vision Transformer
+
+ Args:
+ ----
+ num_classes: number of output classes
+ """
+
+ def __init__(self, num_classes: int) -> None:
+ super().__init__()
+
+ self.head = layers.Dense(num_classes, kernel_initializer="he_normal", name="dense")
+
+ def call(self, x: tf.Tensor) -> tf.Tensor:
+ # (batch_size, num_classes) cls token
+ return self.head(x[:, 0])
+
+
+class VisionTransformer(Sequential):
+ """VisionTransformer architecture as described in
+ `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
+ <https://arxiv.org/abs/2010.11929>`_.
+
+ Args:
+ ----
+ d_model: dimension of the transformer layers
+ num_layers: number of transformer layers
+ num_heads: number of attention heads
+ ffd_ratio: multiplier for the hidden dimension of the feedforward layer
+ patch_size: size of the patches
+ input_shape: size of the input image
+ dropout: dropout rate
+ num_classes: number of output classes
+ include_top: whether the classifier head should be instantiated
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ num_layers: int,
+ num_heads: int,
+ ffd_ratio: int,
+ patch_size: Tuple[int, int] = (4, 4),
+ input_shape: Tuple[int, int, int] = (32, 32, 3),
+ dropout: float = 0.0,
+ num_classes: int = 1000,
+ include_top: bool = True,
+ cfg: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ _layers = [
+ PatchEmbedding(input_shape, d_model, patch_size),
+ EncoderBlock(
+ num_layers,
+ num_heads,
+ d_model,
+ d_model * ffd_ratio,
+ dropout,
+ activation_fct=layers.Activation("gelu"),
+ ),
+ ]
+ if include_top:
+ _layers.append(ClassifierHead(num_classes))
+
+ super().__init__(_layers)
+ self.cfg = cfg
+
+
+def _vit(
+ arch: str,
+ pretrained: bool,
+ **kwargs: Any,
+) -> VisionTransformer:
+ kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+ kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+ kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["num_classes"] = kwargs["num_classes"]
+ _cfg["input_shape"] = kwargs["input_shape"]
+ _cfg["classes"] = kwargs["classes"]
+ kwargs.pop("classes")
+
+ # Build the model
+ model = VisionTransformer(cfg=_cfg, **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ load_pretrained_params(model, default_cfgs[arch]["url"])
+
+ return model
+
+
+def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
+ """VisionTransformer-S architecture
+ `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
+ <https://arxiv.org/abs/2010.11929>`_. Patches: (H, W) -> (H/8, W/8)
+
+ NOTE: unofficial config used in ViTSTR and ParSeq
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import vit_s
+ >>> model = vit_s(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the VisionTransformer architecture
+
+ Returns:
+ -------
+ A feature extractor model
+ """
+ return _vit(
+ "vit_s",
+ pretrained,
+ d_model=384,
+ num_layers=12,
+ num_heads=6,
+ ffd_ratio=4,
+ **kwargs,
+ )
+
+
+def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
+ """VisionTransformer-B architecture as described in
+ `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
+ <https://arxiv.org/abs/2010.11929>`_. Patches: (H, W) -> (H/8, W/8)
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import vit_b
+ >>> model = vit_b(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained: boolean, True if model is pretrained
+ **kwargs: keyword arguments of the VisionTransformer architecture
+
+ Returns:
+ -------
+ A feature extractor model
+ """
+ return _vit(
+ "vit_b",
+ pretrained,
+ d_model=768,
+ num_layers=12,
+ num_heads=12,
+ ffd_ratio=4,
+ **kwargs,
+ )
diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py
new file mode 100644
index 0000000000000000000000000000000000000000..6179ff976aee941342dc9fbd3756764644e47de5
--- /dev/null
+++ b/doctr/models/classification/zoo.py
@@ -0,0 +1,74 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, List
+
+from doctr.file_utils import is_tf_available
+
+from .. import classification
+from ..preprocessor import PreProcessor
+from .predictor import CropOrientationPredictor
+
+__all__ = ["crop_orientation_predictor"]
+
+ARCHS: List[str] = [
+ "magc_resnet31",
+ "mobilenet_v3_small",
+ "mobilenet_v3_small_r",
+ "mobilenet_v3_large",
+ "mobilenet_v3_large_r",
+ "resnet18",
+ "resnet31",
+ "resnet34",
+ "resnet50",
+ "resnet34_wide",
+ "textnet_tiny",
+ "textnet_small",
+ "textnet_base",
+ "vgg16_bn_r",
+ "vit_s",
+ "vit_b",
+]
+ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_orientation"]
+
+
+def _crop_orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> CropOrientationPredictor:
+ if arch not in ORIENTATION_ARCHS:
+ raise ValueError(f"unknown architecture '{arch}'")
+
+ # Load directly classifier from backbone
+ _model = classification.__dict__[arch](pretrained=pretrained)
+ kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
+ kwargs["std"] = kwargs.get("std", _model.cfg["std"])
+ kwargs["batch_size"] = kwargs.get("batch_size", 128)
+ input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
+ predictor = CropOrientationPredictor(
+ PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
+ )
+ return predictor
+
+
+def crop_orientation_predictor(
+ arch: str = "mobilenet_v3_small_orientation", pretrained: bool = False, **kwargs: Any
+) -> CropOrientationPredictor:
+ """Orientation classification architecture.
+
+ >>> import numpy as np
+ >>> from doctr.models import crop_orientation_predictor
+ >>> model = crop_orientation_predictor(arch='mobilenet_v3_small_orientation', pretrained=True)
+ >>> input_crop = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
+ >>> out = model([input_crop])
+
+ Args:
+ ----
+ arch: name of the architecture to use (e.g. 'mobilenet_v3_small_orientation')
+ pretrained: If True, returns a model pre-trained on our recognition crops dataset
+ **kwargs: keyword arguments to be passed to the CropOrientationPredictor
+
+ Returns:
+ -------
+ CropOrientationPredictor
+ """
+ return _crop_orientation_predictor(arch, pretrained, **kwargs)
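`_crop_orientation_predictor` derives the preprocessor's target size from the backbone config by dropping the channel axis from whichever end the backend puts it. A two-line sketch of that slicing, with illustrative shapes:

```python
# Channel-axis handling as in _crop_orientation_predictor (illustrative shapes)
cfg_tf = (128, 128, 3)  # TensorFlow cfg["input_shape"] is channels-last
cfg_pt = (3, 128, 128)  # PyTorch cfg["input_shape"] is channels-first
print(cfg_tf[:-1])  # (128, 128) -- the is_tf_available() branch
print(cfg_pt[1:])   # (128, 128) -- the PyTorch branch
```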
diff --git a/doctr/models/core.py b/doctr/models/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..a05aee7aa9f22dbe2a56511699e9977129f1bd99
--- /dev/null
+++ b/doctr/models/core.py
@@ -0,0 +1,19 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+
+from typing import Any, Dict, Optional
+
+from doctr.utils.repr import NestedObject
+
+__all__ = ["BaseModel"]
+
+
+class BaseModel(NestedObject):
+ """Implements abstract DetectionModel class"""
+
+ def __init__(self, cfg: Optional[Dict[str, Any]] = None) -> None:
+ super().__init__()
+ self.cfg = cfg
diff --git a/doctr/models/detection/__init__.py b/doctr/models/detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b09e4395eb6d15a992960894fddfb582dbbd64db
--- /dev/null
+++ b/doctr/models/detection/__init__.py
@@ -0,0 +1,4 @@
+from .differentiable_binarization import *
+from .linknet import *
+from .fast import *
+from .zoo import *
diff --git a/doctr/models/detection/_utils/__init__.py b/doctr/models/detection/_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc2bfef5e652335569d8da45965b4c64fe56c141
--- /dev/null
+++ b/doctr/models/detection/_utils/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available
+
+if is_tf_available():
+ from .tensorflow import *
+else:
+ from .pytorch import *
diff --git a/doctr/models/detection/_utils/pytorch.py b/doctr/models/detection/_utils/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ac99f4690993c1d4ec6d3563cecfccb07064c85
--- /dev/null
+++ b/doctr/models/detection/_utils/pytorch.py
@@ -0,0 +1,43 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from torch import Tensor
+from torch.nn.functional import max_pool2d
+
+__all__ = ["erode", "dilate"]
+
+
+def erode(x: Tensor, kernel_size: int) -> Tensor:
+ """Performs erosion on a given tensor
+
+ Args:
+ ----
+ x: boolean tensor of shape (N, C, H, W)
+ kernel_size: the size of the kernel to use for erosion
+
+ Returns:
+ -------
+ the eroded tensor
+ """
+ _pad = (kernel_size - 1) // 2
+
+ return 1 - max_pool2d(1 - x, kernel_size, stride=1, padding=_pad)
+
+
+def dilate(x: Tensor, kernel_size: int) -> Tensor:
+ """Performs dilation on a given tensor
+
+ Args:
+ ----
+ x: boolean tensor of shape (N, C, H, W)
+ kernel_size: the size of the kernel to use for dilation
+
+ Returns:
+ -------
+ the dilated tensor
+ """
+ _pad = (kernel_size - 1) // 2
+
+ return max_pool2d(x, kernel_size, stride=1, padding=_pad)
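Both helpers reuse max-pooling: dilation is a plain max-pool, and erosion is its dual, `1 - maxpool(1 - x)`. A quick check with the functions this file adds: a single foreground pixel grows into a 3x3 blob and shrinks back to one pixel:

```python
import torch
from doctr.models.detection._utils.pytorch import dilate, erode

x = torch.zeros(1, 1, 5, 5)
x[0, 0, 2, 2] = 1.0  # single foreground pixel

d = dilate(x, 3)  # grows the pixel into a 3x3 blob
e = erode(d, 3)   # shrinks the blob back to the original pixel
print(int(d.sum()), int(e.sum()))  # 9 1
```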
diff --git a/doctr/models/detection/_utils/tensorflow.py b/doctr/models/detection/_utils/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f5ec217493d0c1c6d96b834a83e45af65d8481e
--- /dev/null
+++ b/doctr/models/detection/_utils/tensorflow.py
@@ -0,0 +1,38 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import tensorflow as tf
+
+__all__ = ["erode", "dilate"]
+
+
+def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
+ """Performs erosion on a given tensor
+
+ Args:
+ ----
+ x: boolean tensor of shape (N, H, W, C)
+ kernel_size: the size of the kernel to use for erosion
+
+ Returns:
+ -------
+ the eroded tensor
+ """
+ return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME")
+
+
+def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor:
+ """Performs dilation on a given tensor
+
+ Args:
+ ----
+ x: boolean tensor of shape (N, H, W, C)
+ kernel_size: the size of the kernel to use for dilation
+
+ Returns:
+ -------
+ the dilated tensor
+ """
+ return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME")
diff --git a/doctr/models/detection/core.py b/doctr/models/detection/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..63fa78615162ceca0dd11d71a64dc2c8edee4af5
--- /dev/null
+++ b/doctr/models/detection/core.py
@@ -0,0 +1,101 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import List
+
+import cv2
+import numpy as np
+
+from doctr.utils.repr import NestedObject
+
+__all__ = ["DetectionPostProcessor"]
+
+
+class DetectionPostProcessor(NestedObject):
+ """Abstract class to postprocess the raw output of the model
+
+ Args:
+ ----
+ box_thresh (float): minimal objectness score to consider a box
+ bin_thresh (float): threshold to apply to segmentation raw heatmap
+ assume_straight_pages (bool): if True, fit straight boxes only
+ """
+
+ def __init__(self, box_thresh: float = 0.5, bin_thresh: float = 0.5, assume_straight_pages: bool = True) -> None:
+ self.box_thresh = box_thresh
+ self.bin_thresh = bin_thresh
+ self.assume_straight_pages = assume_straight_pages
+ self._opening_kernel: np.ndarray = np.ones((3, 3), dtype=np.uint8)
+
+ def extra_repr(self) -> str:
+ return f"bin_thresh={self.bin_thresh}, box_thresh={self.box_thresh}"
+
+ @staticmethod
+ def box_score(pred: np.ndarray, points: np.ndarray, assume_straight_pages: bool = True) -> float:
+ """Compute the confidence score for a polygon : mean of the p values on the polygon
+
+ Args:
+ ----
+ pred (np.ndarray): p map returned by the model
+ points: coordinates of the polygon
+ assume_straight_pages: if True, fit straight boxes only
+
+ Returns:
+ -------
+ polygon objectness
+ """
+ h, w = pred.shape[:2]
+
+ if assume_straight_pages:
+ xmin = np.clip(np.floor(points[:, 0].min()).astype(np.int32), 0, w - 1)
+ xmax = np.clip(np.ceil(points[:, 0].max()).astype(np.int32), 0, w - 1)
+ ymin = np.clip(np.floor(points[:, 1].min()).astype(np.int32), 0, h - 1)
+ ymax = np.clip(np.ceil(points[:, 1].max()).astype(np.int32), 0, h - 1)
+ return pred[ymin : ymax + 1, xmin : xmax + 1].mean()
+
+ else:
+ mask: np.ndarray = np.zeros((h, w), np.int32)
+ cv2.fillPoly(mask, [points.astype(np.int32)], 1.0) # type: ignore[call-overload]
+ product = pred * mask
+ return np.sum(product) / np.count_nonzero(product)
+
+ def bitmap_to_boxes(
+ self,
+ pred: np.ndarray,
+ bitmap: np.ndarray,
+ ) -> np.ndarray:
+ raise NotImplementedError
+
+ def __call__(
+ self,
+ proba_map,
+ ) -> List[List[np.ndarray]]:
+ """Performs postprocessing for a list of model outputs
+
+ Args:
+ ----
+ proba_map: probability map of shape (N, H, W, C)
+
+ Returns:
+ -------
+ list of N class predictions (for each input sample), where each class predictions is a list of C tensors
+ of shape (*, 5) or (*, 6)
+ """
+ if proba_map.ndim != 4:
+ raise AssertionError(f"arg `proba_map` is expected to be 4-dimensional, got {proba_map.ndim}.")
+
+ # Erosion + dilation on the binary map
+ bin_map = [
+ [
+ cv2.morphologyEx(bmap[..., idx], cv2.MORPH_OPEN, self._opening_kernel)
+ for idx in range(proba_map.shape[-1])
+ ]
+ for bmap in (proba_map >= self.bin_thresh).astype(np.uint8)
+ ]
+
+ return [
+ [self.bitmap_to_boxes(pmaps[..., idx], bmaps[idx]) for idx in range(proba_map.shape[-1])]
+ for pmaps, bmaps in zip(proba_map, bin_map)
+ ]
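`box_score` averages the probability map inside the box, so a uniform blob scores exactly its activation value. A toy check of the straight-pages branch using the static method above:

```python
import numpy as np
from doctr.models.detection.core import DetectionPostProcessor

pred = np.zeros((8, 8), dtype=np.float32)
pred[2:5, 2:6] = 0.9  # a text blob spanning rows 2-4, columns 2-5

# Corners of the blob as (x, y) points
points = np.array([[2, 2], [5, 2], [5, 4], [2, 4]])
score = DetectionPostProcessor.box_score(pred, points, assume_straight_pages=True)
print(round(float(score), 2))  # 0.9 -- the mean of pred over the enclosing box
```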
diff --git a/doctr/models/detection/differentiable_binarization/__init__.py b/doctr/models/detection/differentiable_binarization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/models/detection/differentiable_binarization/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f03a2e1bf52861f8e237e9ac63b007f93dd379e
--- /dev/null
+++ b/doctr/models/detection/differentiable_binarization/base.py
@@ -0,0 +1,375 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
+
+from typing import Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+import pyclipper
+from shapely.geometry import Polygon
+
+from ..core import DetectionPostProcessor
+
+__all__ = ["DBPostProcessor"]
+
+
+class DBPostProcessor(DetectionPostProcessor):
+ """Implements a post processor for DBNet adapted from the implementation of `xuannianz
+ <https://github.com/xuannianz/DifferentiableBinarization>`_.
+
+ Args:
+ ----
+ unclip_ratio: ratio used to unshrink polygons
+ min_size_box: minimal length (pix) to keep a box
+ max_candidates: maximum boxes to consider in a single page
+ box_thresh: minimal objectness score to consider a box
+ bin_thresh: threshold used to binarize the p_map at inference time
+
+ """
+
+ def __init__(
+ self,
+ box_thresh: float = 0.1,
+ bin_thresh: float = 0.3,
+ assume_straight_pages: bool = True,
+ ) -> None:
+ super().__init__(box_thresh, bin_thresh, assume_straight_pages)
+ self.unclip_ratio = 1.5
+
+ def polygon_to_box(
+ self,
+ points: np.ndarray,
+ ) -> np.ndarray:
+ """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
+
+ Args:
+ ----
+ points: (N, 2) array of the polygon points to expand
+
+ Returns:
+ -------
+ a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
+ """
+ if not self.assume_straight_pages:
+ # Compute the rectangle polygon enclosing the raw polygon
+ rect = cv2.minAreaRect(points)
+ points = cv2.boxPoints(rect)
+ # Add 1 pixel to correct cv2 approx
+ area = (rect[1][0] + 1) * (1 + rect[1][1])
+ length = 2 * (rect[1][0] + rect[1][1]) + 2
+ else:
+ poly = Polygon(points)
+ area = poly.area
+ length = poly.length
+ distance = area * self.unclip_ratio / length # compute distance to expand polygon
+ offset = pyclipper.PyclipperOffset()
+ offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+ _points = offset.Execute(distance)
+ # Take biggest stack of points
+ idx = 0
+ if len(_points) > 1:
+ max_size = 0
+ for _idx, p in enumerate(_points):
+ if len(p) > max_size:
+ idx = _idx
+ max_size = len(p)
+ # We ensure that _points can be correctly casted to a ndarray
+ _points = [_points[idx]]
+ expanded_points: np.ndarray = np.asarray(_points) # expand polygon
+ if len(expanded_points) < 1:
+ return None # type: ignore[return-value]
+ return (
+ cv2.boundingRect(expanded_points) # type: ignore[return-value]
+ if self.assume_straight_pages
+ else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0)
+ )
+
+ def bitmap_to_boxes(
+ self,
+ pred: np.ndarray,
+ bitmap: np.ndarray,
+ ) -> np.ndarray:
+ """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
+
+ Args:
+ ----
+ pred: Pred map from differentiable binarization output
+ bitmap: Bitmap map computed from pred (binarized)
+
+ Returns:
+ -------
+ numpy array of boxes for the bitmap: rows of (xmin, ymin, xmax, ymax, score) when
+ assume_straight_pages is True, else an array of (4, 2) relative polygons
+ """
+ height, width = bitmap.shape[:2]
+ min_size_box = 2
+ boxes: List[Union[np.ndarray, List[float]]] = []
+ # get contours from connected components on the bitmap
+ contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ for contour in contours:
+ # Check whether smallest enclosing bounding box is not too small
+ if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box):
+ continue
+ # Compute objectness
+ if self.assume_straight_pages:
+ x, y, w, h = cv2.boundingRect(contour)
+ points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]])
+ score = self.box_score(pred, points, assume_straight_pages=True)
+ else:
+ score = self.box_score(pred, contour, assume_straight_pages=False)
+
+ if score < self.box_thresh: # remove polygons with a weak objectness
+ continue
+
+ if self.assume_straight_pages:
+ _box = self.polygon_to_box(points)
+ else:
+ _box = self.polygon_to_box(np.squeeze(contour))
+
+ # Remove too small boxes
+ if self.assume_straight_pages:
+ if _box is None or _box[2] < min_size_box or _box[3] < min_size_box:
+ continue
+ elif np.linalg.norm(_box[2, :] - _box[0, :], axis=-1) < min_size_box:
+ continue
+
+ if self.assume_straight_pages:
+ x, y, w, h = _box
+ # compute relative polygon to get rid of img shape
+ xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height
+ boxes.append([xmin, ymin, xmax, ymax, score])
+ else:
+ # compute relative box to get rid of img shape, in that case _box is a 4pt polygon
+ if not (isinstance(_box, np.ndarray) and _box.shape == (4, 2)):
+ raise AssertionError("When assume_straight_pages is False, a box is a (4, 2) array (polygon)")
+ _box[:, 0] /= width
+ _box[:, 1] /= height
+ boxes.append(_box)
+
+ if not self.assume_straight_pages:
+ return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
+ else:
+ return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
+
+
+class _DBNet:
+ """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
+ <https://arxiv.org/abs/1911.08947>`_.
+
+ Args:
+ ----
+ feature extractor: the backbone serving as feature extractor
+ fpn_channels: number of channels each extracted feature map is mapped to
+ """
+
+ shrink_ratio = 0.4
+ thresh_min = 0.3
+ thresh_max = 0.7
+ min_size_box = 3
+ assume_straight_pages: bool = True
+
+ @staticmethod
+ def compute_distance(
+ xs: np.ndarray,
+ ys: np.ndarray,
+ a: np.ndarray,
+ b: np.ndarray,
+ eps: float = 1e-6,
+ ) -> float:
+ """Compute the distance for each point of the map (xs, ys) to the (a, b) segment
+
+ Args:
+ ----
+ xs : map of x coordinates (height, width)
+ ys : map of y coordinates (height, width)
+ a: first point defining the [ab] segment
+ b: second point defining the [ab] segment
+ eps: epsilon to avoid division by zero
+
+ Returns:
+ -------
+ The computed distance
+
+ """
+ square_dist_1 = np.square(xs - a[0]) + np.square(ys - a[1])
+ square_dist_2 = np.square(xs - b[0]) + np.square(ys - b[1])
+ square_dist = np.square(a[0] - b[0]) + np.square(a[1] - b[1])
+ cosin = (square_dist - square_dist_1 - square_dist_2) / (2 * np.sqrt(square_dist_1 * square_dist_2) + eps)
+ cosin = np.clip(cosin, -1.0, 1.0)
+ square_sin = 1 - np.square(cosin)
+ square_sin = np.nan_to_num(square_sin)
+ result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist + eps)
+ result[cosin < 0] = np.sqrt(np.fmin(square_dist_1, square_dist_2))[cosin < 0]
+ return result
+
+ def draw_thresh_map(
+ self,
+ polygon: np.ndarray,
+ canvas: np.ndarray,
+ mask: np.ndarray,
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Draw a polygon treshold map on a canvas, as described in the DB paper
+
+ Args:
+ ----
+ polygon: (N, 2) array of coordinates delimiting the polygon boundary
+ canvas : threshold map to fill with polygons
+ mask : mask for training on threshold polygons
+ """
+ if polygon.ndim != 2 or polygon.shape[1] != 2:
+ raise AttributeError("polygon should be a 2 dimensional array of coords")
+
+ # Augment polygon by shrink_ratio
+ polygon_shape = Polygon(polygon)
+ distance = polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
+ subject = [tuple(coor) for coor in polygon] # Get coord as list of tuples
+ padding = pyclipper.PyclipperOffset()
+ padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+ padded_polygon: np.ndarray = np.array(padding.Execute(distance)[0])
+
+ # Fill the mask with 1 on the new padded polygon
+ cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) # type: ignore[call-overload]
+
+ # Get min/max to recover polygon after distance computation
+ xmin = padded_polygon[:, 0].min()
+ xmax = padded_polygon[:, 0].max()
+ ymin = padded_polygon[:, 1].min()
+ ymax = padded_polygon[:, 1].max()
+ width = xmax - xmin + 1
+ height = ymax - ymin + 1
+ # Get absolute polygon for distance computation
+ polygon[:, 0] = polygon[:, 0] - xmin
+ polygon[:, 1] = polygon[:, 1] - ymin
+ # Get absolute padded polygon
+ xs: np.ndarray = np.broadcast_to(np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))
+ ys: np.ndarray = np.broadcast_to(np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))
+
+ # Compute distance map to fill the padded polygon
+ distance_map = np.zeros((polygon.shape[0], height, width), dtype=polygon.dtype)
+ for i in range(polygon.shape[0]):
+ j = (i + 1) % polygon.shape[0]
+ absolute_distance = self.compute_distance(xs, ys, polygon[i], polygon[j])
+ distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
+ distance_map = np.min(distance_map, axis=0)
+
+ # Clip the padded polygon inside the canvas
+ xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
+ xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
+ ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
+ ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
+
+ # Fill the canvas with the distances computed inside the valid padded polygon
+ canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax(
+ 1
+ - distance_map[
+ ymin_valid - ymin : ymax_valid - ymax + height, xmin_valid - xmin : xmax_valid - xmax + width
+ ],
+ canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1],
+ )
+
+ return polygon, canvas, mask
+
+ def build_target(
+ self,
+ target: List[Dict[str, np.ndarray]],
+ output_shape: Tuple[int, int, int],
+ channels_last: bool = True,
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
+ raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
+ if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
+ raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.")
+
+ input_dtype = next(iter(target[0].values())).dtype if len(target) > 0 else np.float32
+
+ h: int
+ w: int
+ if channels_last:
+ h, w, num_classes = output_shape
+ else:
+ num_classes, h, w = output_shape
+ target_shape = (len(target), num_classes, h, w)
+
+ seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
+ seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
+ thresh_target: np.ndarray = np.zeros(target_shape, dtype=np.float32)
+ thresh_mask: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
+
+ for idx, tgt in enumerate(target):
+ for class_idx, _tgt in enumerate(tgt.values()):
+ # Draw each polygon on gt
+ if _tgt.shape[0] == 0:
+ # Empty image, full masked
+ seg_mask[idx, class_idx] = False
+
+ # Absolute bounding boxes
+ abs_boxes = _tgt.copy()
+ if abs_boxes.ndim == 3:
+ abs_boxes[:, :, 0] *= w
+ abs_boxes[:, :, 1] *= h
+ polys = abs_boxes
+ boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1)
+ abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32)
+ else:
+ abs_boxes[:, [0, 2]] *= w
+ abs_boxes[:, [1, 3]] *= h
+ abs_boxes = abs_boxes.round().astype(np.int32)
+ polys = np.stack(
+ [
+ abs_boxes[:, [0, 1]],
+ abs_boxes[:, [0, 3]],
+ abs_boxes[:, [2, 3]],
+ abs_boxes[:, [2, 1]],
+ ],
+ axis=1,
+ )
+ boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])
+
+ for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
+ # Mask boxes that are too small
+ if box_size < self.min_size_box:
+ seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+ continue
+
+ # Negative shrink for gt, as described in paper
+ polygon = Polygon(poly)
+ distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length
+ subject = [tuple(coor) for coor in poly]
+ padding = pyclipper.PyclipperOffset()
+ padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+ shrunken = padding.Execute(-distance)
+
+ # Draw polygon on gt if it is valid
+ if len(shrunken) == 0:
+ seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+ continue
+ shrunken = np.array(shrunken[0]).reshape(-1, 2)
+ if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
+ seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+ continue
+ cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload]
+
+ # Draw on both thresh map and thresh mask
+ poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map(
+ poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx]
+ )
+ if channels_last:
+ seg_target = seg_target.transpose((0, 2, 3, 1))
+ seg_mask = seg_mask.transpose((0, 2, 3, 1))
+ thresh_target = thresh_target.transpose((0, 2, 3, 1))
+ thresh_mask = thresh_mask.transpose((0, 2, 3, 1))
+
+ thresh_target = thresh_target.astype(input_dtype) * (self.thresh_max - self.thresh_min) + self.thresh_min
+
+ seg_target = seg_target.astype(input_dtype)
+ seg_mask = seg_mask.astype(bool)
+ thresh_target = thresh_target.astype(input_dtype)
+ thresh_mask = thresh_mask.astype(bool)
+
+ return seg_target, seg_mask, thresh_target, thresh_mask
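Both the training targets and the inference unclipping rely on the DB paper's offset distance: D = A(1 - r^2)/L to shrink the ground truth (with r = shrink_ratio = 0.4) and A * unclip_ratio / L to expand predictions. A standalone sketch of the shrink arithmetic on a 100x100 square:

```python
import numpy as np
import pyclipper
from shapely.geometry import Polygon

poly = np.array([[0, 0], [100, 0], [100, 100], [0, 100]])
shape = Polygon(poly)
shrink_ratio = 0.4

# D = A * (1 - r^2) / L, as used in draw_thresh_map / build_target
distance = shape.area * (1 - shrink_ratio**2) / shape.length  # 21.0

offset = pyclipper.PyclipperOffset()
offset.AddPath([tuple(p) for p in poly], pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
shrunken = np.array(offset.Execute(-distance)[0])  # roughly a 58x58 inner square
print(distance, shrunken.min(0), shrunken.max(0))
```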
diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f7bdd8efe204a22476eb4048108a2d6c6c459ea
--- /dev/null
+++ b/doctr/models/detection/differentiable_binarization/pytorch.py
@@ -0,0 +1,435 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, Callable, Dict, List, Optional
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision.models import resnet34, resnet50
+from torchvision.models._utils import IntermediateLayerGetter
+from torchvision.ops.deform_conv import DeformConv2d
+
+from doctr.file_utils import CLASS_NAME
+
+from ...classification import mobilenet_v3_large
+from ...utils import _bf16_to_float32, load_pretrained_params
+from .base import DBPostProcessor, _DBNet
+
+__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "db_resnet50": {
+ "input_shape": (3, 1024, 1024),
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-79bd7d70.pt&src=0",
+ },
+ "db_resnet34": {
+ "input_shape": (3, 1024, 1024),
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet34-cb6aed9e.pt&src=0",
+ },
+ "db_mobilenet_v3_large": {
+ "input_shape": (3, 1024, 1024),
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-81e9b152.pt&src=0",
+ },
+}
+
+
+class FeaturePyramidNetwork(nn.Module):
+ def __init__(
+ self,
+ in_channels: List[int],
+ out_channels: int,
+ deform_conv: bool = False,
+ ) -> None:
+ super().__init__()
+
+ out_chans = out_channels // len(in_channels)
+
+ conv_layer = DeformConv2d if deform_conv else nn.Conv2d
+
+ self.in_branches = nn.ModuleList([
+ nn.Sequential(
+ conv_layer(chans, out_channels, 1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True),
+ )
+ for idx, chans in enumerate(in_channels)
+ ])
+ self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
+ self.out_branches = nn.ModuleList([
+ nn.Sequential(
+ conv_layer(out_channels, out_chans, 3, padding=1, bias=False),
+ nn.BatchNorm2d(out_chans),
+ nn.ReLU(inplace=True),
+ nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True),
+ )
+ for idx, chans in enumerate(in_channels)
+ ])
+
+ def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
+ if len(x) != len(self.out_branches):
+ raise AssertionError("expected one feature map per output branch")
+ # Conv1x1 to get the same number of channels
+ _x: List[torch.Tensor] = [branch(t) for branch, t in zip(self.in_branches, x)]
+ out: List[torch.Tensor] = [_x[-1]]
+ for t in _x[:-1][::-1]:
+ out.append(self.upsample(out[-1]) + t)
+
+ # Conv and final upsampling
+ out = [branch(t) for branch, t in zip(self.out_branches, out[::-1])]
+
+ return torch.cat(out, dim=1)
+
+
+class DBNet(_DBNet, nn.Module):
+ """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
+ <https://arxiv.org/abs/1911.08947>`_.
+
+ Args:
+ ----
+ feat_extractor: the backbone serving as feature extractor
+ head_chans: the number of channels in the head
+ deform_conv: whether to use deformable convolution
+ bin_thresh: threshold for binarization
+ box_thresh: minimal objectness score to consider a box
+ assume_straight_pages: if True, fit straight bounding boxes only
+ exportable: if True, the forward pass returns only the logits (for ONNX export)
+ cfg: the configuration dict of the model
+ class_names: list of class names
+ """
+
+ def __init__(
+ self,
+ feat_extractor: IntermediateLayerGetter,
+ head_chans: int = 256,
+ deform_conv: bool = False,
+ bin_thresh: float = 0.3,
+ box_thresh: float = 0.1,
+ assume_straight_pages: bool = True,
+ exportable: bool = False,
+ cfg: Optional[Dict[str, Any]] = None,
+ class_names: List[str] = [CLASS_NAME],
+ ) -> None:
+ super().__init__()
+ self.class_names = class_names
+ num_classes: int = len(self.class_names)
+ self.cfg = cfg
+
+ conv_layer = DeformConv2d if deform_conv else nn.Conv2d
+
+ self.exportable = exportable
+ self.assume_straight_pages = assume_straight_pages
+
+ self.feat_extractor = feat_extractor
+ # Identify the number of channels for the head initialization
+ _is_training = self.feat_extractor.training
+ self.feat_extractor = self.feat_extractor.eval()
+ with torch.no_grad():
+ out = self.feat_extractor(torch.zeros((1, 3, 224, 224)))
+ fpn_channels = [v.shape[1] for _, v in out.items()]
+
+ if _is_training:
+ self.feat_extractor = self.feat_extractor.train()
+
+ self.fpn = FeaturePyramidNetwork(fpn_channels, head_chans, deform_conv)
+ # Conv1 map to channels
+
+ self.prob_head = nn.Sequential(
+ conv_layer(head_chans, head_chans // 4, 3, padding=1, bias=False),
+ nn.BatchNorm2d(head_chans // 4),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(head_chans // 4, head_chans // 4, 2, stride=2, bias=False),
+ nn.BatchNorm2d(head_chans // 4),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(head_chans // 4, num_classes, 2, stride=2),
+ )
+ self.thresh_head = nn.Sequential(
+ conv_layer(head_chans, head_chans // 4, 3, padding=1, bias=False),
+ nn.BatchNorm2d(head_chans // 4),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(head_chans // 4, head_chans // 4, 2, stride=2, bias=False),
+ nn.BatchNorm2d(head_chans // 4),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(head_chans // 4, num_classes, 2, stride=2),
+ )
+
+ self.postprocessor = DBPostProcessor(
+ assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
+ )
+
+ for n, m in self.named_modules():
+ # Don't override the initialization of the backbone
+ if n.startswith("feat_extractor."):
+ continue
+ if isinstance(m, (nn.Conv2d, DeformConv2d)):
+ nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity="relu")
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1.0)
+ m.bias.data.zero_()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ target: Optional[List[np.ndarray]] = None,
+ return_model_output: bool = False,
+ return_preds: bool = False,
+ ) -> Dict[str, torch.Tensor]:
+ # Extract feature maps at different stages
+ feats = self.feat_extractor(x)
+ feats = [feats[str(idx)] for idx in range(len(feats))]
+ # Pass through the FPN
+ feat_concat = self.fpn(feats)
+ logits = self.prob_head(feat_concat)
+
+ out: Dict[str, Any] = {}
+ if self.exportable:
+ out["logits"] = logits
+ return out
+
+ if return_model_output or target is None or return_preds:
+ prob_map = _bf16_to_float32(torch.sigmoid(logits))
+
+ if return_model_output:
+ out["out_map"] = prob_map
+
+ if target is None or return_preds:
+ # Post-process boxes (keep only text predictions)
+ out["preds"] = [
+ dict(zip(self.class_names, preds))
+ for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
+ ]
+
+ if target is not None:
+ thresh_map = self.thresh_head(feat_concat)
+ loss = self.compute_loss(logits, thresh_map, target)
+ out["loss"] = loss
+
+ return out
+
+ def compute_loss(
+ self,
+ out_map: torch.Tensor,
+ thresh_map: torch.Tensor,
+ target: List[np.ndarray],
+ gamma: float = 2.0,
+ alpha: float = 0.5,
+ eps: float = 1e-8,
+ ) -> torch.Tensor:
+ """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
+ and a list of masks for each image. From there it computes the loss with the model output
+
+ Args:
+ ----
+ out_map: output feature map of the model of shape (N, C, H, W)
+ thresh_map: threshold map of shape (N, C, H, W)
+ target: list of dictionaries, one per image, mapping each class name to its (relative) boxes
+ gamma: modulating factor in the focal loss formula
+ alpha: balancing factor in the focal loss formula
+ eps: epsilon factor in dice loss
+
+ Returns:
+ -------
+ A loss tensor
+ """
+ if gamma < 0:
+ raise ValueError("Value of gamma should be greater than or equal to zero.")
+
+ prob_map = torch.sigmoid(out_map)
+ thresh_map = torch.sigmoid(thresh_map)
+
+ targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
+
+ seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
+ seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
+ thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3])
+ thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device)
+
+ # Default to zero so the return below is defined even with empty masks
+ focal_loss = dice_loss = l1_loss = torch.zeros(1, device=out_map.device)
+ if torch.any(seg_mask):
+ # Focal loss
+ focal_scale = 10.0
+ bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none")
+
+ p_t = prob_map * seg_target + (1 - prob_map) * (1 - seg_target)
+ alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target)
+ # Unreduced version
+ focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
+ # Class reduced
+ focal_loss = (seg_mask * focal_loss).sum((0, 1, 2, 3)) / seg_mask.sum((0, 1, 2, 3))
+
+ # Compute dice loss for each class or for approx binary_map
+ if len(self.class_names) > 1:
+ dice_map = torch.softmax(out_map, dim=1)
+ else:
+ # compute binary map instead
+ dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
+ # Class reduced
+ inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
+ cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
+ dice_loss = (1 - 2 * inter / (cardinality + eps)).mean()
+
+ # Compute l1 loss for thresh_map
+ if torch.any(thresh_mask):
+ l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps)
+
+ return l1_loss + focal_scale * focal_loss + dice_loss
+
+
+def _dbnet(
+ arch: str,
+ pretrained: bool,
+ backbone_fn: Callable[[bool], nn.Module],
+ fpn_layers: List[str],
+ backbone_submodule: Optional[str] = None,
+ pretrained_backbone: bool = True,
+ ignore_keys: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> DBNet:
+ pretrained_backbone = pretrained_backbone and not pretrained
+
+ # Feature extractor
+ backbone = (
+ backbone_fn(pretrained_backbone)
+ if not arch.split("_")[1].startswith("resnet")
+ # Starting with Imagenet pretrained params introduces some NaNs in layer3 & layer4 of resnet50
+ else backbone_fn(weights=None) # type: ignore[call-arg]
+ )
+ if isinstance(backbone_submodule, str):
+ backbone = getattr(backbone, backbone_submodule)
+ feat_extractor = IntermediateLayerGetter(
+ backbone,
+ {layer_name: str(idx) for idx, layer_name in enumerate(fpn_layers)},
+ )
+
+ if not kwargs.get("class_names", None):
+ kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME])
+ else:
+ kwargs["class_names"] = sorted(kwargs["class_names"])
+ # Build the model
+ model = DBNet(feat_extractor, cfg=default_cfgs[arch], **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ # The number of class_names is not the same as the number of classes in the pretrained model =>
+ # remove the layer weights
+ _ignore_keys = (
+ ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
+ )
+ load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+ return model
+
+
+def db_resnet34(pretrained: bool = False, **kwargs: Any) -> DBNet:
+ """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
+ <https://arxiv.org/abs/1911.08947>`_, using a ResNet-34 backbone.
+
+ >>> import torch
+ >>> from doctr.models import db_resnet34
+ >>> model = db_resnet34(pretrained=True)
+ >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the DBNet architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _dbnet(
+ "db_resnet34",
+ pretrained,
+ resnet34,
+ ["layer1", "layer2", "layer3", "layer4"],
+ None,
+ ignore_keys=[
+ "prob_head.6.weight",
+ "prob_head.6.bias",
+ "thresh_head.6.weight",
+ "thresh_head.6.bias",
+ ],
+ **kwargs,
+ )
+
+
+def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
+ """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
+ <https://arxiv.org/abs/1911.08947>`_, using a ResNet-50 backbone.
+
+ >>> import torch
+ >>> from doctr.models import db_resnet50
+ >>> model = db_resnet50(pretrained=True)
+ >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the DBNet architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _dbnet(
+ "db_resnet50",
+ pretrained,
+ resnet50,
+ ["layer1", "layer2", "layer3", "layer4"],
+ None,
+ ignore_keys=[
+ "prob_head.6.weight",
+ "prob_head.6.bias",
+ "thresh_head.6.weight",
+ "thresh_head.6.bias",
+ ],
+ **kwargs,
+ )
+
+
+def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
+ """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
+ <https://arxiv.org/abs/1911.08947>`_, using a MobileNet V3 Large backbone.
+
+ >>> import torch
+ >>> from doctr.models import db_mobilenet_v3_large
+ >>> model = db_mobilenet_v3_large(pretrained=True)
+ >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the DBNet architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _dbnet(
+ "db_mobilenet_v3_large",
+ pretrained,
+ mobilenet_v3_large,
+ ["3", "6", "12", "16"],
+ "features",
+ ignore_keys=[
+ "prob_head.6.weight",
+ "prob_head.6.bias",
+ "thresh_head.6.weight",
+ "thresh_head.6.bias",
+ ],
+ **kwargs,
+ )
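End to end, the PyTorch DBNet returns a dict whose keys depend on the call: `loss` when targets are given, and per-class box arrays under `preds` at inference (or when `return_preds=True`). A minimal smoke test with untrained weights; it assumes doctr's default class name is `"words"` (the `CLASS_NAME` imported from `doctr.file_utils`):

```python
import numpy as np
import torch
from doctr.models import db_resnet34

model = db_resnet34(pretrained=False)

x = torch.rand(1, 3, 1024, 1024)
# One dict per image, mapping a class name to (num_boxes, 4) relative xyxy boxes
target = [{"words": np.array([[0.1, 0.1, 0.4, 0.2]], dtype=np.float32)}]

out = model(x, target=target, return_preds=True)
print(out["loss"])               # scalar training loss (focal + dice + L1)
print(out["preds"][0]["words"])  # (N, 5) array: xmin, ymin, xmax, ymax, score
```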
diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..df9935b04259ce7912459cda4e3e046589dbeb45
--- /dev/null
+++ b/doctr/models/detection/differentiable_binarization/tensorflow.py
@@ -0,0 +1,402 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+from tensorflow.keras.applications import ResNet50
+
+from doctr.file_utils import CLASS_NAME
+from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
+from doctr.utils.repr import NestedObject
+
+from ...classification import mobilenet_v3_large
+from .base import DBPostProcessor, _DBNet
+
+__all__ = ["DBNet", "db_resnet50", "db_mobilenet_v3_large"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "db_resnet50": {
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "input_shape": (1024, 1024, 3),
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-84171458.zip&src=0",
+ },
+ "db_mobilenet_v3_large": {
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "input_shape": (1024, 1024, 3),
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-da524564.zip&src=0",
+ },
+}
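+
+# Illustrative sketch (not part of the module): the `mean`/`std` entries above are
+# the normalization statistics typically applied to inputs before inference (in
+# practice this is handled by the library's preprocessing), e.g.:
+#   cfg = default_cfgs["db_resnet50"]
+#   img = tf.random.uniform((1, *cfg["input_shape"]), maxval=255.0, dtype=tf.float32)
+#   normalized = (img / 255.0 - cfg["mean"]) / cfg["std"]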
+
+
+class FeaturePyramidNetwork(layers.Layer, NestedObject):
+ """Feature Pyramid Network as described in `"Feature Pyramid Networks for Object Detection"
+ <https://arxiv.org/pdf/1612.03144.pdf>`_.
+
+ Args:
+ ----
+ channels: number of channels to output
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ ) -> None:
+ super().__init__()
+ self.channels = channels
+ self.upsample = layers.UpSampling2D(size=(2, 2), interpolation="nearest")
+ self.inner_blocks = [layers.Conv2D(channels, 1, strides=1, kernel_initializer="he_normal") for _ in range(4)]
+ self.layer_blocks = [self.build_upsampling(channels, dilation_factor=2**idx) for idx in range(4)]
+
+ @staticmethod
+ def build_upsampling(
+ channels: int,
+ dilation_factor: int = 1,
+ ) -> layers.Layer:
+ """Module which performs a 3x3 convolution followed by up-sampling
+
+ Args:
+ ----
+ channels: number of output channels
+ dilation_factor (int): dilation factor to scale the convolution output before concatenation
+
+ Returns:
+ -------
+ a keras.layers.Layer object, wrapping these operations in a sequential module
+
+ """
+ _layers = conv_sequence(channels, "relu", True, kernel_size=3)
+
+ if dilation_factor > 1:
+ _layers.append(layers.UpSampling2D(size=(dilation_factor, dilation_factor), interpolation="nearest"))
+
+ module = keras.Sequential(_layers)
+
+ return module
+
+ def extra_repr(self) -> str:
+ return f"channels={self.channels}"
+
+ def call(
+ self,
+ x: List[tf.Tensor],
+ **kwargs: Any,
+ ) -> tf.Tensor:
+ # Channel mapping
+ results = [block(fmap, **kwargs) for block, fmap in zip(self.inner_blocks, x)]
+ # Upsample & sum
+ for idx in range(len(results) - 2, -1, -1):
+ results[idx] += self.upsample(results[idx + 1])
+ # Conv & upsample
+ results = [block(fmap, **kwargs) for block, fmap in zip(self.layer_blocks, results)]
+
+ return layers.concatenate(results)
+
+
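+# Shape sketch (assumption: a 1024 x 1024 input and four backbone maps at strides
+# 4/8/16/32): the 1x1 inner blocks first map every input to `channels`, the
+# top-down loop then sums each upsampled coarser map into the next finer one, and
+# each layer block upsamples its output by 2**idx so that all four maps share the
+# finest resolution before concatenation:
+#   in:  [(N, 256, 256, C1), (N, 128, 128, C2), (N, 64, 64, C3), (N, 32, 32, C4)]
+#   out: (N, 256, 256, 4 * channels)
+
+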
+class DBNet(_DBNet, keras.Model, NestedObject):
+ """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
+ <https://arxiv.org/pdf/1911.08947.pdf>`_.
+
+ Args:
+ ----
+ feature_extractor: the backbone serving as feature extractor
+ fpn_channels: number of channels each extracted feature map is mapped to
+ bin_thresh: threshold for binarization
+ box_thresh: minimal objectness score to consider a box
+ assume_straight_pages: if True, fit straight bounding boxes only
+ exportable: if True, the model returns only the logits (for ONNX export)
+ cfg: the configuration dict of the model
+ class_names: list of class names
+ """
+
+ _children_names: List[str] = ["feat_extractor", "fpn", "probability_head", "threshold_head", "postprocessor"]
+
+ def __init__(
+ self,
+ feature_extractor: IntermediateLayerGetter,
+ fpn_channels: int = 128, # set to 256 to reproduce the paper's configuration
+ bin_thresh: float = 0.3,
+ box_thresh: float = 0.1,
+ assume_straight_pages: bool = True,
+ exportable: bool = False,
+ cfg: Optional[Dict[str, Any]] = None,
+ class_names: List[str] = [CLASS_NAME],
+ ) -> None:
+ super().__init__()
+ self.class_names = class_names
+ num_classes: int = len(self.class_names)
+ self.cfg = cfg
+
+ self.feat_extractor = feature_extractor
+ self.exportable = exportable
+ self.assume_straight_pages = assume_straight_pages
+
+ self.fpn = FeaturePyramidNetwork(channels=fpn_channels)
+ # Initialize kernels
+ _inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape]
+ output_shape = tuple(self.fpn(_inputs).shape)
+
+ self.probability_head = keras.Sequential([
+ *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
+ layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
+ layers.BatchNormalization(),
+ layers.Activation("relu"),
+ layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
+ ])
+ self.threshold_head = keras.Sequential([
+ *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
+ layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
+ layers.BatchNormalization(),
+ layers.Activation("relu"),
+ layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
+ ])
+
+ self.postprocessor = DBPostProcessor(
+ assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
+ )
+
+ def compute_loss(
+ self,
+ out_map: tf.Tensor,
+ thresh_map: tf.Tensor,
+ target: List[Dict[str, np.ndarray]],
+ gamma: float = 2.0,
+ alpha: float = 0.5,
+ eps: float = 1e-8,
+ ) -> tf.Tensor:
+ """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
+ and a list of masks for each image. From there it computes the loss with the model output
+
+ Args:
+ ----
+ out_map: output feature map of the model of shape (N, H, W, C)
+ thresh_map: threshold map of shape (N, H, W, C)
+ target: list of dictionaries where each dict has a `boxes` and a `flags` entry
+ gamma: modulating factor in the focal loss formula
+ alpha: balancing factor in the focal loss formula
+ eps: epsilon factor in dice loss
+
+ Returns:
+ -------
+ A loss tensor
+ """
+ if gamma < 0:
+ raise ValueError("Value of gamma should be greater than or equal to zero.")
+
+ prob_map = tf.math.sigmoid(out_map)
+ thresh_map = tf.math.sigmoid(thresh_map)
+
+ seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[1:], True)
+ seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
+ seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
+ seg_mask = tf.cast(seg_mask, tf.float32)
+ thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype)
+ thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool)
+
+ # Focal loss
+ focal_scale = 10.0
+ bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
+
+ # Convert logits to prob, compute gamma factor
+ p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map))
+ alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha)
+ # Unreduced loss
+ focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
+ # Class reduced
+ focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3))
+
+ # Compute dice loss for each class or for approx binary_map
+ if len(self.class_names) > 1:
+ dice_map = tf.nn.softmax(out_map, axis=-1)
+ else:
+ # compute binary map instead
+ dice_map = 1.0 / (1.0 + tf.exp(-50 * (prob_map - thresh_map)))
+ # Class-reduced dice loss
+ inter = tf.reduce_sum(seg_mask * dice_map * seg_target, axis=[0, 1, 2])
+ cardinality = tf.reduce_sum(seg_mask * (dice_map + seg_target), axis=[0, 1, 2])
+ dice_loss = tf.reduce_mean(1 - 2 * inter / (cardinality + eps))
+
+ # Compute l1 loss for thresh_map
+ if tf.reduce_any(thresh_mask):
+ thresh_mask = tf.cast(thresh_mask, tf.float32)
+ l1_loss = tf.reduce_sum(tf.abs(thresh_map - thresh_target) * thresh_mask) / (
+ tf.reduce_sum(thresh_mask) + eps
+ )
+ else:
+ l1_loss = tf.constant(0.0)
+
+ return l1_loss + focal_scale * focal_loss + dice_loss
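+
+ # A worked note on the approximate binary map above (single-class case): the
+ # paper's differentiable binarization is B = 1 / (1 + exp(-k * (P - T))) with
+ # k = 50, where P is the probability map and T the threshold map. For P = 0.6
+ # and T = 0.3, B = 1 / (1 + exp(-15)) ~= 1.0, i.e. an almost hard cut-off.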
+
+ def call(
+ self,
+ x: tf.Tensor,
+ target: Optional[List[Dict[str, np.ndarray]]] = None,
+ return_model_output: bool = False,
+ return_preds: bool = False,
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ feat_maps = self.feat_extractor(x, **kwargs)
+ feat_concat = self.fpn(feat_maps, **kwargs)
+ logits = self.probability_head(feat_concat, **kwargs)
+
+ out: Dict[str, tf.Tensor] = {}
+ if self.exportable:
+ out["logits"] = logits
+ return out
+
+ if return_model_output or target is None or return_preds:
+ prob_map = _bf16_to_float32(tf.math.sigmoid(logits))
+
+ if return_model_output:
+ out["out_map"] = prob_map
+
+ if target is None or return_preds:
+ # Post-process boxes (keep only text predictions)
+ out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())]
+
+ if target is not None:
+ thresh_map = self.threshold_head(feat_concat, **kwargs)
+ loss = self.compute_loss(logits, thresh_map, target)
+ out["loss"] = loss
+
+ return out
+
+
+def _db_resnet(
+ arch: str,
+ pretrained: bool,
+ backbone_fn,
+ fpn_layers: List[str],
+ pretrained_backbone: bool = True,
+ input_shape: Optional[Tuple[int, int, int]] = None,
+ **kwargs: Any,
+) -> DBNet:
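+ # When the full detection checkpoint is loaded below, the ImageNet weights of
+ # the backbone would be overwritten anyway, so skip downloading them in that case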
+ pretrained_backbone = pretrained_backbone and not pretrained
+
+ # Patch the config
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["input_shape"] = input_shape or _cfg["input_shape"]
+ if not kwargs.get("class_names", None):
+ kwargs["class_names"] = _cfg.get("class_names", [CLASS_NAME])
+ else:
+ kwargs["class_names"] = sorted(kwargs["class_names"])
+
+ # Feature extractor
+ feat_extractor = IntermediateLayerGetter(
+ backbone_fn(
+ weights="imagenet" if pretrained_backbone else None,
+ include_top=False,
+ pooling=None,
+ input_shape=_cfg["input_shape"],
+ ),
+ fpn_layers,
+ )
+
+ # Build the model
+ model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ load_pretrained_params(model, _cfg["url"])
+
+ return model
+
+
+def _db_mobilenet(
+ arch: str,
+ pretrained: bool,
+ backbone_fn,
+ fpn_layers: List[str],
+ pretrained_backbone: bool = True,
+ input_shape: Optional[Tuple[int, int, int]] = None,
+ **kwargs: Any,
+) -> DBNet:
+ pretrained_backbone = pretrained_backbone and not pretrained
+
+ # Patch the config
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["input_shape"] = input_shape or _cfg["input_shape"]
+
+ # Feature extractor
+ feat_extractor = IntermediateLayerGetter(
+ backbone_fn(
+ input_shape=_cfg["input_shape"],
+ include_top=False,
+ pretrained=pretrained_backbone,
+ ),
+ fpn_layers,
+ )
+
+ # Build the model
+ model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ load_pretrained_params(model, _cfg["url"])
+
+ return model
+
+
+def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
+ """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
+ <https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-50 backbone.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import db_resnet50
+ >>> model = db_resnet50(pretrained=True)
+ >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the DBNet architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _db_resnet(
+ "db_resnet50",
+ pretrained,
+ ResNet50,
+ ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"],
+ **kwargs,
+ )
+
+
+def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
+ """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
+ <https://arxiv.org/pdf/1911.08947.pdf>`_, using a MobileNet V3 Large backbone.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import db_mobilenet_v3_large
+ >>> model = db_mobilenet_v3_large(pretrained=True)
+ >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the DBNet architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _db_mobilenet(
+ "db_mobilenet_v3_large",
+ pretrained,
+ mobilenet_v3_large,
+ ["inverted_2", "inverted_5", "inverted_11", "final_block"],
+ **kwargs,
+ )
diff --git a/doctr/models/detection/fast/__init__.py b/doctr/models/detection/fast/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/models/detection/fast/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/detection/fast/base.py b/doctr/models/detection/fast/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..868c3eadec4c1a284c1f348c117fca96a0476291
--- /dev/null
+++ b/doctr/models/detection/fast/base.py
@@ -0,0 +1,256 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
+
+from typing import Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+import pyclipper
+from shapely.geometry import Polygon
+
+from doctr.models.core import BaseModel
+
+from ..core import DetectionPostProcessor
+
+__all__ = ["_FAST", "FASTPostProcessor"]
+
+
+class FASTPostProcessor(DetectionPostProcessor):
+ """Implements a post processor for FAST model.
+
+ Args:
+ ----
+ bin_thresh: threshold used to binarize p_map at inference time
+ box_thresh: minimal objectness score to consider a box
+ assume_straight_pages: whether the inputs were expected to have horizontal text elements
+ """
+
+ def __init__(
+ self,
+ bin_thresh: float = 0.1,
+ box_thresh: float = 0.1,
+ assume_straight_pages: bool = True,
+ ) -> None:
+ super().__init__(box_thresh, bin_thresh, assume_straight_pages)
+ self.unclip_ratio = 1.0
+
+ def polygon_to_box(
+ self,
+ points: np.ndarray,
+ ) -> np.ndarray:
+ """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
+
+ Args:
+ ----
+ points: (N, 2) array of absolute coordinates of the polygon to expand
+
+ Returns:
+ -------
+ a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
+ """
+ if not self.assume_straight_pages:
+ # Compute the rectangle polygon enclosing the raw polygon
+ rect = cv2.minAreaRect(points)
+ points = cv2.boxPoints(rect)
+ # Add 1 pixel to correct cv2 approx
+ area = (rect[1][0] + 1) * (1 + rect[1][1])
+ length = 2 * (rect[1][0] + rect[1][1]) + 2
+ else:
+ poly = Polygon(points)
+ area = poly.area
+ length = poly.length
+ distance = area * self.unclip_ratio / length # compute distance to expand polygon
+ offset = pyclipper.PyclipperOffset()
+ offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+ _points = offset.Execute(distance)
+ # Take biggest stack of points
+ idx = 0
+ if len(_points) > 1:
+ max_size = 0
+ for _idx, p in enumerate(_points):
+ if len(p) > max_size:
+ idx = _idx
+ max_size = len(p)
+ # We ensure that _points can be correctly cast to a ndarray
+ _points = [_points[idx]]
+ expanded_points: np.ndarray = np.asarray(_points) # expand polygon
+ if len(expanded_points) < 1:
+ return None # type: ignore[return-value]
+ return (
+ cv2.boundingRect(expanded_points) # type: ignore[return-value]
+ if self.assume_straight_pages
+ else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0)
+ )
+
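+ # Expansion sketch (illustrative, not part of the module): pyclipper offsets the
+ # polygon outward by distance = area * unclip_ratio / length. For a 100 x 20
+ # axis-aligned box (area 2000, perimeter 240) and unclip_ratio = 1.0, that is
+ # an offset of ~8.3 px:
+ #   import numpy as np, pyclipper
+ #   pts = np.array([[0, 0], [100, 0], [100, 20], [0, 20]])
+ #   off = pyclipper.PyclipperOffset()
+ #   off.AddPath(pts, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+ #   expanded = off.Execute(2000 * 1.0 / 240)  # ~8.3 px outward offset
+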
+ def bitmap_to_boxes(
+ self,
+ pred: np.ndarray,
+ bitmap: np.ndarray,
+ ) -> np.ndarray:
+ """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
+
+ Args:
+ ----
+ pred: probability map computed from the model output
+ bitmap: bitmap computed from pred (binarized)
+
+ Returns:
+ -------
+ an (N, 5) array of boxes (xmin, ymin, xmax, ymax, score) if assume_straight_pages
+ is True, otherwise an (N, 4, 2) array of polygons, in relative coordinates
+ """
+ height, width = bitmap.shape[:2]
+ boxes: List[Union[np.ndarray, List[float]]] = []
+ # get contours from connected components on the bitmap
+ contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ for contour in contours:
+ # Check whether smallest enclosing bounding box is not too small
+ if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
+ continue
+ # Compute objectness
+ if self.assume_straight_pages:
+ x, y, w, h = cv2.boundingRect(contour)
+ points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]])
+ score = self.box_score(pred, points, assume_straight_pages=True)
+ else:
+ score = self.box_score(pred, contour, assume_straight_pages=False)
+
+ if score < self.box_thresh: # remove polygons with a weak objectness
+ continue
+
+ if self.assume_straight_pages:
+ _box = self.polygon_to_box(points)
+ else:
+ _box = self.polygon_to_box(np.squeeze(contour))
+
+ if self.assume_straight_pages:
+ # compute relative polygon to get rid of img shape
+ x, y, w, h = _box
+ xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height
+ boxes.append([xmin, ymin, xmax, ymax, score])
+ else:
+ # compute relative box to get rid of img shape
+ _box[:, 0] /= width
+ _box[:, 1] /= height
+ boxes.append(_box)
+
+ if not self.assume_straight_pages:
+ return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
+ else:
+ return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
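+
+ # Usage sketch (illustrative, with made-up values): feeding a synthetic
+ # probability map through the post-processing step,
+ #   proba = np.zeros((64, 64), dtype=np.float32)
+ #   proba[10:20, 10:40] = 0.9
+ #   boxes = FASTPostProcessor().bitmap_to_boxes(proba, proba >= 0.1)
+ # yields an (N, 5) array of (xmin, ymin, xmax, ymax, score) in relative coordinates.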
+
+
+class _FAST(BaseModel):
+ """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+ <https://arxiv.org/pdf/2111.02394.pdf>`_.
+ """
+
+ min_size_box: int = 3
+ assume_straight_pages: bool = True
+ shrink_ratio = 0.4
+
+ def build_target(
+ self,
+ target: List[Dict[str, np.ndarray]],
+ output_shape: Tuple[int, int, int],
+ channels_last: bool = True,
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Build the target, and it's mask to be used from loss computation.
+
+ Args:
+ ----
+ target: target coming from dataset
+ output_shape: shape of the output of the model without batch_size
+ channels_last: whether channels are last or not
+
+ Returns:
+ -------
+ the new formatted target, mask and shrunken text kernel
+ """
+ if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
+ raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
+ if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
+ raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.")
+
+ h: int
+ w: int
+ if channels_last:
+ h, w, num_classes = output_shape
+ else:
+ num_classes, h, w = output_shape
+ target_shape = (len(target), num_classes, h, w)
+
+ seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
+ seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
+ shrunken_kernel: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
+
+ for idx, tgt in enumerate(target):
+ for class_idx, _tgt in enumerate(tgt.values()):
+ # Draw each polygon on gt
+ if _tgt.shape[0] == 0:
+ # Empty image, full masked
+ seg_mask[idx, class_idx] = False
+
+ # Absolute bounding boxes
+ abs_boxes = _tgt.copy()
+
+ if abs_boxes.ndim == 3:
+ abs_boxes[:, :, 0] *= w
+ abs_boxes[:, :, 1] *= h
+ polys = abs_boxes
+ boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1)
+ abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32)
+ else:
+ abs_boxes[:, [0, 2]] *= w
+ abs_boxes[:, [1, 3]] *= h
+ abs_boxes = abs_boxes.round().astype(np.int32)
+ polys = np.stack(
+ [
+ abs_boxes[:, [0, 1]],
+ abs_boxes[:, [0, 3]],
+ abs_boxes[:, [2, 3]],
+ abs_boxes[:, [2, 1]],
+ ],
+ axis=1,
+ )
+ boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])
+
+ for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
+ # Mask boxes that are too small
+ if box_size < self.min_size_box:
+ seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+ continue
+
+ # Negative shrink for gt, as described in paper
+ polygon = Polygon(poly)
+ distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length
+ subject = [tuple(coor) for coor in poly]
+ padding = pyclipper.PyclipperOffset()
+ padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+ shrunken = padding.Execute(-distance)
+
+ # Draw polygon on gt if it is valid
+ if len(shrunken) == 0:
+ seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+ continue
+ shrunken = np.array(shrunken[0]).reshape(-1, 2)
+ if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
+ seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+ continue
+ cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload]
+ # draw the original polygon on the segmentation target
+ cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0) # type: ignore[call-overload]
+
+ # Don't forget to switch back to channel last if Tensorflow is used
+ if channels_last:
+ seg_target = seg_target.transpose((0, 2, 3, 1))
+ seg_mask = seg_mask.transpose((0, 2, 3, 1))
+ shrunken_kernel = shrunken_kernel.transpose((0, 2, 3, 1))
+
+ return seg_target, seg_mask, shrunken_kernel
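+
+
+# Shrink sketch (illustrative): the inward offset used above is
+# d = area * (1 - shrink_ratio**2) / length. For a 100 x 20 box (area 2000,
+# perimeter 240) and shrink_ratio = 0.4, d = 2000 * 0.84 / 240 = 7.0, so the
+# text kernel is the polygon offset inward by 7 px.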
diff --git a/doctr/models/detection/fast/pytorch.py b/doctr/models/detection/fast/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac44b182567f890b48e090e292fbb2cb48f9944
--- /dev/null
+++ b/doctr/models/detection/fast/pytorch.py
@@ -0,0 +1,442 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision.models._utils import IntermediateLayerGetter
+
+from doctr.file_utils import CLASS_NAME
+
+from ...classification import textnet_base, textnet_small, textnet_tiny
+from ...modules.layers import FASTConvLayer
+from ...utils import _bf16_to_float32, load_pretrained_params
+from .base import _FAST, FASTPostProcessor
+
+__all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "fast_tiny": {
+ "input_shape": (3, 1024, 1024),
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_tiny-1acac421.pt&src=0",
+ },
+ "fast_small": {
+ "input_shape": (3, 1024, 1024),
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_small-10952cc1.pt&src=0",
+ },
+ "fast_base": {
+ "input_shape": (3, 1024, 1024),
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_base-688a8b34.pt&src=0",
+ },
+}
+
+
+class FastNeck(nn.Module):
+ """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layers.
+
+ Args:
+ ----
+ in_channels: number of input channels
+ out_channels: number of output channels
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int = 128,
+ ) -> None:
+ super().__init__()
+ self.reduction = nn.ModuleList([
+ FASTConvLayer(in_channels * scale, out_channels, kernel_size=3) for scale in [1, 2, 4, 8]
+ ])
+
+ def _upsample(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ return F.interpolate(x, size=y.shape[-2:], mode="bilinear")
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ f1, f2, f3, f4 = x
+ f1, f2, f3, f4 = [reduction(f) for reduction, f in zip(self.reduction, (f1, f2, f3, f4))]
+ f2, f3, f4 = [self._upsample(f, f1) for f in (f2, f3, f4)]
+ f = torch.cat((f1, f2, f3, f4), 1)
+ return f
+
+
+class FastHead(nn.Sequential):
+ """Head of the FAST architecture
+
+ Args:
+ ----
+ in_channels: number of input channels
+ num_classes: number of output classes
+ out_channels: number of output channels
+ dropout: dropout probability
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ num_classes: int,
+ out_channels: int = 128,
+ dropout: float = 0.1,
+ ) -> None:
+ _layers: List[nn.Module] = [
+ FASTConvLayer(in_channels, out_channels, kernel_size=3),
+ nn.Dropout(dropout),
+ nn.Conv2d(out_channels, num_classes, kernel_size=1, bias=False),
+ ]
+ super().__init__(*_layers)
+
+
+class FAST(_FAST, nn.Module):
+ """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+ <https://arxiv.org/pdf/2111.02394.pdf>`_.
+
+ Args:
+ ----
+ feat_extractor: the backbone serving as feature extractor
+ bin_thresh: threshold for binarization
+ box_thresh: minimal objectness score to consider a box
+ dropout_prob: dropout probability
+ pooling_size: size of the pooling layer
+ assume_straight_pages: if True, fit straight bounding boxes only
+ exportable: if True, the model returns only the logits (for ONNX export)
+ cfg: the configuration dict of the model
+ class_names: list of class names
+ """
+
+ def __init__(
+ self,
+ feat_extractor: IntermediateLayerGetter,
+ bin_thresh: float = 0.1,
+ box_thresh: float = 0.1,
+ dropout_prob: float = 0.1,
+ pooling_size: int = 4, # differs from the paper; performs better on dense, text-rich images
+ assume_straight_pages: bool = True,
+ exportable: bool = False,
+ cfg: Optional[Dict[str, Any]] = None,
+ class_names: List[str] = [CLASS_NAME],
+ ) -> None:
+ super().__init__()
+ self.class_names = class_names
+ num_classes: int = len(self.class_names)
+ self.cfg = cfg
+
+ self.exportable = exportable
+ self.assume_straight_pages = assume_straight_pages
+
+ self.feat_extractor = feat_extractor
+ # Identify the number of channels for the neck & head initialization
+ _is_training = self.feat_extractor.training
+ self.feat_extractor = self.feat_extractor.eval()
+ with torch.no_grad():
+ out = self.feat_extractor(torch.zeros((1, 3, 32, 32)))
+ feat_out_channels = [v.shape[1] for _, v in out.items()]
+
+ if _is_training:
+ self.feat_extractor = self.feat_extractor.train()
+
+ # Initialize neck & head
+ self.neck = FastNeck(feat_out_channels[0], feat_out_channels[1])
+ self.prob_head = FastHead(feat_out_channels[-1], num_classes, feat_out_channels[1], dropout_prob)
+
+ # NOTE: the post-processing from the paper does not work well on text-rich images,
+ # so we use a modified version from DBNet
+ self.postprocessor = FASTPostProcessor(
+ assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
+ )
+
+ # Pooling layer as erosion reversal as described in the paper
+ self.pooling = nn.MaxPool2d(kernel_size=pooling_size // 2 + 1, stride=1, padding=(pooling_size // 2) // 2)
+
+ for n, m in self.named_modules():
+ # Don't override the initialization of the backbone
+ if n.startswith("feat_extractor."):
+ continue
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity="relu")
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1.0)
+ m.bias.data.zero_()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ target: Optional[List[np.ndarray]] = None,
+ return_model_output: bool = False,
+ return_preds: bool = False,
+ ) -> Dict[str, torch.Tensor]:
+ # Extract feature maps at different stages
+ feats = self.feat_extractor(x)
+ feats = [feats[str(idx)] for idx in range(len(feats))]
+ # Pass through the Neck & Head & Upsample
+ feat_concat = self.neck(feats)
+ logits = F.interpolate(self.prob_head(feat_concat), size=x.shape[-2:], mode="bilinear")
+
+ out: Dict[str, Any] = {}
+ if self.exportable:
+ out["logits"] = logits
+ return out
+
+ if return_model_output or target is None or return_preds:
+ prob_map = _bf16_to_float32(torch.sigmoid(self.pooling(logits)))
+
+ if return_model_output:
+ out["out_map"] = prob_map
+
+ if target is None or return_preds:
+ # Post-process boxes (keep only text predictions)
+ out["preds"] = [
+ dict(zip(self.class_names, preds))
+ for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
+ ]
+
+ if target is not None:
+ loss = self.compute_loss(logits, target)
+ out["loss"] = loss
+
+ return out
+
+ def compute_loss(
+ self,
+ out_map: torch.Tensor,
+ target: List[np.ndarray],
+ eps: float = 1e-6,
+ ) -> torch.Tensor:
+ """Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5.
+
+ Args:
+ ----
+ out_map: output feature map of the model of shape (N, num_classes, H, W)
+ target: list of dictionaries where each dict has a `boxes` and a `flags` entry
+ eps: epsilon factor in dice loss
+
+ Returns:
+ -------
+ A loss tensor
+ """
+ targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
+
+ seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
+ shrunken_kernel = torch.from_numpy(targets[2]).to(out_map.device)
+ seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
+
+ def ohem_sample(score: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ masks = []
+ for class_idx in range(gt.shape[0]):
+ pos_num = int(torch.sum(gt[class_idx] > 0.5)) - int(
+ torch.sum((gt[class_idx] > 0.5) & (mask[class_idx] <= 0.5))
+ )
+ neg_num = int(torch.sum(gt[class_idx] <= 0.5))
+ neg_num = int(min(pos_num * 3, neg_num))
+
+ if neg_num == 0 or pos_num == 0:
+ masks.append(mask[class_idx])
+ continue
+
+ neg_score_sorted, _ = torch.sort(-score[class_idx][gt[class_idx] <= 0.5])
+ threshold = -neg_score_sorted[neg_num - 1]
+
+ selected_mask = ((score[class_idx] >= threshold) | (gt[class_idx] > 0.5)) & (mask[class_idx] > 0.5)
+ masks.append(selected_mask)
+ # combine all masks to shape (len(masks), H, W)
+ return torch.stack(masks).unsqueeze(0).float()
+
+ if len(self.class_names) > 1:
+ kernels = torch.softmax(out_map, dim=1)
+ prob_map = torch.softmax(self.pooling(out_map), dim=1)
+ else:
+ kernels = torch.sigmoid(out_map)
+ prob_map = torch.sigmoid(self.pooling(out_map))
+
+ # As described in the paper, we use the Dice loss for the text segmentation map, scaled by 0.5.
+ selected_masks = torch.cat(
+ [ohem_sample(score, gt, mask) for score, gt, mask in zip(prob_map, seg_target, seg_mask)], 0
+ ).float()
+ inter = (selected_masks * prob_map * seg_target).sum((0, 2, 3))
+ cardinality = (selected_masks * (prob_map + seg_target)).sum((0, 2, 3))
+ text_loss = (1 - 2 * inter / (cardinality + eps)).mean() * 0.5
+
+ # As described in the paper, we use the Dice loss for the text kernel map.
+ selected_masks = seg_target * seg_mask
+ inter = (selected_masks * kernels * shrunken_kernel).sum((0, 2, 3)) # noqa
+ cardinality = (selected_masks * (kernels + shrunken_kernel)).sum((0, 2, 3)) # noqa
+ kernel_loss = (1 - 2 * inter / (cardinality + eps)).mean()
+
+ return text_loss + kernel_loss
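+
+ # OHEM note (illustrative numbers): for each class, ohem_sample keeps every
+ # positive pixel and only the hardest negatives, capped at a 3:1 neg/pos ratio.
+ # E.g. with 1000 positive and 500000 negative pixels, only the 3000 negatives
+ # with the highest predicted scores survive; the rest are masked out of the loss.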
+
+
+def reparameterize(model: Union[FAST, nn.Module]) -> FAST:
+ """Fuse batchnorm and conv layers and reparameterize the model
+
+ Args:
+ ----
+ model: the FAST model to reparameterize
+
+ Returns:
+ -------
+ the reparameterized model
+ """
+ last_conv = None
+ last_conv_name = None
+
+ for module in model.modules():
+ if hasattr(module, "reparameterize_layer"):
+ module.reparameterize_layer()
+
+ for name, child in model.named_children():
+ if isinstance(child, nn.BatchNorm2d):
+ # fuse batchnorm only if it is followed by a conv layer
+ if last_conv is None:
+ continue
+ conv_w = last_conv.weight
+ conv_b = last_conv.bias if last_conv.bias is not None else torch.zeros_like(child.running_mean)
+
+ factor = child.weight / torch.sqrt(child.running_var + child.eps)
+ last_conv.weight = nn.Parameter(conv_w * factor.reshape([last_conv.out_channels, 1, 1, 1]))
+ last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias)
+ model._modules[last_conv_name] = last_conv
+ model._modules[name] = nn.Identity()
+ last_conv = None
+ elif isinstance(child, nn.Conv2d):
+ last_conv = child
+ last_conv_name = name
+ else:
+ reparameterize(child)
+
+ return model # type: ignore[return-value]
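+
+# Usage sketch: reparameterization is meant for inference, e.g.
+#   model = fast_tiny(pretrained=True).eval()
+#   model = reparameterize(model)
+# The batchnorm statistics are folded into the preceding convolutions, giving a
+# cheaper forward pass with the same outputs in eval mode.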
+
+
+def _fast(
+ arch: str,
+ pretrained: bool,
+ backbone_fn: Callable[[bool], nn.Module],
+ feat_layers: List[str],
+ pretrained_backbone: bool = True,
+ ignore_keys: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> FAST:
+ pretrained_backbone = pretrained_backbone and not pretrained
+
+ # Build the feature extractor
+ feat_extractor = IntermediateLayerGetter(
+ backbone_fn(pretrained_backbone),
+ {layer_name: str(idx) for idx, layer_name in enumerate(feat_layers)},
+ )
+
+ if not kwargs.get("class_names", None):
+ kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME])
+ else:
+ kwargs["class_names"] = sorted(kwargs["class_names"])
+ # Build the model
+ model = FAST(feat_extractor, cfg=default_cfgs[arch], **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ # If the number of class_names differs from the one of the pretrained model,
+ # the corresponding head weights are ignored and reinitialized
+ _ignore_keys = (
+ ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
+ )
+ load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+ return model
+
+
+def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
+ """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+ <https://arxiv.org/pdf/2111.02394.pdf>`_, using a tiny TextNet backbone.
+
+ >>> import torch
+ >>> from doctr.models import fast_tiny
+ >>> model = fast_tiny(pretrained=True)
+ >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the FAST architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _fast(
+ "fast_tiny",
+ pretrained,
+ textnet_tiny,
+ ["3", "4", "5", "6"],
+ ignore_keys=["prob_head.2.weight"],
+ **kwargs,
+ )
+
+
+def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
+ """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+ <https://arxiv.org/pdf/2111.02394.pdf>`_, using a small TextNet backbone.
+
+ >>> import torch
+ >>> from doctr.models import fast_small
+ >>> model = fast_small(pretrained=True)
+ >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the FAST architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _fast(
+ "fast_small",
+ pretrained,
+ textnet_small,
+ ["3", "4", "5", "6"],
+ ignore_keys=["prob_head.2.weight"],
+ **kwargs,
+ )
+
+
+def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
+ """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+ <https://arxiv.org/pdf/2111.02394.pdf>`_, using a base TextNet backbone.
+
+ >>> import torch
+ >>> from doctr.models import fast_base
+ >>> model = fast_base(pretrained=True)
+ >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the FAST architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _fast(
+ "fast_base",
+ pretrained,
+ textnet_base,
+ ["3", "4", "5", "6"],
+ ignore_keys=["prob_head.2.weight"],
+ **kwargs,
+ )
diff --git a/doctr/models/detection/fast/tensorflow.py b/doctr/models/detection/fast/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0934e99c23b27a2241306a8c3babbde6dd25d6a
--- /dev/null
+++ b/doctr/models/detection/fast/tensorflow.py
@@ -0,0 +1,428 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import Sequential, layers
+
+from doctr.file_utils import CLASS_NAME
+from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params
+from doctr.utils.repr import NestedObject
+
+from ...classification import textnet_base, textnet_small, textnet_tiny
+from ...modules.layers import FASTConvLayer
+from .base import _FAST, FASTPostProcessor
+
+__all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "fast_tiny": {
+ "input_shape": (1024, 1024, 3),
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "url": None,
+ },
+ "fast_small": {
+ "input_shape": (1024, 1024, 3),
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "url": None,
+ },
+ "fast_base": {
+ "input_shape": (1024, 1024, 3),
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "url": None,
+ },
+}
+
+
+class FastNeck(layers.Layer, NestedObject):
+ """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layer.
+
+ Args:
+ ----
+ in_channels: number of input channels
+ out_channels: number of output channels
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int = 128,
+ ) -> None:
+ super().__init__()
+ self.reduction = [FASTConvLayer(in_channels * scale, out_channels, kernel_size=3) for scale in [1, 2, 4, 8]]
+
+ def _upsample(self, x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
+ return tf.image.resize(x, size=y.shape[1:3], method="bilinear")
+
+ def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
+ f1, f2, f3, f4 = x
+ f1, f2, f3, f4 = [reduction(f, **kwargs) for reduction, f in zip(self.reduction, (f1, f2, f3, f4))]
+ f2, f3, f4 = [self._upsample(f, f1) for f in (f2, f3, f4)]
+ f = tf.concat((f1, f2, f3, f4), axis=-1)
+ return f
+
+
+class FastHead(Sequential):
+ """Head of the FAST architecture
+
+ Args:
+ ----
+ in_channels: number of input channels
+ num_classes: number of output classes
+ out_channels: number of output channels
+ dropout: dropout probability
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ num_classes: int,
+ out_channels: int = 128,
+ dropout: float = 0.1,
+ ) -> None:
+ _layers = [
+ FASTConvLayer(in_channels, out_channels, kernel_size=3),
+ layers.Dropout(dropout),
+ layers.Conv2D(num_classes, kernel_size=1, use_bias=False),
+ ]
+ super().__init__(_layers)
+
+
+class FAST(_FAST, keras.Model, NestedObject):
+ """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+ <https://arxiv.org/pdf/2111.02394.pdf>`_.
+
+ Args:
+ ----
+ feature_extractor: the backbone serving as feature extractor
+ bin_thresh: threshold for binarization
+ box_thresh: minimal objectness score to consider a box
+ dropout_prob: dropout probability
+ pooling_size: size of the pooling layer
+ assume_straight_pages: if True, fit straight bounding boxes only
+ exportable: if True, the model returns only the logits (for ONNX export)
+ cfg: the configuration dict of the model
+ class_names: list of class names
+ """
+
+ _children_names: List[str] = ["feat_extractor", "neck", "head", "postprocessor"]
+
+ def __init__(
+ self,
+ feature_extractor: IntermediateLayerGetter,
+ bin_thresh: float = 0.1,
+ box_thresh: float = 0.1,
+ dropout_prob: float = 0.1,
+ pooling_size: int = 4, # differs from the paper; performs better on dense, text-rich images
+ assume_straight_pages: bool = True,
+ exportable: bool = False,
+ cfg: Optional[Dict[str, Any]] = None,
+ class_names: List[str] = [CLASS_NAME],
+ ) -> None:
+ super().__init__()
+ self.class_names = class_names
+ num_classes: int = len(self.class_names)
+ self.cfg = cfg
+
+ self.feat_extractor = feature_extractor
+ self.exportable = exportable
+ self.assume_straight_pages = assume_straight_pages
+
+ # Identify the number of channels for the neck & head initialization
+ feat_out_channels = [
+ layers.Input(shape=in_shape[1:]).shape[-1] for in_shape in self.feat_extractor.output_shape
+ ]
+ # Initialize neck & head
+ self.neck = FastNeck(feat_out_channels[0], feat_out_channels[1])
+ self.head = FastHead(feat_out_channels[-1], num_classes, feat_out_channels[1], dropout_prob)
+
+ # NOTE: the post-processing from the paper does not work well on text-rich images,
+ # so we use a modified version from DBNet
+ self.postprocessor = FASTPostProcessor(
+ assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
+ )
+
+ # Pooling layer as erosion reversal as described in the paper
+ self.pooling = layers.MaxPooling2D(pool_size=pooling_size // 2 + 1, strides=1, padding="same")
+
+ def compute_loss(
+ self,
+ out_map: tf.Tensor,
+ target: List[Dict[str, np.ndarray]],
+ eps: float = 1e-6,
+ ) -> tf.Tensor:
+ """Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5.
+
+ Args:
+ ----
+ out_map: output feature map of the model of shape (N, H, W, num_classes)
+ target: list of dictionaries where each dict has a `boxes` and a `flags` entry
+ eps: epsilon factor in dice loss
+
+ Returns:
+ -------
+ A loss tensor
+ """
+ targets = self.build_target(target, out_map.shape[1:], True)
+
+ seg_target = tf.convert_to_tensor(targets[0], dtype=out_map.dtype)
+ seg_mask = tf.convert_to_tensor(targets[1], dtype=out_map.dtype)
+ shrunken_kernel = tf.convert_to_tensor(targets[2], dtype=out_map.dtype)
+
+ def ohem(score: tf.Tensor, gt: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
+ pos_num = tf.reduce_sum(tf.cast(gt > 0.5, dtype=tf.int32)) - tf.reduce_sum(
+ tf.cast((gt > 0.5) & (mask <= 0.5), dtype=tf.int32)
+ )
+ neg_num = tf.reduce_sum(tf.cast(gt <= 0.5, dtype=tf.int32))
+ neg_num = tf.minimum(pos_num * 3, neg_num)
+
+ if neg_num == 0 or pos_num == 0:
+ return mask
+
+ neg_score_sorted, _ = tf.nn.top_k(-tf.boolean_mask(score, gt <= 0.5), k=neg_num)
+ threshold = -neg_score_sorted[-1]
+
+ selected_mask = tf.math.logical_and((score >= threshold) | (gt > 0.5), (mask > 0.5))
+ return tf.cast(selected_mask, dtype=tf.float32)
+
+ if len(self.class_names) > 1:
+ kernels = tf.nn.softmax(out_map, axis=-1)
+ prob_map = tf.nn.softmax(self.pooling(out_map), axis=-1)
+ else:
+ kernels = tf.sigmoid(out_map)
+ prob_map = tf.sigmoid(self.pooling(out_map))
+
+ # As described in the paper, we use the Dice loss for the text segmentation map, scaled by 0.5.
+ selected_masks = tf.stack(
+ [ohem(score, gt, mask) for score, gt, mask in zip(prob_map, seg_target, seg_mask)], axis=0
+ )
+ inter = tf.reduce_sum(selected_masks * prob_map * seg_target, axis=(0, 1, 2))
+ cardinality = tf.reduce_sum(selected_masks * (prob_map + seg_target), axis=(0, 1, 2))
+ text_loss = tf.reduce_mean((1 - 2 * inter / (cardinality + eps))) * 0.5
+
+ # As described in the paper, we use the Dice loss for the text kernel map.
+ selected_masks = seg_target * seg_mask
+ inter = tf.reduce_sum(selected_masks * kernels * shrunken_kernel, axis=(0, 1, 2))
+ cardinality = tf.reduce_sum(selected_masks * (kernels + shrunken_kernel), axis=(0, 1, 2))
+ kernel_loss = tf.reduce_mean((1 - 2 * inter / (cardinality + eps)))
+
+ return text_loss + kernel_loss
+
+ def call(
+ self,
+ x: tf.Tensor,
+ target: Optional[List[Dict[str, np.ndarray]]] = None,
+ return_model_output: bool = False,
+ return_preds: bool = False,
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ feat_maps = self.feat_extractor(x, **kwargs)
+ # Pass through the Neck & Head & Upsample
+ feat_concat = self.neck(feat_maps, **kwargs)
+ logits: tf.Tensor = self.head(feat_concat, **kwargs)
+ logits = layers.UpSampling2D(size=x.shape[-2] // logits.shape[-2], interpolation="bilinear")(logits, **kwargs)
+
+ out: Dict[str, tf.Tensor] = {}
+ if self.exportable:
+ out["logits"] = logits
+ return out
+
+ if return_model_output or target is None or return_preds:
+ prob_map = _bf16_to_float32(tf.math.sigmoid(self.pooling(logits, **kwargs)))
+
+ if return_model_output:
+ out["out_map"] = prob_map
+
+ if target is None or return_preds:
+ # Post-process boxes (keep only text predictions)
+ out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())]
+
+ if target is not None:
+ loss = self.compute_loss(logits, target)
+ out["loss"] = loss
+
+ return out
+
+
+def reparameterize(model: Union[FAST, layers.Layer]) -> FAST:
+ """Fuse batchnorm and conv layers and reparameterize the model
+
+ Args:
+ ----
+ model: the FAST model to reparameterize
+
+ Returns:
+ -------
+ the reparameterized model
+ """
+ last_conv = None
+ last_conv_idx = None
+
+ for idx, layer in enumerate(model.layers):
+ if hasattr(layer, "layers") or isinstance(
+ layer, (FASTConvLayer, FastNeck, FastHead, layers.BatchNormalization, layers.Conv2D)
+ ):
+ if isinstance(layer, layers.BatchNormalization):
+ # fuse batchnorm only if it is followed by a conv layer
+ if last_conv is None:
+ continue
+ conv_w = last_conv.kernel
+ conv_b = last_conv.bias if last_conv.use_bias else tf.zeros_like(layer.moving_mean)
+
+ factor = layer.gamma / tf.sqrt(layer.moving_variance + layer.epsilon)
+ last_conv.kernel = conv_w * factor.numpy().reshape([1, 1, 1, -1])
+ if last_conv.use_bias:
+ last_conv.bias.assign((conv_b - layer.moving_mean) * factor + layer.beta)
+ model.layers[last_conv_idx] = last_conv # Replace the last conv layer with the fused version
+ model.layers[idx] = layers.Lambda(lambda x: x)
+ last_conv = None
+ elif isinstance(layer, layers.Conv2D):
+ last_conv = layer
+ last_conv_idx = idx
+ elif isinstance(layer, FASTConvLayer):
+ layer.reparameterize_layer()
+ elif isinstance(layer, FastNeck):
+ for reduction in layer.reduction:
+ reduction.reparameterize_layer()
+ elif isinstance(layer, FastHead):
+ reparameterize(layer)
+ else:
+ reparameterize(layer)
+ return model
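+
+# Usage sketch: as in the PyTorch version, fusion is intended for inference. The
+# helper _fast below runs a dummy forward pass so the layers are built before
+# reparameterize is called, e.g.:
+#   model = fast_tiny(pretrained=False)
+#   model = reparameterize(model)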
+
+
+def _fast(
+ arch: str,
+ pretrained: bool,
+ backbone_fn,
+ feat_layers: List[str],
+ pretrained_backbone: bool = True,
+ input_shape: Optional[Tuple[int, int, int]] = None,
+ **kwargs: Any,
+) -> FAST:
+ pretrained_backbone = pretrained_backbone and not pretrained
+
+ # Patch the config
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["input_shape"] = input_shape or _cfg["input_shape"]
+ if not kwargs.get("class_names", None):
+ kwargs["class_names"] = _cfg.get("class_names", [CLASS_NAME])
+ else:
+ kwargs["class_names"] = sorted(kwargs["class_names"])
+
+ # Feature extractor
+ feat_extractor = IntermediateLayerGetter(
+ backbone_fn(
+ input_shape=_cfg["input_shape"],
+ include_top=False,
+ pretrained=pretrained_backbone,
+ ),
+ feat_layers,
+ )
+
+ # Build the model
+ model = FAST(feat_extractor, cfg=_cfg, **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ load_pretrained_params(model, _cfg["url"])
+
+ # Build the model for reparameterization to access the layers
+ _ = model(tf.random.uniform(shape=[1, *_cfg["input_shape"]], maxval=1, dtype=tf.float32), training=False)
+
+ return model
+
+
+def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
+ """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+ <https://arxiv.org/pdf/2111.02394.pdf>`_, using a tiny TextNet backbone.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import fast_tiny
+ >>> model = fast_tiny(pretrained=True)
+ >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the FAST architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _fast(
+ "fast_tiny",
+ pretrained,
+ textnet_tiny,
+ ["stage_0", "stage_1", "stage_2", "stage_3"],
+ **kwargs,
+ )
+
+
+def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
+ """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+ <https://arxiv.org/pdf/2111.02394.pdf>`_, using a small TextNet backbone.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import fast_small
+ >>> model = fast_small(pretrained=True)
+ >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the FAST architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _fast(
+ "fast_small",
+ pretrained,
+ textnet_small,
+ ["stage_0", "stage_1", "stage_2", "stage_3"],
+ **kwargs,
+ )
+
+
+def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
+ """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+ <https://arxiv.org/pdf/2111.02394.pdf>`_, using a base TextNet backbone.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import fast_base
+ >>> model = fast_base(pretrained=True)
+ >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the FAST architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _fast(
+ "fast_base",
+ pretrained,
+ textnet_base,
+ ["stage_0", "stage_1", "stage_2", "stage_3"],
+ **kwargs,
+ )
diff --git a/doctr/models/detection/linknet/__init__.py b/doctr/models/detection/linknet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/models/detection/linknet/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/detection/linknet/base.py b/doctr/models/detection/linknet/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..986f57d6ad2fbdc6503dff21901875617e04c50d
--- /dev/null
+++ b/doctr/models/detection/linknet/base.py
@@ -0,0 +1,256 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
+
+from typing import Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+import pyclipper
+from shapely.geometry import Polygon
+
+from doctr.models.core import BaseModel
+
+from ..core import DetectionPostProcessor
+
+__all__ = ["_LinkNet", "LinkNetPostProcessor"]
+
+
+class LinkNetPostProcessor(DetectionPostProcessor):
+ """Implements a post processor for LinkNet model.
+
+ Args:
+ ----
+ bin_thresh: threshold used to binarize p_map at inference time
+ box_thresh: minimal objectness score to consider a box
+ assume_straight_pages: whether the inputs were expected to have horizontal text elements
+ """
+
+ def __init__(
+ self,
+ bin_thresh: float = 0.1,
+ box_thresh: float = 0.1,
+ assume_straight_pages: bool = True,
+ ) -> None:
+ super().__init__(box_thresh, bin_thresh, assume_straight_pages)
+ self.unclip_ratio = 1.5
+
+ def polygon_to_box(
+ self,
+ points: np.ndarray,
+ ) -> np.ndarray:
+ """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
+
+ Args:
+ ----
+ points: (N, 2) array of absolute coordinates of the polygon to expand
+
+ Returns:
+ -------
+ a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
+ """
+ if not self.assume_straight_pages:
+ # Compute the rectangle polygon enclosing the raw polygon
+ rect = cv2.minAreaRect(points)
+ points = cv2.boxPoints(rect)
+ # Add 1 pixel to correct cv2 approx
+ area = (rect[1][0] + 1) * (1 + rect[1][1])
+ length = 2 * (rect[1][0] + rect[1][1]) + 2
+ else:
+ poly = Polygon(points)
+ area = poly.area
+ length = poly.length
+ distance = area * self.unclip_ratio / length # compute distance to expand polygon
+ offset = pyclipper.PyclipperOffset()
+ offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+ _points = offset.Execute(distance)
+ # Take biggest stack of points
+ idx = 0
+ if len(_points) > 1:
+ max_size = 0
+ for _idx, p in enumerate(_points):
+ if len(p) > max_size:
+ idx = _idx
+ max_size = len(p)
+ # We ensure that _points can be correctly cast to a ndarray
+ _points = [_points[idx]]
+ expanded_points: np.ndarray = np.asarray(_points) # expand polygon
+ if len(expanded_points) < 1:
+ return None # type: ignore[return-value]
+ return (
+ cv2.boundingRect(expanded_points) # type: ignore[return-value]
+ if self.assume_straight_pages
+ else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0)
+ )
+
+ def bitmap_to_boxes(
+ self,
+ pred: np.ndarray,
+ bitmap: np.ndarray,
+ ) -> np.ndarray:
+ """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
+
+ Args:
+ ----
+ pred: probability map computed from the model output
+ bitmap: bitmap computed from pred (binarized)
+
+ Returns:
+ -------
+ an (N, 5) array of boxes (xmin, ymin, xmax, ymax, score) if assume_straight_pages
+ is True, otherwise an (N, 4, 2) array of polygons, in relative coordinates
+ """
+ height, width = bitmap.shape[:2]
+ boxes: List[Union[np.ndarray, List[float]]] = []
+ # get contours from connected components on the bitmap
+ contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ for contour in contours:
+ # Check whether smallest enclosing bounding box is not too small
+ if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
+ continue
+ # Compute objectness
+ if self.assume_straight_pages:
+ x, y, w, h = cv2.boundingRect(contour)
+ points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]])
+ score = self.box_score(pred, points, assume_straight_pages=True)
+ else:
+ score = self.box_score(pred, contour, assume_straight_pages=False)
+
+ if score < self.box_thresh: # remove polygons with a weak objectness
+ continue
+
+ if self.assume_straight_pages:
+ _box = self.polygon_to_box(points)
+ else:
+ _box = self.polygon_to_box(np.squeeze(contour))
+
+ if self.assume_straight_pages:
+ # compute relative polygon to get rid of img shape
+ x, y, w, h = _box
+ xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height
+ boxes.append([xmin, ymin, xmax, ymax, score])
+ else:
+ # compute relative box to get rid of img shape
+ _box[:, 0] /= width
+ _box[:, 1] /= height
+ boxes.append(_box)
+
+ if not self.assume_straight_pages:
+ return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
+ else:
+ return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
+
+
+class _LinkNet(BaseModel):
+ """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
+ <https://arxiv.org/abs/1707.03718>`_.
+ """
+
+ min_size_box: int = 3
+ assume_straight_pages: bool = True
+ shrink_ratio = 0.5
+
+ def build_target(
+ self,
+ target: List[Dict[str, np.ndarray]],
+ output_shape: Tuple[int, int, int],
+ channels_last: bool = True,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Build the target, and it's mask to be used from loss computation.
+
+ Args:
+ ----
+ target: target coming from dataset
+ output_shape: shape of the output of the model without batch_size
+ channels_last: whether channels are last or not
+
+ Returns:
+ -------
+ the new formatted target and the mask
+ """
+ if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
+ raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
+ if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
+ raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.")
+
+ h: int
+ w: int
+ if channels_last:
+ h, w, num_classes = output_shape
+ else:
+ num_classes, h, w = output_shape
+ target_shape = (len(target), num_classes, h, w)
+
+ seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
+ seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
+
+ for idx, tgt in enumerate(target):
+ for class_idx, _tgt in enumerate(tgt.values()):
+ # Draw each polygon on gt
+ if _tgt.shape[0] == 0:
+ # Empty image, full masked
+ seg_mask[idx, class_idx] = False
+
+ # Absolute bounding boxes
+ abs_boxes = _tgt.copy()
+
+ if abs_boxes.ndim == 3:
+ abs_boxes[:, :, 0] *= w
+ abs_boxes[:, :, 1] *= h
+ polys = abs_boxes
+ boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1)
+ abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32)
+ else:
+ abs_boxes[:, [0, 2]] *= w
+ abs_boxes[:, [1, 3]] *= h
+ abs_boxes = abs_boxes.round().astype(np.int32)
+ polys = np.stack(
+ [
+ abs_boxes[:, [0, 1]],
+ abs_boxes[:, [0, 3]],
+ abs_boxes[:, [2, 3]],
+ abs_boxes[:, [2, 1]],
+ ],
+ axis=1,
+ )
+ boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])
+
+ for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
+ # Mask boxes that are too small
+ if box_size < self.min_size_box:
+ seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+ continue
+
+ # Negative shrink for gt, as described in paper
+ polygon = Polygon(poly)
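+ # Offset magnitude D = A * (1 - r**2) / P (A: polygon area, P: perimeter,
+ # r: shrink_ratio), i.e. the DB-style shrinking of the gt towards the text core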
+ distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length
+ subject = [tuple(coor) for coor in poly]
+ padding = pyclipper.PyclipperOffset()
+ padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+ shrunken = padding.Execute(-distance)
+
+ # Draw polygon on gt if it is valid
+ if len(shrunken) == 0:
+ seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+ continue
+ shrunken = np.array(shrunken[0]).reshape(-1, 2)
+ if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
+ seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+ continue
+ cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload]
+
+ # Don't forget to switch back to channel last if Tensorflow is used
+ if channels_last:
+ seg_target = seg_target.transpose((0, 2, 3, 1))
+ seg_mask = seg_mask.transpose((0, 2, 3, 1))
+
+ return seg_target, seg_mask
diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..537fd57256526a8964c4feda8c39c05805b697f4
--- /dev/null
+++ b/doctr/models/detection/linknet/pytorch.py
@@ -0,0 +1,380 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision.models._utils import IntermediateLayerGetter
+
+from doctr.file_utils import CLASS_NAME
+from doctr.models.classification import resnet18, resnet34, resnet50
+
+from ...utils import _bf16_to_float32, load_pretrained_params
+from .base import LinkNetPostProcessor, _LinkNet
+
+__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "linknet_resnet18": {
+ "input_shape": (3, 1024, 1024),
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet18-e47a14dc.pt&src=0",
+ },
+ "linknet_resnet34": {
+ "input_shape": (3, 1024, 1024),
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet34-9ca2df3e.pt&src=0",
+ },
+ "linknet_resnet50": {
+ "input_shape": (3, 1024, 1024),
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet50-6cf565c1.pt&src=0",
+ },
+}
+
+
+class LinkNetFPN(nn.Module):
+ def __init__(self, layer_shapes: List[Tuple[int, int, int]]) -> None:
+ super().__init__()
+ strides = [
+ 1 if (in_shape[-1] == out_shape[-1]) else 2
+ for in_shape, out_shape in zip(layer_shapes[:-1], layer_shapes[1:])
+ ]
+
+ chans = [shape[0] for shape in layer_shapes]
+
+ _decoder_layers = [
+ self.decoder_block(ochan, ichan, stride) for ichan, ochan, stride in zip(chans[:-1], chans[1:], strides)
+ ]
+
+ self.decoders = nn.ModuleList(_decoder_layers)
+
+ @staticmethod
+ def decoder_block(in_chan: int, out_chan: int, stride: int) -> nn.Sequential:
+ """Creates a LinkNet decoder block"""
+ mid_chan = in_chan // 4
+ return nn.Sequential(
+ nn.Conv2d(in_chan, mid_chan, kernel_size=1, bias=False),
+ nn.BatchNorm2d(mid_chan),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(mid_chan, mid_chan, 3, padding=1, output_padding=stride - 1, stride=stride, bias=False),
+ nn.BatchNorm2d(mid_chan),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mid_chan, out_chan, kernel_size=1, bias=False),
+ nn.BatchNorm2d(out_chan),
+ nn.ReLU(inplace=True),
+ )
+
+ def forward(self, feats: List[torch.Tensor]) -> torch.Tensor:
+ out = feats[-1]
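+ # Walk the decoders from the deepest level upwards, adding the lateral skip
+ # connection at each scale; zip() stops before self.decoders[0], which is
+ # applied last to recover the stem resolution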
+ for decoder, fmap in zip(self.decoders[::-1], feats[:-1][::-1]):
+ out = decoder(out) + fmap
+
+ out = self.decoders[0](out)
+
+ return out
+
+
+class LinkNet(nn.Module, _LinkNet):
+ """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
+ <https://arxiv.org/abs/1707.03718>`_.
+
+ Args:
+ ----
+ feat_extractor: the backbone serving as feature extractor
+ bin_thresh: threshold for binarization of the output feature map
+ box_thresh: minimal objectness score to consider a box
+ head_chans: number of channels in the head layers
+ assume_straight_pages: if True, fit straight bounding boxes only
+ exportable: if True, the model returns only the logits (for ONNX export)
+ cfg: the configuration dict of the model
+ class_names: list of class names
+ """
+
+ def __init__(
+ self,
+ feat_extractor: IntermediateLayerGetter,
+ bin_thresh: float = 0.1,
+ box_thresh: float = 0.1,
+ head_chans: int = 32,
+ assume_straight_pages: bool = True,
+ exportable: bool = False,
+ cfg: Optional[Dict[str, Any]] = None,
+ class_names: List[str] = [CLASS_NAME],
+ ) -> None:
+ super().__init__()
+ self.class_names = class_names
+ num_classes: int = len(self.class_names)
+ self.cfg = cfg
+ self.exportable = exportable
+ self.assume_straight_pages = assume_straight_pages
+
+ self.feat_extractor = feat_extractor
+ # Identify the number of channels for the FPN initialization
+ self.feat_extractor.eval()
+ with torch.no_grad():
+ in_shape = (3, 512, 512)
+ out = self.feat_extractor(torch.zeros((1, *in_shape)))
+ # Get the shapes of the extracted feature maps
+ _shapes = [v.shape[1:] for _, v in out.items()]
+ # Prepend the expected shape of the first encoder block
+ _shapes = [(_shapes[0][0], in_shape[1] // 4, in_shape[2] // 4)] + _shapes
+ self.feat_extractor.train()
+
+ self.fpn = LinkNetFPN(_shapes)
+
+ self.classifier = nn.Sequential(
+ nn.ConvTranspose2d(
+ _shapes[0][0], head_chans, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False
+ ),
+ nn.BatchNorm2d(head_chans),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(head_chans, head_chans, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(head_chans),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(head_chans, num_classes, kernel_size=2, stride=2),
+ )
+
+ self.postprocessor = LinkNetPostProcessor(
+ assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
+ )
+
+ for n, m in self.named_modules():
+ # Don't override the initialization of the backbone
+ if n.startswith("feat_extractor."):
+ continue
+ if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+ nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity="relu")
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1.0)
+ m.bias.data.zero_()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ target: Optional[List[np.ndarray]] = None,
+ return_model_output: bool = False,
+ return_preds: bool = False,
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ feats = self.feat_extractor(x)
+ logits = self.fpn([feats[str(idx)] for idx in range(len(feats))])
+ logits = self.classifier(logits)
+
+ out: Dict[str, Any] = {}
+ if self.exportable:
+ out["logits"] = logits
+ return out
+
+ if return_model_output or target is None or return_preds:
+ prob_map = _bf16_to_float32(torch.sigmoid(logits))
+ if return_model_output:
+ out["out_map"] = prob_map
+
+ if target is None or return_preds:
+ # Post-process boxes
+ out["preds"] = [
+ dict(zip(self.class_names, preds))
+ for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
+ ]
+
+ if target is not None:
+ loss = self.compute_loss(logits, target)
+ out["loss"] = loss
+
+ return out
+
+ def compute_loss(
+ self,
+ out_map: torch.Tensor,
+ target: List[np.ndarray],
+ gamma: float = 2.0,
+ alpha: float = 0.5,
+ eps: float = 1e-8,
+ ) -> torch.Tensor:
+ """Compute linknet loss, BCE with boosted box edges or focal loss. Focal loss implementation based on
+ `<https://github.com/tensorflow/addons/>`_.
+
+ Args:
+ ----
+ out_map: output feature map of the model of shape (N, num_classes, H, W)
+ target: list of dictionaries where each dict has a `boxes` and a `flags` entry
+ gamma: modulating factor in the focal loss formula
+ alpha: balancing factor in the focal loss formula
+ eps: epsilon factor in dice loss
+
+ Returns:
+ -------
+ A loss tensor
+ """
+ _target, _mask = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
+
+ seg_target, seg_mask = torch.from_numpy(_target).to(dtype=out_map.dtype), torch.from_numpy(_mask)
+ seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
+ seg_mask = seg_mask.to(dtype=torch.float32)
+
+ bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none")
+ proba_map = torch.sigmoid(out_map)
+
+ # Focal loss
+ if gamma < 0:
+ raise ValueError("Value of gamma should be greater than or equal to zero.")
+ p_t = proba_map * seg_target + (1 - proba_map) * (1 - seg_target)
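+ # p_t is the predicted probability of the ground-truth class, so the
+ # (1 - p_t) ** gamma factor below down-weights well-classified pixels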
+ alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target)
+ # Unreduced version
+ focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
+ # Class reduced
+ focal_loss = (seg_mask * focal_loss).sum((0, 1, 2, 3)) / seg_mask.sum((0, 1, 2, 3))
+
+ # Compute dice loss for each class
+ dice_map = torch.softmax(out_map, dim=1) if len(self.class_names) > 1 else proba_map
+ # Class reduced
+ inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
+ cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
+ dice_loss = (1 - 2 * inter / (cardinality + eps)).mean()
+
+ # Return the full loss (equal sum of focal loss and dice loss)
+ return focal_loss + dice_loss
+
+
+def _linknet(
+ arch: str,
+ pretrained: bool,
+ backbone_fn: Callable[[bool], nn.Module],
+ fpn_layers: List[str],
+ pretrained_backbone: bool = True,
+ ignore_keys: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> LinkNet:
+ pretrained_backbone = pretrained_backbone and not pretrained
+
+ # Build the feature extractor
+ backbone = backbone_fn(pretrained_backbone)
+ feat_extractor = IntermediateLayerGetter(
+ backbone,
+ {layer_name: str(idx) for idx, layer_name in enumerate(fpn_layers)},
+ )
+ if not kwargs.get("class_names", None):
+ kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME])
+ else:
+ kwargs["class_names"] = sorted(kwargs["class_names"])
+
+ # Build the model
+ model = LinkNet(feat_extractor, cfg=default_cfgs[arch], **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ # If the given class_names differ from those of the pretrained model,
+ # the mismatched classification head weights are removed
+ _ignore_keys = (
+ ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
+ )
+ load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+ return model
+
+
+def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
+ """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
+ <https://arxiv.org/abs/1707.03718>`_.
+
+ >>> import torch
+ >>> from doctr.models import linknet_resnet18
+ >>> model = linknet_resnet18(pretrained=True).eval()
+ >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the LinkNet architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _linknet(
+ "linknet_resnet18",
+ pretrained,
+ resnet18,
+ ["layer1", "layer2", "layer3", "layer4"],
+ ignore_keys=[
+ "classifier.6.weight",
+ "classifier.6.bias",
+ ],
+ **kwargs,
+ )
+
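+# Note (illustrative, an assumption about intended use): overriding `class_names`
+# changes the classifier width, so `_linknet` drops the pretrained head weights
+# via `ignore_keys`, e.g.
+# >>> model = linknet_resnet18(pretrained=True, class_names=["words", "stamps"])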
+
+def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
+ """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
+ <https://arxiv.org/abs/1707.03718>`_.
+
+ >>> import torch
+ >>> from doctr.models import linknet_resnet34
+ >>> model = linknet_resnet34(pretrained=True).eval()
+ >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the LinkNet architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _linknet(
+ "linknet_resnet34",
+ pretrained,
+ resnet34,
+ ["layer1", "layer2", "layer3", "layer4"],
+ ignore_keys=[
+ "classifier.6.weight",
+ "classifier.6.bias",
+ ],
+ **kwargs,
+ )
+
+
+def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
+ """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
+ <https://arxiv.org/abs/1707.03718>`_.
+
+ >>> import torch
+ >>> from doctr.models import linknet_resnet50
+ >>> model = linknet_resnet50(pretrained=True).eval()
+ >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the LinkNet architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _linknet(
+ "linknet_resnet50",
+ pretrained,
+ resnet50,
+ ["layer1", "layer2", "layer3", "layer4"],
+ ignore_keys=[
+ "classifier.6.weight",
+ "classifier.6.bias",
+ ],
+ **kwargs,
+ )
diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff11dbe4778d9fda973a72b3692cc014ccef70d5
--- /dev/null
+++ b/doctr/models/detection/linknet/tensorflow.py
@@ -0,0 +1,366 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import Model, Sequential, layers
+
+from doctr.file_utils import CLASS_NAME
+from doctr.models.classification import resnet18, resnet34, resnet50
+from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
+from doctr.utils.repr import NestedObject
+
+from .base import LinkNetPostProcessor, _LinkNet
+
+__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "linknet_resnet18": {
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "input_shape": (1024, 1024, 3),
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet18-b9ee56e6.zip&src=0",
+ },
+ "linknet_resnet34": {
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "input_shape": (1024, 1024, 3),
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet34-51909c56.zip&src=0",
+ },
+ "linknet_resnet50": {
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "input_shape": (1024, 1024, 3),
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet50-ac9f3829.zip&src=0",
+ },
+}
+
+
+def decoder_block(in_chan: int, out_chan: int, stride: int, **kwargs: Any) -> Sequential:
+ """Creates a LinkNet decoder block"""
+ return Sequential([
+ *conv_sequence(in_chan // 4, "relu", True, kernel_size=1, **kwargs),
+ layers.Conv2DTranspose(
+ filters=in_chan // 4,
+ kernel_size=3,
+ strides=stride,
+ padding="same",
+ use_bias=False,
+ kernel_initializer="he_normal",
+ ),
+ layers.BatchNormalization(),
+ layers.Activation("relu"),
+ *conv_sequence(out_chan, "relu", True, kernel_size=1),
+ ])
+
+
+class LinkNetFPN(Model, NestedObject):
+ """LinkNet Decoder module"""
+
+ def __init__(
+ self,
+ out_chans: int,
+ in_shapes: List[Tuple[int, ...]],
+ ) -> None:
+ super().__init__()
+ self.out_chans = out_chans
+ strides = [2] * (len(in_shapes) - 1) + [1]
+ i_chans = [s[-1] for s in in_shapes[::-1]]
+ o_chans = i_chans[1:] + [out_chans]
+ self.decoders = [
+ decoder_block(in_chan, out_chan, s, input_shape=in_shape)
+ for in_chan, out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1])
+ ]
+
+ def call(self, x: List[tf.Tensor]) -> tf.Tensor:
+ out = 0
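+ # Fuse from the deepest feature map upwards: x[::-1] walks coarse-to-fine,
+ # and each decoder upsamples the running sum after adding the lateral map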
+ for decoder, fmap in zip(self.decoders, x[::-1]):
+ out = decoder(out + fmap)
+ return out
+
+ def extra_repr(self) -> str:
+ return f"out_chans={self.out_chans}"
+
+
+class LinkNet(_LinkNet, keras.Model):
+ """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
+ <https://arxiv.org/abs/1707.03718>`_.
+
+ Args:
+ ----
+ feat_extractor: the backbone serving as feature extractor
+ fpn_channels: number of channels each extracted feature map is mapped to
+ bin_thresh: threshold for binarization of the output feature map
+ box_thresh: minimal objectness score to consider a box
+ assume_straight_pages: if True, fit straight bounding boxes only
+ exportable: if True, the model returns only the logits (for ONNX export)
+ cfg: the configuration dict of the model
+ class_names: list of class names
+ """
+
+ _children_names: List[str] = ["feat_extractor", "fpn", "classifier", "postprocessor"]
+
+ def __init__(
+ self,
+ feat_extractor: IntermediateLayerGetter,
+ fpn_channels: int = 64,
+ bin_thresh: float = 0.1,
+ box_thresh: float = 0.1,
+ assume_straight_pages: bool = True,
+ exportable: bool = False,
+ cfg: Optional[Dict[str, Any]] = None,
+ class_names: List[str] = [CLASS_NAME],
+ ) -> None:
+ super().__init__(cfg=cfg)
+
+ self.class_names = class_names
+ num_classes: int = len(self.class_names)
+
+ self.exportable = exportable
+ self.assume_straight_pages = assume_straight_pages
+
+ self.feat_extractor = feat_extractor
+
+ self.fpn = LinkNetFPN(fpn_channels, [_shape[1:] for _shape in self.feat_extractor.output_shape])
+ self.fpn.build(self.feat_extractor.output_shape)
+
+ self.classifier = Sequential([
+ layers.Conv2DTranspose(
+ filters=32,
+ kernel_size=3,
+ strides=2,
+ padding="same",
+ use_bias=False,
+ kernel_initializer="he_normal",
+ input_shape=self.fpn.decoders[-1].output_shape[1:],
+ ),
+ layers.BatchNormalization(),
+ layers.Activation("relu"),
+ *conv_sequence(32, "relu", True, kernel_size=3, strides=1),
+ layers.Conv2DTranspose(
+ filters=num_classes,
+ kernel_size=2,
+ strides=2,
+ padding="same",
+ use_bias=True,
+ kernel_initializer="he_normal",
+ ),
+ ])
+
+ self.postprocessor = LinkNetPostProcessor(
+ assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
+ )
+
+ def compute_loss(
+ self,
+ out_map: tf.Tensor,
+ target: List[Dict[str, np.ndarray]],
+ gamma: float = 2.0,
+ alpha: float = 0.5,
+ eps: float = 1e-8,
+ ) -> tf.Tensor:
+ """Compute linknet loss, BCE with boosted box edges or focal loss. Focal loss implementation based on
+ `<https://github.com/tensorflow/addons/>`_.
+
+ Args:
+ ----
+ out_map: output feature map of the model of shape (N, H, W, num_classes)
+ target: list of dictionaries where each dict has a `boxes` and a `flags` entry
+ gamma: modulating factor in the focal loss formula
+ alpha: balancing factor in the focal loss formula
+ eps: epsilon factor in dice loss
+
+ Returns:
+ -------
+ A loss tensor
+ """
+ seg_target, seg_mask = self.build_target(target, out_map.shape[1:], True)
+ seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
+ seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
+ seg_mask = tf.cast(seg_mask, tf.float32)
+
+ bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
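+ # The trailing singleton axis confines the Keras reduction to that axis,
+ # leaving an unreduced per-pixel, per-class BCE map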
+ proba_map = tf.sigmoid(out_map)
+
+ # Focal loss
+ if gamma < 0:
+ raise ValueError("Value of gamma should be greater than or equal to zero.")
+ # Convert logits to prob, compute gamma factor
+ p_t = (seg_target * proba_map) + ((1 - seg_target) * (1 - proba_map))
+ alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha)
+ # Unreduced loss
+ focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
+ # Class reduced
+ focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3))
+
+ # Compute dice loss for each class
+ dice_map = tf.nn.softmax(out_map, axis=-1) if len(self.class_names) > 1 else proba_map
+ # Class-reduced dice loss
+ inter = tf.reduce_sum(seg_mask * dice_map * seg_target, axis=[0, 1, 2])
+ cardinality = tf.reduce_sum(seg_mask * (dice_map + seg_target), axis=[0, 1, 2])
+ dice_loss = tf.reduce_mean(1 - 2 * inter / (cardinality + eps))
+
+ return focal_loss + dice_loss
+
+ def call(
+ self,
+ x: tf.Tensor,
+ target: Optional[List[Dict[str, np.ndarray]]] = None,
+ return_model_output: bool = False,
+ return_preds: bool = False,
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ feat_maps = self.feat_extractor(x, **kwargs)
+ logits = self.fpn(feat_maps, **kwargs)
+ logits = self.classifier(logits, **kwargs)
+
+ out: Dict[str, tf.Tensor] = {}
+ if self.exportable:
+ out["logits"] = logits
+ return out
+
+ if return_model_output or target is None or return_preds:
+ prob_map = _bf16_to_float32(tf.math.sigmoid(logits))
+
+ if return_model_output:
+ out["out_map"] = prob_map
+
+ if target is None or return_preds:
+ # Post-process boxes
+ out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())]
+
+ if target is not None:
+ loss = self.compute_loss(logits, target)
+ out["loss"] = loss
+
+ return out
+
+
+def _linknet(
+ arch: str,
+ pretrained: bool,
+ backbone_fn,
+ fpn_layers: List[str],
+ pretrained_backbone: bool = True,
+ input_shape: Optional[Tuple[int, int, int]] = None,
+ **kwargs: Any,
+) -> LinkNet:
+ pretrained_backbone = pretrained_backbone and not pretrained
+
+ # Patch the config
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["input_shape"] = input_shape or default_cfgs[arch]["input_shape"]
+ if not kwargs.get("class_names", None):
+ kwargs["class_names"] = _cfg.get("class_names", [CLASS_NAME])
+ else:
+ kwargs["class_names"] = sorted(kwargs["class_names"])
+
+ # Feature extractor
+ feat_extractor = IntermediateLayerGetter(
+ backbone_fn(
+ pretrained=pretrained_backbone,
+ include_top=False,
+ input_shape=_cfg["input_shape"],
+ ),
+ fpn_layers,
+ )
+
+ # Build the model
+ model = LinkNet(feat_extractor, cfg=_cfg, **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ load_pretrained_params(model, _cfg["url"])
+
+ return model
+
+
+def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
+ """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
+ <https://arxiv.org/abs/1707.03718>`_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import linknet_resnet18
+ >>> model = linknet_resnet18(pretrained=True)
+ >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the LinkNet architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _linknet(
+ "linknet_resnet18",
+ pretrained,
+ resnet18,
+ ["resnet_block_1", "resnet_block_3", "resnet_block_5", "resnet_block_7"],
+ **kwargs,
+ )
+
+
+def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
+ """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
+ <https://arxiv.org/abs/1707.03718>`_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import linknet_resnet34
+ >>> model = linknet_resnet34(pretrained=True)
+ >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the LinkNet architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _linknet(
+ "linknet_resnet34",
+ pretrained,
+ resnet34,
+ ["resnet_block_2", "resnet_block_6", "resnet_block_12", "resnet_block_15"],
+ **kwargs,
+ )
+
+
+def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
+ """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
+ <https://arxiv.org/abs/1707.03718>`_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import linknet_resnet50
+ >>> model = linknet_resnet50(pretrained=True)
+ >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+ **kwargs: keyword arguments of the LinkNet architecture
+
+ Returns:
+ -------
+ text detection architecture
+ """
+ return _linknet(
+ "linknet_resnet50",
+ pretrained,
+ resnet50,
+ ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"],
+ **kwargs,
+ )
diff --git a/doctr/models/detection/predictor/__init__.py b/doctr/models/detection/predictor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff30c3b2e7d34bf85e30291e39f9d3206c0f4bdd
--- /dev/null
+++ b/doctr/models/detection/predictor/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available
+
+if is_tf_available():
+ from .tensorflow import *
+else:
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/detection/predictor/pytorch.py b/doctr/models/detection/predictor/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b83d61906cd5631ac3638dc18c4bb418bee6556
--- /dev/null
+++ b/doctr/models/detection/predictor/pytorch.py
@@ -0,0 +1,61 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, Dict, List, Tuple, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+from doctr.models.preprocessor import PreProcessor
+from doctr.models.utils import set_device_and_dtype
+
+__all__ = ["DetectionPredictor"]
+
+
+class DetectionPredictor(nn.Module):
+ """Implements an object able to localize text elements in a document
+
+ Args:
+ ----
+ pre_processor: transform inputs for easier batched model inference
+ model: core detection architecture
+ """
+
+ def __init__(
+ self,
+ pre_processor: PreProcessor,
+ model: nn.Module,
+ ) -> None:
+ super().__init__()
+ self.pre_processor = pre_processor
+ self.model = model.eval()
+
+ @torch.inference_mode()
+ def forward(
+ self,
+ pages: List[Union[np.ndarray, torch.Tensor]],
+ return_maps: bool = False,
+ **kwargs: Any,
+ ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
+ # Dimension check
+ if any(page.ndim != 3 for page in pages):
+ raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
+
+ processed_batches = self.pre_processor(pages)
+ _params = next(self.model.parameters())
+ self.model, processed_batches = set_device_and_dtype(
+ self.model, processed_batches, _params.device, _params.dtype
+ )
+ predicted_batches = [
+ self.model(batch, return_preds=True, return_model_output=True, **kwargs) for batch in processed_batches
+ ]
+ preds = [pred for batch in predicted_batches for pred in batch["preds"]]
+ if return_maps:
+ seg_maps = [
+ pred.permute(1, 2, 0).detach().cpu().numpy() for batch in predicted_batches for pred in batch["out_map"]
+ ]
+ return preds, seg_maps
+ return preds
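+
+# Illustrative usage (assuming a predictor assembled by the detection zoo):
+# >>> import numpy as np
+# >>> from doctr.models import detection_predictor
+# >>> predictor = detection_predictor(pretrained=True)
+# >>> preds = predictor([np.zeros((1024, 1024, 3), dtype=np.uint8)])
+# >>> list(preds[0].keys()) # one entry per class, e.g. ['words']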
diff --git a/doctr/models/detection/predictor/tensorflow.py b/doctr/models/detection/predictor/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ff3f388bd6bcc60824d31b0e8af806f5aaffaf5
--- /dev/null
+++ b/doctr/models/detection/predictor/tensorflow.py
@@ -0,0 +1,57 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, Dict, List, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+
+from doctr.models.preprocessor import PreProcessor
+from doctr.utils.repr import NestedObject
+
+__all__ = ["DetectionPredictor"]
+
+
+class DetectionPredictor(NestedObject):
+ """Implements an object able to localize text elements in a document
+
+ Args:
+ ----
+ pre_processor: transform inputs for easier batched model inference
+ model: core detection architecture
+ """
+
+ _children_names: List[str] = ["pre_processor", "model"]
+
+ def __init__(
+ self,
+ pre_processor: PreProcessor,
+ model: keras.Model,
+ ) -> None:
+ self.pre_processor = pre_processor
+ self.model = model
+
+ def __call__(
+ self,
+ pages: List[Union[np.ndarray, tf.Tensor]],
+ return_maps: bool = False,
+ **kwargs: Any,
+ ) -> Union[List[Dict[str, np.ndarray]], Tuple[List[Dict[str, np.ndarray]], List[np.ndarray]]]:
+ # Dimension check
+ if any(page.ndim != 3 for page in pages):
+ raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
+
+ processed_batches = self.pre_processor(pages)
+ predicted_batches = [
+ self.model(batch, return_preds=True, return_model_output=True, training=False, **kwargs)
+ for batch in processed_batches
+ ]
+
+ preds = [pred for batch in predicted_batches for pred in batch["preds"]]
+ if return_maps:
+ seg_maps = [pred.numpy() for batch in predicted_batches for pred in batch["out_map"]]
+ return preds, seg_maps
+ return preds
diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py
new file mode 100644
index 0000000000000000000000000000000000000000..45cbc1adc5d9d40d89965420bb546f06c2e3a154
--- /dev/null
+++ b/doctr/models/detection/zoo.py
@@ -0,0 +1,102 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, List
+
+from doctr.file_utils import is_tf_available, is_torch_available
+
+from .. import detection
+from ..detection.fast import reparameterize
+from ..preprocessor import PreProcessor
+from .predictor import DetectionPredictor
+
+__all__ = ["detection_predictor"]
+
+ARCHS: List[str]
+
+
+if is_tf_available():
+ ARCHS = [
+ "db_resnet50",
+ "db_mobilenet_v3_large",
+ "linknet_resnet18",
+ "linknet_resnet34",
+ "linknet_resnet50",
+ "fast_tiny",
+ "fast_small",
+ "fast_base",
+ ]
+elif is_torch_available():
+ ARCHS = [
+ "db_resnet34",
+ "db_resnet50",
+ "db_mobilenet_v3_large",
+ "linknet_resnet18",
+ "linknet_resnet34",
+ "linknet_resnet50",
+ "fast_tiny",
+ "fast_small",
+ "fast_base",
+ ]
+
+
+def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
+ if isinstance(arch, str):
+ if arch not in ARCHS:
+ raise ValueError(f"unknown architecture '{arch}'")
+
+ _model = detection.__dict__[arch](
+ pretrained=pretrained,
+ pretrained_backbone=kwargs.get("pretrained_backbone", True),
+ assume_straight_pages=assume_straight_pages,
+ )
+ # Reparameterize FAST models by default to lower inference latency and memory usage
+ if isinstance(_model, detection.FAST):
+ _model = reparameterize(_model)
+ else:
+ if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
+ raise ValueError(f"unknown architecture: {type(arch)}")
+
+ _model = arch
+ _model.assume_straight_pages = assume_straight_pages
+
+ kwargs.pop("pretrained_backbone", None)
+
+ kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
+ kwargs["std"] = kwargs.get("std", _model.cfg["std"])
+ kwargs["batch_size"] = kwargs.get("batch_size", 2)
+ predictor = DetectionPredictor(
+ PreProcessor(_model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:], **kwargs),
+ _model,
+ )
+ return predictor
+
+
+def detection_predictor(
+ arch: Any = "db_resnet50",
+ pretrained: bool = False,
+ assume_straight_pages: bool = True,
+ **kwargs: Any,
+) -> DetectionPredictor:
+ """Text detection architecture.
+
+ >>> import numpy as np
+ >>> from doctr.models import detection_predictor
+ >>> model = detection_predictor(arch='db_resnet50', pretrained=True)
+ >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
+ >>> out = model([input_page])
+
+ Args:
+ ----
+ arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
+ pretrained: If True, returns a model pre-trained on our text detection dataset
+ assume_straight_pages: If True, fit straight boxes to the page
+ **kwargs: optional keyword arguments passed to the architecture
+
+ Returns:
+ -------
+ Detection predictor
+ """
+ return _predictor(arch, pretrained, assume_straight_pages, **kwargs)
diff --git a/doctr/models/factory/__init__.py b/doctr/models/factory/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b5b25a325e8a5ecc11546832996a78de1407ce4
--- /dev/null
+++ b/doctr/models/factory/__init__.py
@@ -0,0 +1 @@
+from .hub import *
diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e3e28e4f1f414968f401adcb817d59a49950806
--- /dev/null
+++ b/doctr/models/factory/hub.py
@@ -0,0 +1,240 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+# Inspired by: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/hub.py
+
+import json
+import logging
+import os
+import subprocess
+import textwrap
+from pathlib import Path
+from typing import Any
+
+from huggingface_hub import (
+ HfApi,
+ Repository,
+ get_token,
+ get_token_permission,
+ hf_hub_download,
+ login,
+ snapshot_download,
+)
+
+from doctr import models
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_torch_available():
+ import torch
+
+__all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]
+
+
+AVAILABLE_ARCHS = {
+ "classification": models.classification.zoo.ARCHS,
+ "detection": models.detection.zoo.ARCHS,
+ "recognition": models.recognition.zoo.ARCHS,
+ "obj_detection": ["fasterrcnn_mobilenet_v3_large_fpn"] if is_torch_available() else None,
+}
+
+
+def login_to_hub() -> None: # pragma: no cover
+ """Login to huggingface hub"""
+ access_token = get_token()
+ if access_token is not None and get_token_permission(access_token):
+ logging.info("Huggingface Hub token found and valid")
+ login(token=access_token, write_permission=True)
+ else:
+ login()
+ # check if git lfs is installed
+ try:
+ subprocess.call(["git", "lfs", "version"])
+ except FileNotFoundError:
+ raise OSError(
+ "Looks like you do not have git-lfs installed, please install. \
+ You can install from https://git-lfs.github.com/. \
+ Then run `git lfs install` (you only have to do this once)."
+ )
+
+
+def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task: str) -> None:
+ """Save model and config to disk for pushing to huggingface hub
+
+ Args:
+ ----
+ model: TF or PyTorch model to be saved
+ save_dir: directory to save model and config
+ arch: architecture name
+ task: task name
+ """
+ save_directory = Path(save_dir)
+
+ if is_torch_available():
+ weights_path = save_directory / "pytorch_model.bin"
+ torch.save(model.state_dict(), weights_path)
+ elif is_tf_available():
+ weights_path = save_directory / "tf_model" / "weights"
+ model.save_weights(str(weights_path))
+
+ config_path = save_directory / "config.json"
+
+ # add model configuration
+ model_config = model.cfg
+ model_config["arch"] = arch
+ model_config["task"] = task
+
+ with config_path.open("w") as f:
+ json.dump(model_config, f, indent=2, ensure_ascii=False)
+
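+# Resulting layout (illustrative): `pytorch_model.bin` (or `tf_model/weights*`
+# under TensorFlow) alongside a `config.json` holding model.cfg plus arch/task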
+
+def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: # pragma: no cover
+ """Save model and its configuration on HF hub
+
+ >>> from doctr.models import login_to_hub, push_to_hf_hub
+ >>> from doctr.models.recognition import crnn_mobilenet_v3_small
+ >>> login_to_hub()
+ >>> model = crnn_mobilenet_v3_small(pretrained=True)
+ >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')
+
+ Args:
+ ----
+ model: TF or PyTorch model to be saved
+ model_name: name of the model which is also the repository name
+ task: task name
+ **kwargs: keyword arguments for push_to_hf_hub
+ """
+ run_config = kwargs.get("run_config", None)
+ arch = kwargs.get("arch", None)
+
+ if run_config is None and arch is None:
+ raise ValueError("run_config or arch must be specified")
+ if task not in ["classification", "detection", "recognition", "obj_detection"]:
+ raise ValueError("task must be one of classification, detection, recognition, obj_detection")
+
+ # default readme
+ readme = textwrap.dedent(
+ f"""
+ ---
+ language: en
+ ---
+
+
+
+
+
+ **Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch**
+
+ ## Task: {task}
+
+ https://github.com/mindee/doctr
+
+ ### Example usage:
+
+ ```python
+ >>> from doctr.io import DocumentFile
+ >>> from doctr.models import ocr_predictor, from_hub
+
+ >>> img = DocumentFile.from_images(['<image_path>'])
+ >>> # Load your model from the hub
+ >>> model = from_hub('mindee/my-model')
+
+ >>> # Pass it to the predictor
+ >>> # If your model is a recognition model:
+ >>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large',
+ >>> reco_arch=model,
+ >>> pretrained=True)
+
+ >>> # If your model is a detection model:
+ >>> predictor = ocr_predictor(det_arch=model,
+ >>> reco_arch='crnn_mobilenet_v3_small',
+ >>> pretrained=True)
+
+ >>> # Get your predictions
+ >>> res = predictor(img)
+ ```
+ """
+ )
+
+ # add run configuration to readme if available
+ if run_config is not None:
+ arch = run_config.arch
+ readme += textwrap.dedent(
+ f"""### Run Configuration
+ \n{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}"""
+ )
+
+ if arch not in AVAILABLE_ARCHS[task]: # type: ignore
+ raise ValueError(
+ f"Architecture: {arch} for task: {task} not found.\
+ \nAvailable architectures: {AVAILABLE_ARCHS}"
+ )
+
+ commit_message = f"Add {model_name} model"
+
+ local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name)
+ repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=False)
+ repo = Repository(local_dir=local_cache_dir, clone_from=repo_url, use_auth_token=True)
+
+ with repo.commit(commit_message):
+ _save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task)
+ readme_path = Path(repo.local_dir) / "README.md"
+ readme_path.write_text(readme)
+
+ repo.git_push()
+
+
+def from_hub(repo_id: str, **kwargs: Any):
+ """Instantiate & load a pretrained model from HF hub.
+
+ >>> from doctr.models import from_hub
+ >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")
+
+ Args:
+ ----
+ repo_id: HuggingFace model hub repo
+ kwargs: kwargs of `hf_hub_download` or `snapshot_download`
+
+ Returns:
+ -------
+ Model loaded with the checkpoint
+ """
+ # Get the config
+ with open(hf_hub_download(repo_id, filename="config.json", **kwargs), "rb") as f:
+ cfg = json.load(f)
+
+ arch = cfg["arch"]
+ task = cfg["task"]
+ cfg.pop("arch")
+ cfg.pop("task")
+
+ if task == "classification":
+ model = models.classification.__dict__[arch](
+ pretrained=False, classes=cfg["classes"], num_classes=cfg["num_classes"]
+ )
+ elif task == "detection":
+ model = models.detection.__dict__[arch](pretrained=False)
+ elif task == "recognition":
+ model = models.recognition.__dict__[arch](pretrained=False, input_shape=cfg["input_shape"], vocab=cfg["vocab"])
+ elif task == "obj_detection" and is_torch_available():
+ model = models.obj_detection.__dict__[arch](
+ pretrained=False,
+ image_mean=cfg["mean"],
+ image_std=cfg["std"],
+ max_size=cfg["input_shape"][-1],
+ num_classes=len(cfg["classes"]),
+ )
+
+ # update model cfg
+ model.cfg = cfg
+
+ # Load checkpoint
+ if is_torch_available():
+ state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu")
+ model.load_state_dict(state_dict)
+ else: # tf
+ repo_path = snapshot_download(repo_id, **kwargs)
+ model.load_weights(os.path.join(repo_path, "tf_model", "weights"))
+
+ return model
diff --git a/doctr/models/kie_predictor/__init__.py b/doctr/models/kie_predictor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff30c3b2e7d34bf85e30291e39f9d3206c0f4bdd
--- /dev/null
+++ b/doctr/models/kie_predictor/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available
+
+if is_tf_available():
+ from .tensorflow import *
+else:
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/kie_predictor/base.py b/doctr/models/kie_predictor/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..107009bed491f7fed49c1d6d400a2973618c187e
--- /dev/null
+++ b/doctr/models/kie_predictor/base.py
@@ -0,0 +1,43 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, Optional
+
+from doctr.models.builder import KIEDocumentBuilder
+
+from ..classification.predictor import CropOrientationPredictor
+from ..predictor.base import _OCRPredictor
+
+__all__ = ["_KIEPredictor"]
+
+
+class _KIEPredictor(_OCRPredictor):
+ """Implements an object able to localize and identify text elements in a set of documents
+
+ Args:
+ ----
+ assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
+ without rotated textual elements.
+ straighten_pages: if True, estimates the page general orientation based on the median line orientation.
+ Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
+ accordingly. Doing so will improve performances for documents with page-uniform rotations.
+ preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
+ symmetric_pad: if True and preserve_aspect_ratio is True, pad the image symmetrically.
+ kwargs: keyword args of `DocumentBuilder`
+ """
+
+ crop_orientation_predictor: Optional[CropOrientationPredictor]
+
+ def __init__(
+ self,
+ assume_straight_pages: bool = True,
+ straighten_pages: bool = False,
+ preserve_aspect_ratio: bool = True,
+ symmetric_pad: bool = True,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs)
+
+ self.doc_builder: KIEDocumentBuilder = KIEDocumentBuilder(**kwargs)
diff --git a/doctr/models/kie_predictor/pytorch.py b/doctr/models/kie_predictor/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fa183f5b930ae8cdc217d2af0b5283aa0f69c82
--- /dev/null
+++ b/doctr/models/kie_predictor/pytorch.py
@@ -0,0 +1,176 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, Dict, List, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+from doctr.io.elements import Document
+from doctr.models._utils import estimate_orientation, get_language, invert_data_structure
+from doctr.models.detection.predictor import DetectionPredictor
+from doctr.models.recognition.predictor import RecognitionPredictor
+from doctr.utils.geometry import rotate_image
+
+from .base import _KIEPredictor
+
+__all__ = ["KIEPredictor"]
+
+
+class KIEPredictor(nn.Module, _KIEPredictor):
+ """Implements an object able to localize and identify text elements in a set of documents
+
+ Args:
+ ----
+ det_predictor: detection module
+ reco_predictor: recognition module
+ assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
+ without rotated textual elements.
+ straighten_pages: if True, estimates the page general orientation based on the median line orientation.
+ Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
+ accordingly. Doing so will improve performances for documents with page-uniform rotations.
+ detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
+ page. Doing so will slightly deteriorate the overall latency.
+ detect_language: if True, the language prediction will be added to the predictions for each
+ page. Doing so will slightly deteriorate the overall latency.
+ **kwargs: keyword args of `DocumentBuilder`
+ """
+
+ def __init__(
+ self,
+ det_predictor: DetectionPredictor,
+ reco_predictor: RecognitionPredictor,
+ assume_straight_pages: bool = True,
+ straighten_pages: bool = False,
+ preserve_aspect_ratio: bool = True,
+ symmetric_pad: bool = True,
+ detect_orientation: bool = False,
+ detect_language: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ nn.Module.__init__(self)
+ self.det_predictor = det_predictor.eval() # type: ignore[attr-defined]
+ self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined]
+ _KIEPredictor.__init__(
+ self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
+ )
+ self.detect_orientation = detect_orientation
+ self.detect_language = detect_language
+
+ @torch.inference_mode()
+ def forward(
+ self,
+ pages: List[Union[np.ndarray, torch.Tensor]],
+ **kwargs: Any,
+ ) -> Document:
+ # Dimension check
+ if any(page.ndim != 3 for page in pages):
+ raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
+
+ origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages]
+
+ # Localize text elements
+ loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
+
+ # Detect document rotation and rotate pages
+ seg_maps = [
+ np.where(np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > kwargs.get("bin_thresh", 0.3), 255, 0).astype(
+ np.uint8
+ )
+ for out_map in out_maps
+ ]
+ if self.detect_orientation:
+ origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
+ orientations = [
+ {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
+ ]
+ else:
+ orientations = None
+ if self.straighten_pages:
+ origin_page_orientations = (
+ origin_page_orientations
+ if self.detect_orientation
+ else [estimate_orientation(seq_map) for seq_map in seg_maps]
+ )
+ pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)] # type: ignore[arg-type]
+ # Forward again to get predictions on straight pages
+ loc_preds = self.det_predictor(pages, **kwargs)
+
+ dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore[assignment]
+ # Check whether crop mode should be switched to channels first
+ channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
+
+ # Rectify crops if aspect ratio
+ dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()} # type: ignore[arg-type]
+
+ # Apply hooks to loc_preds if any
+ for hook in self.hooks:
+ dict_loc_preds = hook(dict_loc_preds)
+
+ # Crop images
+ crops = {}
+ for class_name in dict_loc_preds.keys():
+ crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
+ pages, # type: ignore[arg-type]
+ dict_loc_preds[class_name],
+ channels_last=channels_last,
+ assume_straight_pages=self.assume_straight_pages,
+ )
+ # Rectify crop orientation
+ crop_orientations: Any = {}
+ if not self.assume_straight_pages:
+ for class_name in dict_loc_preds.keys():
+ crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
+ crops[class_name], dict_loc_preds[class_name]
+ )
+ crop_orientations[class_name] = [
+ {"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations
+ ]
+
+ # Identify character sequences
+ word_preds = {
+ k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs)
+ for k, crop_value in crops.items()
+ }
+ if not crop_orientations:
+ crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
+
+ boxes: Dict = {}
+ text_preds: Dict = {}
+ word_crop_orientations: Dict = {}
+ for class_name in dict_loc_preds.keys():
+ boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
+ dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
+ )
+
+ boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment]
+ text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment]
+ crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
+
+ if self.detect_language:
+ languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
+ languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
+ else:
+ languages_dict = None
+
+ out = self.doc_builder(
+ pages, # type: ignore[arg-type]
+ boxes_per_page,
+ text_preds_per_page,
+ origin_page_shapes, # type: ignore[arg-type]
+ crop_orientations_per_page,
+ orientations,
+ languages_dict,
+ )
+ return out
+
+ @staticmethod
+ def get_text(text_pred: Dict) -> str:
+ text = []
+ for value in text_pred.values():
+ text += [item[0] for item in value]
+
+ return " ".join(text)
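+
+# Illustrative usage (assuming the `kie_predictor` zoo helper of this package):
+# >>> import numpy as np
+# >>> from doctr.models import kie_predictor
+# >>> predictor = kie_predictor(pretrained=True)
+# >>> doc = predictor([np.zeros((1024, 1024, 3), dtype=np.uint8)])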
diff --git a/doctr/models/kie_predictor/tensorflow.py b/doctr/models/kie_predictor/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..52b1211dd070e43dfa5f5385b0d326a38608a576
--- /dev/null
+++ b/doctr/models/kie_predictor/tensorflow.py
@@ -0,0 +1,171 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, Dict, List, Union
+
+import numpy as np
+import tensorflow as tf
+
+from doctr.io.elements import Document
+from doctr.models._utils import estimate_orientation, get_language, invert_data_structure
+from doctr.models.detection.predictor import DetectionPredictor
+from doctr.models.recognition.predictor import RecognitionPredictor
+from doctr.utils.geometry import rotate_image
+from doctr.utils.repr import NestedObject
+
+from .base import _KIEPredictor
+
+__all__ = ["KIEPredictor"]
+
+
+class KIEPredictor(NestedObject, _KIEPredictor):
+ """Implements an object able to localize and identify text elements in a set of documents
+
+ Args:
+ ----
+ det_predictor: detection module
+ reco_predictor: recognition module
+ assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
+ without rotated textual elements.
+ straighten_pages: if True, estimates the page general orientation based on the median line orientation.
+ Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
+ accordingly. Doing so will improve performances for documents with page-uniform rotations.
+ detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
+ page. Doing so will slightly deteriorate the overall latency.
+ detect_language: if True, the language prediction will be added to the predictions for each
+ page. Doing so will slightly deteriorate the overall latency.
+ **kwargs: keyword args of `DocumentBuilder`
+ """
+
+ _children_names = ["det_predictor", "reco_predictor", "doc_builder"]
+
+ def __init__(
+ self,
+ det_predictor: DetectionPredictor,
+ reco_predictor: RecognitionPredictor,
+ assume_straight_pages: bool = True,
+ straighten_pages: bool = False,
+ preserve_aspect_ratio: bool = True,
+ symmetric_pad: bool = True,
+ detect_orientation: bool = False,
+ detect_language: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ self.det_predictor = det_predictor
+ self.reco_predictor = reco_predictor
+ _KIEPredictor.__init__(
+ self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
+ )
+ self.detect_orientation = detect_orientation
+ self.detect_language = detect_language
+
+ def __call__(
+ self,
+ pages: List[Union[np.ndarray, tf.Tensor]],
+ **kwargs: Any,
+ ) -> Document:
+ # Dimension check
+ if any(page.ndim != 3 for page in pages):
+ raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
+
+ origin_page_shapes = [page.shape[:2] for page in pages]
+
+ # Localize text elements
+ loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
+
+ # Detect document rotation and rotate pages
+ seg_maps = [
+ np.where(np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > kwargs.get("bin_thresh", 0.3), 255, 0).astype(
+ np.uint8
+ )
+ for out_map in out_maps
+ ]
+ if self.detect_orientation:
+ origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
+ orientations = [
+ {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
+ ]
+ else:
+ orientations = None
+ if self.straighten_pages:
+ origin_page_orientations = (
+ origin_page_orientations
+ if self.detect_orientation
+ else [estimate_orientation(seq_map) for seq_map in seg_maps]
+ )
+ pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+ # Forward again to get predictions on straight pages
+ loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment]
+
+ dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore
+ # Rectify crops if aspect ratio
+ dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}
+
+ # Apply hooks to loc_preds if any
+ for hook in self.hooks:
+ dict_loc_preds = hook(dict_loc_preds)
+
+ # Crop images
+ crops = {}
+ for class_name in dict_loc_preds.keys():
+ crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
+ pages, dict_loc_preds[class_name], channels_last=True, assume_straight_pages=self.assume_straight_pages
+ )
+
+ # Rectify crop orientation
+ crop_orientations: Any = {}
+ if not self.assume_straight_pages:
+ for class_name in dict_loc_preds.keys():
+ crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
+ crops[class_name], dict_loc_preds[class_name]
+ )
+ crop_orientations[class_name] = [
+ {"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations
+ ]
+
+ # Identify character sequences
+ word_preds = {
+ k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs)
+ for k, crop_value in crops.items()
+ }
+ if not crop_orientations:
+ crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
+
+ boxes: Dict = {}
+ text_preds: Dict = {}
+ word_crop_orientations: Dict = {}
+ for class_name in dict_loc_preds.keys():
+ boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
+ dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
+ )
+
+ boxes_per_page: List[Dict] = invert_data_structure(boxes) # type: ignore[assignment]
+ text_preds_per_page: List[Dict] = invert_data_structure(text_preds) # type: ignore[assignment]
+ crop_orientations_per_page: List[Dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
+
+ if self.detect_language:
+ languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
+ languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
+ else:
+ languages_dict = None
+
+ out = self.doc_builder(
+ pages,
+ boxes_per_page,
+ text_preds_per_page,
+ origin_page_shapes, # type: ignore[arg-type]
+ crop_orientations_per_page,
+ orientations,
+ languages_dict,
+ )
+ return out
+
+ @staticmethod
+ def get_text(text_pred: Dict) -> str:
+ text = []
+ for value in text_pred.values():
+ text += [item[0] for item in value]
+
+ return " ".join(text)
diff --git a/doctr/models/modules/__init__.py b/doctr/models/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d659f1cdb7d675f32c1b504965bd453cc7a9a4d8
--- /dev/null
+++ b/doctr/models/modules/__init__.py
@@ -0,0 +1,3 @@
+from .layers import *
+from .transformer import *
+from .vision_transformer import *
diff --git a/doctr/models/modules/layers/__init__.py b/doctr/models/modules/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/models/modules/layers/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/modules/layers/pytorch.py b/doctr/models/modules/layers/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7ad119ec9ba00caaaabaaf673b8d442014da258
--- /dev/null
+++ b/doctr/models/modules/layers/pytorch.py
@@ -0,0 +1,165 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+__all__ = ["FASTConvLayer"]
+
+
+class FASTConvLayer(nn.Module):
+ """Convolutional layer used in the TextNet and FAST architectures"""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ stride: int = 1,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.groups = groups
+ self.in_channels = in_channels
+ self.converted_ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+
+ self.hor_conv, self.hor_bn = None, None
+ self.ver_conv, self.ver_bn = None, None
+
+ padding = (int(((self.converted_ks[0] - 1) * dilation) / 2), int(((self.converted_ks[1] - 1) * dilation) / 2))
+
+ self.activation = nn.ReLU(inplace=True)
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=self.converted_ks,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ )
+
+ self.bn = nn.BatchNorm2d(out_channels)
+
+ if self.converted_ks[1] != 1:
+ self.ver_conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=(self.converted_ks[0], 1),
+ padding=(int(((self.converted_ks[0] - 1) * dilation) / 2), 0),
+ stride=stride,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ )
+ self.ver_bn = nn.BatchNorm2d(out_channels)
+
+ if self.converted_ks[0] != 1:
+ self.hor_conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=(1, self.converted_ks[1]),
+ padding=(0, int(((self.converted_ks[1] - 1) * dilation) / 2)),
+ stride=stride,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ )
+ self.hor_bn = nn.BatchNorm2d(out_channels)
+
+ self.rbr_identity = nn.BatchNorm2d(in_channels) if out_channels == in_channels and stride == 1 else None
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if hasattr(self, "fused_conv"):
+ return self.activation(self.fused_conv(x))
+
+ main_outputs = self.bn(self.conv(x))
+ vertical_outputs = self.ver_bn(self.ver_conv(x)) if self.ver_conv is not None and self.ver_bn is not None else 0
+ horizontal_outputs = (
+ self.hor_bn(self.hor_conv(x)) if self.hor_bn is not None and self.hor_conv is not None else 0
+ )
+ id_out = self.rbr_identity(x) if self.rbr_identity is not None else 0
+
+ return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
+
+ # The following logic is used to reparametrize the layer
+ # Borrowed from: https://github.com/czczup/FAST/blob/main/models/utils/nas_utils.py
+ def _identity_to_conv(
+ self, identity: Union[nn.BatchNorm2d, None]
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
+ if identity is None or identity.running_var is None:
+ return 0, 0
+ if not hasattr(self, "id_tensor"):
+ input_dim = self.in_channels // self.groups
+ kernel_value = np.zeros((self.in_channels, input_dim, 1, 1), dtype=np.float32)
+ for i in range(self.in_channels):
+ kernel_value[i, i % input_dim, 0, 0] = 1
+ id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
+ self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
+ kernel = self.id_tensor
+ std = (identity.running_var + identity.eps).sqrt()
+ t = (identity.weight / std).reshape(-1, 1, 1, 1)
+ return kernel * t, identity.bias - identity.running_mean * identity.weight / std
+
+ def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> Tuple[torch.Tensor, torch.Tensor]:
+ kernel = conv.weight
+ kernel = self._pad_to_mxn_tensor(kernel)
+ std = (bn.running_var + bn.eps).sqrt() # type: ignore
+ t = (bn.weight / std).reshape(-1, 1, 1, 1)
+ return kernel * t, bn.bias - bn.running_mean * bn.weight / std
+
+ def _get_equivalent_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
+ if self.ver_conv is not None:
+ kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn) # type: ignore[arg-type]
+ else:
+ kernel_mx1, bias_mx1 = 0, 0 # type: ignore[assignment]
+ if self.hor_conv is not None:
+ kernel_1xn, bias_1xn = self._fuse_bn_tensor(self.hor_conv, self.hor_bn) # type: ignore[arg-type]
+ else:
+ kernel_1xn, bias_1xn = 0, 0 # type: ignore[assignment]
+ kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
+ kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
+ bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
+ return kernel_mxn, bias_mxn
+
+ def _pad_to_mxn_tensor(self, kernel: torch.Tensor) -> torch.Tensor:
+ kernel_height, kernel_width = self.converted_ks
+ height, width = kernel.shape[2:]
+ pad_left_right = (kernel_width - width) // 2
+ pad_top_down = (kernel_height - height) // 2
+ return torch.nn.functional.pad(kernel, [pad_left_right, pad_left_right, pad_top_down, pad_top_down], value=0)
+
+ def reparameterize_layer(self):
+ if hasattr(self, "fused_conv"):
+ return
+ kernel, bias = self._get_equivalent_kernel_bias()
+ self.fused_conv = nn.Conv2d(
+ in_channels=self.conv.in_channels,
+ out_channels=self.conv.out_channels,
+ kernel_size=self.conv.kernel_size, # type: ignore[arg-type]
+ stride=self.conv.stride, # type: ignore[arg-type]
+ padding=self.conv.padding, # type: ignore[arg-type]
+ dilation=self.conv.dilation, # type: ignore[arg-type]
+ groups=self.conv.groups,
+ bias=True,
+ )
+ self.fused_conv.weight.data = kernel
+ self.fused_conv.bias.data = bias # type: ignore[union-attr]
+ for para in self.parameters():
+ para.detach_()
+ for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]:
+ if hasattr(self, attr):
+ self.__delattr__(attr)
+
+ if hasattr(self, "rbr_identity"):
+ self.__delattr__("rbr_identity")
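+
+
+# Usage sketch (illustrative only, not part of the module): after training, the
+# multi-branch layer can be folded into a single conv; in eval mode (BN running
+# statistics) the fused layer reproduces the multi-branch output.
+# >>> layer = FASTConvLayer(16, 16, kernel_size=3).eval()
+# >>> x = torch.rand(1, 16, 32, 32)
+# >>> out = layer(x)
+# >>> layer.reparameterize_layer()
+# >>> torch.allclose(out, layer(x), atol=1e-4)
+# True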
diff --git a/doctr/models/modules/layers/tensorflow.py b/doctr/models/modules/layers/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..68849fbf6e5ad5f08a34ff90403dc44612b53834
--- /dev/null
+++ b/doctr/models/modules/layers/tensorflow.py
@@ -0,0 +1,173 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras import layers
+
+from doctr.utils.repr import NestedObject
+
+__all__ = ["FASTConvLayer"]
+
+
+class FASTConvLayer(layers.Layer, NestedObject):
+ """Convolutional layer used in the TextNet and FAST architectures"""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ stride: int = 1,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.groups = groups
+ self.in_channels = in_channels
+ self.converted_ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+
+ self.hor_conv, self.hor_bn = None, None
+ self.ver_conv, self.ver_bn = None, None
+
+ padding = ((self.converted_ks[0] - 1) * dilation // 2, (self.converted_ks[1] - 1) * dilation // 2)
+
+ self.activation = layers.ReLU()
+ self.conv_pad = layers.ZeroPadding2D(padding=padding)
+
+ self.conv = layers.Conv2D(
+ filters=out_channels,
+ kernel_size=self.converted_ks,
+ strides=stride,
+ dilation_rate=dilation,
+ groups=groups,
+ use_bias=bias,
+ )
+
+ self.bn = layers.BatchNormalization()
+
+ if self.converted_ks[1] != 1:
+ self.ver_pad = layers.ZeroPadding2D(
+ padding=(int(((self.converted_ks[0] - 1) * dilation) / 2), 0),
+ )
+ self.ver_conv = layers.Conv2D(
+ filters=out_channels,
+ kernel_size=(self.converted_ks[0], 1),
+ strides=stride,
+ dilation_rate=dilation,
+ groups=groups,
+ use_bias=bias,
+ )
+ self.ver_bn = layers.BatchNormalization()
+
+ if self.converted_ks[0] != 1:
+ self.hor_pad = layers.ZeroPadding2D(
+ padding=(0, int(((self.converted_ks[1] - 1) * dilation) / 2)),
+ )
+ self.hor_conv = layers.Conv2D(
+ filters=out_channels,
+ kernel_size=(1, self.converted_ks[1]),
+ strides=stride,
+ dilation_rate=dilation,
+ groups=groups,
+ use_bias=bias,
+ )
+ self.hor_bn = layers.BatchNormalization()
+
+ self.rbr_identity = layers.BatchNormalization() if out_channels == in_channels and stride == 1 else None
+
+ def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
+ if hasattr(self, "fused_conv"):
+ return self.activation(self.fused_conv(self.conv_pad(x, **kwargs), **kwargs))
+
+ main_outputs = self.bn(self.conv(self.conv_pad(x, **kwargs), **kwargs), **kwargs)
+ vertical_outputs = (
+ self.ver_bn(self.ver_conv(self.ver_pad(x, **kwargs), **kwargs), **kwargs)
+ if self.ver_conv is not None and self.ver_bn is not None
+ else 0
+ )
+ horizontal_outputs = (
+ self.hor_bn(self.hor_conv(self.hor_pad(x, **kwargs), **kwargs), **kwargs)
+ if self.hor_bn is not None and self.hor_conv is not None
+ else 0
+ )
+ id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None else 0
+
+ return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
+
+ # The following logic is used to reparametrize the layer
+ # Adapted from: https://github.com/mindee/doctr/blob/main/doctr/models/modules/layers/pytorch.py
+ def _identity_to_conv(
+ self, identity: layers.BatchNormalization
+ ) -> Union[Tuple[tf.Tensor, tf.Tensor], Tuple[int, int]]:
+ if identity is None or not hasattr(identity, "moving_mean") or not hasattr(identity, "moving_variance"):
+ return 0, 0
+ if not hasattr(self, "id_tensor"):
+ input_dim = self.in_channels // self.groups
+ kernel_value = np.zeros((1, 1, input_dim, self.in_channels), dtype=np.float32)
+ for i in range(self.in_channels):
+ kernel_value[0, 0, i % input_dim, i] = 1
+ id_tensor = tf.constant(kernel_value, dtype=tf.float32)
+ self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
+ kernel = self.id_tensor
+ std = tf.sqrt(identity.moving_variance + identity.epsilon)
+ t = tf.reshape(identity.gamma / std, (1, 1, 1, -1))
+ return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std
+
+ def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> Tuple[tf.Tensor, tf.Tensor]:
+ kernel = conv.kernel
+ kernel = self._pad_to_mxn_tensor(kernel)
+ std = tf.sqrt(bn.moving_variance + bn.epsilon)
+ t = tf.reshape(bn.gamma / std, (1, 1, 1, -1))
+ return kernel * t, bn.beta - bn.moving_mean * bn.gamma / std
+
+ def _get_equivalent_kernel_bias(self):
+ kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
+ if self.ver_conv is not None:
+ kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn)
+ else:
+ kernel_mx1, bias_mx1 = 0, 0
+ if self.hor_conv is not None:
+ kernel_1xn, bias_1xn = self._fuse_bn_tensor(self.hor_conv, self.hor_bn)
+ else:
+ kernel_1xn, bias_1xn = 0, 0
+ kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
+ kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
+ bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
+ return kernel_mxn, bias_mxn
+
+ def _pad_to_mxn_tensor(self, kernel: tf.Tensor) -> tf.Tensor:
+ kernel_height, kernel_width = self.converted_ks
+ height, width = kernel.shape[:2]
+ pad_left_right = tf.maximum(0, (kernel_width - width) // 2)
+ pad_top_down = tf.maximum(0, (kernel_height - height) // 2)
+ return tf.pad(kernel, [[pad_top_down, pad_top_down], [pad_left_right, pad_left_right], [0, 0], [0, 0]])
+
+ def reparameterize_layer(self):
+ kernel, bias = self._get_equivalent_kernel_bias()
+ self.fused_conv = layers.Conv2D(
+ filters=self.conv.filters,
+ kernel_size=self.conv.kernel_size,
+ strides=self.conv.strides,
+ padding=self.conv.padding,
+ dilation_rate=self.conv.dilation_rate,
+ groups=self.conv.groups,
+ use_bias=True,
+ )
+ # build layer to initialize weights and biases
+ self.fused_conv.build(input_shape=(None, None, None, kernel.shape[-2]))
+ self.fused_conv.set_weights([kernel.numpy(), bias.numpy()])
+ for para in self.trainable_variables:
+ para._trainable = False
+ for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]:
+ if hasattr(self, attr):
+ delattr(self, attr)
+
+ if hasattr(self, "rbr_identity"):
+ delattr(self, "rbr_identity")
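+
+
+# Usage sketch (illustrative only): same folding on the TF side; the layer must
+# be called once so its sub-layers are built before reparameterization.
+# >>> layer = FASTConvLayer(16, 16, kernel_size=3)
+# >>> x = tf.random.uniform((1, 32, 32, 16))
+# >>> out = layer(x, training=False)
+# >>> layer.reparameterize_layer()
+# >>> bool(tf.reduce_max(tf.abs(out - layer(x, training=False))) < 1e-4)
+# True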
diff --git a/doctr/models/modules/transformer/__init__.py b/doctr/models/modules/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/models/modules/transformer/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/modules/transformer/pytorch.py b/doctr/models/modules/transformer/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f1c61297813b1a6b3d7124fd61dd3e37601dcfb
--- /dev/null
+++ b/doctr/models/modules/transformer/pytorch.py
@@ -0,0 +1,202 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+# This module 'transformer.py' is inspired by https://github.com/wenwenyu/MASTER-pytorch, from which the Decoder is borrowed.
+
+import math
+from typing import Any, Callable, Optional, Tuple
+
+import torch
+from torch import nn
+
+__all__ = ["Decoder", "PositionalEncoding", "EncoderBlock", "MultiHeadAttention", "PositionwiseFeedForward"]
+
+
+class PositionalEncoding(nn.Module):
+ """Compute positional encoding"""
+
+ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
+ super(PositionalEncoding, self).__init__()
+ self.dropout = nn.Dropout(p=dropout)
+
+ # Compute the positional encodings once in log space.
+ pe = torch.zeros(max_len, d_model)
+ position = torch.arange(0, max_len).unsqueeze(1).float()
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ self.register_buffer("pe", pe.unsqueeze(0))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass
+
+ Args:
+ ----
+ x: embeddings (batch, max_len, d_model)
+
+ Returns
+ -------
+ positional embeddings (batch, max_len, d_model)
+ """
+ x = x + self.pe[:, : x.size(1)]
+ return self.dropout(x)
+
+
+def scaled_dot_product_attention(
+ query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Scaled Dot-Product Attention"""
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
+ if mask is not None:
+ # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
+ scores = scores.masked_fill(mask == 0, float("-inf"))
+ p_attn = torch.softmax(scores, dim=-1)
+ return torch.matmul(p_attn, value), p_attn
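+
+
+# Shape sketch (illustrative): with (batch, heads, seq, d_k) inputs,
+# >>> q = k = v = torch.rand(2, 8, 10, 64)
+# >>> out, attn = scaled_dot_product_attention(q, k, v)
+# >>> out.shape, attn.shape
+# (torch.Size([2, 8, 10, 64]), torch.Size([2, 8, 10, 10]))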
+
+
+class PositionwiseFeedForward(nn.Sequential):
+ """Position-wise Feed-Forward Network"""
+
+ def __init__(
+ self, d_model: int, ffd: int, dropout: float = 0.1, activation_fct: Callable[[Any], Any] = nn.ReLU()
+ ) -> None:
+ super().__init__( # type: ignore[call-overload]
+ nn.Linear(d_model, ffd),
+ activation_fct,
+ nn.Dropout(p=dropout),
+ nn.Linear(ffd, d_model),
+ nn.Dropout(p=dropout),
+ )
+
+
+class MultiHeadAttention(nn.Module):
+ """Multi-Head Attention"""
+
+ def __init__(self, num_heads: int, d_model: int, dropout: float = 0.1) -> None:
+ super().__init__()
+ assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
+
+ self.d_k = d_model // num_heads
+ self.num_heads = num_heads
+
+ self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
+ self.output_linear = nn.Linear(d_model, d_model)
+
+ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask=None) -> torch.Tensor:
+ batch_size = query.size(0)
+
+ # linear projections of Q, K, V
+ query, key, value = [
+ linear(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
+ for linear, x in zip(self.linear_layers, (query, key, value))
+ ]
+
+ # apply attention on all the projected vectors in batch
+ x, attn = scaled_dot_product_attention(query, key, value, mask=mask)
+
+ # Concat attention heads
+ x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
+
+ return self.output_linear(x)
+
+
+class EncoderBlock(nn.Module):
+ """Transformer Encoder Block"""
+
+ def __init__(
+ self,
+ num_layers: int,
+ num_heads: int,
+ d_model: int,
+ dff: int, # hidden dimension of the feedforward network
+ dropout: float,
+ activation_fct: Callable[[Any], Any] = nn.ReLU(),
+ ) -> None:
+ super().__init__()
+
+ self.num_layers = num_layers
+
+ self.layer_norm_input = nn.LayerNorm(d_model, eps=1e-5)
+ self.layer_norm_attention = nn.LayerNorm(d_model, eps=1e-5)
+ self.layer_norm_output = nn.LayerNorm(d_model, eps=1e-5)
+ self.dropout = nn.Dropout(dropout)
+
+ self.attention = nn.ModuleList([
+ MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)
+ ])
+ self.position_feed_forward = nn.ModuleList([
+ PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
+ ])
+
+ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ output = x
+
+ for i in range(self.num_layers):
+ normed_output = self.layer_norm_input(output)
+ output = output + self.dropout(self.attention[i](normed_output, normed_output, normed_output, mask))
+ normed_output = self.layer_norm_attention(output)
+ output = output + self.dropout(self.position_feed_forward[i](normed_output))
+
+ # (batch_size, seq_len, d_model)
+ return self.layer_norm_output(output)
+
+
+class Decoder(nn.Module):
+ """Transformer Decoder"""
+
+ def __init__(
+ self,
+ num_layers: int,
+ num_heads: int,
+ d_model: int,
+ vocab_size: int,
+ dropout: float = 0.2,
+ dff: int = 2048, # hidden dimension of the feedforward network
+ maximum_position_encoding: int = 50,
+ ) -> None:
+ super(Decoder, self).__init__()
+ self.num_layers = num_layers
+ self.d_model = d_model
+
+ self.layer_norm_input = nn.LayerNorm(d_model, eps=1e-5)
+ self.layer_norm_masked_attention = nn.LayerNorm(d_model, eps=1e-5)
+ self.layer_norm_attention = nn.LayerNorm(d_model, eps=1e-5)
+ self.layer_norm_output = nn.LayerNorm(d_model, eps=1e-5)
+
+ self.dropout = nn.Dropout(dropout)
+ self.embed = nn.Embedding(vocab_size, d_model)
+ self.positional_encoding = PositionalEncoding(d_model, dropout, maximum_position_encoding)
+
+ self.attention = nn.ModuleList([
+ MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)
+ ])
+ self.source_attention = nn.ModuleList([
+ MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)
+ ])
+ self.position_feed_forward = nn.ModuleList([
+ PositionwiseFeedForward(d_model, dff, dropout) for _ in range(self.num_layers)
+ ])
+
+ def forward(
+ self,
+ tgt: torch.Tensor,
+ memory: torch.Tensor,
+ source_mask: Optional[torch.Tensor] = None,
+ target_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ tgt = self.embed(tgt) * math.sqrt(self.d_model)
+ pos_enc_tgt = self.positional_encoding(tgt)
+ output = pos_enc_tgt
+
+ for i in range(self.num_layers):
+ normed_output = self.layer_norm_input(output)
+ output = output + self.dropout(self.attention[i](normed_output, normed_output, normed_output, target_mask))
+ normed_output = self.layer_norm_masked_attention(output)
+ output = output + self.dropout(self.source_attention[i](normed_output, memory, memory, source_mask))
+ normed_output = self.layer_norm_attention(output)
+ output = output + self.dropout(self.position_feed_forward[i](normed_output))
+
+ # (batch_size, seq_len, d_model)
+ return self.layer_norm_output(output)
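+
+
+# Usage sketch (illustrative): decode 20 target tokens against encoder memory.
+# Note that `maximum_position_encoding` defaults to 50, so the target length
+# must stay below it.
+# >>> decoder = Decoder(num_layers=3, num_heads=8, d_model=512, vocab_size=100)
+# >>> tgt = torch.randint(0, 100, (2, 20))
+# >>> memory = torch.rand(2, 40, 512)
+# >>> decoder(tgt, memory).shape
+# torch.Size([2, 20, 512])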
diff --git a/doctr/models/modules/transformer/tensorflow.py b/doctr/models/modules/transformer/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..403f99117dccc90c051f9cddbc3a97baf0c94d7d
--- /dev/null
+++ b/doctr/models/modules/transformer/tensorflow.py
@@ -0,0 +1,238 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import math
+from typing import Any, Callable, Optional, Tuple
+
+import tensorflow as tf
+from tensorflow.keras import layers
+
+from doctr.utils.repr import NestedObject
+
+__all__ = ["Decoder", "PositionalEncoding", "EncoderBlock", "PositionwiseFeedForward", "MultiHeadAttention"]
+
+tf.config.run_functions_eagerly(True)
+
+
+class PositionalEncoding(layers.Layer, NestedObject):
+ """Compute positional encoding"""
+
+ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
+ super(PositionalEncoding, self).__init__()
+ self.dropout = layers.Dropout(rate=dropout)
+
+ # Compute the positional encodings once in log space.
+ pe = tf.Variable(tf.zeros((max_len, d_model)))
+ position = tf.cast(
+ tf.expand_dims(tf.experimental.numpy.arange(start=0, stop=max_len), axis=1), dtype=tf.float32
+ )
+ div_term = tf.math.exp(
+ tf.cast(tf.experimental.numpy.arange(start=0, stop=d_model, step=2), dtype=tf.float32)
+ * -(math.log(10000.0) / d_model)
+ )
+ pe = pe.numpy()
+ pe[:, 0::2] = tf.math.sin(position * div_term)
+ pe[:, 1::2] = tf.math.cos(position * div_term)
+ self.pe = tf.expand_dims(tf.convert_to_tensor(pe), axis=0)
+
+ def call(
+ self,
+ x: tf.Tensor,
+ **kwargs: Any,
+ ) -> tf.Tensor:
+ """Forward pass
+
+ Args:
+ ----
+ x: embeddings (batch, max_len, d_model)
+ **kwargs: additional arguments
+
+ Returns
+ -------
+ positional embeddings (batch, max_len, d_model)
+ """
+ if x.dtype == tf.float16: # amp fix: cast to half
+ x = x + tf.cast(self.pe[:, : x.shape[1]], dtype=tf.half)
+ else:
+ x = x + self.pe[:, : x.shape[1]]
+ return self.dropout(x, **kwargs)
+
+
+@tf.function
+def scaled_dot_product_attention(
+ query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, mask: Optional[tf.Tensor] = None
+) -> Tuple[tf.Tensor, tf.Tensor]:
+ """Scaled Dot-Product Attention"""
+ scores = tf.matmul(query, tf.transpose(key, perm=[0, 1, 3, 2])) / math.sqrt(query.shape[-1])
+ if mask is not None:
+ # NOTE: to ensure the ONNX compatibility, tf.where works only with bool type condition
+ scores = tf.where(mask == False, float("-inf"), scores) # noqa: E712
+ p_attn = tf.nn.softmax(scores, axis=-1)
+ return tf.matmul(p_attn, value), p_attn
+
+
+class PositionwiseFeedForward(layers.Layer, NestedObject):
+ """Position-wise Feed-Forward Network"""
+
+ def __init__(
+ self, d_model: int, ffd: int, dropout: float = 0.1, activation_fct: Callable[[Any], Any] = layers.ReLU()
+ ) -> None:
+ super(PositionwiseFeedForward, self).__init__()
+ self.activation_fct = activation_fct
+
+ self.first_linear = layers.Dense(ffd, kernel_initializer=tf.initializers.he_uniform())
+ self.sec_linear = layers.Dense(d_model, kernel_initializer=tf.initializers.he_uniform())
+ self.dropout = layers.Dropout(rate=dropout)
+
+ def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
+ x = self.first_linear(x, **kwargs)
+ x = self.activation_fct(x)
+ x = self.dropout(x, **kwargs)
+ x = self.sec_linear(x, **kwargs)
+ x = self.dropout(x, **kwargs)
+ return x
+
+
+class MultiHeadAttention(layers.Layer, NestedObject):
+ """Multi-Head Attention"""
+
+ def __init__(self, num_heads: int, d_model: int, dropout: float = 0.1) -> None:
+ super().__init__()
+ assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
+
+ self.d_k = d_model // num_heads
+ self.num_heads = num_heads
+
+ self.linear_layers = [layers.Dense(d_model, kernel_initializer=tf.initializers.he_uniform()) for _ in range(3)]
+ self.output_linear = layers.Dense(d_model, kernel_initializer=tf.initializers.he_uniform())
+
+ def call(
+ self,
+ query: tf.Tensor,
+ key: tf.Tensor,
+ value: tf.Tensor,
+ mask: Optional[tf.Tensor] = None,
+ **kwargs: Any,
+ ) -> tf.Tensor:
+ batch_size = query.shape[0]
+
+ # linear projections of Q, K, V
+ query, key, value = [
+ tf.transpose(
+ tf.reshape(linear(x, **kwargs), shape=[batch_size, -1, self.num_heads, self.d_k]), perm=[0, 2, 1, 3]
+ )
+ for linear, x in zip(self.linear_layers, (query, key, value))
+ ]
+
+ # apply attention on all the projected vectors in batch
+ x, attn = scaled_dot_product_attention(query, key, value, mask=mask)
+
+ # Concat attention heads
+ x = tf.transpose(x, perm=[0, 2, 1, 3])
+ x = tf.reshape(x, shape=[batch_size, -1, self.num_heads * self.d_k])
+
+ return self.output_linear(x, **kwargs)
+
+
+class EncoderBlock(layers.Layer, NestedObject):
+ """Transformer Encoder Block"""
+
+ def __init__(
+ self,
+ num_layers: int,
+ num_heads: int,
+ d_model: int,
+ dff: int, # hidden dimension of the feedforward network
+ dropout: float,
+ activation_fct: Callable[[Any], Any] = layers.ReLU(),
+ ) -> None:
+ super().__init__()
+
+ self.num_layers = num_layers
+
+ self.layer_norm_input = layers.LayerNormalization(epsilon=1e-5)
+ self.layer_norm_attention = layers.LayerNormalization(epsilon=1e-5)
+ self.layer_norm_output = layers.LayerNormalization(epsilon=1e-5)
+ self.dropout = layers.Dropout(rate=dropout)
+
+ self.attention = [MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)]
+ self.position_feed_forward = [
+ PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
+ ]
+
+ def call(self, x: tf.Tensor, mask: Optional[tf.Tensor] = None, **kwargs: Any) -> tf.Tensor:
+ output = x
+
+ for i in range(self.num_layers):
+ normed_output = self.layer_norm_input(output, **kwargs)
+ output = output + self.dropout(
+ self.attention[i](normed_output, normed_output, normed_output, mask, **kwargs),
+ **kwargs,
+ )
+ normed_output = self.layer_norm_attention(output, **kwargs)
+ output = output + self.dropout(self.position_feed_forward[i](normed_output, **kwargs), **kwargs)
+
+ # (batch_size, seq_len, d_model)
+ return self.layer_norm_output(output, **kwargs)
+
+
+class Decoder(layers.Layer, NestedObject):
+ """Transformer Decoder"""
+
+ def __init__(
+ self,
+ num_layers: int,
+ num_heads: int,
+ d_model: int,
+ vocab_size: int,
+ dropout: float = 0.2,
+ dff: int = 2048, # hidden dimension of the feedforward network
+ maximum_position_encoding: int = 50,
+ ) -> None:
+ super(Decoder, self).__init__()
+ self.num_layers = num_layers
+ self.d_model = d_model
+
+ self.layer_norm_input = layers.LayerNormalization(epsilon=1e-5)
+ self.layer_norm_masked_attention = layers.LayerNormalization(epsilon=1e-5)
+ self.layer_norm_attention = layers.LayerNormalization(epsilon=1e-5)
+ self.layer_norm_output = layers.LayerNormalization(epsilon=1e-5)
+
+ self.dropout = layers.Dropout(rate=dropout)
+ self.embed = layers.Embedding(vocab_size, d_model)
+ self.positional_encoding = PositionalEncoding(d_model, dropout, maximum_position_encoding)
+
+ self.attention = [MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)]
+ self.source_attention = [MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)]
+ self.position_feed_forward = [PositionwiseFeedForward(d_model, dff, dropout) for _ in range(self.num_layers)]
+
+ def call(
+ self,
+ tgt: tf.Tensor,
+ memory: tf.Tensor,
+ source_mask: Optional[tf.Tensor] = None,
+ target_mask: Optional[tf.Tensor] = None,
+ **kwargs: Any,
+ ) -> tf.Tensor:
+ tgt = self.embed(tgt, **kwargs) * math.sqrt(self.d_model)
+ pos_enc_tgt = self.positional_encoding(tgt, **kwargs)
+ output = pos_enc_tgt
+
+ for i in range(self.num_layers):
+ normed_output = self.layer_norm_input(output, **kwargs)
+ output = output + self.dropout(
+ self.attention[i](normed_output, normed_output, normed_output, target_mask, **kwargs),
+ **kwargs,
+ )
+ normed_output = self.layer_norm_masked_attention(output, **kwargs)
+ output = output + self.dropout(
+ self.source_attention[i](normed_output, memory, memory, source_mask, **kwargs),
+ **kwargs,
+ )
+ normed_output = self.layer_norm_attention(output, **kwargs)
+ output = output + self.dropout(self.position_feed_forward[i](normed_output, **kwargs), **kwargs)
+
+ # (batch_size, seq_len, d_model)
+ return self.layer_norm_output(output, **kwargs)
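+
+
+# Shape sketch (illustrative), mirroring the PyTorch modules above:
+# >>> encoder = EncoderBlock(num_layers=2, num_heads=8, d_model=512, dff=2048, dropout=0.1)
+# >>> encoder(tf.random.uniform((2, 40, 512))).shape
+# TensorShape([2, 40, 512])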
diff --git a/doctr/models/modules/vision_transformer/__init__.py b/doctr/models/modules/vision_transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/models/modules/vision_transformer/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/modules/vision_transformer/pytorch.py b/doctr/models/modules/vision_transformer/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ff07ed4ffe4ea6fd769d87fae35c085d927e971
--- /dev/null
+++ b/doctr/models/modules/vision_transformer/pytorch.py
@@ -0,0 +1,84 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import math
+from typing import Tuple
+
+import torch
+from torch import nn
+
+__all__ = ["PatchEmbedding"]
+
+
+class PatchEmbedding(nn.Module):
+ """Compute 2D patch embeddings with cls token and positional encoding"""
+
+ def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int, patch_size: Tuple[int, int]) -> None:
+ super().__init__()
+ channels, height, width = input_shape
+ self.patch_size = patch_size
+ self.interpolate = patch_size[0] == patch_size[1]
+ self.grid_size = tuple([s // p for s, p in zip((height, width), self.patch_size)])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+
+ self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
+ self.positions = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
+ self.projection = nn.Conv2d(channels, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """100% borrowed from:
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py
+
+ This method allows interpolating the pre-trained position encodings so the model can be used on
+ higher-resolution images.
+
+ Source:
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py
+ """
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.positions.shape[1] - 1
+ if num_patches == num_positions and height == width:
+ return self.positions
+ class_pos_embed = self.positions[:, 0]
+ patch_pos_embed = self.positions[:, 1:]
+ dim = embeddings.shape[-1]
+ h0 = float(height // self.patch_size[0])
+ w0 = float(width // self.patch_size[1])
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ h0, w0 = h0 + 0.1, w0 + 0.1
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+ mode="bilinear",
+ align_corners=False,
+ recompute_scale_factor=True,
+ )
+ assert int(h0) == patch_pos_embed.shape[-2], "height of interpolated patch embedding doesn't match"
+ assert int(w0) == patch_pos_embed.shape[-1], "width of interpolated patch embedding doesn't match"
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, C, H, W = x.shape
+ assert H % self.patch_size[0] == 0, "Image height must be divisible by patch height"
+ assert W % self.patch_size[1] == 0, "Image width must be divisible by patch width"
+
+ # patchify image
+ patches = self.projection(x).flatten(2).transpose(1, 2)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1) # (batch_size, 1, d_model)
+ # concate cls_tokens to patches
+ embeddings = torch.cat([cls_tokens, patches], dim=1) # (batch_size, num_patches + 1, d_model)
+ # add positions to embeddings
+ if self.interpolate:
+ embeddings += self.interpolate_pos_encoding(embeddings, H, W)
+ else:
+ embeddings += self.positions
+
+ return embeddings # (batch_size, num_patches + 1, d_model)
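+
+
+# Usage sketch (illustrative): a 3x32x128 input split into 4x8 patches yields an
+# 8x16 grid, i.e. 128 patches plus the cls token.
+# >>> embedder = PatchEmbedding(input_shape=(3, 32, 128), embed_dim=768, patch_size=(4, 8))
+# >>> embedder(torch.rand(2, 3, 32, 128)).shape
+# torch.Size([2, 129, 768])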
diff --git a/doctr/models/modules/vision_transformer/tensorflow.py b/doctr/models/modules/vision_transformer/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..a78f0da3fbdb035859bc4db4cb27b055e1330e4a
--- /dev/null
+++ b/doctr/models/modules/vision_transformer/tensorflow.py
@@ -0,0 +1,100 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import math
+from typing import Any, Tuple
+
+import tensorflow as tf
+from tensorflow.keras import layers
+
+from doctr.utils.repr import NestedObject
+
+__all__ = ["PatchEmbedding"]
+
+
+class PatchEmbedding(layers.Layer, NestedObject):
+ """Compute 2D patch embeddings with cls token and positional encoding"""
+
+ def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int, patch_size: Tuple[int, int]) -> None:
+ super().__init__()
+ height, width, _ = input_shape
+ self.patch_size = patch_size
+ self.interpolate = patch_size[0] == patch_size[1]
+ self.grid_size = tuple([s // p for s, p in zip((height, width), self.patch_size)])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+
+ self.cls_token = self.add_weight(shape=(1, 1, embed_dim), initializer="zeros", trainable=True, name="cls_token")
+ self.positions = self.add_weight(
+ shape=(1, self.num_patches + 1, embed_dim),
+ initializer="zeros",
+ trainable=True,
+ name="positions",
+ )
+ self.projection = layers.Conv2D(
+ filters=embed_dim,
+ kernel_size=self.patch_size,
+ strides=self.patch_size,
+ padding="valid",
+ data_format="channels_last",
+ use_bias=True,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ name="projection",
+ )
+
+ def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor:
+ """100% borrowed from:
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_tf_vit.py
+
+ This method allows interpolating the pre-trained position encodings so the model can be used on
+ higher-resolution images.
+
+ Source:
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py
+ """
+ seq_len, dim = embeddings.shape[1:]
+ num_patches = seq_len - 1
+
+ num_positions = self.positions.shape[1] - 1
+
+ if num_patches == num_positions and height == width:
+ return self.positions
+ class_pos_embed = self.positions[:, :1]
+ patch_pos_embed = self.positions[:, 1:]
+ h0 = height // self.patch_size[0]
+ w0 = width // self.patch_size[1]
+ patch_pos_embed = tf.image.resize(
+ images=tf.reshape(
+ patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+ ),
+ size=(h0, w0),
+ method="bilinear",
+ )
+
+ shape = patch_pos_embed.shape
+ assert h0 == shape[-3], "height of interpolated patch embedding doesn't match"
+ assert w0 == shape[-2], "width of interpolated patch embedding doesn't match"
+
+ patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
+ return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
+
+ def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
+ B, H, W, C = x.shape
+ assert H % self.patch_size[0] == 0, "Image height must be divisible by patch height"
+ assert W % self.patch_size[1] == 0, "Image width must be divisible by patch width"
+ # patchify image
+ patches = self.projection(x, **kwargs) # (batch_size, num_patches, d_model)
+ patches = tf.reshape(patches, (B, self.num_patches, -1)) # (batch_size, num_patches, d_model)
+
+ cls_tokens = tf.repeat(self.cls_token, B, axis=0) # (batch_size, 1, d_model)
+ # concate cls_tokens to patches
+ embeddings = tf.concat([cls_tokens, patches], axis=1) # (batch_size, num_patches + 1, d_model)
+ # add positions to embeddings
+ if self.interpolate:
+ embeddings += self.interpolate_pos_encoding(embeddings, H, W)
+ else:
+ embeddings += self.positions
+
+ return embeddings # (batch_size, num_patches + 1, d_model)
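+
+
+# Same sketch (illustrative) on the TF side, with a channels-last input:
+# >>> embedder = PatchEmbedding(input_shape=(32, 128, 3), embed_dim=768, patch_size=(4, 8))
+# >>> embedder(tf.random.uniform((2, 32, 128, 3))).shape
+# TensorShape([2, 129, 768])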
diff --git a/doctr/models/obj_detection/__init__.py b/doctr/models/obj_detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7e08d7ac8e7b7114437a6f82347e0dd15430b0d
--- /dev/null
+++ b/doctr/models/obj_detection/__init__.py
@@ -0,0 +1 @@
+from .faster_rcnn import *
diff --git a/doctr/models/obj_detection/faster_rcnn/__init__.py b/doctr/models/obj_detection/faster_rcnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..96dc37cd75a58954bae7c45e78f0f64c1fba6919
--- /dev/null
+++ b/doctr/models/obj_detection/faster_rcnn/__init__.py
@@ -0,0 +1,4 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if not is_tf_available() and is_torch_available():
+ from .pytorch import *
diff --git a/doctr/models/obj_detection/faster_rcnn/pytorch.py b/doctr/models/obj_detection/faster_rcnn/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e670fa7cb628064cd80ec092b356626884bfc9f
--- /dev/null
+++ b/doctr/models/obj_detection/faster_rcnn/pytorch.py
@@ -0,0 +1,81 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, Dict
+
+from torchvision.models.detection import FasterRCNN, FasterRCNN_MobileNet_V3_Large_FPN_Weights, faster_rcnn
+
+from ...utils import load_pretrained_params
+
+__all__ = ["fasterrcnn_mobilenet_v3_large_fpn"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "fasterrcnn_mobilenet_v3_large_fpn": {
+ "input_shape": (3, 1024, 1024),
+ "mean": (0.485, 0.456, 0.406),
+ "std": (0.229, 0.224, 0.225),
+ "classes": ["background", "qr_code", "bar_code", "logo", "photo"],
+ "url": "https://doctr-static.mindee.com/models?id=v0.4.1/fasterrcnn_mobilenet_v3_large_fpn-d5b2490d.pt&src=0",
+ },
+}
+
+
+def _fasterrcnn(arch: str, pretrained: bool, **kwargs: Any) -> FasterRCNN:
+ _kwargs = {
+ "image_mean": default_cfgs[arch]["mean"],
+ "image_std": default_cfgs[arch]["std"],
+ "box_detections_per_img": 150,
+ "box_score_thresh": 0.5,
+ "box_positive_fraction": 0.35,
+ "box_nms_thresh": 0.2,
+ "rpn_nms_thresh": 0.2,
+ "num_classes": len(default_cfgs[arch]["classes"]),
+ }
+
+ # Build the model
+ _kwargs.update(kwargs)
+ model = faster_rcnn.__dict__[arch](weights=None, weights_backbone=None, **_kwargs)
+ model.cfg = default_cfgs[arch]
+
+ if pretrained:
+ # Load pretrained parameters
+ load_pretrained_params(model, default_cfgs[arch]["url"])
+ else:
+ # Filter keys
+ state_dict = {
+ k: v
+ for k, v in faster_rcnn.__dict__[arch](weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
+ .state_dict()
+ .items()
+ if not k.startswith("roi_heads.")
+ }
+
+ # Load state dict
+ model.load_state_dict(state_dict, strict=False)
+
+ return model
+
+
+def fasterrcnn_mobilenet_v3_large_fpn(pretrained: bool = False, **kwargs: Any) -> FasterRCNN:
+ """Faster-RCNN architecture with a MobileNet V3 backbone as described in `"Faster R-CNN: Towards Real-Time
+ Object Detection with Region Proposal Networks" <https://arxiv.org/abs/1506.01497>`_.
+
+ >>> import torch
+ >>> from doctr.models.obj_detection import fasterrcnn_mobilenet_v3_large_fpn
+ >>> model = fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
+ >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our object detection dataset
+ **kwargs: keyword arguments of the FasterRCNN architecture
+
+ Returns:
+ -------
+ object detection architecture
+ """
+ return _fasterrcnn("fasterrcnn_mobilenet_v3_large_fpn", pretrained, **kwargs)
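+
+
+# Since user kwargs are merged over the defaults in `_fasterrcnn`, the detection
+# thresholds can be tuned at build time (illustrative):
+# >>> model = fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, box_score_thresh=0.7)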
diff --git a/doctr/models/predictor/__init__.py b/doctr/models/predictor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff30c3b2e7d34bf85e30291e39f9d3206c0f4bdd
--- /dev/null
+++ b/doctr/models/predictor/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available
+
+if is_tf_available():
+ from .tensorflow import *
+else:
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/predictor/base.py b/doctr/models/predictor/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc5bfcb5db1fee11b3fceb98bd6508118116818d
--- /dev/null
+++ b/doctr/models/predictor/base.py
@@ -0,0 +1,170 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+
+from doctr.models.builder import DocumentBuilder
+from doctr.utils.geometry import extract_crops, extract_rcrops
+
+from .._utils import rectify_crops, rectify_loc_preds
+from ..classification import crop_orientation_predictor
+from ..classification.predictor import CropOrientationPredictor
+
+__all__ = ["_OCRPredictor"]
+
+
+class _OCRPredictor:
+ """Implements an object able to localize and identify text elements in a set of documents
+
+ Args:
+ ----
+ assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
+ without rotated textual elements.
+ straighten_pages: if True, estimates the page general orientation based on the median line orientation.
+ Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
+ accordingly. Doing so will improve performance for documents with page-uniform rotations.
+ preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
+ symmetric_pad: if True and preserve_aspect_ratio is True, pad the image symmetrically.
+ **kwargs: keyword args of `DocumentBuilder`
+ """
+
+ crop_orientation_predictor: Optional[CropOrientationPredictor]
+
+ def __init__(
+ self,
+ assume_straight_pages: bool = True,
+ straighten_pages: bool = False,
+ preserve_aspect_ratio: bool = True,
+ symmetric_pad: bool = True,
+ **kwargs: Any,
+ ) -> None:
+ self.assume_straight_pages = assume_straight_pages
+ self.straighten_pages = straighten_pages
+ self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor(pretrained=True)
+ self.doc_builder = DocumentBuilder(**kwargs)
+ self.preserve_aspect_ratio = preserve_aspect_ratio
+ self.symmetric_pad = symmetric_pad
+ self.hooks: List[Callable] = []
+
+ @staticmethod
+ def _generate_crops(
+ pages: List[np.ndarray],
+ loc_preds: List[np.ndarray],
+ channels_last: bool,
+ assume_straight_pages: bool = False,
+ ) -> List[List[np.ndarray]]:
+ extraction_fn = extract_crops if assume_straight_pages else extract_rcrops
+
+ crops = [
+ extraction_fn(page, _boxes[:, :4], channels_last=channels_last) # type: ignore[operator]
+ for page, _boxes in zip(pages, loc_preds)
+ ]
+ return crops
+
+ @staticmethod
+ def _prepare_crops(
+ pages: List[np.ndarray],
+ loc_preds: List[np.ndarray],
+ channels_last: bool,
+ assume_straight_pages: bool = False,
+ ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
+ crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages)
+
+ # Avoid sending zero-sized crops
+ is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
+ crops = [
+ [crop for crop, _kept in zip(page_crops, page_kept) if _kept]
+ for page_crops, page_kept in zip(crops, is_kept)
+ ]
+ loc_preds = [_boxes[_kept] for _boxes, _kept in zip(loc_preds, is_kept)]
+
+ return crops, loc_preds
+
+ def _rectify_crops(
+ self,
+ crops: List[List[np.ndarray]],
+ loc_preds: List[np.ndarray],
+ ) -> Tuple[List[List[np.ndarray]], List[np.ndarray], List[Tuple[int, float]]]:
+ # Work at a page level
+ orientations, classes, probs = zip(*[self.crop_orientation_predictor(page_crops) for page_crops in crops]) # type: ignore[misc]
+ rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)]
+ rect_loc_preds = [
+ rectify_loc_preds(page_loc_preds, orientation) if len(page_loc_preds) > 0 else page_loc_preds
+ for page_loc_preds, orientation in zip(loc_preds, orientations)
+ ]
+ # Flatten to list of tuples with (value, confidence)
+ crop_orientations = [
+ (orientation, prob)
+ for page_classes, page_probs in zip(classes, probs)
+ for orientation, prob in zip(page_classes, page_probs)
+ ]
+ return rect_crops, rect_loc_preds, crop_orientations # type: ignore[return-value]
+
+ def _remove_padding(
+ self,
+ pages: List[np.ndarray],
+ loc_preds: List[np.ndarray],
+ ) -> List[np.ndarray]:
+ if self.preserve_aspect_ratio:
+ # Rectify loc_preds to remove padding
+ rectified_preds = []
+ for page, loc_pred in zip(pages, loc_preds):
+ h, w = page.shape[0], page.shape[1]
+ if h > w:
+ # y unchanged, dilate x coord
+ if self.symmetric_pad:
+ if self.assume_straight_pages:
+ loc_pred[:, [0, 2]] = np.clip((loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5, 0, 1)
+ else:
+ loc_pred[:, :, 0] = np.clip((loc_pred[:, :, 0] - 0.5) * h / w + 0.5, 0, 1)
+ else:
+ if self.assume_straight_pages:
+ loc_pred[:, [0, 2]] *= h / w
+ else:
+ loc_pred[:, :, 0] *= h / w
+ elif w > h:
+ # x unchanged, dilate y coord
+ if self.symmetric_pad:
+ if self.assume_straight_pages:
+ loc_pred[:, [1, 3]] = np.clip((loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5, 0, 1)
+ else:
+ loc_pred[:, :, 1] = np.clip((loc_pred[:, :, 1] - 0.5) * w / h + 0.5, 0, 1)
+ else:
+ if self.assume_straight_pages:
+ loc_pred[:, [1, 3]] *= w / h
+ else:
+ loc_pred[:, :, 1] *= w / h
+ rectified_preds.append(loc_pred)
+ return rectified_preds
+ return loc_preds
+
+ @staticmethod
+ def _process_predictions(
+ loc_preds: List[np.ndarray],
+ word_preds: List[Tuple[str, float]],
+ crop_orientations: List[Dict[str, Any]],
+ ) -> Tuple[List[np.ndarray], List[List[Tuple[str, float]]], List[List[Dict[str, Any]]]]:
+ text_preds = []
+ crop_orientation_preds = []
+ if len(loc_preds) > 0:
+ # Text & crop orientation predictions at page level
+ _idx = 0
+ for page_boxes in loc_preds:
+ text_preds.append(word_preds[_idx : _idx + page_boxes.shape[0]])
+ crop_orientation_preds.append(crop_orientations[_idx : _idx + page_boxes.shape[0]])
+ _idx += page_boxes.shape[0]
+
+ return loc_preds, text_preds, crop_orientation_preds
+
+ def add_hook(self, hook: Callable) -> None:
+ """Add a hook to the predictor
+
+ Args:
+ ----
+ hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds`
+ """
+ self.hooks.append(hook)
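+
+
+# Worked example (illustrative) of the symmetric-pad remapping in `_remove_padding`:
+# a portrait page (h=200, w=100) was padded left/right to a square before
+# inference, so a predicted relative x of 0.4 maps back to
+# np.clip((0.4 - 0.5) * 200 / 100 + 0.5, 0, 1) == 0.3 on the unpadded page.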
diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..91640371e49874eb0ab88c87675709e27b3641be
--- /dev/null
+++ b/doctr/models/predictor/pytorch.py
@@ -0,0 +1,152 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, List, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+from doctr.io.elements import Document
+from doctr.models._utils import estimate_orientation, get_language
+from doctr.models.detection.predictor import DetectionPredictor
+from doctr.models.recognition.predictor import RecognitionPredictor
+from doctr.utils.geometry import rotate_image
+
+from .base import _OCRPredictor
+
+__all__ = ["OCRPredictor"]
+
+
+class OCRPredictor(nn.Module, _OCRPredictor):
+ """Implements an object able to localize and identify text elements in a set of documents
+
+ Args:
+ ----
+ det_predictor: detection module
+ reco_predictor: recognition module
+ assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
+ without rotated textual elements.
+ straighten_pages: if True, estimates the page general orientation based on the median line orientation.
+ Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
+ accordingly. Doing so will improve performance for documents with page-uniform rotations.
+ detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
+ page. Doing so will slightly deteriorate the overall latency.
+ detect_language: if True, the language prediction will be added to the predictions for each
+ page. Doing so will slightly deteriorate the overall latency.
+ **kwargs: keyword args of `DocumentBuilder`
+ """
+
+ def __init__(
+ self,
+ det_predictor: DetectionPredictor,
+ reco_predictor: RecognitionPredictor,
+ assume_straight_pages: bool = True,
+ straighten_pages: bool = False,
+ preserve_aspect_ratio: bool = True,
+ symmetric_pad: bool = True,
+ detect_orientation: bool = False,
+ detect_language: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ nn.Module.__init__(self)
+ self.det_predictor = det_predictor.eval() # type: ignore[attr-defined]
+ self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined]
+ _OCRPredictor.__init__(
+ self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
+ )
+ self.detect_orientation = detect_orientation
+ self.detect_language = detect_language
+
+ @torch.inference_mode()
+ def forward(
+ self,
+ pages: List[Union[np.ndarray, torch.Tensor]],
+ **kwargs: Any,
+ ) -> Document:
+ # Dimension check
+ if any(page.ndim != 3 for page in pages):
+ raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
+
+ origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages]
+
+ # Localize text elements
+ loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
+
+ # Detect document rotation and rotate pages
+ seg_maps = [
+ np.where(out_map > getattr(self.det_predictor.model.postprocessor, "bin_thresh"), 255, 0).astype(np.uint8)
+ for out_map in out_maps
+ ]
+ if self.detect_orientation:
+ origin_page_orientations = [estimate_orientation(seg_map) for seg_map in seg_maps]
+ orientations = [
+ {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
+ ]
+ else:
+ orientations = None
+ if self.straighten_pages:
+ origin_page_orientations = (
+ origin_page_orientations
+ if self.detect_orientation
+ else [estimate_orientation(seg_map) for seg_map in seg_maps]
+ )
+ pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)] # type: ignore[arg-type]
+ # Forward again to get predictions on straight pages
+ loc_preds = self.det_predictor(pages, **kwargs)
+
+ assert all(
+ len(loc_pred) == 1 for loc_pred in loc_preds
+ ), "Detection Model in ocr_predictor should output only one class"
+
+ loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds]
+ # Check whether crop mode should be switched to channels first
+ channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
+
+ # Rectify crops if aspect ratio
+ loc_preds = self._remove_padding(pages, loc_preds) # type: ignore[arg-type]
+
+ # Apply hooks to loc_preds if any
+ for hook in self.hooks:
+ loc_preds = hook(loc_preds)
+
+ # Crop images
+ crops, loc_preds = self._prepare_crops(
+ pages, # type: ignore[arg-type]
+ loc_preds,
+ channels_last=channels_last,
+ assume_straight_pages=self.assume_straight_pages,
+ )
+ # Rectify crop orientation and get crop orientation predictions
+ crop_orientations: Any = []
+ if not self.assume_straight_pages:
+ crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds)
+ crop_orientations = [
+ {"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations
+ ]
+
+ # Identify character sequences
+ word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs)
+ if not crop_orientations:
+ crop_orientations = [{"value": 0, "confidence": None} for _ in word_preds]
+
+ boxes, text_preds, crop_orientations = self._process_predictions(loc_preds, word_preds, crop_orientations)
+
+ if self.detect_language:
+ languages = [get_language(" ".join([item[0] for item in text_pred])) for text_pred in text_preds]
+ languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
+ else:
+ languages_dict = None
+
+ out = self.doc_builder(
+ pages, # type: ignore[arg-type]
+ boxes,
+ text_preds,
+ origin_page_shapes, # type: ignore[arg-type]
+ crop_orientations,
+ orientations,
+ languages_dict,
+ )
+ return out
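+
+
+# Usage sketch (illustrative, assuming detection/recognition predictors built
+# with doctr's `detection_predictor` and `recognition_predictor` factories):
+# >>> predictor = OCRPredictor(det_predictor, reco_predictor)
+# >>> page = np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8)
+# >>> doc = predictor([page])  # returns a doctr.io.elements.Document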
diff --git a/doctr/models/predictor/tensorflow.py b/doctr/models/predictor/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..f424b7b50b4527235d9a243a5981db4b3858b103
--- /dev/null
+++ b/doctr/models/predictor/tensorflow.py
@@ -0,0 +1,146 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, List, Union
+
+import numpy as np
+import tensorflow as tf
+
+from doctr.io.elements import Document
+from doctr.models._utils import estimate_orientation, get_language
+from doctr.models.detection.predictor import DetectionPredictor
+from doctr.models.recognition.predictor import RecognitionPredictor
+from doctr.utils.geometry import rotate_image
+from doctr.utils.repr import NestedObject
+
+from .base import _OCRPredictor
+
+__all__ = ["OCRPredictor"]
+
+
+class OCRPredictor(NestedObject, _OCRPredictor):
+ """Implements an object able to localize and identify text elements in a set of documents
+
+ Args:
+ ----
+ det_predictor: detection module
+ reco_predictor: recognition module
+ assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
+ without rotated textual elements.
+ straighten_pages: if True, estimates the page general orientation based on the median line orientation.
+ Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
+ accordingly. Doing so will improve performance for documents with page-uniform rotations.
+ detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
+ page. Doing so will slightly deteriorate the overall latency.
+ detect_language: if True, the language prediction will be added to the predictions for each
+ page. Doing so will slightly deteriorate the overall latency.
+ **kwargs: keyword args of `DocumentBuilder`
+ """
+
+ _children_names = ["det_predictor", "reco_predictor", "doc_builder"]
+
+ def __init__(
+ self,
+ det_predictor: DetectionPredictor,
+ reco_predictor: RecognitionPredictor,
+ assume_straight_pages: bool = True,
+ straighten_pages: bool = False,
+ preserve_aspect_ratio: bool = True,
+ symmetric_pad: bool = True,
+ detect_orientation: bool = False,
+ detect_language: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ self.det_predictor = det_predictor
+ self.reco_predictor = reco_predictor
+ _OCRPredictor.__init__(
+ self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
+ )
+ self.detect_orientation = detect_orientation
+ self.detect_language = detect_language
+
+ def __call__(
+ self,
+ pages: List[Union[np.ndarray, tf.Tensor]],
+ **kwargs: Any,
+ ) -> Document:
+ # Dimension check
+ if any(page.ndim != 3 for page in pages):
+ raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
+
+ origin_page_shapes = [page.shape[:2] for page in pages]
+
+ # Localize text elements
+ loc_preds_dict, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
+
+ # Detect document rotation and rotate pages
+ seg_maps = [
+ np.where(out_map > getattr(self.det_predictor.model.postprocessor, "bin_thresh"), 255, 0).astype(np.uint8)
+ for out_map in out_maps
+ ]
+ if self.detect_orientation:
+ origin_page_orientations = [estimate_orientation(seg_map) for seg_map in seg_maps]
+ orientations = [
+ {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
+ ]
+ else:
+ orientations = None
+ if self.straighten_pages:
+ origin_page_orientations = (
+ origin_page_orientations
+ if self.detect_orientation
+ else [estimate_orientation(seg_map) for seg_map in seg_maps]
+ )
+ pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+ # forward again to get predictions on straight pages
+ loc_preds_dict = self.det_predictor(pages, **kwargs) # type: ignore[assignment]
+
+ assert all(
+ len(loc_pred) == 1 for loc_pred in loc_preds_dict
+ ), "Detection Model in ocr_predictor should output only one class"
+ loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict] # type: ignore[union-attr]
+
+ # Rectify crops if aspect ratio
+ loc_preds = self._remove_padding(pages, loc_preds)
+
+ # Apply hooks to loc_preds if any
+ for hook in self.hooks:
+ loc_preds = hook(loc_preds)
+
+ # Crop images
+ crops, loc_preds = self._prepare_crops(
+ pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages
+ )
+ # Rectify crop orientation and get crop orientation predictions
+ crop_orientations: Any = []
+ if not self.assume_straight_pages:
+ crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds)
+ crop_orientations = [
+ {"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations
+ ]
+
+ # Identify character sequences
+ word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs)
+ if not crop_orientations:
+ crop_orientations = [{"value": 0, "confidence": None} for _ in word_preds]
+
+ boxes, text_preds, crop_orientations = self._process_predictions(loc_preds, word_preds, crop_orientations)
+
+ if self.detect_language:
+ languages = [get_language(" ".join([item[0] for item in text_pred])) for text_pred in text_preds]
+ languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
+ else:
+ languages_dict = None
+
+ out = self.doc_builder(
+ pages,
+ boxes,
+ text_preds,
+ origin_page_shapes, # type: ignore[arg-type]
+ crop_orientations,
+ orientations,
+ languages_dict,
+ )
+ return out
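For context, a minimal sketch of how this TensorFlow `OCRPredictor` is meant to be wired up; the `detection_predictor`/`recognition_predictor` zoo helpers and architecture names are assumptions based on the rest of this PR:

```python
# Hypothetical wiring sketch (TensorFlow backend).
import numpy as np

from doctr.models import detection_predictor, recognition_predictor
from doctr.models.predictor.tensorflow import OCRPredictor

det = detection_predictor("db_resnet50", pretrained=True)
reco = recognition_predictor("crnn_vgg16_bn", pretrained=True)
predictor = OCRPredictor(det, reco, detect_orientation=True, detect_language=True)

page = (np.random.rand(1024, 768, 3) * 255).astype(np.uint8)  # one H x W x C page
document = predictor([page])
print(document.render())
```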
diff --git a/doctr/models/preprocessor/__init__.py b/doctr/models/preprocessor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/models/preprocessor/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/preprocessor/pytorch.py b/doctr/models/preprocessor/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..58a236bd08a7314885b26b816209879d2457bc37
--- /dev/null
+++ b/doctr/models/preprocessor/pytorch.py
@@ -0,0 +1,128 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import math
+from typing import Any, List, Tuple, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torchvision.transforms import functional as F
+from torchvision.transforms import transforms as T
+
+from doctr.transforms import Resize
+from doctr.utils.multithreading import multithread_exec
+
+__all__ = ["PreProcessor"]
+
+
+class PreProcessor(nn.Module):
+ """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
+
+ Args:
+ ----
+ output_size: expected size of each page in format (H, W)
+ batch_size: the size of page batches
+ mean: mean value of the training distribution by channel
+ std: standard deviation of the training distribution by channel
+ """
+
+ def __init__(
+ self,
+ output_size: Tuple[int, int],
+ batch_size: int,
+ mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+ std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
+ **kwargs: Any,
+ ) -> None:
+ super().__init__()
+ self.batch_size = batch_size
+ self.resize: T.Resize = Resize(output_size, **kwargs)
+ # Perform the division by 255 at the same time
+ self.normalize = T.Normalize(mean, std)
+
+ def batch_inputs(self, samples: List[torch.Tensor]) -> List[torch.Tensor]:
+ """Gather samples into batches for inference purposes
+
+ Args:
+ ----
+ samples: list of samples of shape (C, H, W)
+
+ Returns:
+ -------
+ list of batched samples (*, C, H, W)
+ """
+ num_batches = int(math.ceil(len(samples) / self.batch_size))
+ batches = [
+ torch.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], dim=0)
+ for idx in range(int(num_batches))
+ ]
+
+ return batches
+
+ def sample_transforms(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ if x.ndim != 3:
+ raise AssertionError("expected list of 3D Tensors")
+ if isinstance(x, np.ndarray):
+ if x.dtype not in (np.uint8, np.float32):
+ raise TypeError("unsupported data type for numpy.ndarray")
+ x = torch.from_numpy(x.copy()).permute(2, 0, 1)
+ elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
+ raise TypeError("unsupported data type for torch.Tensor")
+ # Resizing
+ x = self.resize(x)
+ # Data type
+ if x.dtype == torch.uint8:
+ x = x.to(dtype=torch.float32).div(255).clip(0, 1) # type: ignore[union-attr]
+ else:
+ x = x.to(dtype=torch.float32) # type: ignore[union-attr]
+
+ return x
+
+ def __call__(self, x: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]) -> List[torch.Tensor]:
+ """Prepare document data for model forwarding
+
+ Args:
+ ----
+ x: list of images (np.array) or tensors (already resized and batched)
+
+ Returns:
+ -------
+ list of page batches
+ """
+ # Input type check
+ if isinstance(x, (np.ndarray, torch.Tensor)):
+ if x.ndim != 4:
+ raise AssertionError("expected 4D Tensor")
+ if isinstance(x, np.ndarray):
+ if x.dtype not in (np.uint8, np.float32):
+ raise TypeError("unsupported data type for numpy.ndarray")
+ x = torch.from_numpy(x.copy()).permute(0, 3, 1, 2)
+ elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
+ raise TypeError("unsupported data type for torch.Tensor")
+ # Resizing
+ if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:
+ x = F.resize(
+ x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
+ )
+ # Data type
+ if x.dtype == torch.uint8: # type: ignore[union-attr]
+ x = x.to(dtype=torch.float32).div(255).clip(0, 1) # type: ignore[union-attr]
+ else:
+ x = x.to(dtype=torch.float32) # type: ignore[union-attr]
+ batches = [x]
+
+ elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, torch.Tensor)) for sample in x):
+ # Sample transform (to tensor, resize)
+ samples = list(multithread_exec(self.sample_transforms, x))
+ # Batching
+ batches = self.batch_inputs(samples)
+ else:
+ raise TypeError(f"invalid input type: {type(x)}")
+
+ # Batch transforms (normalize)
+ batches = list(multithread_exec(self.normalize, batches))
+
+ return batches
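A minimal sketch of the `PreProcessor` contract on the PyTorch side: variable-size uint8 crops in, fixed-size normalized float batches out (shapes are the expected behavior of the code above, not a documented guarantee):

```python
# Variable-size HWC uint8 crops -> resized, casted, batched, normalized tensors.
import numpy as np

from doctr.models.preprocessor.pytorch import PreProcessor

processor = PreProcessor(output_size=(32, 128), batch_size=2)

crops = [(np.random.rand(40, 160, 3) * 255).astype(np.uint8) for _ in range(3)]
batches = processor(crops)
print([tuple(b.shape) for b in batches])  # [(2, 3, 32, 128), (1, 3, 32, 128)]
```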
diff --git a/doctr/models/preprocessor/tensorflow.py b/doctr/models/preprocessor/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..431f95b11fff641bb99d2e8e13e2b74ce36b57e5
--- /dev/null
+++ b/doctr/models/preprocessor/tensorflow.py
@@ -0,0 +1,125 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import math
+from typing import Any, List, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from doctr.transforms import Normalize, Resize
+from doctr.utils.multithreading import multithread_exec
+from doctr.utils.repr import NestedObject
+
+__all__ = ["PreProcessor"]
+
+
+class PreProcessor(NestedObject):
+ """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
+
+ Args:
+ ----
+ output_size: expected size of each page in format (H, W)
+ batch_size: the size of page batches
+ mean: mean value of the training distribution by channel
+ std: standard deviation of the training distribution by channel
+ """
+
+ _children_names: List[str] = ["resize", "normalize"]
+
+ def __init__(
+ self,
+ output_size: Tuple[int, int],
+ batch_size: int,
+ mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+ std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
+ **kwargs: Any,
+ ) -> None:
+ self.batch_size = batch_size
+ self.resize = Resize(output_size, **kwargs)
+ # Perform the division by 255 at the same time
+ self.normalize = Normalize(mean, std)
+
+ def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]:
+ """Gather samples into batches for inference purposes
+
+ Args:
+ ----
+ samples: list of samples (tf.Tensor)
+
+ Returns:
+ -------
+ list of batched samples
+ """
+ num_batches = int(math.ceil(len(samples) / self.batch_size))
+ batches = [
+ tf.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], axis=0)
+ for idx in range(int(num_batches))
+ ]
+
+ return batches
+
+ def sample_transforms(self, x: Union[np.ndarray, tf.Tensor]) -> tf.Tensor:
+ if x.ndim != 3:
+ raise AssertionError("expected list of 3D Tensors")
+ if isinstance(x, np.ndarray):
+ if x.dtype not in (np.uint8, np.float32):
+ raise TypeError("unsupported data type for numpy.ndarray")
+ x = tf.convert_to_tensor(x)
+ elif x.dtype not in (tf.uint8, tf.float16, tf.float32):
+ raise TypeError("unsupported data type for torch.Tensor")
+ # Data type & 255 division
+ if x.dtype == tf.uint8:
+ x = tf.image.convert_image_dtype(x, dtype=tf.float32)
+ # Resizing
+ x = self.resize(x)
+
+ return x
+
+ def __call__(self, x: Union[tf.Tensor, np.ndarray, List[Union[tf.Tensor, np.ndarray]]]) -> List[tf.Tensor]:
+ """Prepare document data for model forwarding
+
+ Args:
+ ----
+ x: list of images (np.array) or tensors (already resized and batched)
+
+ Returns:
+ -------
+ list of page batches
+ """
+ # Input type check
+ if isinstance(x, (np.ndarray, tf.Tensor)):
+ if x.ndim != 4:
+ raise AssertionError("expected 4D Tensor")
+ if isinstance(x, np.ndarray):
+ if x.dtype not in (np.uint8, np.float32):
+ raise TypeError("unsupported data type for numpy.ndarray")
+ x = tf.convert_to_tensor(x)
+ elif x.dtype not in (tf.uint8, tf.float16, tf.float32):
+ raise TypeError("unsupported data type for torch.Tensor")
+
+ # Data type & 255 division
+ if x.dtype == tf.uint8:
+ x = tf.image.convert_image_dtype(x, dtype=tf.float32)
+ # Resizing
+ if (x.shape[1], x.shape[2]) != self.resize.output_size:
+ x = tf.image.resize(
+ x, self.resize.output_size, method=self.resize.method, antialias=self.resize.antialias
+ )
+
+ batches = [x]
+
+ elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, tf.Tensor)) for sample in x):
+ # Sample transform (to tensor, resize)
+ samples = list(multithread_exec(self.sample_transforms, x))
+ # Batching
+ batches = self.batch_inputs(samples)
+ else:
+ raise TypeError(f"invalid input type: {type(x)}")
+
+ # Batch transforms (normalize)
+ batches = list(multithread_exec(self.normalize, batches))
+
+ return batches
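Note that the 4D fast path above behaves differently from the list path: an already-batched input is returned as a single batch, regardless of `batch_size`. A small sketch of that expectation:

```python
# An already-batched 4D input is resized/casted as a whole and kept as one batch.
import numpy as np

from doctr.models.preprocessor.tensorflow import PreProcessor

processor = PreProcessor(output_size=(32, 128), batch_size=4)
pages = (np.random.rand(8, 64, 256, 3) * 255).astype(np.uint8)
batches = processor(pages)
print(len(batches), batches[0].shape)  # 1 (8, 32, 128, 3)
```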
diff --git a/doctr/models/recognition/__init__.py b/doctr/models/recognition/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f2f723a722dc7c8fab2477fa7aaf2e83afdc8e6
--- /dev/null
+++ b/doctr/models/recognition/__init__.py
@@ -0,0 +1,6 @@
+from .crnn import *
+from .master import *
+from .sar import *
+from .vitstr import *
+from .parseq import *
+from .zoo import *
diff --git a/doctr/models/recognition/core.py b/doctr/models/recognition/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab82218cce1fa62325013606c02587e5617650c8
--- /dev/null
+++ b/doctr/models/recognition/core.py
@@ -0,0 +1,58 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import List, Tuple
+
+import numpy as np
+
+from doctr.datasets import encode_sequences
+from doctr.utils.repr import NestedObject
+
+__all__ = ["RecognitionPostProcessor", "RecognitionModel"]
+
+
+class RecognitionModel(NestedObject):
+ """Implements abstract RecognitionModel class"""
+
+ vocab: str
+ max_length: int
+
+ def build_target(
+ self,
+ gts: List[str],
+ ) -> Tuple[np.ndarray, List[int]]:
+ """Encode a list of gts sequences into a np array and gives the corresponding*
+ sequence lengths.
+
+ Args:
+ ----
+ gts: list of ground-truth labels
+
+ Returns:
+ -------
+ A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
+ """
+ encoded = encode_sequences(sequences=gts, vocab=self.vocab, target_size=self.max_length, eos=len(self.vocab))
+ seq_len = [len(word) for word in gts]
+ return encoded, seq_len
+
+
+class RecognitionPostProcessor(NestedObject):
+ """Abstract class to postprocess the raw output of the model
+
+ Args:
+ ----
+ vocab: string containing the ordered sequence of supported characters
+ """
+
+ def __init__(
+ self,
+ vocab: str,
+ ) -> None:
+ self.vocab = vocab
+ self._embedding = list(self.vocab) + ["<eos>"]
+
+ def extra_repr(self) -> str:
+ return f"vocab_size={len(self.vocab)}"
diff --git a/doctr/models/recognition/crnn/__init__.py b/doctr/models/recognition/crnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/models/recognition/crnn/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/recognition/crnn/pytorch.py b/doctr/models/recognition/crnn/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4b891f9a236b2e138dbae6ae77f5e0a6799c47
--- /dev/null
+++ b/doctr/models/recognition/crnn/pytorch.py
@@ -0,0 +1,339 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from itertools import groupby
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from doctr.datasets import VOCABS, decode_sequence
+
+from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
+from ...utils.pytorch import load_pretrained_params
+from ..core import RecognitionModel, RecognitionPostProcessor
+
+__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "crnn_vgg16_bn": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 128),
+ "vocab": VOCABS["legacy_french"],
+ "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_vgg16_bn-9762b0b0.pt&src=0",
+ },
+ "crnn_mobilenet_v3_small": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 128),
+ "vocab": VOCABS["french"],
+ "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_mobilenet_v3_small_pt-3b919a02.pt&src=0",
+ },
+ "crnn_mobilenet_v3_large": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 128),
+ "vocab": VOCABS["french"],
+ "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_mobilenet_v3_large_pt-f5259ec2.pt&src=0",
+ },
+}
+
+
+class CTCPostProcessor(RecognitionPostProcessor):
+ """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
+
+ Args:
+ ----
+ vocab: string containing the ordered sequence of supported characters
+ """
+
+ @staticmethod
+ def ctc_best_path(
+ logits: torch.Tensor,
+ vocab: str = VOCABS["french"],
+ blank: int = 0,
+ ) -> List[Tuple[str, float]]:
+ """Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from
+ `<https://github.com/githubharald/CTCDecoder>`_.
+
+ Args:
+ ----
+ logits: model output, shape: N x T x C
+ vocab: vocabulary to use
+ blank: index of blank label
+
+ Returns:
+ -------
+ A list of tuples: (word, confidence)
+ """
+ # Gather the most confident characters, and assign the smallest conf among those to the sequence prob
+ probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values
+
+ # collapse best path (using itertools.groupby), map to chars, join char list to string
+ words = [
+ decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab)
+ for seq in torch.argmax(logits, dim=-1)
+ ]
+
+ return list(zip(words, probs.tolist()))
+
+ def __call__(self, logits: torch.Tensor) -> List[Tuple[str, float]]:
+ """Performs decoding of raw output with CTC and decoding of CTC predictions
+ with label_to_idx mapping dictionary
+
+ Args:
+ ----
+ logits: raw output of the model, shape (N, seq_len, C + 1)
+
+ Returns:
+ -------
+ A list of tuples: (word, confidence)
+
+ """
+ # Decode CTC
+ return self.ctc_best_path(logits=logits, vocab=self.vocab, blank=len(self.vocab))
+
+
+class CRNN(RecognitionModel, nn.Module):
+ """Implements a CRNN architecture as described in `"An End-to-End Trainable Neural Network for Image-based
+ Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
+
+ Args:
+ ----
+ feature_extractor: the backbone serving as feature extractor
+ vocab: vocabulary used for encoding
+ rnn_units: number of units in the LSTM layers
+ exportable: if True, the forward pass returns only the logits (for ONNX export)
+ cfg: configuration dictionary
+ """
+
+ _children_names: List[str] = ["feat_extractor", "decoder", "linear", "postprocessor"]
+
+ def __init__(
+ self,
+ feature_extractor: nn.Module,
+ vocab: str,
+ rnn_units: int = 128,
+ input_shape: Tuple[int, int, int] = (3, 32, 128),
+ exportable: bool = False,
+ cfg: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ super().__init__()
+ self.vocab = vocab
+ self.cfg = cfg
+ self.max_length = 32
+ self.exportable = exportable
+ self.feat_extractor = feature_extractor
+
+ # Resolve the input_size of the LSTM
+ with torch.inference_mode():
+ out_shape = self.feat_extractor(torch.zeros((1, *input_shape))).shape
+ lstm_in = out_shape[1] * out_shape[2]
+
+ self.decoder = nn.LSTM(
+ input_size=lstm_in,
+ hidden_size=rnn_units,
+ batch_first=True,
+ num_layers=2,
+ bidirectional=True,
+ )
+
+ # features units = 2 * rnn_units because bidirectional layers
+ self.linear = nn.Linear(in_features=2 * rnn_units, out_features=len(vocab) + 1)
+
+ self.postprocessor = CTCPostProcessor(vocab=vocab)
+
+ for n, m in self.named_modules():
+ # Don't override the initialization of the backbone
+ if n.startswith("feat_extractor."):
+ continue
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity="relu")
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1.0)
+ m.bias.data.zero_()
+
+ def compute_loss(
+ self,
+ model_output: torch.Tensor,
+ target: List[str],
+ ) -> torch.Tensor:
+ """Compute CTC loss for the model.
+
+ Args:
+ ----
+ model_output: predicted logits of the model
+ target: list of target strings
+
+ Returns:
+ -------
+ The loss of the model on the batch
+ """
+ gt, seq_len = self.build_target(target)
+ batch_len = model_output.shape[0]
+ input_length = model_output.shape[1] * torch.ones(size=(batch_len,), dtype=torch.int32)
+ # N x T x C -> T x N x C
+ logits = model_output.permute(1, 0, 2)
+ probs = F.log_softmax(logits, dim=-1)
+ ctc_loss = F.ctc_loss(
+ probs,
+ torch.from_numpy(gt),
+ input_length,
+ torch.tensor(seq_len, dtype=torch.int),
+ len(self.vocab),
+ zero_infinity=True,
+ )
+
+ return ctc_loss
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ target: Optional[List[str]] = None,
+ return_model_output: bool = False,
+ return_preds: bool = False,
+ ) -> Dict[str, Any]:
+ if self.training and target is None:
+ raise ValueError("Need to provide labels during training")
+
+ features = self.feat_extractor(x)
+ # B x C x H x W --> B x C*H x W --> B x W x C*H
+ c, h, w = features.shape[1], features.shape[2], features.shape[3]
+ features_seq = torch.reshape(features, shape=(-1, h * c, w))
+ features_seq = torch.transpose(features_seq, 1, 2)
+ logits, _ = self.decoder(features_seq)
+ logits = self.linear(logits)
+
+ out: Dict[str, Any] = {}
+ if self.exportable:
+ out["logits"] = logits
+ return out
+
+ if return_model_output:
+ out["out_map"] = logits
+
+ if target is None or return_preds:
+ # Post-process boxes
+ out["preds"] = self.postprocessor(logits)
+
+ if target is not None:
+ out["loss"] = self.compute_loss(logits, target)
+
+ return out
+
+
+def _crnn(
+ arch: str,
+ pretrained: bool,
+ backbone_fn: Callable[[Any], nn.Module],
+ pretrained_backbone: bool = True,
+ ignore_keys: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> CRNN:
+ pretrained_backbone = pretrained_backbone and not pretrained
+
+ # Feature extractor
+ feat_extractor = backbone_fn(pretrained=pretrained_backbone).features # type: ignore[call-arg]
+
+ kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"])
+ kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["vocab"] = kwargs["vocab"]
+ _cfg["input_shape"] = kwargs["input_shape"]
+
+ # Build the model
+ model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ # The number of classes is not the same as the number of classes in the pretrained model =>
+ # remove the last layer weights
+ _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
+ load_pretrained_params(model, _cfg["url"], ignore_keys=_ignore_keys)
+
+ return model
+
+
+def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
+ """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based
+ Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
+
+ >>> import torch
+ >>> from doctr.models import crnn_vgg16_bn
+ >>> model = crnn_vgg16_bn(pretrained=True)
+ >>> input_tensor = torch.rand(1, 3, 32, 128)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+ **kwargs: keyword arguments of the CRNN architecture
+
+ Returns:
+ -------
+ text recognition architecture
+ """
+ return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, ignore_keys=["linear.weight", "linear.bias"], **kwargs)
+
+
+def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
+ """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based
+ Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
+
+ >>> import torch
+ >>> from doctr.models import crnn_mobilenet_v3_small
+ >>> model = crnn_mobilenet_v3_small(pretrained=True)
+ >>> input_tensor = torch.rand(1, 3, 32, 128)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+ **kwargs: keyword arguments of the CRNN architecture
+
+ Returns:
+ -------
+ text recognition architecture
+ """
+ return _crnn(
+ "crnn_mobilenet_v3_small",
+ pretrained,
+ mobilenet_v3_small_r,
+ ignore_keys=["linear.weight", "linear.bias"],
+ **kwargs,
+ )
+
+
+def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
+ """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based
+ Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
+
+ >>> import torch
+ >>> from doctr.models import crnn_mobilenet_v3_large
+ >>> model = crnn_mobilenet_v3_large(pretrained=True)
+ >>> input_tensor = torch.rand(1, 3, 32, 128)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+ **kwargs: keyword arguments of the CRNN architecture
+
+ Returns:
+ -------
+ text recognition architecture
+ """
+ return _crnn(
+ "crnn_mobilenet_v3_large",
+ pretrained,
+ mobilenet_v3_large_r,
+ ignore_keys=["linear.weight", "linear.bias"],
+ **kwargs,
+ )
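The greedy CTC decoding implemented by `CTCPostProcessor.ctc_best_path` takes the per-step argmax, merges repeated indices, then drops blanks. A tiny worked example in plain Python:

```python
# Worked sketch of the greedy CTC collapse: "aa-bb--c" -> "abc".
from itertools import groupby

vocab = "abc"
blank = len(vocab)  # blank index, as passed by CTCPostProcessor.__call__
best_path = [0, 0, blank, 1, 1, blank, blank, 2]  # argmax indices over timesteps
collapsed = [k for k, _ in groupby(best_path) if k != blank]
print("".join(vocab[idx] for idx in collapsed))  # -> "abc"
```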
diff --git a/doctr/models/recognition/crnn/tensorflow.py b/doctr/models/recognition/crnn/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ec48c4f0e90be4100b5dc700eb661df957a8a25
--- /dev/null
+++ b/doctr/models/recognition/crnn/tensorflow.py
@@ -0,0 +1,318 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import tensorflow as tf
+from tensorflow.keras import layers
+from tensorflow.keras.models import Model, Sequential
+
+from doctr.datasets import VOCABS
+
+from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
+from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ..core import RecognitionModel, RecognitionPostProcessor
+
+__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "crnn_vgg16_bn": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 128, 3),
+ "vocab": VOCABS["legacy_french"],
+ "url": "https://doctr-static.mindee.com/models?id=v0.3.0/crnn_vgg16_bn-76b7f2c6.zip&src=0",
+ },
+ "crnn_mobilenet_v3_small": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 128, 3),
+ "vocab": VOCABS["french"],
+ "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_mobilenet_v3_small-7f36edec.zip&src=0",
+ },
+ "crnn_mobilenet_v3_large": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 128, 3),
+ "vocab": VOCABS["french"],
+ "url": "https://doctr-static.mindee.com/models?id=v0.6.0/crnn_mobilenet_v3_large-cccc50b1.zip&src=0",
+ },
+}
+
+
+class CTCPostProcessor(RecognitionPostProcessor):
+ """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
+
+ Args:
+ ----
+ vocab: string containing the ordered sequence of supported characters
+ """
+
+ def __call__(
+ self,
+ logits: tf.Tensor,
+ beam_width: int = 1,
+ top_paths: int = 1,
+ ) -> Union[List[Tuple[str, float]], List[Tuple[List[str], List[float]]]]:
+ """Performs decoding of raw output with CTC and decoding of CTC predictions
+ with label_to_idx mapping dictionary
+
+ Args:
+ ----
+ logits: raw output of the model, shape BATCH_SIZE X SEQ_LEN X NUM_CLASSES + 1
+ beam_width: An int scalar >= 0 (beam search beam width).
+ top_paths: An int scalar >= 0, <= beam_width (controls output size).
+
+ Returns:
+ -------
+ A list of decoded words of length BATCH_SIZE
+ """
+ # Decode CTC
+ _decoded, _log_prob = tf.nn.ctc_beam_search_decoder(
+ tf.transpose(logits, perm=[1, 0, 2]),
+ tf.fill(tf.shape(logits)[:1], tf.shape(logits)[1]),
+ beam_width=beam_width,
+ top_paths=top_paths,
+ )
+
+ _decoded = tf.sparse.concat(
+ 1,
+ [tf.sparse.expand_dims(dec, axis=1) for dec in _decoded],
+ expand_nonconcat_dims=True,
+ ) # dim : batchsize x beamwidth x actual_max_len_predictions
+ out_idxs = tf.sparse.to_dense(_decoded, default_value=len(self.vocab))
+
+ # Map it to characters
+ _decoded_strings_pred = tf.strings.reduce_join(
+ inputs=tf.nn.embedding_lookup(tf.constant(self._embedding, dtype=tf.string), out_idxs),
+ axis=-1,
+ )
+ _decoded_strings_pred = tf.strings.split(_decoded_strings_pred, "<eos>")
+ decoded_strings_pred = tf.sparse.to_dense(_decoded_strings_pred.to_sparse(), default_value="not valid")[
+ :, :, 0
+ ] # dim : batch_size x beam_width
+
+ if top_paths == 1:
+ probs = tf.math.exp(tf.squeeze(_log_prob, axis=1)) # dim : batchsize
+ decoded_strings_pred = tf.squeeze(decoded_strings_pred, axis=1)
+ word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
+ else:
+ probs = tf.math.exp(_log_prob) # dim : batchsize x beamwidth
+ word_values = [[word.decode() for word in words] for words in decoded_strings_pred.numpy().tolist()]
+ return list(zip(word_values, probs.numpy().tolist()))
+
+
+class CRNN(RecognitionModel, Model):
+ """Implements a CRNN architecture as described in `"An End-to-End Trainable Neural Network for Image-based
+ Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
+
+ Args:
+ ----
+ feature_extractor: the backbone serving as feature extractor
+ vocab: vocabulary used for encoding
+ rnn_units: number of units in the LSTM layers
+ exportable: if True, the forward pass returns only the logits (for ONNX export)
+ beam_width: beam width for beam search decoding
+ top_paths: number of top paths for beam search decoding
+ cfg: configuration dictionary
+ """
+
+ _children_names: List[str] = ["feat_extractor", "decoder", "postprocessor"]
+
+ def __init__(
+ self,
+ feature_extractor: tf.keras.Model,
+ vocab: str,
+ rnn_units: int = 128,
+ exportable: bool = False,
+ beam_width: int = 1,
+ top_paths: int = 1,
+ cfg: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ # Initialize kernels
+ h, w, c = feature_extractor.output_shape[1:]
+
+ super().__init__()
+ self.vocab = vocab
+ self.max_length = w
+ self.cfg = cfg
+ self.exportable = exportable
+ self.feat_extractor = feature_extractor
+
+ self.decoder = Sequential([
+ layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)),
+ layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)),
+ layers.Dense(units=len(vocab) + 1),
+ ])
+ self.decoder.build(input_shape=(None, w, h * c))
+
+ self.postprocessor = CTCPostProcessor(vocab=vocab)
+
+ self.beam_width = beam_width
+ self.top_paths = top_paths
+
+ def compute_loss(
+ self,
+ model_output: tf.Tensor,
+ target: List[str],
+ ) -> tf.Tensor:
+ """Compute CTC loss for the model.
+
+ Args:
+ ----
+ model_output: predicted logits of the model
+ target: list of target strings
+
+ Returns:
+ -------
+ The loss of the model on the batch
+ """
+ gt, seq_len = self.build_target(target)
+ batch_len = model_output.shape[0]
+ input_length = tf.fill((batch_len,), model_output.shape[1])
+ ctc_loss = tf.nn.ctc_loss(
+ gt, model_output, seq_len, input_length, logits_time_major=False, blank_index=len(self.vocab)
+ )
+ return ctc_loss
+
+ def call(
+ self,
+ x: tf.Tensor,
+ target: Optional[List[str]] = None,
+ return_model_output: bool = False,
+ return_preds: bool = False,
+ beam_width: int = 1,
+ top_paths: int = 1,
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ if kwargs.get("training", False) and target is None:
+ raise ValueError("Need to provide labels during training")
+
+ features = self.feat_extractor(x, **kwargs)
+ # B x H x W x C --> B x W x H x C
+ transposed_feat = tf.transpose(features, perm=[0, 2, 1, 3])
+ w, h, c = transposed_feat.get_shape().as_list()[1:]
+ # B x W x H x C --> B x W x H * C
+ features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c))
+ logits = _bf16_to_float32(self.decoder(features_seq, **kwargs))
+
+ out: Dict[str, tf.Tensor] = {}
+ if self.exportable:
+ out["logits"] = logits
+ return out
+
+ if return_model_output:
+ out["out_map"] = logits
+
+ if target is None or return_preds:
+ # Post-process boxes
+ out["preds"] = self.postprocessor(logits, beam_width=beam_width, top_paths=top_paths)
+
+ if target is not None:
+ out["loss"] = self.compute_loss(logits, target)
+
+ return out
+
+
+def _crnn(
+ arch: str,
+ pretrained: bool,
+ backbone_fn,
+ pretrained_backbone: bool = True,
+ input_shape: Optional[Tuple[int, int, int]] = None,
+ **kwargs: Any,
+) -> CRNN:
+ pretrained_backbone = pretrained_backbone and not pretrained
+
+ kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"])
+
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["vocab"] = kwargs["vocab"]
+ _cfg["input_shape"] = input_shape or default_cfgs[arch]["input_shape"]
+
+ feat_extractor = backbone_fn(
+ input_shape=_cfg["input_shape"],
+ include_top=False,
+ pretrained=pretrained_backbone,
+ )
+
+ # Build the model
+ model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ load_pretrained_params(model, _cfg["url"])
+
+ return model
+
+
+def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
+ """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based
+ Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import crnn_vgg16_bn
+ >>> model = crnn_vgg16_bn(pretrained=True)
+ >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+ **kwargs: keyword arguments of the CRNN architecture
+
+ Returns:
+ -------
+ text recognition architecture
+ """
+ return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, **kwargs)
+
+
+def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
+ """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based
+ Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import crnn_mobilenet_v3_small
+ >>> model = crnn_mobilenet_v3_small(pretrained=True)
+ >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+ **kwargs: keyword arguments of the CRNN architecture
+
+ Returns:
+ -------
+ text recognition architecture
+ """
+ return _crnn("crnn_mobilenet_v3_small", pretrained, mobilenet_v3_small_r, **kwargs)
+
+
+def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
+ """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based
+ Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import crnn_mobilenet_v3_large
+ >>> model = crnn_mobilenet_v3_large(pretrained=True)
+ >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+ **kwargs: keyword arguments of the CRNN architecture
+
+ Returns:
+ -------
+ text recognition architecture
+ """
+ return _crnn("crnn_mobilenet_v3_large", pretrained, mobilenet_v3_large_r, **kwargs)
diff --git a/doctr/models/recognition/master/__init__.py b/doctr/models/recognition/master/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/models/recognition/master/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/recognition/master/base.py b/doctr/models/recognition/master/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d3002893ee26b4aa41838e41e63cd3f17c35779
--- /dev/null
+++ b/doctr/models/recognition/master/base.py
@@ -0,0 +1,58 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import List, Tuple
+
+import numpy as np
+
+from ....datasets import encode_sequences
+from ..core import RecognitionPostProcessor
+
+
+class _MASTER:
+ vocab: str
+ max_length: int
+
+ def build_target(
+ self,
+ gts: List[str],
+ ) -> Tuple[np.ndarray, List[int]]:
+ """Encode a list of gts sequences into a np array and gives the corresponding*
+ sequence lengths.
+
+ Args:
+ ----
+ gts: list of ground-truth labels
+
+ Returns:
+ -------
+ A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
+ """
+ encoded = encode_sequences(
+ sequences=gts,
+ vocab=self.vocab,
+ target_size=self.max_length,
+ eos=len(self.vocab),
+ sos=len(self.vocab) + 1,
+ pad=len(self.vocab) + 2,
+ )
+ seq_len = [len(word) for word in gts]
+ return encoded, seq_len
+
+
+class _MASTERPostProcessor(RecognitionPostProcessor):
+ """Abstract class to postprocess the raw output of the model
+
+ Args:
+ ----
+ vocab: string containing the ordered sequence of supported characters
+ """
+
+ def __init__(
+ self,
+ vocab: str,
+ ) -> None:
+ super().__init__(vocab)
+ self._embedding = list(vocab) + ["<eos>"] + ["<sos>"] + ["<pad>"]
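A sketch of the index layout implied by `build_target` and `_embedding` above, for a toy vocabulary: characters keep their vocab indices, then the three special tokens follow.

```python
# MASTER targets reserve three extra indices beyond the vocabulary.
vocab = "abc"
eos, sos, pad = len(vocab), len(vocab) + 1, len(vocab) + 2  # 3, 4, 5
# "ab" encoded with max_length 6 is, conceptually, [sos, 0, 1, eos, pad, pad]:
print([sos, vocab.index("a"), vocab.index("b"), eos, pad, pad])
```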
diff --git a/doctr/models/recognition/master/pytorch.py b/doctr/models/recognition/master/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..875fcbd687d6e644623632eba39da6f57f26edf0
--- /dev/null
+++ b/doctr/models/recognition/master/pytorch.py
@@ -0,0 +1,338 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision.models._utils import IntermediateLayerGetter
+
+from doctr.datasets import VOCABS
+from doctr.models.classification import magc_resnet31
+from doctr.models.modules.transformer import Decoder, PositionalEncoding
+
+from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+from .base import _MASTER, _MASTERPostProcessor
+
+__all__ = ["MASTER", "master"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "master": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 128),
+ "vocab": VOCABS["french"],
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/master-fde31e4a.pt&src=0",
+ },
+}
+
+
+class MASTER(_MASTER, nn.Module):
+ """Implements MASTER as described in paper: `_.
+ Implementation based on the official Pytorch implementation: `_.
+
+ Args:
+ ----
+ feature_extractor: the backbone serving as feature extractor
+ vocab: vocabulary, (without EOS, SOS, PAD)
+ d_model: d parameter for the transformer decoder
+ dff: depth of the pointwise feed-forward layer
+ num_heads: number of heads for the multi-head attention module
+ num_layers: number of decoder layers to stack
+ max_length: maximum length of character sequence handled by the model
+ dropout: dropout probability of the decoder
+ input_shape: size of the image inputs
+ exportable: if True, the forward pass returns only the logits (for ONNX export)
+ cfg: dictionary containing information about the model
+ """
+
+ def __init__(
+ self,
+ feature_extractor: nn.Module,
+ vocab: str,
+ d_model: int = 512,
+ dff: int = 2048,
+ num_heads: int = 8, # number of heads in the transformer decoder
+ num_layers: int = 3,
+ max_length: int = 50,
+ dropout: float = 0.2,
+ input_shape: Tuple[int, int, int] = (3, 32, 128), # different from the paper
+ exportable: bool = False,
+ cfg: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ super().__init__()
+
+ self.exportable = exportable
+ self.max_length = max_length
+ self.d_model = d_model
+ self.vocab = vocab
+ self.cfg = cfg
+ self.vocab_size = len(vocab)
+
+ self.feat_extractor = feature_extractor
+ self.positional_encoding = PositionalEncoding(self.d_model, dropout, max_len=input_shape[1] * input_shape[2])
+
+ self.decoder = Decoder(
+ num_layers=num_layers,
+ d_model=self.d_model,
+ num_heads=num_heads,
+ vocab_size=self.vocab_size + 3, # EOS, SOS, PAD
+ dff=dff,
+ dropout=dropout,
+ maximum_position_encoding=self.max_length,
+ )
+
+ self.linear = nn.Linear(self.d_model, self.vocab_size + 3)
+ self.postprocessor = MASTERPostProcessor(vocab=self.vocab)
+
+ for n, m in self.named_modules():
+ # Don't override the initialization of the backbone
+ if n.startswith("feat_extractor."):
+ continue
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def make_source_and_target_mask(
+ self, source: torch.Tensor, target: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # borrowed and slightly modified from https://github.com/wenwenyu/MASTER-pytorch
+ # NOTE: nn.TransformerDecoder takes the inverse from this implementation
+ # [True, True, True, ..., False, False, False] -> False is masked
+ # (N, 1, 1, max_length)
+ target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
+ target_length = target.size(1)
+ # sub mask filled diagonal with True = see and False = masked (max_length, max_length)
+ # NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)
+ target_sub_mask = torch.tril(torch.ones((target_length, target_length), device=source.device), diagonal=0).to(
+ dtype=torch.bool
+ )
+ # source mask filled with ones (max_length, positional_encoded_seq_len)
+ source_mask = torch.ones((target_length, source.size(1)), dtype=torch.uint8, device=source.device)
+ # combine the two masks into one (N, 1, max_length, max_length)
+ target_mask = target_pad_mask & target_sub_mask
+ return source_mask, target_mask.int()
+
+ @staticmethod
+ def compute_loss(
+ model_output: torch.Tensor,
+ gt: torch.Tensor,
+ seq_len: torch.Tensor,
+ ) -> torch.Tensor:
+ """Compute categorical cross-entropy loss for the model.
+ Sequences are masked after the EOS character.
+
+ Args:
+ ----
+ gt: the encoded tensor with gt labels
+ model_output: predicted logits of the model
+ seq_len: lengths of each gt word inside the batch
+
+ Returns:
+ -------
+ The loss of the model on the batch
+ """
+ # Input length : number of timesteps
+ input_len = model_output.shape[1]
+ # Add one for additional token (sos disappear in shift!)
+ seq_len = seq_len + 1
+ # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
+ # The "masked" first gt char is . Delete last logit of the model output.
+ cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
+ # Compute mask, remove 1 timestep here as well
+ mask_2d = torch.arange(input_len - 1, device=model_output.device)[None, :] >= seq_len[:, None]
+ cce[mask_2d] = 0
+
+ ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype)
+ return ce_loss.mean()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ target: Optional[List[str]] = None,
+ return_model_output: bool = False,
+ return_preds: bool = False,
+ ) -> Dict[str, Any]:
+ """Call function for training
+
+ Args:
+ ----
+ x: images
+ target: list of str labels
+ return_model_output: if True, return logits
+ return_preds: if True, decode logits
+
+ Returns:
+ -------
+ A dictionary containing the loss, the logits and/or the decoded predictions, depending on the arguments.
+ """
+ # Encode
+ features = self.feat_extractor(x)["features"]
+ b, c, h, w = features.shape
+ # (N, C, H, W) --> (N, H * W, C)
+ features = features.view(b, c, h * w).permute((0, 2, 1))
+ # add positional encoding to features
+ encoded = self.positional_encoding(features)
+
+ out: Dict[str, Any] = {}
+
+ if self.training and target is None:
+ raise ValueError("Need to provide labels during training")
+
+ if target is not None:
+ # Compute target: tensor of gts and sequence lengths
+ _gt, _seq_len = self.build_target(target)
+ gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len)
+ gt, seq_len = gt.to(x.device), seq_len.to(x.device)
+
+ # Compute source mask and target mask
+ source_mask, target_mask = self.make_source_and_target_mask(encoded, gt)
+ output = self.decoder(gt, encoded, source_mask, target_mask)
+ # Compute logits
+ logits = self.linear(output)
+ else:
+ logits = self.decode(encoded)
+
+ logits = _bf16_to_float32(logits)
+
+ if self.exportable:
+ out["logits"] = logits
+ return out
+
+ if target is not None:
+ out["loss"] = self.compute_loss(logits, gt, seq_len)
+
+ if return_model_output:
+ out["out_map"] = logits
+
+ if return_preds:
+ out["preds"] = self.postprocessor(logits)
+
+ return out
+
+ def decode(self, encoded: torch.Tensor) -> torch.Tensor:
+ """Decode function for prediction
+
+ Args:
+ ----
+ encoded: input tensor
+
+ Returns:
+ -------
+ The logits tensor, of shape (N, max_length, vocab_size + 3)
+ """
+ b = encoded.size(0)
+
+ # Padding symbol + SOS at the beginning
+ ys = torch.full((b, self.max_length), self.vocab_size + 2, dtype=torch.long, device=encoded.device) # pad
+ ys[:, 0] = self.vocab_size + 1 # sos
+
+ # Final dimension include EOS/SOS/PAD
+ for i in range(self.max_length - 1):
+ source_mask, target_mask = self.make_source_and_target_mask(encoded, ys)
+ output = self.decoder(ys, encoded, source_mask, target_mask)
+ logits = self.linear(output)
+ prob = torch.softmax(logits, dim=-1)
+ next_token = torch.max(prob, dim=-1).indices
+ # update ys with the next token and ignore the first token (SOS)
+ ys[:, i + 1] = next_token[:, i]
+
+ # Shape (N, max_length, vocab_size + 3)
+ return logits
+
+
+class MASTERPostProcessor(_MASTERPostProcessor):
+ """Post processor for MASTER architectures"""
+
+ def __call__(
+ self,
+ logits: torch.Tensor,
+ ) -> List[Tuple[str, float]]:
+ # compute pred with argmax for attention models
+ out_idxs = logits.argmax(-1)
+ # N x L
+ probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
+ # Take the minimum confidence of the sequence
+ probs = probs.min(dim=1).values.detach().cpu()
+
+ # Manual decoding
+ word_values = [
+ "".join(self._embedding[idx] for idx in encoded_seq).split("")[0]
+ for encoded_seq in out_idxs.cpu().numpy()
+ ]
+
+ return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))
+
+
+def _master(
+ arch: str,
+ pretrained: bool,
+ backbone_fn: Callable[[bool], nn.Module],
+ layer: str,
+ pretrained_backbone: bool = True,
+ ignore_keys: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> MASTER:
+ pretrained_backbone = pretrained_backbone and not pretrained
+
+ # Patch the config
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
+ _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
+
+ kwargs["vocab"] = _cfg["vocab"]
+ kwargs["input_shape"] = _cfg["input_shape"]
+
+ # Build the model
+ feat_extractor = IntermediateLayerGetter(
+ backbone_fn(pretrained_backbone),
+ {layer: "features"},
+ )
+ model = MASTER(feat_extractor, cfg=_cfg, **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ # The number of classes is not the same as the number of classes in the pretrained model =>
+ # remove the last layer weights
+ _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
+ load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+ return model
+
+
+def master(pretrained: bool = False, **kwargs: Any) -> MASTER:
+ """MASTER as described in paper: `_.
+
+ >>> import torch
+ >>> from doctr.models import master
+ >>> model = master(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 32, 128))
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+ **kwargs: keyword arguments passed to the MASTER architecture
+
+ Returns:
+ -------
+ text recognition architecture
+ """
+ return _master(
+ "master",
+ pretrained,
+ magc_resnet31,
+ "10",
+ ignore_keys=[
+ "decoder.embed.weight",
+ "linear.weight",
+ "linear.bias",
+ ],
+ **kwargs,
+ )
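A minimal inference sketch for the PyTorch MASTER recognizer (random weights, so predictions are meaningless; `pretrained_backbone=False` avoids downloading the backbone checkpoint):

```python
# Autoregressive inference: decode() is used internally when no target is given.
import torch

from doctr.models import master

model = master(pretrained=False, pretrained_backbone=False).eval()
x = torch.rand(2, 3, 32, 128)
with torch.inference_mode():
    out = model(x, return_preds=True)
print(out["preds"])  # list of (word, confidence) tuples, one per crop
```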
diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3ecadcc15b071c781d38280b0904901ba365619
--- /dev/null
+++ b/doctr/models/recognition/master/tensorflow.py
@@ -0,0 +1,318 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple
+
+import tensorflow as tf
+from tensorflow.keras import Model, layers
+
+from doctr.datasets import VOCABS
+from doctr.models.classification import magc_resnet31
+from doctr.models.modules.transformer import Decoder, PositionalEncoding
+
+from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from .base import _MASTER, _MASTERPostProcessor
+
+__all__ = ["MASTER", "master"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "master": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 128, 3),
+ "vocab": VOCABS["french"],
+ "url": "https://doctr-static.mindee.com/models?id=v0.6.0/master-a8232e9f.zip&src=0",
+ },
+}
+
+
+class MASTER(_MASTER, Model):
+ """Implements MASTER as described in paper: `_.
+ Implementation based on the official TF implementation: `_.
+
+ Args:
+ ----
+ feature_extractor: the backbone serving as feature extractor
+ vocab: vocabulary, (without EOS, SOS, PAD)
+ d_model: d parameter for the transformer decoder
+ dff: depth of the pointwise feed-forward layer
+ num_heads: number of heads for the multi-head attention module
+ num_layers: number of decoder layers to stack
+ max_length: maximum length of character sequence handled by the model
+ dropout: dropout probability of the decoder
+ input_shape: size of the image inputs
+ exportable: if True, the forward pass returns only the logits (for ONNX export)
+ cfg: dictionary containing information about the model
+ """
+
+ def __init__(
+ self,
+ feature_extractor: tf.keras.Model,
+ vocab: str,
+ d_model: int = 512,
+ dff: int = 2048,
+ num_heads: int = 8, # number of heads in the transformer decoder
+ num_layers: int = 3,
+ max_length: int = 50,
+ dropout: float = 0.2,
+ input_shape: Tuple[int, int, int] = (32, 128, 3), # different from the paper
+ exportable: bool = False,
+ cfg: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ super().__init__()
+
+ self.exportable = exportable
+ self.max_length = max_length
+ self.d_model = d_model
+ self.vocab = vocab
+ self.cfg = cfg
+ self.vocab_size = len(vocab)
+
+ self.feat_extractor = feature_extractor
+ self.positional_encoding = PositionalEncoding(self.d_model, dropout, max_len=input_shape[0] * input_shape[1])
+
+ self.decoder = Decoder(
+ num_layers=num_layers,
+ d_model=self.d_model,
+ num_heads=num_heads,
+ vocab_size=self.vocab_size + 3, # EOS, SOS, PAD
+ dff=dff,
+ dropout=dropout,
+ maximum_position_encoding=self.max_length,
+ )
+
+ self.linear = layers.Dense(self.vocab_size + 3, kernel_initializer=tf.initializers.he_uniform())
+ self.postprocessor = MASTERPostProcessor(vocab=self.vocab)
+
+ @tf.function
+ def make_source_and_target_mask(self, source: tf.Tensor, target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+ # [1, 1, 1, ..., 0, 0, 0] -> 0 is masked
+ # (N, 1, 1, max_length)
+ target_pad_mask = tf.cast(tf.math.not_equal(target, self.vocab_size + 2), dtype=tf.uint8)
+ target_pad_mask = target_pad_mask[:, tf.newaxis, tf.newaxis, :]
+ target_length = target.shape[1]
+ # sub mask filled diagonal with 1 = see 0 = masked (max_length, max_length)
+ target_sub_mask = tf.linalg.band_part(tf.ones((target_length, target_length)), -1, 0)
+ # source mask filled with ones (max_length, positional_encoded_seq_len)
+ source_mask = tf.ones((target_length, source.shape[1]))
+ # combine the two masks into one boolean mask where False is masked (N, 1, max_length, max_length)
+ target_mask = tf.math.logical_and(
+ tf.cast(target_sub_mask, dtype=tf.bool), tf.cast(target_pad_mask, dtype=tf.bool)
+ )
+ return source_mask, target_mask
+
+ @staticmethod
+ def compute_loss(
+ model_output: tf.Tensor,
+ gt: tf.Tensor,
+ seq_len: List[int],
+ ) -> tf.Tensor:
+ """Compute categorical cross-entropy loss for the model.
+ Sequences are masked after the EOS character.
+
+ Args:
+ ----
+ gt: the encoded tensor with gt labels
+ model_output: predicted logits of the model
+ seq_len: lengths of each gt word inside the batch
+
+ Returns:
+ -------
+ The loss of the model on the batch
+ """
+ # Input length : number of timesteps
+ input_len = tf.shape(model_output)[1]
+ # Add one for additional token (sos disappear in shift!)
+ seq_len = tf.cast(seq_len, tf.int32) + 1
+ # One-hot gt labels
+ oh_gt = tf.one_hot(gt, depth=model_output.shape[2])
+ # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
+ # The "masked" first gt char is . Delete last logit of the model output.
+ cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt[:, 1:, :], model_output[:, :-1, :])
+ # Compute mask
+ mask_values = tf.zeros_like(cce)
+ mask_2d = tf.sequence_mask(seq_len, input_len - 1) # delete the last mask timestep as well
+ masked_loss = tf.where(mask_2d, cce, mask_values)
+ ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype))
+
+ return tf.expand_dims(ce_loss, axis=1)
+
+ def call(
+ self,
+ x: tf.Tensor,
+ target: Optional[List[str]] = None,
+ return_model_output: bool = False,
+ return_preds: bool = False,
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ """Call function for training
+
+ Args:
+ ----
+ x: images
+ target: list of str labels
+ return_model_output: if True, return logits
+ return_preds: if True, decode logits
+ **kwargs: keyword arguments passed to the decoder
+
+ Returns:
+ -------
+ A dictionary containing the loss, the logits and/or the decoded predictions, depending on the arguments.
+ """
+ # Encode
+ feature = self.feat_extractor(x, **kwargs)
+ b, h, w, c = feature.get_shape()
+ # (N, H, W, C) --> (N, H * W, C)
+ feature = tf.reshape(feature, shape=(b, h * w, c))
+ # add positional encoding to features
+ encoded = self.positional_encoding(feature, **kwargs)
+
+ out: Dict[str, tf.Tensor] = {}
+
+ if kwargs.get("training", False) and target is None:
+ raise ValueError("Need to provide labels during training")
+
+ if target is not None:
+ # Compute target: tensor of gts and sequence lengths
+ gt, seq_len = self.build_target(target)
+ # Compute decoder masks
+ source_mask, target_mask = self.make_source_and_target_mask(encoded, gt)
+ # Compute logits
+ output = self.decoder(gt, encoded, source_mask, target_mask, **kwargs)
+ logits = self.linear(output, **kwargs)
+ else:
+ logits = self.decode(encoded, **kwargs)
+
+ logits = _bf16_to_float32(logits)
+
+ if self.exportable:
+ out["logits"] = logits
+ return out
+
+ if target is not None:
+ out["loss"] = self.compute_loss(logits, gt, seq_len)
+
+ if return_model_output:
+ out["out_map"] = logits
+
+ if return_preds:
+ out["preds"] = self.postprocessor(logits)
+
+ return out
+
+ @tf.function
+ def decode(self, encoded: tf.Tensor, **kwargs: Any) -> tf.Tensor:
+ """Decode function for prediction
+
+ Args:
+ ----
+ encoded: encoded features
+ **kwargs: keyword arguments passed to the decoder
+
+ Returns:
+ -------
+ The logits tensor, of shape (N, max_length, vocab_size + 3)
+ """
+ b = encoded.shape[0]
+
+ start_symbol = tf.constant(self.vocab_size + 1, dtype=tf.int32) # SOS
+ padding_symbol = tf.constant(self.vocab_size + 2, dtype=tf.int32) # PAD
+
+ ys = tf.fill(dims=(b, self.max_length - 1), value=padding_symbol)
+ start_vector = tf.fill(dims=(b, 1), value=start_symbol)
+ ys = tf.concat([start_vector, ys], axis=-1)
+
+ # Final dimension include EOS/SOS/PAD
+ for i in range(self.max_length - 1):
+ source_mask, target_mask = self.make_source_and_target_mask(encoded, ys)
+ output = self.decoder(ys, encoded, source_mask, target_mask, **kwargs)
+ logits = self.linear(output, **kwargs)
+ prob = tf.nn.softmax(logits, axis=-1)
+ next_token = tf.argmax(prob, axis=-1, output_type=ys.dtype)
+ # update ys with the next token and ignore the first token (SOS)
+ i_mesh, j_mesh = tf.meshgrid(tf.range(b), tf.range(self.max_length), indexing="ij")
+ indices = tf.stack([i_mesh[:, i + 1], j_mesh[:, i + 1]], axis=1)
+
+ ys = tf.tensor_scatter_nd_update(ys, indices, next_token[:, i])
+
+ # Shape (N, max_length, vocab_size + 3)
+ return logits
+
+
+class MASTERPostProcessor(_MASTERPostProcessor):
+ """Post processor for MASTER architectures
+
+ Args:
+ ----
+ vocab: string containing the ordered sequence of supported characters
+ """
+
+ def __call__(
+ self,
+ logits: tf.Tensor,
+ ) -> List[Tuple[str, float]]:
+ # compute pred with argmax for attention models
+ out_idxs = tf.math.argmax(logits, axis=2)
+ # N x L
+ probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2)
+ # Take the minimum confidence of the sequence
+ probs = tf.math.reduce_min(probs, axis=1)
+
+ # decode raw output of the model with tf_label_to_idx
+ out_idxs = tf.cast(out_idxs, dtype="int32")
+ embedding = tf.constant(self._embedding, dtype=tf.string)
+ decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1)
+        decoded_strings_pred = tf.strings.split(decoded_strings_pred, "<eos>")
+ decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
+ word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
+
+ return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))
+
+
+def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool = True, **kwargs: Any) -> MASTER:
+ pretrained_backbone = pretrained_backbone and not pretrained
+
+ # Patch the config
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
+ _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
+
+ kwargs["vocab"] = _cfg["vocab"]
+ kwargs["input_shape"] = _cfg["input_shape"]
+
+ # Build the model
+ model = MASTER(
+ backbone_fn(pretrained=pretrained_backbone, input_shape=_cfg["input_shape"], include_top=False),
+ cfg=_cfg,
+ **kwargs,
+ )
+ # Load pretrained parameters
+ if pretrained:
+ load_pretrained_params(model, default_cfgs[arch]["url"])
+
+ return model
+
+
+def master(pretrained: bool = False, **kwargs: Any) -> MASTER:
+ """MASTER as described in paper: `_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import master
+ >>> model = master(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+    **kwargs: keyword arguments passed to the MASTER architecture
+
+ Returns:
+ -------
+ text recognition architecture
+ """
+ return _master("master", pretrained, magc_resnet31, **kwargs)
diff --git a/doctr/models/recognition/parseq/__init__.py b/doctr/models/recognition/parseq/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/models/recognition/parseq/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/recognition/parseq/base.py b/doctr/models/recognition/parseq/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..60aa1fcfcf073751bac75350ae25dfdfa29bc491
--- /dev/null
+++ b/doctr/models/recognition/parseq/base.py
@@ -0,0 +1,58 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import List, Tuple
+
+import numpy as np
+
+from ....datasets import encode_sequences
+from ..core import RecognitionPostProcessor
+
+
+class _PARSeq:
+ vocab: str
+ max_length: int
+
+ def build_target(
+ self,
+ gts: List[str],
+ ) -> Tuple[np.ndarray, List[int]]:
+ """Encode a list of gts sequences into a np array and gives the corresponding*
+ sequence lengths.
+
+ Args:
+ ----
+ gts: list of ground-truth labels
+
+ Returns:
+ -------
+ A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
+ """
+ encoded = encode_sequences(
+ sequences=gts,
+ vocab=self.vocab,
+ target_size=self.max_length,
+ eos=len(self.vocab),
+ sos=len(self.vocab) + 1,
+ pad=len(self.vocab) + 2,
+ )
+ seq_len = [len(word) for word in gts]
+ return encoded, seq_len
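+
+    # Worked example (hypothetical vocab "abc", max_length = 6): build_target(["ab"]) returns
+    # the encoded array [[4, 0, 1, 3, 5, 5]] (sos=4, "a"=0, "b"=1, eos=3, pad=5) along with the
+    # raw sequence lengths [2].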
+
+
+class _PARSeqPostProcessor(RecognitionPostProcessor):
+ """Abstract class to postprocess the raw output of the model
+
+ Args:
+ ----
+ vocab: string containing the ordered sequence of supported characters
+ """
+
+ def __init__(
+ self,
+ vocab: str,
+ ) -> None:
+ super().__init__(vocab)
+        self._embedding = list(vocab) + ["<eos>", "<sos>", "<pad>"]
diff --git a/doctr/models/recognition/parseq/pytorch.py b/doctr/models/recognition/parseq/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3a21cd5780b5ac063c8b047596a7dfa11a8e588
--- /dev/null
+++ b/doctr/models/recognition/parseq/pytorch.py
@@ -0,0 +1,481 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import math
+from copy import deepcopy
+from itertools import permutations
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision.models._utils import IntermediateLayerGetter
+
+from doctr.datasets import VOCABS
+from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward
+
+from ...classification import vit_s
+from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+from .base import _PARSeq, _PARSeqPostProcessor
+
+__all__ = ["PARSeq", "parseq"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "parseq": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 128),
+ "vocab": VOCABS["french"],
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/parseq-56125471.pt&src=0",
+ },
+}
+
+
+class CharEmbedding(nn.Module):
+ """Implements the character embedding module
+
+ Args:
+ ----
+ vocab_size: size of the vocabulary
+ d_model: dimension of the model
+ """
+
+ def __init__(self, vocab_size: int, d_model: int):
+ super().__init__()
+ self.embedding = nn.Embedding(vocab_size, d_model)
+ self.d_model = d_model
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return math.sqrt(self.d_model) * self.embedding(x)
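+
+    # NOTE: scaling the embeddings by sqrt(d_model) follows the Transformer convention
+    # (Vaswani et al., 2017), keeping token embeddings on the same scale as the positional
+    # queries they are summed with.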
+
+
+class PARSeqDecoder(nn.Module):
+ """Implements decoder module of the PARSeq model
+
+ Args:
+ ----
+ d_model: dimension of the model
+ num_heads: number of attention heads
+ ffd: dimension of the feed forward layer
+ ffd_ratio: depth multiplier for the feed forward layer
+ dropout: dropout rate
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ num_heads: int = 12,
+ ffd: int = 2048,
+ ffd_ratio: int = 4,
+ dropout: float = 0.1,
+ ):
+ super().__init__()
+ self.attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
+ self.cross_attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
+ self.position_feed_forward = PositionwiseFeedForward(d_model, ffd * ffd_ratio, dropout, nn.GELU())
+
+ self.attention_norm = nn.LayerNorm(d_model, eps=1e-5)
+ self.cross_attention_norm = nn.LayerNorm(d_model, eps=1e-5)
+ self.query_norm = nn.LayerNorm(d_model, eps=1e-5)
+ self.content_norm = nn.LayerNorm(d_model, eps=1e-5)
+ self.feed_forward_norm = nn.LayerNorm(d_model, eps=1e-5)
+ self.output_norm = nn.LayerNorm(d_model, eps=1e-5)
+ self.attention_dropout = nn.Dropout(dropout)
+ self.cross_attention_dropout = nn.Dropout(dropout)
+ self.feed_forward_dropout = nn.Dropout(dropout)
+
+ def forward(
+ self,
+ target,
+ content,
+ memory,
+ target_mask: Optional[torch.Tensor] = None,
+ ):
+ query_norm = self.query_norm(target)
+ content_norm = self.content_norm(content)
+ target = target.clone() + self.attention_dropout(
+ self.attention(query_norm, content_norm, content_norm, mask=target_mask)
+ )
+ target = target.clone() + self.cross_attention_dropout(
+ self.cross_attention(self.query_norm(target), memory, memory)
+ )
+ target = target.clone() + self.feed_forward_dropout(self.position_feed_forward(self.feed_forward_norm(target)))
+ return self.output_norm(target)
+
+
+class PARSeq(_PARSeq, nn.Module):
+ """Implements a PARSeq architecture as described in `"Scene Text Recognition
+ with Permuted Autoregressive Sequence Models" `_.
+ Slightly modified implementation based on the official Pytorch implementation: None:
+ super().__init__()
+ self.vocab = vocab
+ self.exportable = exportable
+ self.cfg = cfg
+ self.max_length = max_length
+ self.vocab_size = len(vocab)
+ self.rng = np.random.default_rng()
+
+ self.feat_extractor = feature_extractor
+ self.decoder = PARSeqDecoder(embedding_units, dec_num_heads, dec_ff_dim, dec_ffd_ratio, dropout_prob)
+ self.head = nn.Linear(embedding_units, self.vocab_size + 1) # +1 for EOS
+ self.embed = CharEmbedding(self.vocab_size + 3, embedding_units) # +3 for SOS, EOS, PAD
+
+ self.pos_queries = nn.Parameter(torch.Tensor(1, self.max_length + 1, embedding_units)) # +1 for EOS
+ self.dropout = nn.Dropout(p=dropout_prob)
+
+ self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
+
+ nn.init.trunc_normal_(self.pos_queries, std=0.02)
+ for n, m in self.named_modules():
+ # Don't override the initialization of the backbone
+ if n.startswith("feat_extractor."):
+ continue
+ if isinstance(m, nn.Linear):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Embedding):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ if m.padding_idx is not None:
+ m.weight.data[m.padding_idx].zero_()
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def generate_permutations(self, seqlen: torch.Tensor) -> torch.Tensor:
+ # Generates permutations of the target sequence.
+ # Borrowed from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
+ # with small modifications
+
+ max_num_chars = int(seqlen.max().item()) # get longest sequence length in batch
+ perms = [torch.arange(max_num_chars, device=seqlen.device)]
+
+ max_perms = math.factorial(max_num_chars) // 2
+ num_gen_perms = min(3, max_perms)
+ if max_num_chars < 5:
+ # Pool of permutations to sample from. We only need the first half (if complementary option is selected)
+ # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
+ if max_num_chars == 4:
+ selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
+ else:
+ selector = list(range(max_perms))
+ perm_pool = torch.as_tensor(list(permutations(range(max_num_chars), max_num_chars)), device=seqlen.device)[
+ selector
+ ]
+ # If the forward permutation is always selected, no need to add it to the pool for sampling
+ perm_pool = perm_pool[1:]
+ final_perms = torch.stack(perms)
+ if len(perm_pool):
+ i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(final_perms), replace=False)
+ final_perms = torch.cat([final_perms, perm_pool[i]])
+ else:
+ perms.extend([
+ torch.randperm(max_num_chars, device=seqlen.device) for _ in range(num_gen_perms - len(perms))
+ ])
+ final_perms = torch.stack(perms)
+
+ comp = final_perms.flip(-1)
+ final_perms = torch.stack([final_perms, comp]).transpose(0, 1).reshape(-1, max_num_chars)
+
+ sos_idx = torch.zeros(len(final_perms), 1, device=seqlen.device)
+ eos_idx = torch.full((len(final_perms), 1), max_num_chars + 1, device=seqlen.device)
+ combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
+ if len(combined) > 1:
+ combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
+ return combined
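+
+    # Illustrative walk-through (hypothetical batch with max_num_chars = 2): the canonical
+    # ordering [0, 1] is always kept first, every generated permutation is paired with its
+    # flipped complement, and row 1 is overwritten with the pure right-to-left ordering, so
+    # forward and reverse AR decoding are always covered during training.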
+
+ def generate_permutations_attention_masks(self, permutation: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Generate source and target mask for the decoder attention.
+ sz = permutation.shape[0]
+ mask = torch.ones((sz, sz), device=permutation.device)
+
+ for i in range(sz):
+ query_idx = permutation[i]
+ masked_keys = permutation[i + 1 :]
+ mask[query_idx, masked_keys] = 0.0
+ source_mask = mask[:-1, :-1].clone()
+ mask[torch.eye(sz, dtype=torch.bool, device=permutation.device)] = 0.0
+ target_mask = mask[1:, :-1]
+
+ return source_mask.int(), target_mask.int()
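+
+    # Example (illustrative): for the canonical permutation, target_mask reduces to a strictly
+    # causal mask (diagonal removed), i.e. each query position may only attend to keys decoded
+    # strictly before it in the permutation order.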
+
+ def decode(
+ self,
+ target: torch.Tensor,
+ memory: torch.Tensor,
+ target_mask: Optional[torch.Tensor] = None,
+ target_query: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Add positional information to the target sequence and pass it through the decoder."""
+ batch_size, sequence_length = target.shape
+ # apply positional information to the target sequence excluding the SOS token
+ null_ctx = self.embed(target[:, :1])
+ content = self.pos_queries[:, : sequence_length - 1] + self.embed(target[:, 1:])
+ content = self.dropout(torch.cat([null_ctx, content], dim=1))
+ if target_query is None:
+ target_query = self.pos_queries[:, :sequence_length].expand(batch_size, -1, -1)
+ target_query = self.dropout(target_query)
+ return self.decoder(target_query, content, memory, target_mask)
+
+ def decode_autoregressive(self, features: torch.Tensor, max_len: Optional[int] = None) -> torch.Tensor:
+ """Generate predictions for the given features."""
+ max_length = max_len if max_len is not None else self.max_length
+ max_length = min(max_length, self.max_length) + 1
+ # Padding symbol + SOS at the beginning
+ ys = torch.full(
+ (features.size(0), max_length), self.vocab_size + 2, dtype=torch.long, device=features.device
+ ) # pad
+ ys[:, 0] = self.vocab_size + 1 # SOS token
+ pos_queries = self.pos_queries[:, :max_length].expand(features.size(0), -1, -1)
+ # Create query mask for the decoder attention
+ query_mask = (
+ torch.tril(torch.ones((max_length, max_length), device=features.device), diagonal=0).to(dtype=torch.bool)
+ ).int()
+
+ pos_logits = []
+ for i in range(max_length):
+ # Decode one token at a time without providing information about the future tokens
+ tgt_out = self.decode(
+ ys[:, : i + 1],
+ features,
+ query_mask[i : i + 1, : i + 1],
+ target_query=pos_queries[:, i : i + 1],
+ )
+ pos_prob = self.head(tgt_out)
+ pos_logits.append(pos_prob)
+
+ if i + 1 < max_length:
+ # Update with the next token
+ ys[:, i + 1] = pos_prob.squeeze().argmax(-1)
+
+ # Stop decoding if all sequences have reached the EOS token
+ if max_len is None and (ys == self.vocab_size).any(dim=-1).all():
+ break
+
+ logits = torch.cat(pos_logits, dim=1) # (N, max_length, vocab_size + 1)
+
+ # One refine iteration
+ # Update query mask
+ query_mask[torch.triu(torch.ones(max_length, max_length, dtype=torch.bool, device=features.device), 2)] = 1
+
+ # Prepare target input for 1 refine iteration
+ sos = torch.full((features.size(0), 1), self.vocab_size + 1, dtype=torch.long, device=features.device)
+ ys = torch.cat([sos, logits[:, :-1].argmax(-1)], dim=1)
+
+        # Create a padding mask for the refined target input: positions after the first EOS are set to False (masked)
+ # (N, 1, 1, max_length)
+ target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
+ mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
+ logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
+
+ return logits # (N, max_length, vocab_size + 1)
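+
+    # Illustrative note: the loop above is plain greedy AR decoding; the final "refine" pass
+    # re-scores the whole greedy hypothesis in one cloze-style step, with query_mask updated to
+    # also expose the right-hand context (everything but the immediately following position).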
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ target: Optional[List[str]] = None,
+ return_model_output: bool = False,
+ return_preds: bool = False,
+ ) -> Dict[str, Any]:
+ features = self.feat_extractor(x)["features"] # (batch_size, patches_seqlen, d_model)
+ # remove cls token
+ features = features[:, 1:, :]
+
+ if self.training and target is None:
+ raise ValueError("Need to provide labels during training")
+
+ if target is not None:
+ # Build target tensor
+ _gt, _seq_len = self.build_target(target)
+ gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long).to(x.device), torch.tensor(_seq_len).to(x.device)
+ gt = gt[:, : int(seq_len.max().item()) + 2] # slice up to the max length of the batch + 2 (SOS + EOS)
+
+ if self.training:
+ # Generate permutations for the target sequences
+ tgt_perms = self.generate_permutations(seq_len)
+
+ gt_in = gt[:, :-1] # remove EOS token from longest target sequence
+ gt_out = gt[:, 1:] # remove SOS token
+ # Create padding mask for target input
+ # [True, True, True, ..., False, False, False] -> False is masked
+ padding_mask = ~(
+ ((gt_in == self.vocab_size + 2) | (gt_in == self.vocab_size)).int().cumsum(-1) > 0
+ ).unsqueeze(1).unsqueeze(1) # (N, 1, 1, seq_len)
+
+ loss = torch.tensor(0.0, device=features.device)
+ loss_numel: Union[int, float] = 0
+ n = (gt_out != self.vocab_size + 2).sum().item()
+ for i, perm in enumerate(tgt_perms):
+ _, target_mask = self.generate_permutations_attention_masks(perm) # (seq_len, seq_len)
+ # combine both masks
+ mask = (target_mask.bool() & padding_mask.bool()).int() # (N, 1, seq_len, seq_len)
+
+ logits = self.head(self.decode(gt_in, features, mask)).flatten(end_dim=1)
+ loss += n * F.cross_entropy(logits, gt_out.flatten(), ignore_index=self.vocab_size + 2)
+ loss_numel += n
+ # After the second iteration (i.e. done with canonical and reverse orderings),
+ # remove the [EOS] tokens for the succeeding perms
+ if i == 1:
+ gt_out = torch.where(gt_out == self.vocab_size, self.vocab_size + 2, gt_out)
+ n = (gt_out != self.vocab_size + 2).sum().item()
+
+ loss /= loss_numel
+
+ else:
+ gt = gt[:, 1:] # remove SOS token
+ max_len = gt.shape[1] - 1 # exclude EOS token
+ logits = self.decode_autoregressive(features, max_len)
+ loss = F.cross_entropy(logits.flatten(end_dim=1), gt.flatten(), ignore_index=self.vocab_size + 2)
+ else:
+ logits = self.decode_autoregressive(features)
+
+ logits = _bf16_to_float32(logits)
+
+ out: Dict[str, Any] = {}
+ if self.exportable:
+ out["logits"] = logits
+ return out
+
+ if return_model_output:
+ out["out_map"] = logits
+
+ if target is None or return_preds:
+ # Post-process boxes
+ out["preds"] = self.postprocessor(logits)
+
+ if target is not None:
+ out["loss"] = loss
+
+ return out
+
+
+class PARSeqPostProcessor(_PARSeqPostProcessor):
+ """Post processor for PARSeq architecture
+
+ Args:
+ ----
+ vocab: string containing the ordered sequence of supported characters
+ """
+
+ def __call__(
+ self,
+ logits: torch.Tensor,
+ ) -> List[Tuple[str, float]]:
+ # compute pred with argmax for attention models
+ out_idxs = logits.argmax(-1)
+ preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
+
+ # Manual decoding
+ word_values = [
+ "".join(self._embedding[idx] for idx in encoded_seq).split("")[0]
+ for encoded_seq in out_idxs.cpu().numpy()
+ ]
+        # compute probabilities for each word up to the EOS token
+ probs = [
+ preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values)
+ ]
+
+ return list(zip(word_values, probs))
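+
+    # Usage sketch (illustrative, hypothetical vocab "abc"): logits of shape (N, L, 4) decode
+    # to e.g. [("ab", 0.97)], where 0.97 is the mean max-softmax probability over the
+    # characters kept before the first <eos>.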
+
+
+def _parseq(
+ arch: str,
+ pretrained: bool,
+ backbone_fn: Callable[[bool], nn.Module],
+ layer: str,
+ ignore_keys: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> PARSeq:
+ # Patch the config
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
+ _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
+ patch_size = kwargs.get("patch_size", (4, 8))
+
+ kwargs["vocab"] = _cfg["vocab"]
+ kwargs["input_shape"] = _cfg["input_shape"]
+
+ # Feature extractor
+ feat_extractor = IntermediateLayerGetter(
+        # NOTE: we don't use a pretrained backbone with rectangular (non-square) patches to avoid the pos embed mismatch
+ backbone_fn(False, input_shape=_cfg["input_shape"], patch_size=patch_size), # type: ignore[call-arg]
+ {layer: "features"},
+ )
+
+ kwargs.pop("patch_size", None)
+ kwargs.pop("pretrained_backbone", None)
+
+ # Build the model
+ model = PARSeq(feat_extractor, cfg=_cfg, **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+        # If the vocab (and thus the number of classes) differs from the pretrained model's,
+        # remove the last layer weights before loading
+ _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
+ load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+ return model
+
+
+def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
+ """PARSeq architecture from
+ `"Scene Text Recognition with Permuted Autoregressive Sequence Models" `_.
+
+ >>> import torch
+ >>> from doctr.models import parseq
+ >>> model = parseq(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 32, 128))
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+ **kwargs: keyword arguments of the PARSeq architecture
+
+ Returns:
+ -------
+ text recognition architecture
+ """
+ return _parseq(
+ "parseq",
+ pretrained,
+ vit_s,
+ "1",
+ embedding_units=384,
+ patch_size=(4, 8),
+ ignore_keys=["embed.embedding.weight", "head.weight", "head.bias"],
+ **kwargs,
+ )
diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1904c082e1d28b3a61945685e3cf79bca50133e
--- /dev/null
+++ b/doctr/models/recognition/parseq/tensorflow.py
@@ -0,0 +1,511 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import math
+from copy import deepcopy
+from itertools import permutations
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras import Model, layers
+
+from doctr.datasets import VOCABS
+from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward
+
+from ...classification import vit_s
+from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from .base import _PARSeq, _PARSeqPostProcessor
+
+__all__ = ["PARSeq", "parseq"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "parseq": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 128, 3),
+ "vocab": VOCABS["french"],
+ "url": "https://doctr-static.mindee.com/models?id=v0.6.0/parseq-24cf693e.zip&src=0",
+ },
+}
+
+
+class CharEmbedding(layers.Layer):
+ """Implements the character embedding module
+
+ Args:
+ ----
+ vocab_size: size of the vocabulary
+ d_model: dimension of the model
+ """
+
+ def __init__(self, vocab_size: int, d_model: int):
+ super(CharEmbedding, self).__init__()
+ self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
+ self.d_model = d_model
+
+ def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
+ return math.sqrt(self.d_model) * self.embedding(x, **kwargs)
+
+
+class PARSeqDecoder(layers.Layer):
+ """Implements decoder module of the PARSeq model
+
+ Args:
+ ----
+ d_model: dimension of the model
+ num_heads: number of attention heads
+ ffd: dimension of the feed forward layer
+ ffd_ratio: depth multiplier for the feed forward layer
+ dropout: dropout rate
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ num_heads: int = 12,
+ ffd: int = 2048,
+ ffd_ratio: int = 4,
+ dropout: float = 0.1,
+ ):
+ super(PARSeqDecoder, self).__init__()
+ self.attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
+ self.cross_attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
+ self.position_feed_forward = PositionwiseFeedForward(
+ d_model, ffd * ffd_ratio, dropout, layers.Activation(tf.nn.gelu)
+ )
+
+ self.attention_norm = layers.LayerNormalization(epsilon=1e-5)
+ self.cross_attention_norm = layers.LayerNormalization(epsilon=1e-5)
+ self.query_norm = layers.LayerNormalization(epsilon=1e-5)
+ self.content_norm = layers.LayerNormalization(epsilon=1e-5)
+ self.feed_forward_norm = layers.LayerNormalization(epsilon=1e-5)
+ self.output_norm = layers.LayerNormalization(epsilon=1e-5)
+ self.attention_dropout = layers.Dropout(dropout)
+ self.cross_attention_dropout = layers.Dropout(dropout)
+ self.feed_forward_dropout = layers.Dropout(dropout)
+
+ def call(
+ self,
+ target,
+ content,
+ memory,
+ target_mask=None,
+ **kwargs: Any,
+ ):
+ query_norm = self.query_norm(target, **kwargs)
+ content_norm = self.content_norm(content, **kwargs)
+ target = target + self.attention_dropout(
+ self.attention(query_norm, content_norm, content_norm, mask=target_mask, **kwargs), **kwargs
+ )
+ target = target + self.cross_attention_dropout(
+ self.cross_attention(self.query_norm(target, **kwargs), memory, memory, **kwargs), **kwargs
+ )
+ target = target + self.feed_forward_dropout(
+ self.position_feed_forward(self.feed_forward_norm(target, **kwargs), **kwargs), **kwargs
+ )
+ return self.output_norm(target, **kwargs)
+
+
+class PARSeq(_PARSeq, Model):
+ """Implements a PARSeq architecture as described in `"Scene Text Recognition
+ with Permuted Autoregressive Sequence Models" `_.
+ Modified implementation based on the official Pytorch implementation: None:
+ super().__init__()
+ self.vocab = vocab
+ self.exportable = exportable
+ self.cfg = cfg
+ self.max_length = max_length
+ self.vocab_size = len(vocab)
+ self.rng = np.random.default_rng()
+
+ self.feat_extractor = feature_extractor
+ self.decoder = PARSeqDecoder(embedding_units, dec_num_heads, dec_ff_dim, dec_ffd_ratio, dropout_prob)
+ self.embed = CharEmbedding(self.vocab_size + 3, embedding_units) # +3 for SOS, EOS, PAD
+ self.head = layers.Dense(self.vocab_size + 1, name="head") # +1 for EOS
+ self.pos_queries = self.add_weight(
+ shape=(1, self.max_length + 1, embedding_units),
+ initializer="zeros",
+ trainable=True,
+ name="positions",
+ )
+ self.dropout = layers.Dropout(dropout_prob)
+
+ self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
+
+ @tf.function
+ def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
+ # Generates permutations of the target sequence.
+ # Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
+ # with small modifications
+
+ max_num_chars = int(tf.reduce_max(seqlen)) # get longest sequence length in batch
+ perms = [tf.range(max_num_chars, dtype=tf.int32)]
+
+ max_perms = math.factorial(max_num_chars) // 2
+ num_gen_perms = min(3, max_perms)
+ if max_num_chars < 5:
+ # Pool of permutations to sample from. We only need the first half (if complementary option is selected)
+ # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
+ if max_num_chars == 4:
+ selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
+ else:
+ selector = list(range(max_perms))
+ perm_pool_candidates = list(permutations(range(max_num_chars), max_num_chars))
+ perm_pool = tf.convert_to_tensor([perm_pool_candidates[i] for i in selector])
+ # If the forward permutation is always selected, no need to add it to the pool for sampling
+ perm_pool = perm_pool[1:]
+ final_perms = tf.stack(perms)
+ if len(perm_pool):
+ i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(final_perms), replace=False)
+                final_perms = tf.concat([final_perms, tf.gather(perm_pool, i)], axis=0)
+ else:
+ perms.extend([
+ tf.random.shuffle(tf.range(max_num_chars, dtype=tf.int32)) for _ in range(num_gen_perms - len(perms))
+ ])
+ final_perms = tf.stack(perms)
+
+ comp = tf.reverse(final_perms, axis=[-1])
+ final_perms = tf.stack([final_perms, comp])
+ final_perms = tf.transpose(final_perms, perm=[1, 0, 2])
+ final_perms = tf.reshape(final_perms, shape=(-1, max_num_chars))
+
+ sos_idx = tf.zeros([tf.shape(final_perms)[0], 1], dtype=tf.int32)
+ eos_idx = tf.fill([tf.shape(final_perms)[0], 1], max_num_chars + 1)
+ combined = tf.concat([sos_idx, final_perms + 1, eos_idx], axis=1)
+ combined = tf.cast(combined, dtype=tf.int32)
+ if tf.shape(combined)[0] > 1:
+ combined = tf.tensor_scatter_nd_update(
+ combined, [[1, i] for i in range(1, max_num_chars + 2)], max_num_chars + 1 - tf.range(max_num_chars + 1)
+ )
+ return combined
+
+ @tf.function
+ def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+ # Generate source and target mask for the decoder attention.
+ sz = permutation.shape[0]
+ mask = tf.ones((sz, sz), dtype=tf.float32)
+
+ for i in range(sz - 1):
+ query_idx = int(permutation[i])
+ masked_keys = permutation[i + 1 :].numpy().tolist()
+ indices = tf.constant([[query_idx, j] for j in masked_keys], dtype=tf.int32)
+ mask = tf.tensor_scatter_nd_update(mask, indices, tf.zeros(len(masked_keys), dtype=tf.float32))
+
+ source_mask = tf.identity(mask[:-1, :-1])
+ eye_indices = tf.eye(sz, dtype=tf.bool)
+ mask = tf.tensor_scatter_nd_update(
+ mask, tf.where(eye_indices), tf.zeros_like(tf.boolean_mask(mask, eye_indices))
+ )
+ target_mask = mask[1:, :-1]
+ return tf.cast(source_mask, dtype=tf.bool), tf.cast(target_mask, dtype=tf.bool)
+
+ @tf.function
+ def decode(
+ self,
+ target: tf.Tensor,
+        memory: tf.Tensor,
+ target_mask: Optional[tf.Tensor] = None,
+ target_query: Optional[tf.Tensor] = None,
+ **kwargs: Any,
+ ) -> tf.Tensor:
+ batch_size, sequence_length = target.shape
+ # apply positional information to the target sequence excluding the SOS token
+ null_ctx = self.embed(target[:, :1], **kwargs)
+ content = self.pos_queries[:, : sequence_length - 1] + self.embed(target[:, 1:], **kwargs)
+ content = self.dropout(tf.concat([null_ctx, content], axis=1), **kwargs)
+ if target_query is None:
+ target_query = tf.tile(self.pos_queries[:, :sequence_length], [batch_size, 1, 1])
+ target_query = self.dropout(target_query, **kwargs)
+ return self.decoder(target_query, content, memory, target_mask, **kwargs)
+
+ @tf.function
+ def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = None, **kwargs) -> tf.Tensor:
+ """Generate predictions for the given features."""
+ max_length = max_len if max_len is not None else self.max_length
+ max_length = min(max_length, self.max_length) + 1
+ b = tf.shape(features)[0]
+ # Padding symbol + SOS at the beginning
+ ys = tf.fill(dims=(b, max_length), value=self.vocab_size + 2)
+ start_vector = tf.fill(dims=(b, 1), value=self.vocab_size + 1)
+ ys = tf.concat([start_vector, ys], axis=-1)
+ pos_queries = tf.tile(self.pos_queries[:, :max_length], [b, 1, 1])
+ query_mask = tf.cast(tf.linalg.band_part(tf.ones((max_length, max_length)), -1, 0), dtype=tf.bool)
+
+ pos_logits = []
+ for i in range(max_length):
+ # Decode one token at a time without providing information about the future tokens
+ tgt_out = self.decode(
+ ys[:, : i + 1],
+ features,
+ query_mask[i : i + 1, : i + 1],
+ target_query=pos_queries[:, i : i + 1],
+ **kwargs,
+ )
+ pos_prob = self.head(tgt_out)
+ pos_logits.append(pos_prob)
+
+ if i + 1 < max_length:
+ # update ys with the next token
+ i_mesh, j_mesh = tf.meshgrid(tf.range(b), tf.range(max_length), indexing="ij")
+ indices = tf.stack([i_mesh[:, i + 1], j_mesh[:, i + 1]], axis=1)
+ ys = tf.tensor_scatter_nd_update(
+ ys, indices, tf.cast(tf.argmax(pos_prob[:, -1, :], axis=-1), dtype=tf.int32)
+ )
+
+            # Stop decoding once every sequence contains an EOS token
+            # (kept as a single explicit boolean condition to stay compatible with ONNX export)
+            if max_len is None and tf.reduce_all(tf.reduce_any(tf.equal(ys, self.vocab_size), axis=-1)):
+                break
+
+ logits = tf.concat(pos_logits, axis=1) # (N, max_length, vocab_size + 1)
+
+ # One refine iteration
+ # Update query mask
+ diag_matrix = tf.eye(max_length)
+ diag_matrix = tf.cast(tf.logical_not(tf.cast(diag_matrix, dtype=tf.bool)), dtype=tf.float32)
+ query_mask = tf.cast(tf.concat([diag_matrix[1:], tf.ones((1, max_length))], axis=0), dtype=tf.bool)
+
+ sos = tf.fill((tf.shape(features)[0], 1), self.vocab_size + 1)
+ ys = tf.concat([sos, tf.cast(tf.argmax(logits[:, :-1], axis=-1), dtype=tf.int32)], axis=1)
+        # Create a padding mask for the refined target input: positions after the first EOS are set to False (masked)
+ # (N, 1, 1, max_length)
+ mask = tf.cast(tf.equal(ys, self.vocab_size), tf.float32)
+ first_eos_indices = tf.argmax(mask, axis=1, output_type=tf.int32)
+ mask = tf.sequence_mask(first_eos_indices + 1, maxlen=ys.shape[-1], dtype=tf.float32)
+ target_pad_mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.bool)
+
+ mask = tf.math.logical_and(target_pad_mask, query_mask[:, : ys.shape[1]])
+ logits = self.head(self.decode(ys, features, mask, target_query=pos_queries, **kwargs), **kwargs)
+
+ return logits # (N, max_length, vocab_size + 1)
+
+ def call(
+ self,
+ x: tf.Tensor,
+ target: Optional[List[str]] = None,
+ return_model_output: bool = False,
+ return_preds: bool = False,
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ features = self.feat_extractor(x, **kwargs) # (batch_size, patches_seqlen, d_model)
+ # remove cls token
+ features = features[:, 1:, :]
+
+ if kwargs.get("training", False) and target is None:
+ raise ValueError("Need to provide labels during training")
+
+ if target is not None:
+ gt, seq_len = self.build_target(target)
+ seq_len = tf.cast(seq_len, tf.int32)
+ gt = gt[:, : int(tf.reduce_max(seq_len)) + 2] # slice up to the max length of the batch + 2 (SOS + EOS)
+
+ if kwargs.get("training", False):
+ # Generate permutations of the target sequences
+ tgt_perms = self.generate_permutations(seq_len)
+
+ gt_in = gt[:, :-1] # remove EOS token from longest target sequence
+ gt_out = gt[:, 1:] # remove SOS token
+
+ # Create padding mask for target input
+ # [True, True, True, ..., False, False, False] -> False is masked
+ padding_mask = tf.math.logical_and(
+ tf.math.not_equal(gt_in, self.vocab_size + 2), tf.math.not_equal(gt_in, self.vocab_size)
+ )
+ padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :] # (N, 1, 1, seq_len)
+
+ loss = tf.constant(0.0)
+ loss_numel = tf.constant(0.0)
+ n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32))
+ for i, perm in enumerate(tgt_perms):
+ _, target_mask = self.generate_permutations_attention_masks(perm) # (seq_len, seq_len)
+ # combine both masks to (N, 1, seq_len, seq_len)
+ mask = tf.logical_and(padding_mask, tf.expand_dims(tf.expand_dims(target_mask, axis=0), axis=0))
+
+ logits = self.head(self.decode(gt_in, features, mask, **kwargs), **kwargs)
+ logits_flat = tf.reshape(logits, (-1, logits.shape[-1]))
+ targets_flat = tf.reshape(gt_out, (-1,))
+ mask = tf.not_equal(targets_flat, self.vocab_size + 2)
+ loss += n * tf.reduce_mean(
+ tf.nn.sparse_softmax_cross_entropy_with_logits(
+ labels=tf.boolean_mask(targets_flat, mask), logits=tf.boolean_mask(logits_flat, mask)
+ )
+ )
+ loss_numel += n
+
+ # After the second iteration (i.e. done with canonical and reverse orderings),
+ # remove the [EOS] tokens for the succeeding perms
+ if i == 1:
+ gt_out = tf.where(tf.equal(gt_out, self.vocab_size), self.vocab_size + 2, gt_out)
+ n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32))
+
+ loss /= loss_numel
+
+ else:
+ gt = gt[:, 1:] # remove SOS token
+ max_len = gt.shape[1] - 1 # exclude EOS token
+ logits = self.decode_autoregressive(features, max_len, **kwargs)
+ logits_flat = tf.reshape(logits, (-1, logits.shape[-1]))
+ targets_flat = tf.reshape(gt, (-1,))
+ mask = tf.not_equal(targets_flat, self.vocab_size + 2)
+ loss = tf.reduce_mean(
+ tf.nn.sparse_softmax_cross_entropy_with_logits(
+ labels=tf.boolean_mask(targets_flat, mask), logits=tf.boolean_mask(logits_flat, mask)
+ )
+ )
+ else:
+ logits = self.decode_autoregressive(features, **kwargs)
+
+ logits = _bf16_to_float32(logits)
+
+ out: Dict[str, tf.Tensor] = {}
+ if self.exportable:
+ out["logits"] = logits
+ return out
+
+ if return_model_output:
+ out["out_map"] = logits
+
+ if target is None or return_preds:
+ # Post-process boxes
+ out["preds"] = self.postprocessor(logits)
+
+ if target is not None:
+ out["loss"] = loss
+
+ return out
+
+
+class PARSeqPostProcessor(_PARSeqPostProcessor):
+ """Post processor for PARSeq architecture
+
+ Args:
+ ----
+ vocab: string containing the ordered sequence of supported characters
+ """
+
+ def __call__(
+ self,
+ logits: tf.Tensor,
+ ) -> List[Tuple[str, float]]:
+ # compute pred with argmax for attention models
+ out_idxs = tf.math.argmax(logits, axis=2)
+ preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
+
+ # decode raw output of the model with tf_label_to_idx
+ out_idxs = tf.cast(out_idxs, dtype="int32")
+ embedding = tf.constant(self._embedding, dtype=tf.string)
+ decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1)
+        decoded_strings_pred = tf.strings.split(decoded_strings_pred, "<eos>")
+ decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
+ word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
+
+        # compute probabilities for each word up to the EOS token
+ probs = [
+ preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
+ for i, word in enumerate(word_values)
+ ]
+
+ return list(zip(word_values, probs))
+
+
+def _parseq(
+ arch: str,
+ pretrained: bool,
+ backbone_fn,
+ input_shape: Optional[Tuple[int, int, int]] = None,
+ **kwargs: Any,
+) -> PARSeq:
+ # Patch the config
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["input_shape"] = input_shape or _cfg["input_shape"]
+ _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
+ patch_size = kwargs.get("patch_size", (4, 8))
+
+ kwargs["vocab"] = _cfg["vocab"]
+
+ # Feature extractor
+ feat_extractor = backbone_fn(
+        # NOTE: we don't use a pretrained backbone with rectangular (non-square) patches to avoid the pos embed mismatch
+ pretrained=False,
+ input_shape=_cfg["input_shape"],
+ patch_size=patch_size,
+ include_top=False,
+ )
+
+ kwargs.pop("patch_size", None)
+ kwargs.pop("pretrained_backbone", None)
+
+ # Build the model
+ model = PARSeq(feat_extractor, cfg=_cfg, **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ load_pretrained_params(model, default_cfgs[arch]["url"])
+
+ return model
+
+
+def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
+ """PARSeq architecture from
+ `"Scene Text Recognition with Permuted Autoregressive Sequence Models" `_.
+
+ >>> import tensorflow as tf
+ >>> from doctr.models import parseq
+ >>> model = parseq(pretrained=False)
+ >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+ **kwargs: keyword arguments of the PARSeq architecture
+
+ Returns:
+ -------
+ text recognition architecture
+ """
+ return _parseq(
+ "parseq",
+ pretrained,
+ vit_s,
+ embedding_units=384,
+ patch_size=(4, 8),
+ **kwargs,
+ )
diff --git a/doctr/models/recognition/predictor/__init__.py b/doctr/models/recognition/predictor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff30c3b2e7d34bf85e30291e39f9d3206c0f4bdd
--- /dev/null
+++ b/doctr/models/recognition/predictor/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available
+
+if is_tf_available():
+ from .tensorflow import *
+else:
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/recognition/predictor/_utils.py b/doctr/models/recognition/predictor/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac98d41862f04efec4ec105bbc50000b35b55555
--- /dev/null
+++ b/doctr/models/recognition/predictor/_utils.py
@@ -0,0 +1,86 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import List, Tuple, Union
+
+import numpy as np
+
+from ..utils import merge_multi_strings
+
+__all__ = ["split_crops", "remap_preds"]
+
+
+def split_crops(
+ crops: List[np.ndarray],
+ max_ratio: float,
+ target_ratio: int,
+ dilation: float,
+ channels_last: bool = True,
+) -> Tuple[List[np.ndarray], List[Union[int, Tuple[int, int]]], bool]:
+ """Chunk crops horizontally to match a given aspect ratio
+
+ Args:
+ ----
+ crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
+ max_ratio: the maximum aspect ratio that won't trigger the chunk
+ target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
+ dilation: the width dilation of final chunks (to provide some overlaps)
+ channels_last: whether the numpy array has dimensions in channels last order
+
+ Returns:
+ -------
+ a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required
+ """
+ _remap_required = False
+ crop_map: List[Union[int, Tuple[int, int]]] = []
+ new_crops: List[np.ndarray] = []
+ for crop in crops:
+ h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
+ aspect_ratio = w / h
+ if aspect_ratio > max_ratio:
+ # Determine the number of crops, reference aspect ratio = 4 = 128 / 32
+ num_subcrops = int(aspect_ratio // target_ratio)
+ # Find the new widths, additional dilation factor to overlap crops
+ width = dilation * w / num_subcrops
+ centers = [(w / num_subcrops) * (1 / 2 + idx) for idx in range(num_subcrops)]
+ # Get the crops
+ if channels_last:
+ _crops = [
+ crop[:, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2))), :]
+ for center in centers
+ ]
+ else:
+ _crops = [
+ crop[:, :, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2)))]
+ for center in centers
+ ]
+ # Avoid sending zero-sized crops
+ _crops = [crop for crop in _crops if all(s > 0 for s in crop.shape)]
+ # Record the slice of crops
+ crop_map.append((len(new_crops), len(new_crops) + len(_crops)))
+ new_crops.extend(_crops)
+ # At least one crop will require merging
+ _remap_required = True
+ else:
+ crop_map.append(len(new_crops))
+ new_crops.append(crop)
+
+ return new_crops, crop_map, _remap_required
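+
+# Worked example (illustrative, with the predictor defaults critical_ar=8, target_ar=6,
+# dil_factor=1.4): a 32x512 crop has aspect ratio 16 > 8, so it is chunked into
+# int(16 // 6) = 2 sub-crops of width 1.4 * 512 / 2 ≈ 358 px centred at x = 128 and x = 384,
+# which overlap in the middle thanks to the dilation factor.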
+
+
+def remap_preds(
+ preds: List[Tuple[str, float]], crop_map: List[Union[int, Tuple[int, int]]], dilation: float
+) -> List[Tuple[str, float]]:
+ remapped_out = []
+ for _idx in crop_map:
+ # Crop hasn't been split
+ if isinstance(_idx, int):
+ remapped_out.append(preds[_idx])
+ else:
+ # unzip
+ vals, probs = zip(*preds[_idx[0] : _idx[1]])
+ # Merge the string values
+ remapped_out.append((merge_multi_strings(vals, dilation), min(probs))) # type: ignore[arg-type]
+ return remapped_out
diff --git a/doctr/models/recognition/predictor/pytorch.py b/doctr/models/recognition/predictor/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b71202f7c28c5dacdc6ab4e605695ddb24df3ff8
--- /dev/null
+++ b/doctr/models/recognition/predictor/pytorch.py
@@ -0,0 +1,86 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, List, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+from doctr.models.preprocessor import PreProcessor
+from doctr.models.utils import set_device_and_dtype
+
+from ._utils import remap_preds, split_crops
+
+__all__ = ["RecognitionPredictor"]
+
+
+class RecognitionPredictor(nn.Module):
+ """Implements an object able to identify character sequences in images
+
+ Args:
+ ----
+ pre_processor: transform inputs for easier batched model inference
+ model: core detection architecture
+        split_wide_crops: whether to use crop splitting for high aspect ratio crops
+ """
+
+ def __init__(
+ self,
+ pre_processor: PreProcessor,
+ model: nn.Module,
+ split_wide_crops: bool = True,
+ ) -> None:
+ super().__init__()
+ self.pre_processor = pre_processor
+ self.model = model.eval()
+ self.split_wide_crops = split_wide_crops
+ self.critical_ar = 8 # Critical aspect ratio
+ self.dil_factor = 1.4 # Dilation factor to overlap the crops
+ self.target_ar = 6 # Target aspect ratio
+
+ @torch.inference_mode()
+ def forward(
+ self,
+ crops: Sequence[Union[np.ndarray, torch.Tensor]],
+ **kwargs: Any,
+ ) -> List[Tuple[str, float]]:
+ if len(crops) == 0:
+ return []
+ # Dimension check
+ if any(crop.ndim != 3 for crop in crops):
+ raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.")
+
+ # Split crops that are too wide
+ remapped = False
+ if self.split_wide_crops:
+ new_crops, crop_map, remapped = split_crops(
+ crops, # type: ignore[arg-type]
+ self.critical_ar,
+ self.target_ar,
+ self.dil_factor,
+ isinstance(crops[0], np.ndarray),
+ )
+ if remapped:
+ crops = new_crops
+
+ # Resize & batch them
+ processed_batches = self.pre_processor(crops)
+
+ # Forward it
+ _params = next(self.model.parameters())
+ self.model, processed_batches = set_device_and_dtype(
+ self.model, processed_batches, _params.device, _params.dtype
+ )
+ raw = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches]
+
+ # Process outputs
+ out = [charseq for batch in raw for charseq in batch]
+
+ # Remap crops
+ if self.split_wide_crops and remapped:
+ out = remap_preds(out, crop_map, self.dil_factor)
+
+ return out
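+
+    # Usage sketch (illustrative; in practice the predictor is built via
+    # doctr.models.recognition_predictor, which wires the PreProcessor for you):
+    #   predictor = recognition_predictor(pretrained=True)
+    #   words = predictor([crop])  # crop: np.ndarray of shape (H, W, 3)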
diff --git a/doctr/models/recognition/predictor/tensorflow.py b/doctr/models/recognition/predictor/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..409f39323af4159b342f1a60b0682aca6142c51b
--- /dev/null
+++ b/doctr/models/recognition/predictor/tensorflow.py
@@ -0,0 +1,80 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, List, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from doctr.models.preprocessor import PreProcessor
+from doctr.utils.repr import NestedObject
+
+from ..core import RecognitionModel
+from ._utils import remap_preds, split_crops
+
+__all__ = ["RecognitionPredictor"]
+
+
+class RecognitionPredictor(NestedObject):
+ """Implements an object able to identify character sequences in images
+
+ Args:
+ ----
+ pre_processor: transform inputs for easier batched model inference
+ model: core detection architecture
+        split_wide_crops: whether to use crop splitting for high aspect ratio crops
+ """
+
+ _children_names: List[str] = ["pre_processor", "model"]
+
+ def __init__(
+ self,
+ pre_processor: PreProcessor,
+ model: RecognitionModel,
+ split_wide_crops: bool = True,
+ ) -> None:
+ super().__init__()
+ self.pre_processor = pre_processor
+ self.model = model
+ self.split_wide_crops = split_wide_crops
+ self.critical_ar = 8 # Critical aspect ratio
+ self.dil_factor = 1.4 # Dilation factor to overlap the crops
+ self.target_ar = 6 # Target aspect ratio
+
+ def __call__(
+ self,
+ crops: List[Union[np.ndarray, tf.Tensor]],
+ **kwargs: Any,
+ ) -> List[Tuple[str, float]]:
+ if len(crops) == 0:
+ return []
+ # Dimension check
+ if any(crop.ndim != 3 for crop in crops):
+ raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.")
+
+ # Split crops that are too wide
+ remapped = False
+ if self.split_wide_crops:
+ new_crops, crop_map, remapped = split_crops(crops, self.critical_ar, self.target_ar, self.dil_factor)
+ if remapped:
+ crops = new_crops
+
+ # Resize & batch them
+ processed_batches = self.pre_processor(crops)
+
+ # Forward it
+ raw = [
+ self.model(batch, return_preds=True, training=False, **kwargs)["preds"] # type: ignore[operator]
+ for batch in processed_batches
+ ]
+
+ # Process outputs
+ out = [charseq for batch in raw for charseq in batch]
+
+ # Remap crops
+ if self.split_wide_crops and remapped:
+ out = remap_preds(out, crop_map, self.dil_factor)
+
+ return out
diff --git a/doctr/models/recognition/sar/__init__.py b/doctr/models/recognition/sar/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7110f5669d4e8637b00a46e3fc34ece581eb10a
--- /dev/null
+++ b/doctr/models/recognition/sar/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+ from .tensorflow import *
+elif is_torch_available():
+ from .pytorch import * # type: ignore[assignment]
diff --git a/doctr/models/recognition/sar/pytorch.py b/doctr/models/recognition/sar/pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..a66bd32036c11ea66ebb18fcaf4aec7742869e1d
--- /dev/null
+++ b/doctr/models/recognition/sar/pytorch.py
@@ -0,0 +1,402 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision.models._utils import IntermediateLayerGetter
+
+from doctr.datasets import VOCABS
+
+from ...classification import resnet31
+from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+from ..core import RecognitionModel, RecognitionPostProcessor
+
+__all__ = ["SAR", "sar_resnet31"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "sar_resnet31": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (3, 32, 128),
+ "vocab": VOCABS["french"],
+ "url": "https://doctr-static.mindee.com/models?id=v0.7.0/sar_resnet31-9a1deedf.pt&src=0",
+ },
+}
+
+
+class SAREncoder(nn.Module):
+ def __init__(self, in_feats: int, rnn_units: int, dropout_prob: float = 0.0) -> None:
+ super().__init__()
+ self.rnn = nn.LSTM(in_feats, rnn_units, 2, batch_first=True, dropout=dropout_prob)
+ self.linear = nn.Linear(rnn_units, rnn_units)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # (N, L, C) --> (N, T, C)
+ encoded = self.rnn(x)[0]
+ # (N, C)
+ return self.linear(encoded[:, -1, :])
+
+
+class AttentionModule(nn.Module):
+ def __init__(self, feat_chans: int, state_chans: int, attention_units: int) -> None:
+ super().__init__()
+ self.feat_conv = nn.Conv2d(feat_chans, attention_units, kernel_size=3, padding=1)
+ # No need to add another bias since both tensors are summed together
+ self.state_conv = nn.Conv2d(state_chans, attention_units, kernel_size=1, bias=False)
+ self.attention_projector = nn.Conv2d(attention_units, 1, kernel_size=1, bias=False)
+
+ def forward(
+ self,
+ features: torch.Tensor, # (N, C, H, W)
+ hidden_state: torch.Tensor, # (N, C)
+ ) -> torch.Tensor:
+ H_f, W_f = features.shape[2:]
+
+ # (N, feat_chans, H, W) --> (N, attention_units, H, W)
+ feat_projection = self.feat_conv(features)
+ # (N, state_chans, 1, 1) --> (N, attention_units, 1, 1)
+ hidden_state = hidden_state.view(hidden_state.size(0), hidden_state.size(1), 1, 1)
+ state_projection = self.state_conv(hidden_state)
+ state_projection = state_projection.expand(-1, -1, H_f, W_f)
+ # (N, attention_units, 1, 1) --> (N, attention_units, H_f, W_f)
+ attention_weights = torch.tanh(feat_projection + state_projection)
+ # (N, attention_units, H_f, W_f) --> (N, 1, H_f, W_f)
+ attention_weights = self.attention_projector(attention_weights)
+ B, C, H, W = attention_weights.size()
+
+ # (N, H, W) --> (N, 1, H, W)
+ attention_weights = torch.softmax(attention_weights.view(B, -1), dim=-1).view(B, C, H, W)
+ # fuse features and attention weights (N, C)
+ return (features * attention_weights).sum(dim=(2, 3))
+
+
+class SARDecoder(nn.Module):
+ """Implements decoder module of the SAR model
+
+ Args:
+ ----
+ rnn_units: number of hidden units in recurrent cells
+ max_length: maximum length of a sequence
+ vocab_size: number of classes in the model alphabet
+ embedding_units: number of hidden embedding units
+ attention_units: number of hidden attention units
+
+ """
+
+ def __init__(
+ self,
+ rnn_units: int,
+ max_length: int,
+ vocab_size: int,
+ embedding_units: int,
+ attention_units: int,
+ feat_chans: int = 512,
+ dropout_prob: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.vocab_size = vocab_size
+ self.max_length = max_length
+
+ self.embed = nn.Linear(self.vocab_size + 1, embedding_units)
+ self.embed_tgt = nn.Embedding(embedding_units, self.vocab_size + 1)
+ self.attention_module = AttentionModule(feat_chans, rnn_units, attention_units)
+ self.lstm_cell = nn.LSTMCell(rnn_units, rnn_units)
+ self.output_dense = nn.Linear(2 * rnn_units, self.vocab_size + 1)
+ self.dropout = nn.Dropout(dropout_prob)
+
+ def forward(
+ self,
+ features: torch.Tensor, # (N, C, H, W)
+ holistic: torch.Tensor, # (N, C)
+ gt: Optional[torch.Tensor] = None, # (N, L)
+ ) -> torch.Tensor:
+ if gt is not None:
+ gt_embedding = self.embed_tgt(gt)
+
+ logits_list: List[torch.Tensor] = []
+
+ for t in range(self.max_length + 1): # 32
+ if t == 0:
+ # step to init the first states of the LSTMCell
+ hidden_state_init = cell_state_init = torch.zeros(
+ features.size(0), features.size(1), device=features.device, dtype=features.dtype
+ )
+ hidden_state, cell_state = hidden_state_init, cell_state_init
+ prev_symbol = holistic
+ elif t == 1:
+ # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
+ # (N, vocab_size + 1) --> (N, embedding_units)
+ prev_symbol = torch.zeros(
+ features.size(0), self.vocab_size + 1, device=features.device, dtype=features.dtype
+ )
+ prev_symbol = self.embed(prev_symbol)
+ else:
+ if gt is not None and self.training:
+                    # (N, embedding_units) -2 because of <sos> and <eos> (same shift)
+ prev_symbol = self.embed(gt_embedding[:, t - 2])
+ else:
+ # -1 to start at timestep where prev_symbol was initialized
+ index = logits_list[t - 1].argmax(-1)
+ # update prev_symbol with ones at the index of the previous logit vector
+ prev_symbol = self.embed(self.embed_tgt(index))
+
+ # (N, C), (N, C) take the last hidden state and cell state from current timestep
+ hidden_state_init, cell_state_init = self.lstm_cell(prev_symbol, (hidden_state_init, cell_state_init))
+ hidden_state, cell_state = self.lstm_cell(hidden_state_init, (hidden_state, cell_state))
+ # (N, C, H, W), (N, C) --> (N, C)
+ glimpse = self.attention_module(features, hidden_state)
+ # (N, C), (N, C) --> (N, 2 * C)
+ logits = torch.cat([hidden_state, glimpse], dim=1)
+ logits = self.dropout(logits)
+ # (N, vocab_size + 1)
+ logits_list.append(self.output_dense(logits))
+
+ # (max_length + 1, N, vocab_size + 1) --> (N, max_length + 1, vocab_size + 1)
+ return torch.stack(logits_list[1:]).permute(1, 0, 2)
+
+
+class SAR(nn.Module, RecognitionModel):
+ """Implements a SAR architecture as described in `"Show, Attend and Read:A Simple and Strong Baseline for
+ Irregular Text Recognition" `_.
+
+ Args:
+ ----
+ feature_extractor: the backbone serving as feature extractor
+ vocab: vocabulary used for encoding
+ rnn_units: number of hidden units in both encoder and decoder LSTM
+ embedding_units: number of embedding units
+ attention_units: number of hidden units in attention module
+ max_length: maximum word length handled by the model
+ dropout_prob: dropout probability of the encoder LSTM
+ exportable: onnx exportable returns only logits
+ cfg: dictionary containing information about the model
+ """
+
+ def __init__(
+ self,
+ feature_extractor,
+ vocab: str,
+ rnn_units: int = 512,
+ embedding_units: int = 512,
+ attention_units: int = 512,
+ max_length: int = 30,
+ dropout_prob: float = 0.0,
+ input_shape: Tuple[int, int, int] = (3, 32, 128),
+ exportable: bool = False,
+ cfg: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ super().__init__()
+ self.vocab = vocab
+ self.exportable = exportable
+ self.cfg = cfg
+
+ self.max_length = max_length + 1 # Add 1 timestep for EOS after the longest word
+
+ self.feat_extractor = feature_extractor
+
+ # Size the LSTM
+ self.feat_extractor.eval()
+ with torch.no_grad():
+ out_shape = self.feat_extractor(torch.zeros((1, *input_shape)))["features"].shape
+ # Switch back to original mode
+ self.feat_extractor.train()
+
+ self.encoder = SAREncoder(out_shape[1], rnn_units, dropout_prob)
+ self.decoder = SARDecoder(
+ rnn_units,
+ self.max_length,
+ len(self.vocab),
+ embedding_units,
+ attention_units,
+ dropout_prob=dropout_prob,
+ )
+
+ self.postprocessor = SARPostProcessor(vocab=vocab)
+
+ for n, m in self.named_modules():
+ # Don't override the initialization of the backbone
+ if n.startswith("feat_extractor."):
+ continue
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ target: Optional[List[str]] = None,
+ return_model_output: bool = False,
+ return_preds: bool = False,
+ ) -> Dict[str, Any]:
+ features = self.feat_extractor(x)["features"]
+        # NOTE: use max() instead of F.max_pool2d, whose kernel_size argument breaks ONNX export
+ # Vertical max pooling (N, C, H, W) --> (N, C, W)
+ pooled_features = features.max(dim=-2).values
+ # (N, W, C)
+ pooled_features = pooled_features.permute(0, 2, 1).contiguous()
+ # (N, C)
+ encoded = self.encoder(pooled_features)
+ if target is not None:
+ _gt, _seq_len = self.build_target(target)
+ gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len)
+ gt, seq_len = gt.to(x.device), seq_len.to(x.device)
+
+ if self.training and target is None:
+ raise ValueError("Need to provide labels during training for teacher forcing")
+
+ decoded_features = _bf16_to_float32(self.decoder(features, encoded, gt=None if target is None else gt))
+
+ out: Dict[str, Any] = {}
+ if self.exportable:
+ out["logits"] = decoded_features
+ return out
+
+ if return_model_output:
+ out["out_map"] = decoded_features
+
+ if target is None or return_preds:
+ # Post-process into (word, confidence) predictions
+ out["preds"] = self.postprocessor(decoded_features)
+
+ if target is not None:
+ out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
+
+ return out
+
+ @staticmethod
+ def compute_loss(
+ model_output: torch.Tensor,
+ gt: torch.Tensor,
+ seq_len: torch.Tensor,
+ ) -> torch.Tensor:
+ """Compute categorical cross-entropy loss for the model.
+ Sequences are masked after the EOS character.
+
+ Args:
+ ----
+ model_output: predicted logits of the model
+ gt: the encoded tensor with gt labels
+ seq_len: lengths of each gt word inside the batch
+
+ Returns:
+ -------
+ The loss of the model on the batch
+ """
+ # Input length: number of timesteps
+ input_len = model_output.shape[1]
+ # Add one for additional token
+ seq_len = seq_len + 1
+ # Compute loss
+ # (N, L, vocab_size + 1)
+ cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
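+ # zero out the loss contribution of every timestep located after each word's EOS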
+ mask_2d = torch.arange(input_len, device=model_output.device)[None, :] >= seq_len[:, None]
+ cce[mask_2d] = 0
+
+ ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype)
+ return ce_loss.mean()
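+
+ # Toy sanity check of the masking above (illustrative only, not part of the model):
+ # >>> logits = torch.randn(2, 5, 4) # (N, L, vocab_size + 1) with vocab_size = 3
+ # >>> gt = torch.randint(0, 4, (2, 5)) # encoded targets, EOS-padded
+ # >>> seq_len = torch.tensor([2, 4]) # word lengths, EOS accounted for internally
+ # >>> SAR.compute_loss(logits, gt, seq_len).ndim # scalar mean over the batch
+ # 0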
+
+
+class SARPostProcessor(RecognitionPostProcessor):
+ """Post processor for SAR architectures
+
+ Args:
+ ----
+ vocab: string containing the ordered sequence of supported characters
+ """
+
+ def __call__(
+ self,
+ logits: torch.Tensor,
+ ) -> List[Tuple[str, float]]:
+ # compute pred with argmax for attention models
+ out_idxs = logits.argmax(-1)
+ # N x L
+ probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
+ # Take the minimum confidence of the sequence
+ probs = probs.min(dim=1).values.detach().cpu()
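+ # taking the min rather than the mean keeps the confidence conservative:
+ # one uncertain character is enough to lower the whole word's score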
+
+ # Manual decoding
+ word_values = [
+ "".join(self._embedding[idx] for idx in encoded_seq).split("")[0]
+ for encoded_seq in out_idxs.detach().cpu().numpy()
+ ]
+
+ return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))
+
+
+def _sar(
+ arch: str,
+ pretrained: bool,
+ backbone_fn: Callable[[bool], nn.Module],
+ layer: str,
+ pretrained_backbone: bool = True,
+ ignore_keys: Optional[List[str]] = None,
+ **kwargs: Any,
+) -> SAR:
+ pretrained_backbone = pretrained_backbone and not pretrained
+
+ # Patch the config
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
+ _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
+
+ # Feature extractor
+ feat_extractor = IntermediateLayerGetter(
+ backbone_fn(pretrained_backbone),
+ {layer: "features"},
+ )
+ kwargs["vocab"] = _cfg["vocab"]
+ kwargs["input_shape"] = _cfg["input_shape"]
+
+ # Build the model
+ model = SAR(feat_extractor, cfg=_cfg, **kwargs)
+ # Load pretrained parameters
+ if pretrained:
+ # If the vocab differs from the pretrained model's, the classification head
+ # has a different number of classes => remove its weights before loading
+ _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
+ load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+ return model
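+
+ # e.g. sar_resnet31(pretrained=True, vocab="0123456789") reaches this point with a
+ # non-default vocab: ignore_keys then drops the classification head weights so the
+ # head is re-initialized for the new charset instead of loaded from the checkpoint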
+
+
+def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR:
+ """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong
+ Baseline for Irregular Text Recognition" `_.
+
+ >>> import torch
+ >>> from doctr.models import sar_resnet31
+ >>> model = sar_resnet31(pretrained=False)
+ >>> input_tensor = torch.rand((1, 3, 32, 128))
+ >>> out = model(input_tensor)
+
+ Args:
+ ----
+ pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
+ **kwargs: keyword arguments of the SAR architecture
+
+ Returns:
+ -------
+ text recognition architecture
+ """
+ return _sar(
+ "sar_resnet31",
+ pretrained,
+ resnet31,
+ "10",
+ ignore_keys=[
+ "decoder.embed.weight",
+ "decoder.embed_tgt.weight",
+ "decoder.output_dense.weight",
+ "decoder.output_dense.bias",
+ ],
+ **kwargs,
+ )
diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5e557c2329e0a6c07efe1f6c10f2d8694109100
--- /dev/null
+++ b/doctr/models/recognition/sar/tensorflow.py
@@ -0,0 +1,421 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple
+
+import tensorflow as tf
+from tensorflow.keras import Model, Sequential, layers
+
+from doctr.datasets import VOCABS
+from doctr.utils.repr import NestedObject
+
+from ...classification import resnet31
+from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ..core import RecognitionModel, RecognitionPostProcessor
+
+__all__ = ["SAR", "sar_resnet31"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+ "sar_resnet31": {
+ "mean": (0.694, 0.695, 0.693),
+ "std": (0.299, 0.296, 0.301),
+ "input_shape": (32, 128, 3),
+ "vocab": VOCABS["french"],
+ "url": "https://doctr-static.mindee.com/models?id=v0.6.0/sar_resnet31-c41e32a5.zip&src=0",
+ },
+}
+
+
+class SAREncoder(layers.Layer, NestedObject):
+ """Implements encoder module of the SAR model
+
+ Args:
+ ----
+ rnn_units: number of hidden rnn units
+ dropout_prob: dropout probability
+ """
+
+ def __init__(self, rnn_units: int, dropout_prob: float = 0.0) -> None:
+ super().__init__()
+ self.rnn = Sequential([
+ layers.LSTM(units=rnn_units, return_sequences=True, recurrent_dropout=dropout_prob),
+ layers.LSTM(units=rnn_units, return_sequences=False, recurrent_dropout=dropout_prob),
+ ])
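+ # two stacked LSTMs: the first returns the full sequence for the second one,
+ # which keeps only its last hidden state as the holistic feature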
+
+ def call(
+ self,
+ x: tf.Tensor,
+ **kwargs: Any,
+ ) -> tf.Tensor:
+ # (N, C)
+ return self.rnn(x, **kwargs)
+
+
+class AttentionModule(layers.Layer, NestedObject):
+ """Implements attention module of the SAR model
+
+ Args:
+ ----
+ attention_units: number of hidden attention units
+
+ """
+
+ def __init__(self, attention_units: int) -> None:
+ super().__init__()
+ self.hidden_state_projector = layers.Conv2D(
+ attention_units,
+ 1,
+ strides=1,
+ use_bias=False,
+ padding="same",
+ kernel_initializer="he_normal",
+ )
+ self.features_projector = layers.Conv2D(
+ attention_units,
+ 3,
+ strides=1,
+ use_bias=True,
+ padding="same",
+ kernel_initializer="he_normal",
+ )
+ self.attention_projector = layers.Conv2D(
+ 1,
+ 1,
+ strides=1,
+ use_bias=False,
+ padding="same",
+ kernel_initializer="he_normal",
+ )
+ self.flatten = layers.Flatten()
+
+ def call(
+ self,
+ features: tf.Tensor,
+ hidden_state: tf.Tensor,
+ **kwargs: Any,
+ ) -> tf.Tensor:
+ [H, W] = features.get_shape().as_list()[1:3]
+ # shape (N, H, W, backbone_units) -> (N, H, W, attention_units)
+ features_projection = self.features_projector(features, **kwargs)
+ # shape (N, 1, 1, rnn_units) -> (N, 1, 1, attention_units)
+ hidden_state = tf.expand_dims(tf.expand_dims(hidden_state, axis=1), axis=1)
+ hidden_state_projection = self.hidden_state_projector(hidden_state, **kwargs)
+ projection = tf.math.tanh(hidden_state_projection + features_projection)
+ # shape (N, H, W, attention_units) -> (N, H, W, 1)
+ attention = self.attention_projector(projection, **kwargs)
+ # shape (N, H, W, 1) -> (N, H * W)
+ attention = self.flatten(attention)
+ attention = tf.nn.softmax(attention)
+ # shape (N, H * W) -> (N, H, W, 1)
+ attention_map = tf.reshape(attention, [-1, H, W, 1])
+ glimpse = tf.math.multiply(features, attention_map)
+ # shape (N, H, W, C) -> (N, C)
+ return tf.reduce_sum(glimpse, axis=[1, 2])
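+
+ # In effect this is additive (Bahdanau-style) attention over the H x W feature
+ # grid: scores = w . tanh(W_f * features + W_h * hidden_state), softmax-normalized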
+
+
+class SARDecoder(layers.Layer, NestedObject):
+ """Implements decoder module of the SAR model
+
+ Args:
+ ----
+ rnn_units: number of hidden units in recurrent cells
+ max_length: maximum length of a sequence
+ vocab_size: number of classes in the model alphabet
+ embedding_units: number of hidden embedding units
+ attention_units: number of hidden attention units
+ num_decoder_cells: number of LSTMCell layers to stack
+ dropout_prob: dropout probability
+
+ """
+
+ def __init__(
+ self,
+ rnn_units: int,
+ max_length: int,
+ vocab_size: int,
+ embedding_units: int,
+ attention_units: int,
+ num_decoder_cells: int = 2,
+ dropout_prob: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.vocab_size = vocab_size
+ self.max_length = max_length
+
+ self.embed = layers.Dense(embedding_units, use_bias=False)
+ self.embed_tgt = layers.Embedding(embedding_units, self.vocab_size + 1)
+
+ self.lstm_cells = layers.StackedRNNCells([
+ layers.LSTMCell(rnn_units, implementation=1) for _ in range(num_decoder_cells)
+ ])
+ self.attention_module = AttentionModule(attention_units)
+ self.output_dense = layers.Dense(self.vocab_size + 1, use_bias=True)
+ self.dropout = layers.Dropout(dropout_prob)
+
+ def call(
+ self,
+ features: tf.Tensor,
+ holistic: tf.Tensor,
+ gt: Optional[tf.Tensor] = None,
+ **kwargs: Any,
+ ) -> tf.Tensor:
+ if gt is not None:
+ gt_embedding = self.embed_tgt(gt, **kwargs)
+
+ logits_list: List[tf.Tensor] = []
+
+ for t in range(self.max_length + 1): # e.g. 32 timesteps with the default max_length=30
+ if t == 0:
+ # step to init the first states of the LSTMCell
+ states = self.lstm_cells.get_initial_state(
+ inputs=None, batch_size=features.shape[0], dtype=features.dtype
+ )
+ prev_symbol = holistic
+ elif t == 1:
+ # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
+ # (N, vocab_size + 1) --> (N, embedding_units)
+ prev_symbol = tf.zeros([features.shape[0], self.vocab_size + 1], dtype=features.dtype)
+ prev_symbol = self.embed(prev_symbol, **kwargs)
+ else:
+ if gt is not None and kwargs.get("training", False):
+ # (N, embedding_units) -2 because of <bos> and <eos> (same)
+ prev_symbol = self.embed(gt_embedding[:, t - 2], **kwargs)
+ else:
+ # take the argmax of the logits emitted at the previous timestep
+ index = tf.argmax(logits_list[t - 1], axis=-1)
+ # look up its embedding and project it to serve as the next input symbol
+ prev_symbol = self.embed(self.embed_tgt(index, **kwargs), **kwargs)
+
+ # (N, C), (N, C) take the last hidden state and cell state from current timestep
+ _, states = self.lstm_cells(prev_symbol, states, **kwargs)
+ # states = (hidden_state, cell_state)
+ hidden_state = states[0][0]
+ # (N, H, W, C), (N, C) --> (N, C)
+ glimpse = self.attention_module(features, hidden_state, **kwargs)
+ # (N, C), (N, C) --> (N, 2 * C)
+ logits = tf.concat([hidden_state, glimpse], axis=1)
+ logits = self.dropout(logits, **kwargs)
+ # (N, vocab_size + 1)
+ logits_list.append(self.output_dense(logits, **kwargs))
+
+ # (max_length + 1, N, vocab_size + 1) --> (N, max_length + 1, vocab_size + 1)
+ return tf.transpose(tf.stack(logits_list[1:]), (1, 0, 2))
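+
+ # With gt provided at training time, the decoder is teacher-forced: it feeds back
+ # the embedded ground-truth symbol. At inference it feeds back the embedding of
+ # its own argmax prediction from the previous timestep instead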
+
+
+class SAR(Model, RecognitionModel):
+ """Implements a SAR architecture as described in `"Show, Attend and Read:A Simple and Strong Baseline for
+ Irregular Text Recognition" `_.
+
+ Args:
+ ----
+ feature_extractor: the backbone serving as feature extractor
+ vocab: vocabulary used for encoding
+ rnn_units: number of hidden units in both encoder and decoder LSTM
+ embedding_units: number of embedding units
+ attention_units: number of hidden units in attention module
+ max_length: maximum word length handled by the model
+ num_decoder_cells: number of LSTMCell layers to stack
+ dropout_prob: dropout probability for the encoder and decoder
+ exportable: if True (for ONNX export), the forward pass only returns the logits
+ cfg: dictionary containing information about the model
+ """
+
+ _children_names: List[str] = ["feat_extractor", "encoder", "decoder", "postprocessor"]
+
+ def __init__(
+ self,
+ feature_extractor,
+ vocab: str,
+ rnn_units: int = 512,
+ embedding_units: int = 512,
+ attention_units: int = 512,
+ max_length: int = 30,
+ num_decoder_cells: int = 2,
+ dropout_prob: float = 0.0,
+ exportable: bool = False,
+ cfg: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ super().__init__()
+ self.vocab = vocab
+ self.exportable = exportable
+ self.cfg = cfg
+ self.max_length = max_length + 1 # Add 1 timestep for EOS after the longest word
+
+ self.feat_extractor = feature_extractor
+
+ self.encoder = SAREncoder(rnn_units, dropout_prob)
+ self.decoder = SARDecoder(
+ rnn_units,
+ self.max_length,
+ len(vocab),
+ embedding_units,
+ attention_units,
+ num_decoder_cells,
+ dropout_prob,
+ )
+
+ self.postprocessor = SARPostProcessor(vocab=vocab)
+
+ @staticmethod
+ def compute_loss(
+ model_output: tf.Tensor,
+ gt: tf.Tensor,
+ seq_len: tf.Tensor,
+ ) -> tf.Tensor:
+ """Compute categorical cross-entropy loss for the model.
+ Sequences are masked after the EOS character.
+
+ Args:
+ ----
+ model_output: predicted logits of the model
+ gt: the encoded tensor with gt labels
+ seq_len: lengths of each gt word inside the batch
+
+ Returns:
+ -------
+ The loss of the model on the batch
+ """
+ # Input length: number of timesteps
+ input_len = tf.shape(model_output)[1]
+ # Add one for additional token
+ seq_len = seq_len + 1
+ # One-hot gt labels
+ oh_gt = tf.one_hot(gt, depth=model_output.shape[2])
+ # Compute loss
+ cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt, model_output)
+ # Compute mask
+ mask_values = tf.zeros_like(cce)
+ mask_2d = tf.sequence_mask(seq_len, input_len)
+ masked_loss = tf.where(mask_2d, cce, mask_values)
+ ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype))
+ return tf.expand_dims(ce_loss, axis=1)
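+ # note: unlike the PyTorch implementation above, this returns a per-sample loss
+ # of shape (N, 1) rather than a scalar mean over the batch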
+
+ def call(
+ self,
+ x: tf.Tensor,
+ target: Optional[List[str]] = None,
+ return_model_output: bool = False,
+ return_preds: bool = False,
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ features = self.feat_extractor(x, **kwargs)
+ # vertical max pooling: (N, H, W, C) --> (N, W, C)
+ pooled_features = tf.reduce_max(features, axis=1)
+ # holistic (N, C)
+ encoded = self.encoder(pooled_features, **kwargs)
+
+ if target is not None:
+ gt, seq_len = self.build_target(target)
+ seq_len = tf.cast(seq_len, tf.int32)
+
+ if kwargs.get("training", False) and target is None:
+ raise ValueError("Need to provide labels during training for teacher forcing")
+
+ decoded_features = _bf16_to_float32(
+ self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
+ )
+
+ out: Dict[str, tf.Tensor] = {}
+ if self.exportable:
+ out["logits"] = decoded_features
+ return out
+
+ if return_model_output:
+ out["out_map"] = decoded_features
+
+ if target is None or return_preds:
+ # Post-process into (word, confidence) predictions
+ out["preds"] = self.postprocessor(decoded_features)
+
+ if target is not None:
+ out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
+
+ return out
+
+
+class SARPostProcessor(RecognitionPostProcessor):
+ """Post processor for SAR architectures
+
+ Args:
+ ----
+ vocab: string containing the ordered sequence of supported characters
+ """
+
+ def __call__(
+ self,
+ logits: tf.Tensor,
+ ) -> List[Tuple[str, float]]:
+ # compute pred with argmax for attention models
+ out_idxs = tf.math.argmax(logits, axis=2)
+ # N x L
+ probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2)
+ # Take the minimum confidence of the sequence
+ probs = tf.math.reduce_min(probs, axis=1)
+
+ # decode raw output of the model by mapping indices back to characters
+ out_idxs = tf.cast(out_idxs, dtype="int32")
+ embedding = tf.constant(self._embedding, dtype=tf.string)
+ decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1)
+ decoded_strings_pred = tf.strings.split(decoded_strings_pred, "