initial files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +140 -0
- .pre-commit-config.yaml +23 -0
- CODE_OF_CONDUCT.md +128 -0
- CONTRIBUTING.md +92 -0
- Dockerfile +75 -0
- LICENSE +201 -0
- Makefile +33 -0
- README.md +384 -12
- backend/pytorch.py +93 -0
- doctr/__init__.py +3 -0
- doctr/datasets/__init__.py +26 -0
- doctr/datasets/cord.py +121 -0
- doctr/datasets/datasets/__init__.py +6 -0
- doctr/datasets/datasets/base.py +132 -0
- doctr/datasets/datasets/pytorch.py +59 -0
- doctr/datasets/datasets/tensorflow.py +59 -0
- doctr/datasets/detection.py +98 -0
- doctr/datasets/doc_artefacts.py +82 -0
- doctr/datasets/funsd.py +112 -0
- doctr/datasets/generator/__init__.py +6 -0
- doctr/datasets/generator/base.py +155 -0
- doctr/datasets/generator/pytorch.py +54 -0
- doctr/datasets/generator/tensorflow.py +60 -0
- doctr/datasets/ic03.py +126 -0
- doctr/datasets/ic13.py +99 -0
- doctr/datasets/iiit5k.py +103 -0
- doctr/datasets/iiithws.py +75 -0
- doctr/datasets/imgur5k.py +147 -0
- doctr/datasets/loader.py +102 -0
- doctr/datasets/mjsynth.py +106 -0
- doctr/datasets/ocr.py +71 -0
- doctr/datasets/orientation.py +40 -0
- doctr/datasets/recognition.py +56 -0
- doctr/datasets/sroie.py +103 -0
- doctr/datasets/svhn.py +131 -0
- doctr/datasets/svt.py +117 -0
- doctr/datasets/synthtext.py +128 -0
- doctr/datasets/utils.py +216 -0
- doctr/datasets/vocabs.py +71 -0
- doctr/datasets/wildreceipt.py +111 -0
- doctr/file_utils.py +92 -0
- doctr/io/__init__.py +5 -0
- doctr/io/elements.py +621 -0
- doctr/io/html.py +28 -0
- doctr/io/image/__init__.py +8 -0
- doctr/io/image/base.py +56 -0
- doctr/io/image/pytorch.py +109 -0
- doctr/io/image/tensorflow.py +110 -0
- doctr/io/pdf.py +42 -0
- doctr/io/reader.py +79 -0
.gitignore
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
pip-wheel-metadata/
|
| 24 |
+
share/python-wheels/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
.installed.cfg
|
| 27 |
+
*.egg
|
| 28 |
+
MANIFEST
|
| 29 |
+
|
| 30 |
+
# PyInstaller
|
| 31 |
+
# Usually these files are written by a python script from a template
|
| 32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 33 |
+
*.manifest
|
| 34 |
+
*.spec
|
| 35 |
+
|
| 36 |
+
# Installer logs
|
| 37 |
+
pip-log.txt
|
| 38 |
+
pip-delete-this-directory.txt
|
| 39 |
+
|
| 40 |
+
# Unit test / coverage reports
|
| 41 |
+
htmlcov/
|
| 42 |
+
.tox/
|
| 43 |
+
.nox/
|
| 44 |
+
.coverage
|
| 45 |
+
.coverage.*
|
| 46 |
+
.cache
|
| 47 |
+
nosetests.xml
|
| 48 |
+
coverage.xml
|
| 49 |
+
*.cover
|
| 50 |
+
*.py,cover
|
| 51 |
+
.hypothesis/
|
| 52 |
+
.pytest_cache/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
target/
|
| 76 |
+
|
| 77 |
+
# Jupyter Notebook
|
| 78 |
+
.ipynb_checkpoints
|
| 79 |
+
|
| 80 |
+
# IPython
|
| 81 |
+
profile_default/
|
| 82 |
+
ipython_config.py
|
| 83 |
+
|
| 84 |
+
# pyenv
|
| 85 |
+
.python-version
|
| 86 |
+
|
| 87 |
+
# pipenv
|
| 88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 91 |
+
# install all needed dependencies.
|
| 92 |
+
#Pipfile.lock
|
| 93 |
+
|
| 94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 95 |
+
__pypackages__/
|
| 96 |
+
|
| 97 |
+
# Celery stuff
|
| 98 |
+
celerybeat-schedule
|
| 99 |
+
celerybeat.pid
|
| 100 |
+
|
| 101 |
+
# SageMath parsed files
|
| 102 |
+
*.sage.py
|
| 103 |
+
|
| 104 |
+
# Environments
|
| 105 |
+
.env
|
| 106 |
+
.venv
|
| 107 |
+
env/
|
| 108 |
+
venv/
|
| 109 |
+
ENV/
|
| 110 |
+
env.bak/
|
| 111 |
+
venv.bak/
|
| 112 |
+
|
| 113 |
+
# Spyder project settings
|
| 114 |
+
.spyderproject
|
| 115 |
+
.spyproject
|
| 116 |
+
|
| 117 |
+
# Rope project settings
|
| 118 |
+
.ropeproject
|
| 119 |
+
|
| 120 |
+
# mkdocs documentation
|
| 121 |
+
/site
|
| 122 |
+
|
| 123 |
+
# mypy
|
| 124 |
+
.mypy_cache/
|
| 125 |
+
.dmypy.json
|
| 126 |
+
dmypy.json
|
| 127 |
+
|
| 128 |
+
# Pyre type checker
|
| 129 |
+
.pyre/
|
| 130 |
+
|
| 131 |
+
# Temp files
|
| 132 |
+
doctr/version.py
|
| 133 |
+
logs/
|
| 134 |
+
wandb/
|
| 135 |
+
.idea/
|
| 136 |
+
|
| 137 |
+
# Checkpoints
|
| 138 |
+
*.pt
|
| 139 |
+
*.pb
|
| 140 |
+
*.index
|
.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 3 |
+
rev: v4.5.0
|
| 4 |
+
hooks:
|
| 5 |
+
- id: check-ast
|
| 6 |
+
- id: check-yaml
|
| 7 |
+
exclude: .conda
|
| 8 |
+
- id: check-toml
|
| 9 |
+
- id: check-json
|
| 10 |
+
- id: check-added-large-files
|
| 11 |
+
exclude: docs/images/
|
| 12 |
+
- id: end-of-file-fixer
|
| 13 |
+
- id: trailing-whitespace
|
| 14 |
+
- id: debug-statements
|
| 15 |
+
- id: check-merge-conflict
|
| 16 |
+
- id: no-commit-to-branch
|
| 17 |
+
args: ['--branch', 'main']
|
| 18 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 19 |
+
rev: v0.3.2
|
| 20 |
+
hooks:
|
| 21 |
+
- id: ruff
|
| 22 |
+
args: [ --fix ]
|
| 23 |
+
- id: ruff-format
|
CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributor Covenant Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
We as members, contributors, and leaders pledge to make participation in our
|
| 6 |
+
community a harassment-free experience for everyone, regardless of age, body
|
| 7 |
+
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
| 8 |
+
identity and expression, level of experience, education, socio-economic status,
|
| 9 |
+
nationality, personal appearance, race, religion, or sexual identity
|
| 10 |
+
and orientation.
|
| 11 |
+
|
| 12 |
+
We pledge to act and interact in ways that contribute to an open, welcoming,
|
| 13 |
+
diverse, inclusive, and healthy community.
|
| 14 |
+
|
| 15 |
+
## Our Standards
|
| 16 |
+
|
| 17 |
+
Examples of behavior that contributes to a positive environment for our
|
| 18 |
+
community include:
|
| 19 |
+
|
| 20 |
+
* Demonstrating empathy and kindness toward other people
|
| 21 |
+
* Being respectful of differing opinions, viewpoints, and experiences
|
| 22 |
+
* Giving and gracefully accepting constructive feedback
|
| 23 |
+
* Accepting responsibility and apologizing to those affected by our mistakes,
|
| 24 |
+
and learning from the experience
|
| 25 |
+
* Focusing on what is best not just for us as individuals, but for the
|
| 26 |
+
overall community
|
| 27 |
+
|
| 28 |
+
Examples of unacceptable behavior include:
|
| 29 |
+
|
| 30 |
+
* The use of sexualized language or imagery, and sexual attention or
|
| 31 |
+
advances of any kind
|
| 32 |
+
* Trolling, insulting or derogatory comments, and personal or political attacks
|
| 33 |
+
* Public or private harassment
|
| 34 |
+
* Publishing others' private information, such as a physical or email
|
| 35 |
+
address, without their explicit permission
|
| 36 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 37 |
+
professional setting
|
| 38 |
+
|
| 39 |
+
## Enforcement Responsibilities
|
| 40 |
+
|
| 41 |
+
Community leaders are responsible for clarifying and enforcing our standards of
|
| 42 |
+
acceptable behavior and will take appropriate and fair corrective action in
|
| 43 |
+
response to any behavior that they deem inappropriate, threatening, offensive,
|
| 44 |
+
or harmful.
|
| 45 |
+
|
| 46 |
+
Community leaders have the right and responsibility to remove, edit, or reject
|
| 47 |
+
comments, commits, code, wiki edits, issues, and other contributions that are
|
| 48 |
+
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
| 49 |
+
decisions when appropriate.
|
| 50 |
+
|
| 51 |
+
## Scope
|
| 52 |
+
|
| 53 |
+
This Code of Conduct applies within all community spaces, and also applies when
|
| 54 |
+
an individual is officially representing the community in public spaces.
|
| 55 |
+
Examples of representing our community include using an official e-mail address,
|
| 56 |
+
posting via an official social media account, or acting as an appointed
|
| 57 |
+
representative at an online or offline event.
|
| 58 |
+
|
| 59 |
+
## Enforcement
|
| 60 |
+
|
| 61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 62 |
+
reported to the community leaders responsible for enforcement at
|
| 63 |
+
contact@mindee.com.
|
| 64 |
+
All complaints will be reviewed and investigated promptly and fairly.
|
| 65 |
+
|
| 66 |
+
All community leaders are obligated to respect the privacy and security of the
|
| 67 |
+
reporter of any incident.
|
| 68 |
+
|
| 69 |
+
## Enforcement Guidelines
|
| 70 |
+
|
| 71 |
+
Community leaders will follow these Community Impact Guidelines in determining
|
| 72 |
+
the consequences for any action they deem in violation of this Code of Conduct:
|
| 73 |
+
|
| 74 |
+
### 1. Correction
|
| 75 |
+
|
| 76 |
+
**Community Impact**: Use of inappropriate language or other behavior deemed
|
| 77 |
+
unprofessional or unwelcome in the community.
|
| 78 |
+
|
| 79 |
+
**Consequence**: A private, written warning from community leaders, providing
|
| 80 |
+
clarity around the nature of the violation and an explanation of why the
|
| 81 |
+
behavior was inappropriate. A public apology may be requested.
|
| 82 |
+
|
| 83 |
+
### 2. Warning
|
| 84 |
+
|
| 85 |
+
**Community Impact**: A violation through a single incident or series
|
| 86 |
+
of actions.
|
| 87 |
+
|
| 88 |
+
**Consequence**: A warning with consequences for continued behavior. No
|
| 89 |
+
interaction with the people involved, including unsolicited interaction with
|
| 90 |
+
those enforcing the Code of Conduct, for a specified period of time. This
|
| 91 |
+
includes avoiding interactions in community spaces as well as external channels
|
| 92 |
+
like social media. Violating these terms may lead to a temporary or
|
| 93 |
+
permanent ban.
|
| 94 |
+
|
| 95 |
+
### 3. Temporary Ban
|
| 96 |
+
|
| 97 |
+
**Community Impact**: A serious violation of community standards, including
|
| 98 |
+
sustained inappropriate behavior.
|
| 99 |
+
|
| 100 |
+
**Consequence**: A temporary ban from any sort of interaction or public
|
| 101 |
+
communication with the community for a specified period of time. No public or
|
| 102 |
+
private interaction with the people involved, including unsolicited interaction
|
| 103 |
+
with those enforcing the Code of Conduct, is allowed during this period.
|
| 104 |
+
Violating these terms may lead to a permanent ban.
|
| 105 |
+
|
| 106 |
+
### 4. Permanent Ban
|
| 107 |
+
|
| 108 |
+
**Community Impact**: Demonstrating a pattern of violation of community
|
| 109 |
+
standards, including sustained inappropriate behavior, harassment of an
|
| 110 |
+
individual, or aggression toward or disparagement of classes of individuals.
|
| 111 |
+
|
| 112 |
+
**Consequence**: A permanent ban from any sort of public interaction within
|
| 113 |
+
the community.
|
| 114 |
+
|
| 115 |
+
## Attribution
|
| 116 |
+
|
| 117 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
| 118 |
+
version 2.0, available at
|
| 119 |
+
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
|
| 120 |
+
|
| 121 |
+
Community Impact Guidelines were inspired by [Mozilla's code of conduct
|
| 122 |
+
enforcement ladder](https://github.com/mozilla/diversity).
|
| 123 |
+
|
| 124 |
+
[homepage]: https://www.contributor-covenant.org
|
| 125 |
+
|
| 126 |
+
For answers to common questions about this code of conduct, see the FAQ at
|
| 127 |
+
https://www.contributor-covenant.org/faq. Translations are available at
|
| 128 |
+
https://www.contributor-covenant.org/translations.
|
CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing to docTR
|
| 2 |
+
|
| 3 |
+
Everything you need to know to contribute efficiently to the project.
|
| 4 |
+
|
| 5 |
+
## Codebase structure
|
| 6 |
+
|
| 7 |
+
- [doctr](https://github.com/mindee/doctr/blob/main/doctr) - The package codebase
|
| 8 |
+
- [tests](https://github.com/mindee/doctr/blob/main/tests) - Python unit tests
|
| 9 |
+
- [docs](https://github.com/mindee/doctr/blob/main/docs) - Library documentation building
|
| 10 |
+
- [scripts](https://github.com/mindee/doctr/blob/main/scripts) - Example scripts
|
| 11 |
+
- [references](https://github.com/mindee/doctr/blob/main/references) - Reference training scripts
|
| 12 |
+
- [demo](https://github.com/mindee/doctr/blob/main/demo) - Small demo app to showcase docTR capabilities
|
| 13 |
+
- [api](https://github.com/mindee/doctr/blob/main/api) - A minimal template to deploy a REST API with docTR
|
| 14 |
+
|
| 15 |
+
## Continuous Integration
|
| 16 |
+
|
| 17 |
+
This project uses the following integrations to ensure proper codebase maintenance:
|
| 18 |
+
|
| 19 |
+
- [GitHub Workflow](https://help.github.com/en/actions/configuring-and-managing-workflows/configuring-a-workflow) - run jobs for package build and coverage
|
| 20 |
+
- [Codecov](https://codecov.io/) - reports back coverage results
|
| 21 |
+
|
| 22 |
+
As a contributor, you will only have to ensure coverage of your code by adding appropriate unit testing of your code.
|
| 23 |
+
|
| 24 |
+
## Feedback
|
| 25 |
+
|
| 26 |
+
### Feature requests & bug report
|
| 27 |
+
|
| 28 |
+
Whether you encountered a problem, or you have a feature suggestion, your input has value and can be used by contributors to reference it in their developments. For this purpose, we advise you to use Github [issues](https://github.com/mindee/doctr/issues).
|
| 29 |
+
|
| 30 |
+
First, check whether the topic wasn't already covered in an open / closed issue. If not, feel free to open a new one! When doing so, use issue templates whenever possible and provide enough information for other contributors to jump in.
|
| 31 |
+
|
| 32 |
+
### Questions
|
| 33 |
+
|
| 34 |
+
If you are wondering how to do something with docTR, or a more general question, you should consider checking out Github [discussions](https://github.com/mindee/doctr/discussions). See it as a Q&A forum, or the docTR-specific StackOverflow!
|
| 35 |
+
|
| 36 |
+
## Developing docTR
|
| 37 |
+
|
| 38 |
+
### Developer mode installation
|
| 39 |
+
|
| 40 |
+
Install all additional dependencies with the following command:
|
| 41 |
+
|
| 42 |
+
```shell
|
| 43 |
+
python -m pip install --upgrade pip
|
| 44 |
+
pip install -e .[dev]
|
| 45 |
+
pre-commit install
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
### Commits
|
| 49 |
+
|
| 50 |
+
- **Code**: ensure to provide docstrings to your Python code. In doing so, please follow [Google-style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) so it can ease the process of documentation later.
|
| 51 |
+
- **Commit message**: please follow [Udacity guide](http://udacity.github.io/git-styleguide/)
|
| 52 |
+
|
| 53 |
+
### Unit tests
|
| 54 |
+
|
| 55 |
+
In order to run the same unit tests as the CI workflows, you can run unittests locally:
|
| 56 |
+
|
| 57 |
+
```shell
|
| 58 |
+
make test
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
### Code quality
|
| 62 |
+
|
| 63 |
+
To run all quality checks together
|
| 64 |
+
|
| 65 |
+
```shell
|
| 66 |
+
make quality
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
#### Code style verification
|
| 70 |
+
|
| 71 |
+
To run all style checks together
|
| 72 |
+
|
| 73 |
+
```shell
|
| 74 |
+
make style
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
### Modifying the documentation
|
| 78 |
+
|
| 79 |
+
The current documentation is built using `sphinx` thanks to our CI.
|
| 80 |
+
You can build the documentation locally:
|
| 81 |
+
|
| 82 |
+
```shell
|
| 83 |
+
make docs-single-version
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
Please note that files that have not been modified will not be rebuilt. If you want to force a complete rebuild, you can delete the `_build` directory. Additionally, you may need to clear your web browser's cache to see the modifications.
|
| 87 |
+
|
| 88 |
+
You can now open your local version of the documentation located at `docs/_build/index.html` in your browser
|
| 89 |
+
|
| 90 |
+
## Let's connect
|
| 91 |
+
|
| 92 |
+
Should you wish to connect somewhere else than on GitHub, feel free to join us on [Slack](https://join.slack.com/t/mindee-community/shared_invite/zt-uzgmljfl-MotFVfH~IdEZxjp~0zldww), where you will find a `#doctr` channel!
|
Dockerfile
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM ubuntu:22.04
|
| 2 |
+
|
| 3 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 4 |
+
ENV LANG=C.UTF-8
|
| 5 |
+
ENV PYTHONUNBUFFERED=1
|
| 6 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 7 |
+
|
| 8 |
+
ARG SYSTEM=gpu
|
| 9 |
+
|
| 10 |
+
# Enroll NVIDIA GPG public key and install CUDA
|
| 11 |
+
RUN if [ "$SYSTEM" = "gpu" ]; then \
|
| 12 |
+
apt-get update && \
|
| 13 |
+
apt-get install -y gnupg ca-certificates wget && \
|
| 14 |
+
# - Install Nvidia repo keys
|
| 15 |
+
# - See: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#network-repo-installation-for-ubuntu
|
| 16 |
+
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb && \
|
| 17 |
+
dpkg -i cuda-keyring_1.1-1_all.deb && \
|
| 18 |
+
apt-get update && apt-get install -y --no-install-recommends \
|
| 19 |
+
cuda-command-line-tools-11-8 \
|
| 20 |
+
cuda-cudart-dev-11-8 \
|
| 21 |
+
cuda-nvcc-11-8 \
|
| 22 |
+
cuda-cupti-11-8 \
|
| 23 |
+
cuda-nvprune-11-8 \
|
| 24 |
+
cuda-libraries-11-8 \
|
| 25 |
+
cuda-nvrtc-11-8 \
|
| 26 |
+
libcufft-11-8 \
|
| 27 |
+
libcurand-11-8 \
|
| 28 |
+
libcusolver-11-8 \
|
| 29 |
+
libcusparse-11-8 \
|
| 30 |
+
libcublas-11-8 \
|
| 31 |
+
# - CuDNN: https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#ubuntu-network-installation
|
| 32 |
+
libcudnn8=8.6.0.163-1+cuda11.8 \
|
| 33 |
+
libnvinfer-plugin8=8.6.1.6-1+cuda11.8 \
|
| 34 |
+
libnvinfer8=8.6.1.6-1+cuda11.8; \
|
| 35 |
+
fi
|
| 36 |
+
|
| 37 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 38 |
+
# - Other packages
|
| 39 |
+
build-essential \
|
| 40 |
+
pkg-config \
|
| 41 |
+
curl \
|
| 42 |
+
wget \
|
| 43 |
+
software-properties-common \
|
| 44 |
+
unzip \
|
| 45 |
+
git \
|
| 46 |
+
# - Packages to build Python
|
| 47 |
+
tar make gcc zlib1g-dev libffi-dev libssl-dev liblzma-dev libbz2-dev libsqlite3-dev \
|
| 48 |
+
# - Packages for docTR
|
| 49 |
+
libgl1-mesa-dev libsm6 libxext6 libxrender-dev libpangocairo-1.0-0 \
|
| 50 |
+
&& apt-get clean \
|
| 51 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 52 |
+

|
| 53 |
+
|
| 54 |
+
# Install Python
|
| 55 |
+
ARG PYTHON_VERSION=3.10.13
|
| 56 |
+
|
| 57 |
+
RUN wget https://www.python.org/ftp/python/$PYTHON_VERSION/Python-$PYTHON_VERSION.tgz && \
|
| 58 |
+
tar -zxf Python-$PYTHON_VERSION.tgz && \
|
| 59 |
+
cd Python-$PYTHON_VERSION && \
|
| 60 |
+
mkdir /opt/python/ && \
|
| 61 |
+
./configure --prefix=/opt/python && \
|
| 62 |
+
make && \
|
| 63 |
+
make install && \
|
| 64 |
+
cd .. && \
|
| 65 |
+
rm Python-$PYTHON_VERSION.tgz && \
|
| 66 |
+
rm -r Python-$PYTHON_VERSION
|
| 67 |
+
|
| 68 |
+
ENV PATH=/opt/python/bin:$PATH
|
| 69 |
+
|
| 70 |
+
# Install docTR
|
| 71 |
+
ARG FRAMEWORK=tf
|
| 72 |
+
ARG DOCTR_REPO='mindee/doctr'
|
| 73 |
+
ARG DOCTR_VERSION=main
|
| 74 |
+
RUN pip3 install -U pip setuptools wheel && \
|
| 75 |
+
pip3 install "python-doctr[$FRAMEWORK]@git+https://github.com/$DOCTR_REPO.git@$DOCTR_VERSION"
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright 2022 Mindee
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
Makefile
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: quality style test test-common test-tf test-torch docs-single-version docs
|
| 2 |
+
# this target runs checks on all files
|
| 3 |
+
quality:
|
| 4 |
+
ruff check .
|
| 5 |
+
mypy doctr/
|
| 6 |
+
|
| 7 |
+
# this target runs checks on all files and potentially modifies some of them
|
| 8 |
+
style:
|
| 9 |
+
ruff check --fix .
|
| 10 |
+
ruff format .
|
| 11 |
+
|
| 12 |
+
# Run tests for the library
|
| 13 |
+
test:
|
| 14 |
+
coverage run -m pytest tests/common/
|
| 15 |
+
USE_TF='1' coverage run -m pytest tests/tensorflow/
|
| 16 |
+
USE_TORCH='1' coverage run -m pytest tests/pytorch/
|
| 17 |
+
|
| 18 |
+
test-common:
|
| 19 |
+
coverage run -m pytest tests/common/
|
| 20 |
+
|
| 21 |
+
test-tf:
|
| 22 |
+
USE_TF='1' coverage run -m pytest tests/tensorflow/
|
| 23 |
+
|
| 24 |
+
test-torch:
|
| 25 |
+
USE_TORCH='1' coverage run -m pytest tests/pytorch/
|
| 26 |
+
|
| 27 |
+
# Check that docs can build
|
| 28 |
+
docs-single-version:
|
| 29 |
+
sphinx-build docs/source docs/_build -a
|
| 30 |
+
|
| 31 |
+
# Check that docs can build
|
| 32 |
+
docs:
|
| 33 |
+
cd docs && bash build.sh
|
README.md
CHANGED
|
@@ -1,12 +1,384 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="center">
|
| 2 |
+
<img src="https://github.com/mindee/doctr/raw/main/docs/images/Logo_doctr.gif" width="40%">
|
| 3 |
+
</p>
|
| 4 |
+
|
| 5 |
+
[](https://slack.mindee.com) [](LICENSE)  [](https://github.com/mindee/doctr/pkgs/container/doctr) [](https://codecov.io/gh/mindee/doctr) [](https://www.codefactor.io/repository/github/mindee/doctr) [](https://app.codacy.com/gh/mindee/doctr?utm_source=github.com&utm_medium=referral&utm_content=mindee/doctr&utm_campaign=Badge_Grade) [](https://mindee.github.io/doctr) [](https://pypi.org/project/python-doctr/) [](https://huggingface.co/spaces/mindee/doctr) [](https://colab.research.google.com/github/mindee/notebooks/blob/main/doctr/quicktour.ipynb)
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
**Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch**
|
| 9 |
+
|
| 10 |
+
What you can expect from this repository:
|
| 11 |
+
|
| 12 |
+
- efficient ways to parse textual information (localize and identify each word) from your documents
|
| 13 |
+
- guidance on how to integrate this in your current architecture
|
| 14 |
+
|
| 15 |
+

|
| 16 |
+
|
| 17 |
+
## Quick Tour
|
| 18 |
+
|
| 19 |
+
### Getting your pretrained model
|
| 20 |
+
|
| 21 |
+
End-to-End OCR is achieved in docTR using a two-stage approach: text detection (localizing words), then text recognition (identify all characters in the word).
|
| 22 |
+
As such, you can select the architecture used for [text detection](https://mindee.github.io/doctr/latest/modules/models.html#doctr-models-detection), and the one for [text recognition](https://mindee.github.io/doctr/latest//modules/models.html#doctr-models-recognition) from the list of available implementations.
|
| 23 |
+
|
| 24 |
+
```python
|
| 25 |
+
from doctr.models import ocr_predictor
|
| 26 |
+
|
| 27 |
+
model = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
### Reading files
|
| 31 |
+
|
| 32 |
+
Documents can be interpreted from PDF or images:
|
| 33 |
+
|
| 34 |
+
```python
|
| 35 |
+
from doctr.io import DocumentFile
|
| 36 |
+
# PDF
|
| 37 |
+
pdf_doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
|
| 38 |
+
# Image
|
| 39 |
+
single_img_doc = DocumentFile.from_images("path/to/your/img.jpg")
|
| 40 |
+
# Webpage
|
| 41 |
+
webpage_doc = DocumentFile.from_url("https://www.yoursite.com")
|
| 42 |
+
# Multiple page images
|
| 43 |
+
multi_img_doc = DocumentFile.from_images(["path/to/page1.jpg", "path/to/page2.jpg"])
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
### Putting it together
|
| 47 |
+
|
| 48 |
+
Let's use the default pretrained model for an example:
|
| 49 |
+
|
| 50 |
+
```python
|
| 51 |
+
from doctr.io import DocumentFile
|
| 52 |
+
from doctr.models import ocr_predictor
|
| 53 |
+
|
| 54 |
+
model = ocr_predictor(pretrained=True)
|
| 55 |
+
# PDF
|
| 56 |
+
doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
|
| 57 |
+
# Analyze
|
| 58 |
+
result = model(doc)
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
### Dealing with rotated documents
|
| 62 |
+
|
| 63 |
+
Should you use docTR on documents that include rotated pages, or pages with multiple box orientations,
|
| 64 |
+
you have multiple options to handle it:
|
| 65 |
+
|
| 66 |
+
- If you only use straight document pages with straight words (horizontal, same reading direction),
|
| 67 |
+
consider passing `assume_straight_boxes=True` to the ocr_predictor. It will directly fit straight boxes
|
| 68 |
+
on your page and return straight boxes, which makes it the fastest option.
|
| 69 |
+
|
| 70 |
+
- If you want the predictor to output straight boxes (no matter the orientation of your pages, the final localizations
|
| 71 |
+
will be converted to straight boxes), you need to pass `export_as_straight_boxes=True` in the predictor. Otherwise, if `assume_straight_pages=False`, it will return rotated bounding boxes (potentially with an angle of 0°).
|
| 72 |
+
|
| 73 |
+
If both options are set to False, the predictor will always fit and return rotated boxes.
|
| 74 |
+
|
| 75 |
+
To interpret your model's predictions, you can visualize them interactively as follows:
|
| 76 |
+
|
| 77 |
+
```python
|
| 78 |
+
result.show()
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+

|
| 82 |
+
|
| 83 |
+
Or even rebuild the original document from its predictions:
|
| 84 |
+
|
| 85 |
+
```python
|
| 86 |
+
import matplotlib.pyplot as plt
|
| 87 |
+
|
| 88 |
+
synthetic_pages = result.synthesize()
|
| 89 |
+
plt.imshow(synthetic_pages[0]); plt.axis('off'); plt.show()
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+

|
| 93 |
+
|
| 94 |
+
The `ocr_predictor` returns a `Document` object with a nested structure (with `Page`, `Block`, `Line`, `Word`, `Artefact`).
|
| 95 |
+
To get a better understanding of our document model, check our [documentation](https://mindee.github.io/doctr/modules/io.html#document-structure):
|
| 96 |
+
|
| 97 |
+
You can also export them as a nested dict, more appropriate for JSON format:
|
| 98 |
+
|
| 99 |
+
```python
|
| 100 |
+
json_output = result.export()
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
### Use the KIE predictor
|
| 104 |
+
|
| 105 |
+
The KIE predictor is a more flexible predictor compared to OCR as your detection model can detect multiple classes in a document. For example, you can have a detection model to detect just dates and addresses in a document.
|
| 106 |
+
|
| 107 |
+
The KIE predictor makes it possible to use detector with multiple classes with a recognition model and to have the whole pipeline already setup for you.
|
| 108 |
+
|
| 109 |
+
```python
|
| 110 |
+
from doctr.io import DocumentFile
|
| 111 |
+
from doctr.models import kie_predictor
|
| 112 |
+
|
| 113 |
+
# Model
|
| 114 |
+
model = kie_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
|
| 115 |
+
# PDF
|
| 116 |
+
doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
|
| 117 |
+
# Analyze
|
| 118 |
+
result = model(doc)
|
| 119 |
+
|
| 120 |
+
predictions = result.pages[0].predictions
|
| 121 |
+
for class_name in predictions.keys():
|
| 122 |
+
list_predictions = predictions[class_name]
|
| 123 |
+
for prediction in list_predictions:
|
| 124 |
+
print(f"Prediction for {class_name}: {prediction}")
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
The KIE predictor results per page are in a dictionary format with each key representing a class name and its value containing the predictions for that class.
|
| 128 |
+
|
| 129 |
+
### If you are looking for support from the Mindee team
|
| 130 |
+
|
| 131 |
+
[](https://mindee.com/product/doctr)
|
| 132 |
+
|
| 133 |
+
## Installation
|
| 134 |
+
|
| 135 |
+
### Prerequisites
|
| 136 |
+
|
| 137 |
+
Python 3.9 (or higher) and [pip](https://pip.pypa.io/en/stable/) are required to install docTR.
|
| 138 |
+
|
| 139 |
+
Since we use [weasyprint](https://weasyprint.org/), you will need extra dependencies if you are not running Linux.
|
| 140 |
+
|
| 141 |
+
For MacOS users, you can install them as follows:
|
| 142 |
+
|
| 143 |
+
```shell
|
| 144 |
+
brew install cairo pango gdk-pixbuf libffi
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
For Windows users, those dependencies are included in GTK. You can find the latest installer over [here](https://github.com/tschoonj/GTK-for-Windows-Runtime-Environment-Installer/releases).
|
| 148 |
+
|
| 149 |
+
### Latest release
|
| 150 |
+
|
| 151 |
+
You can then install the latest release of the package using [pypi](https://pypi.org/project/python-doctr/) as follows:
|
| 152 |
+
|
| 153 |
+
```shell
|
| 154 |
+
pip install python-doctr
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
> :warning: Please note that the basic installation is not standalone, as it does not provide a deep learning framework, which is required for the package to run.
|
| 158 |
+
|
| 159 |
+
We try to keep framework-specific dependencies to a minimum. You can install framework-specific builds as follows:
|
| 160 |
+
|
| 161 |
+
```shell
|
| 162 |
+
# for TensorFlow
|
| 163 |
+
pip install "python-doctr[tf]"
|
| 164 |
+
# for PyTorch
|
| 165 |
+
pip install "python-doctr[torch]"
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
For MacBooks with M1 chip, you will need some additional packages or specific versions:
|
| 169 |
+
|
| 170 |
+
- TensorFlow 2: [metal plugin](https://developer.apple.com/metal/tensorflow-plugin/)
|
| 171 |
+
- PyTorch: [version >= 1.12.0](https://pytorch.org/get-started/locally/#start-locally)
|
| 172 |
+
|
| 173 |
+
### Developer mode
|
| 174 |
+
|
| 175 |
+
Alternatively, you can install it from source, which will require you to install [Git](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git).
|
| 176 |
+
First clone the project repository:
|
| 177 |
+
|
| 178 |
+
```shell
|
| 179 |
+
git clone https://github.com/mindee/doctr.git
|
| 180 |
+
pip install -e doctr/.
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
+
Again, if you prefer to avoid the risk of missing dependencies, you can install the TensorFlow or the PyTorch build:
|
| 184 |
+
|
| 185 |
+
```shell
|
| 186 |
+
# for TensorFlow
|
| 187 |
+
pip install -e doctr/.[tf]
|
| 188 |
+
# for PyTorch
|
| 189 |
+
pip install -e doctr/.[torch]
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
## Models architectures
|
| 193 |
+
|
| 194 |
+
Credits where it's due: this repository is implementing, among others, architectures from published research papers.
|
| 195 |
+
|
| 196 |
+
### Text Detection
|
| 197 |
+
|
| 198 |
+
- DBNet: [Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/pdf/1911.08947.pdf).
|
| 199 |
+
- LinkNet: [LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation](https://arxiv.org/pdf/1707.03718.pdf)
|
| 200 |
+
- FAST: [FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation](https://arxiv.org/pdf/2111.02394.pdf)
|
| 201 |
+
|
| 202 |
+
### Text Recognition
|
| 203 |
+
|
| 204 |
+
- CRNN: [An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition](https://arxiv.org/pdf/1507.05717.pdf).
|
| 205 |
+
- SAR: [Show, Attend and Read:A Simple and Strong Baseline for Irregular Text Recognition](https://arxiv.org/pdf/1811.00751.pdf).
|
| 206 |
+
- MASTER: [MASTER: Multi-Aspect Non-local Network for Scene Text Recognition](https://arxiv.org/pdf/1910.02562.pdf).
|
| 207 |
+
- ViTSTR: [Vision Transformer for Fast and Efficient Scene Text Recognition](https://arxiv.org/pdf/2105.08582.pdf).
|
| 208 |
+
- PARSeq: [Scene Text Recognition with Permuted Autoregressive Sequence Models](https://arxiv.org/pdf/2207.06966).
|
| 209 |
+
|
| 210 |
+
## More goodies
|
| 211 |
+
|
| 212 |
+
### Documentation
|
| 213 |
+
|
| 214 |
+
The full package documentation is available [here](https://mindee.github.io/doctr/) for detailed specifications.
|
| 215 |
+
|
| 216 |
+
### Demo app
|
| 217 |
+
|
| 218 |
+
A minimal demo app is provided for you to play with our end-to-end OCR models!
|
| 219 |
+
|
| 220 |
+

|
| 221 |
+
|
| 222 |
+
#### Live demo
|
| 223 |
+
|
| 224 |
+
Courtesy of :hugs: [Hugging Face](https://huggingface.co/) :hugs:, docTR has now a fully deployed version available on [Spaces](https://huggingface.co/spaces)!
|
| 225 |
+
Check it out [](https://huggingface.co/spaces/mindee/doctr)
|
| 226 |
+
|
| 227 |
+
#### Running it locally
|
| 228 |
+
|
| 229 |
+
If you prefer to use it locally, there is an extra dependency ([Streamlit](https://streamlit.io/)) that is required.
|
| 230 |
+
|
| 231 |
+
##### Tensorflow version
|
| 232 |
+
|
| 233 |
+
```shell
|
| 234 |
+
pip install -r demo/tf-requirements.txt
|
| 235 |
+
```
|
| 236 |
+
|
| 237 |
+
Then run your app in your default browser with:
|
| 238 |
+
|
| 239 |
+
```shell
|
| 240 |
+
USE_TF=1 streamlit run demo/app.py
|
| 241 |
+
```
|
| 242 |
+
|
| 243 |
+
##### PyTorch version
|
| 244 |
+
|
| 245 |
+
```shell
|
| 246 |
+
pip install -r demo/pt-requirements.txt
|
| 247 |
+
```
|
| 248 |
+
|
| 249 |
+
Then run your app in your default browser with:
|
| 250 |
+
|
| 251 |
+
```shell
|
| 252 |
+
USE_TORCH=1 streamlit run demo/app.py
|
| 253 |
+
```
|
| 254 |
+
|
| 255 |
+
#### TensorFlow.js
|
| 256 |
+
|
| 257 |
+
Instead of having your demo actually running Python, you would prefer to run everything in your web browser?
|
| 258 |
+
Check out our [TensorFlow.js demo](https://github.com/mindee/doctr-tfjs-demo) to get started!
|
| 259 |
+
|
| 260 |
+

|
| 261 |
+
|
| 262 |
+
### Docker container
|
| 263 |
+
|
| 264 |
+
[We offer Docker container support for easy testing and deployment](https://github.com/mindee/doctr/pkgs/container/doctr).
|
| 265 |
+
|
| 266 |
+
#### Using GPU with docTR Docker Images
|
| 267 |
+
|
| 268 |
+
The docTR Docker images are GPU-ready and based on CUDA `11.8`.
|
| 269 |
+
However, to use GPU support with these Docker images, please ensure that Docker is configured to use your GPU.
|
| 270 |
+
|
| 271 |
+
To verify and configure GPU support for Docker, please follow the instructions provided in the [NVIDIA Container Toolkit Installation Guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
|
| 272 |
+
|
| 273 |
+
Once Docker is configured to use GPUs, you can run docTR Docker containers with GPU support:
|
| 274 |
+
|
| 275 |
+
```shell
|
| 276 |
+
docker run -it --gpus all ghcr.io/mindee/doctr:tf-py3.8.18-gpu-2023-09 bash
|
| 277 |
+
```
|
| 278 |
+
|
| 279 |
+
#### Available Tags
|
| 280 |
+
|
| 281 |
+
The Docker images for docTR follow a specific tag nomenclature: `<framework>-py<python_version>-<system>-<doctr_version|YYYY-MM>`. Here's a breakdown of the tag structure:
|
| 282 |
+
|
| 283 |
+
- `<framework>`: `tf` (TensorFlow) or `torch` (PyTorch).
|
| 284 |
+
- `<python_version>`: `3.8.18`, `3.9.18`, or `3.10.13`.
|
| 285 |
+
- `<system>`: `cpu` or `gpu`
|
| 286 |
+
- `<doctr_version>`: a tag >= `v0.7.1`
|
| 287 |
+
- `<YYYY-MM>`: e.g. `2023-09`
|
| 288 |
+
|
| 289 |
+
Here are examples of different image tags:
|
| 290 |
+
|
| 291 |
+
| Tag | Description |
|
| 292 |
+
|----------------------------|---------------------------------------------------|
|
| 293 |
+
| `tf-py3.8.18-cpu-v0.7.1`   | TensorFlow image with Python `3.8.18` and docTR `v0.7.1`. |
|
| 294 |
+
| `torch-py3.9.18-gpu-2023-09`| PyTorch image with Python `3.9.18`, GPU support and a monthly build from `2023-09`. |
|
| 295 |
+
|
| 296 |
+
#### Building Docker Images Locally
|
| 297 |
+
|
| 298 |
+
You can also build docTR Docker images locally on your computer.
|
| 299 |
+
|
| 300 |
+
```shell
|
| 301 |
+
docker build -t doctr .
|
| 302 |
+
```
|
| 303 |
+
|
| 304 |
+
You can specify custom Python versions and docTR versions using build arguments. For example, to build a docTR image with TensorFlow, Python version `3.9.10`, and docTR version `v0.7.0`, run the following command:
|
| 305 |
+
|
| 306 |
+
```shell
|
| 307 |
+
docker build -t doctr --build-arg FRAMEWORK=tf --build-arg PYTHON_VERSION=3.9.10 --build-arg DOCTR_VERSION=v0.7.0 .
|
| 308 |
+
```
|
| 309 |
+
|
| 310 |
+
### Example script
|
| 311 |
+
|
| 312 |
+
An example script is provided for a simple documentation analysis of a PDF or image file:
|
| 313 |
+
|
| 314 |
+
```shell
|
| 315 |
+
python scripts/analyze.py path/to/your/doc.pdf
|
| 316 |
+
```
|
| 317 |
+
|
| 318 |
+
All script arguments can be checked using `python scripts/analyze.py --help`
|
| 319 |
+
|
| 320 |
+
### Minimal API integration
|
| 321 |
+
|
| 322 |
+
Looking to integrate docTR into your API? Here is a template to get you started with a fully working API using the wonderful [FastAPI](https://github.com/tiangolo/fastapi) framework.
|
| 323 |
+
|
| 324 |
+
#### Deploy your API locally
|
| 325 |
+
|
| 326 |
+
Specific dependencies are required to run the API template, which you can install as follows:
|
| 327 |
+
|
| 328 |
+
```shell
|
| 329 |
+
cd api/
|
| 330 |
+
pip install poetry
|
| 331 |
+
make lock
|
| 332 |
+
pip install -r requirements.txt
|
| 333 |
+
```
|
| 334 |
+
|
| 335 |
+
You can now run your API locally:
|
| 336 |
+
|
| 337 |
+
```shell
|
| 338 |
+
uvicorn --reload --workers 1 --host 0.0.0.0 --port=8002 --app-dir api/ app.main:app
|
| 339 |
+
```
|
| 340 |
+
|
| 341 |
+
Alternatively, you can run the same server on a docker container if you prefer using:
|
| 342 |
+
|
| 343 |
+
```shell
|
| 344 |
+
PORT=8002 docker-compose up -d --build
|
| 345 |
+
```
|
| 346 |
+
|
| 347 |
+
#### What you have deployed
|
| 348 |
+
|
| 349 |
+
Your API should now be running locally on your port 8002. Access your automatically-built documentation at [http://localhost:8002/redoc](http://localhost:8002/redoc) and enjoy your four functional routes ("/detection", "/recognition", "/ocr", "/kie"). Here is an example with Python to send a request to the OCR route:
|
| 350 |
+
|
| 351 |
+
```python
|
| 352 |
+
import requests
|
| 353 |
+
with open('/path/to/your/doc.jpg', 'rb') as f:
|
| 354 |
+
data = f.read()
|
| 355 |
+
response = requests.post("http://localhost:8002/ocr", files={'file': data}).json()
|
| 356 |
+
```
|
| 357 |
+
|
| 358 |
+
### Example notebooks
|
| 359 |
+
|
| 360 |
+
Looking for more illustrations of docTR features? You might want to check the [Jupyter notebooks](https://github.com/mindee/doctr/tree/main/notebooks) designed to give you a broader overview.
|
| 361 |
+
|
| 362 |
+
## Citation
|
| 363 |
+
|
| 364 |
+
If you wish to cite this project, feel free to use this [BibTeX](http://www.bibtex.org/) reference:
|
| 365 |
+
|
| 366 |
+
```bibtex
|
| 367 |
+
@misc{doctr2021,
|
| 368 |
+
title={docTR: Document Text Recognition},
|
| 369 |
+
author={Mindee},
|
| 370 |
+
year={2021},
|
| 371 |
+
publisher = {GitHub},
|
| 372 |
+
howpublished = {\url{https://github.com/mindee/doctr}}
|
| 373 |
+
}
|
| 374 |
+
```
|
| 375 |
+
|
| 376 |
+
## Contributing
|
| 377 |
+
|
| 378 |
+
If you scrolled down to this section, you most likely appreciate open source. Do you feel like extending the range of our supported characters? Or perhaps submitting a paper implementation? Or contributing in any other way?
|
| 379 |
+
|
| 380 |
+
You're in luck, we compiled a short guide (cf. [`CONTRIBUTING`](https://mindee.github.io/doctr/contributing/contributing.html)) for you to easily do so!
|
| 381 |
+
|
| 382 |
+
## License
|
| 383 |
+
|
| 384 |
+
Distributed under the Apache 2.0 License. See [`LICENSE`](https://github.com/mindee/doctr?tab=Apache-2.0-1-ov-file#readme) for more information.
|
backend/pytorch.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from doctr.models import ocr_predictor
|
| 10 |
+
from doctr.models.predictor import OCRPredictor
|
| 11 |
+
|
| 12 |
+
# Names of the text-detection architectures exposed by the demo backend.
DET_ARCHS = [
    "db_resnet50",
    "db_resnet34",
    "db_mobilenet_v3_large",
    "linknet_resnet18",
    "linknet_resnet34",
    "linknet_resnet50",
    "fast_tiny",
    "fast_small",
    "fast_base",
]

# Names of the text-recognition architectures exposed by the demo backend.
RECO_ARCHS = [
    "crnn_vgg16_bn",
    "crnn_mobilenet_v3_small",
    "crnn_mobilenet_v3_large",
    "master",
    "sar_resnet31",
    "vitstr_small",
    "vitstr_base",
    "parseq",
]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def load_predictor(
    det_arch: str,
    reco_arch: str,
    assume_straight_pages: bool,
    straighten_pages: bool,
    bin_thresh: float,
    box_thresh: float,
    device: torch.device,
) -> OCRPredictor:
    """Build a pretrained OCR predictor and move it onto the target device.

    Args:
    ----
        det_arch: detection architecture
        reco_arch: recognition architecture
        assume_straight_pages: whether to assume straight pages or not
        straighten_pages: whether to straighten rotated pages or not
        bin_thresh: binarization threshold for the segmentation map
        box_thresh: minimal objectness score to consider a box
        device: torch.device, the device to load the predictor on

    Returns:
    -------
        instance of OCRPredictor
    """
    # Orientation detection is only useful when pages are not assumed straight.
    predictor = ocr_predictor(
        det_arch,
        reco_arch,
        pretrained=True,
        assume_straight_pages=assume_straight_pages,
        straighten_pages=straighten_pages,
        export_as_straight_boxes=straighten_pages,
        detect_orientation=not assume_straight_pages,
    ).to(device)
    # Override the detection post-processing thresholds with the caller's values.
    postprocessor = predictor.det_predictor.model.postprocessor
    postprocessor.bin_thresh = bin_thresh
    postprocessor.box_thresh = box_thresh
    return predictor
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def forward_image(predictor: OCRPredictor, image: np.ndarray, device: torch.device) -> np.ndarray:
    """Run the detection model on one image and return its raw segmentation map.

    Args:
    ----
        predictor: instance of OCRPredictor
        image: image to process
        device: torch.device, the device to process the image on

    Returns:
    -------
        segmentation map
    """
    det_predictor = predictor.det_predictor
    # Pure inference: no gradient tracking needed.
    with torch.no_grad():
        batches = det_predictor.pre_processor([image])
        model_out = det_predictor.model(batches[0].to(device), return_model_output=True)
        seg_map = model_out["out_map"].to("cpu").numpy()

    return seg_map
|
doctr/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import io, models, datasets, transforms, utils
|
| 2 |
+
from .file_utils import is_tf_available, is_torch_available
|
| 3 |
+
from .version import __version__ # noqa: F401
|
doctr/datasets/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from doctr.file_utils import is_tf_available
|
| 2 |
+
|
| 3 |
+
from .generator import *
|
| 4 |
+
from .cord import *
|
| 5 |
+
from .detection import *
|
| 6 |
+
from .doc_artefacts import *
|
| 7 |
+
from .funsd import *
|
| 8 |
+
from .ic03 import *
|
| 9 |
+
from .ic13 import *
|
| 10 |
+
from .iiit5k import *
|
| 11 |
+
from .iiithws import *
|
| 12 |
+
from .imgur5k import *
|
| 13 |
+
from .mjsynth import *
|
| 14 |
+
from .ocr import *
|
| 15 |
+
from .recognition import *
|
| 16 |
+
from .orientation import *
|
| 17 |
+
from .sroie import *
|
| 18 |
+
from .svhn import *
|
| 19 |
+
from .svt import *
|
| 20 |
+
from .synthtext import *
|
| 21 |
+
from .utils import *
|
| 22 |
+
from .vocabs import *
|
| 23 |
+
from .wildreceipt import *
|
| 24 |
+
|
| 25 |
+
if is_tf_available():
|
| 26 |
+
from .loader import *
|
doctr/datasets/cord.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from .datasets import VisionDataset
|
| 15 |
+
from .utils import convert_target_to_relative, crop_bboxes_from_image
|
| 16 |
+
|
| 17 |
+
__all__ = ["CORD"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class CORD(VisionDataset):
    """CORD dataset from `"CORD: A Consolidated Receipt Dataset for Post-OCR Parsing"
    <https://openreview.net/pdf?id=SJl3z659UH>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/cord-grid.png&src=0
        :align: center

    >>> from doctr.datasets import CORD
    >>> train_set = CORD(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
    ----
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `VisionDataset`.
    """

    # (url, sha256 checksum, archive file name) for each split
    TRAIN = (
        "https://doctr-static.mindee.com/models?id=v0.1.1/cord_train.zip&src=0",
        "45f9dc77f126490f3e52d7cb4f70ef3c57e649ea86d19d862a2757c9c455d7f8",
        "cord_train.zip",
    )

    TEST = (
        "https://doctr-static.mindee.com/models?id=v0.1.1/cord_test.zip&src=0",
        "8c895e3d6f7e1161c5b7245e3723ce15c04d84be89eaa6093949b75a66fb3c58",
        "cord_test.zip",
    )

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        url, sha256, name = self.TRAIN if train else self.TEST
        # VisionDataset handles download/extraction; targets are converted to
        # relative coordinates unless we only need text crops (recognition)
        super().__init__(
            url,
            name,
            sha256,
            True,
            pre_transforms=convert_target_to_relative if not recognition_task else None,
            **kwargs,
        )

        # List images
        tmp_root = os.path.join(self.root, "image")
        # Each entry is either (image path, detection target dict) or
        # (cropped image array, text label) depending on recognition_task
        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
        self.train = train
        np_dtype = np.float32
        for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking CORD", total=len(os.listdir(tmp_root))):
            # File existence check
            if not os.path.exists(os.path.join(tmp_root, img_path)):
                raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")

            stem = Path(img_path).stem
            _targets = []
            # Annotations live next to the images in a sibling "json" folder
            with open(os.path.join(self.root, "json", f"{stem}.json"), "rb") as f:
                label = json.load(f)
                for line in label["valid_line"]:
                    for word in line["words"]:
                        # Skip words with empty transcriptions
                        if len(word["text"]) > 0:
                            x = word["quad"]["x1"], word["quad"]["x2"], word["quad"]["x3"], word["quad"]["x4"]
                            y = word["quad"]["y1"], word["quad"]["y2"], word["quad"]["y3"], word["quad"]["y4"]
                            box: Union[List[float], np.ndarray]
                            if use_polygons:
                                # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                                box = np.array(
                                    [
                                        [x[0], y[0]],
                                        [x[1], y[1]],
                                        [x[2], y[2]],
                                        [x[3], y[3]],
                                    ],
                                    dtype=np_dtype,
                                )
                            else:
                                # Reduce 8 coords to 4 -> xmin, ymin, xmax, ymax
                                box = [min(x), min(y), max(x), max(y)]
                            _targets.append((word["text"], box))

            text_targets, box_targets = zip(*_targets)

            if recognition_task:
                # Recognition mode: store one cropped image per word with its text.
                # Boxes are clipped to >= 0 since annotations may spill off-image.
                crops = crop_bboxes_from_image(
                    img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0)
                )
                for crop, label in zip(crops, list(text_targets)):
                    self.data.append((crop, label))
            else:
                # Detection mode: one entry per page with all word boxes + labels.
                # NOTE(review): boxes are cast to int here (pixel coordinates),
                # unlike FUNSD which keeps float32 — presumably intentional; confirm.
                self.data.append((
                    img_path,
                    dict(boxes=np.asarray(box_targets, dtype=int).clip(min=0), labels=list(text_targets)),
                ))

        # From now on, images are resolved relative to the "image" subfolder
        self.root = tmp_root

    def extra_repr(self) -> str:
        # Shown in __repr__ of the base class
        return f"train={self.train}"
|
doctr/datasets/datasets/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from doctr.file_utils import is_tf_available, is_torch_available
|
| 2 |
+
|
| 3 |
+
if is_tf_available():
|
| 4 |
+
from .tensorflow import *
|
| 5 |
+
elif is_torch_available():
|
| 6 |
+
from .pytorch import * # type: ignore[assignment]
|
doctr/datasets/datasets/base.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import shutil
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
from doctr.io.image import get_img_shape
|
| 14 |
+
from doctr.utils.data import download_from_url
|
| 15 |
+
|
| 16 |
+
from ...models.utils import _copy_tensor
|
| 17 |
+
|
| 18 |
+
__all__ = ["_AbstractDataset", "_VisionDataset"]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class _AbstractDataset:
    """Base class implementing sample storage and transform pipelining."""

    # Populated by subclasses with (image reference, target) pairs
    data: List[Any] = []
    _pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None

    def __init__(
        self,
        root: Union[str, Path],
        img_transforms: Optional[Callable[[Any], Any]] = None,
        sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
        pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
    ) -> None:
        # root must already exist on disk — this class never downloads anything
        if not Path(root).is_dir():
            raise ValueError(f"expected a path to a reachable folder: {root}")

        self.root = root
        # Applied to the image only
        self.img_transforms = img_transforms
        # Applied jointly to (image, target)
        self.sample_transforms = sample_transforms
        # Applied first, before any other transform (e.g. format conversion)
        self._pre_transforms = pre_transforms
        self._get_img_shape = get_img_shape

    def __len__(self) -> int:
        return len(self.data)

    def _read_sample(self, index: int) -> Tuple[Any, Any]:
        # Framework-specific (torch/tf) subclasses implement image loading
        raise NotImplementedError

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        # Read image
        img, target = self._read_sample(index)
        # Pre-transforms (format conversion at run-time etc.)
        if self._pre_transforms is not None:
            img, target = self._pre_transforms(img, target)

        if self.img_transforms is not None:
            # typing issue cf. https://github.com/python/mypy/issues/5485
            img = self.img_transforms(img)

        if self.sample_transforms is not None:
            # Conditions to assess it is detection model with multiple classes and avoid confusion with other tasks.
            if (
                isinstance(target, dict)
                and all(isinstance(item, np.ndarray) for item in target.values())
                and set(target.keys()) != {"boxes", "labels"}  # avoid confusion with obj detection target
            ):
                img_transformed = _copy_tensor(img)
                for class_name, bboxes in target.items():
                    # NOTE(review): each class' boxes are transformed against the
                    # ORIGINAL image and the copy made above is immediately
                    # overwritten — only the last returned image is kept.
                    # Presumably the image-side transform is identical across
                    # classes; confirm against the transform implementations.
                    img_transformed, target[class_name] = self.sample_transforms(img, bboxes)
                img = img_transformed
            else:
                img, target = self.sample_transforms(img, target)

        return img, target

    def extra_repr(self) -> str:
        # Subclasses override to surface their configuration in __repr__
        return ""

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.extra_repr()})"
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class _VisionDataset(_AbstractDataset):
    """Implements an abstract dataset

    Args:
    ----
        url: URL of the dataset
        file_name: name of the file once downloaded
        file_hash: expected SHA256 of the file
        extract_archive: whether the downloaded file is an archive to be extracted
        download: whether the dataset should be downloaded if not present on disk
        overwrite: whether the archive should be re-extracted
        cache_dir: cache directory
        cache_subdir: subfolder to use in the cache
    """

    def __init__(
        self,
        url: str,
        file_name: Optional[str] = None,
        file_hash: Optional[str] = None,
        extract_archive: bool = False,
        download: bool = False,
        overwrite: bool = False,
        cache_dir: Optional[str] = None,
        cache_subdir: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        # Cache location can be overridden via the DOCTR_CACHE_DIR env var;
        # defaults to ~/.cache/doctr
        cache_dir = (
            str(os.environ.get("DOCTR_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "doctr")))
            if cache_dir is None
            else cache_dir
        )

        cache_subdir = "datasets" if cache_subdir is None else cache_subdir

        # Fall back to the URL's basename when no explicit file name is given
        file_name = file_name if isinstance(file_name, str) else os.path.basename(url)
        # Download the file if not present
        archive_path: Union[str, Path] = os.path.join(cache_dir, cache_subdir, file_name)

        if not os.path.exists(archive_path) and not download:
            raise ValueError("the dataset needs to be downloaded first with download=True")

        # NOTE(review): called even when the archive already exists — presumably
        # download_from_url short-circuits on a cached file with a matching hash;
        # confirm against doctr.utils.data.download_from_url.
        archive_path = download_from_url(url, file_name, file_hash, cache_dir=cache_dir, cache_subdir=cache_subdir)

        # Extract the archive
        if extract_archive:
            archive_path = Path(archive_path)
            # Extract into a sibling folder named after the archive (sans extension)
            dataset_path = archive_path.parent.joinpath(archive_path.stem)
            if not dataset_path.is_dir() or overwrite:
                shutil.unpack_archive(archive_path, dataset_path)

        # dataset_path is only defined when extract_archive is True, matching its use here
        super().__init__(dataset_path if extract_archive else archive_path, **kwargs)
|
doctr/datasets/datasets/pytorch.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
from typing import Any, List, Tuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from doctr.io import read_img_as_tensor, tensor_from_numpy
|
| 14 |
+
|
| 15 |
+
from .base import _AbstractDataset, _VisionDataset
|
| 16 |
+
|
| 17 |
+
__all__ = ["AbstractDataset", "VisionDataset"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class AbstractDataset(_AbstractDataset):
    """Abstract class for all datasets"""

    def _read_sample(self, index: int) -> Tuple[torch.Tensor, Any]:
        """Load one sample: the image as a float32 tensor plus a copy of its target."""
        img_name, target = self.data[index]

        # Validate the target structure before doing any (expensive) image I/O
        if isinstance(target, dict):
            assert "boxes" in target, "Target should contain 'boxes' key"
            assert "labels" in target, "Target should contain 'labels' key"
        elif isinstance(target, tuple):
            assert len(target) == 2
            assert isinstance(
                target[0], (str, np.ndarray)
            ), "first element of the tuple should be a string or a numpy array"
            assert isinstance(target[1], list), "second element of the tuple should be a list"
        else:
            assert isinstance(target, (str, np.ndarray)), "Target should be a string or a numpy array"

        # In-memory crops are stored as numpy arrays; otherwise read from disk
        if isinstance(img_name, np.ndarray):
            img = tensor_from_numpy(img_name, dtype=torch.float32)
        else:
            img = read_img_as_tensor(os.path.join(self.root, img_name), dtype=torch.float32)

        # deepcopy so callers can mutate the target without corrupting the dataset
        return img, deepcopy(target)

    @staticmethod
    def collate_fn(samples: List[Tuple[torch.Tensor, Any]]) -> Tuple[torch.Tensor, List[Any]]:
        """Stack images into one batch tensor; keep targets as a plain list."""
        images, targets = zip(*samples)
        return torch.stack(images, dim=0), list(targets)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Combines torch image loading (AbstractDataset) with download/extraction
# logic (_VisionDataset); no additional behavior is required.
class VisionDataset(AbstractDataset, _VisionDataset):  # noqa: D101
    pass
|
doctr/datasets/datasets/tensorflow.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
from typing import Any, List, Tuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import tensorflow as tf
|
| 12 |
+
|
| 13 |
+
from doctr.io import read_img_as_tensor, tensor_from_numpy
|
| 14 |
+
|
| 15 |
+
from .base import _AbstractDataset, _VisionDataset
|
| 16 |
+
|
| 17 |
+
__all__ = ["AbstractDataset", "VisionDataset"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class AbstractDataset(_AbstractDataset):
    """Abstract class for all datasets"""

    def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]:
        """Load one sample: the image as a float32 tensor plus a copy of its target."""
        img_name, target = self.data[index]

        # Validate the target structure before doing any (expensive) image I/O
        if isinstance(target, dict):
            assert "boxes" in target, "Target should contain 'boxes' key"
            assert "labels" in target, "Target should contain 'labels' key"
        elif isinstance(target, tuple):
            assert len(target) == 2
            assert isinstance(
                target[0], (str, np.ndarray)
            ), "first element of the tuple should be a string or a numpy array"
            assert isinstance(target[1], list), "second element of the tuple should be a list"
        else:
            assert isinstance(target, (str, np.ndarray)), "Target should be a string or a numpy array"

        # In-memory crops are stored as numpy arrays; otherwise read from disk
        if isinstance(img_name, np.ndarray):
            img = tensor_from_numpy(img_name, dtype=tf.float32)
        else:
            img = read_img_as_tensor(os.path.join(self.root, img_name), dtype=tf.float32)

        # deepcopy so callers can mutate the target without corrupting the dataset
        return img, deepcopy(target)

    @staticmethod
    def collate_fn(samples: List[Tuple[tf.Tensor, Any]]) -> Tuple[tf.Tensor, List[Any]]:
        """Stack images into one batch tensor; keep targets as a plain list."""
        images, targets = zip(*samples)
        return tf.stack(images, axis=0), list(targets)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Combines TensorFlow image loading (AbstractDataset) with download/extraction
# logic (_VisionDataset); no additional behavior is required.
class VisionDataset(AbstractDataset, _VisionDataset):  # noqa: D101
    pass
|
doctr/datasets/detection.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from typing import Any, Dict, List, Tuple, Type, Union
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
from doctr.file_utils import CLASS_NAME
|
| 13 |
+
|
| 14 |
+
from .datasets import AbstractDataset
|
| 15 |
+
from .utils import pre_transform_multiclass
|
| 16 |
+
|
| 17 |
+
__all__ = ["DetectionDataset"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class DetectionDataset(AbstractDataset):
    """Implements a text detection dataset

    >>> from doctr.datasets import DetectionDataset
    >>> train_set = DetectionDataset(img_folder="/path/to/images",
    >>>                              label_path="/path/to/labels.json")
    >>> img, target = train_set[0]

    Args:
    ----
        img_folder: folder with all the images of the dataset
        label_path: path to the annotations of each image
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        **kwargs: keyword arguments from `AbstractDataset`.
    """

    def __init__(
        self,
        img_folder: str,
        label_path: str,
        use_polygons: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            img_folder,
            pre_transforms=pre_transform_multiclass,
            **kwargs,
        )

        # File existence check
        # Accumulates every class seen across all annotations (deduplicated
        # via the class_names property)
        self._class_names: List = []
        if not os.path.exists(label_path):
            raise FileNotFoundError(f"unable to locate {label_path}")
        with open(label_path, "rb") as f:
            labels = json.load(f)

        self.data: List[Tuple[str, Tuple[np.ndarray, List[str]]]] = []
        np_dtype = np.float32
        for img_name, label in labels.items():
            # File existence check
            if not os.path.exists(os.path.join(self.root, img_name)):
                raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")

            geoms, polygons_classes = self.format_polygons(label["polygons"], use_polygons, np_dtype)

            self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))

    def format_polygons(
        self, polygons: Union[List, Dict], use_polygons: bool, np_dtype: Type
    ) -> Tuple[np.ndarray, List[str]]:
        """Format polygons into an array

        Args:
        ----
            polygons: the bounding boxes
            use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
            np_dtype: dtype of array

        Returns:
        -------
            geoms: bounding boxes as np array
            polygons_classes: list of classes for each bounding box
        """
        if isinstance(polygons, list):
            # Single-class annotation: everything belongs to the default class
            self._class_names += [CLASS_NAME]
            polygons_classes = [CLASS_NAME for _ in polygons]
            _polygons: np.ndarray = np.asarray(polygons, dtype=np_dtype)
        elif isinstance(polygons, dict):
            # Multi-class annotation: one polygon list per class name.
            # Class labels are repeated once per polygon of that class, so they
            # line up with the concatenated geometry array below.
            self._class_names += list(polygons.keys())
            polygons_classes = [k for k, v in polygons.items() for _ in v]
            # Empty per-class lists are filtered out before concatenation.
            # NOTE(review): a dict whose values are ALL empty would make
            # np.concatenate raise on an empty sequence — presumably annotations
            # always contain at least one polygon; confirm upstream.
            _polygons = np.concatenate([np.asarray(poly, dtype=np_dtype) for poly in polygons.values() if poly], axis=0)
        else:
            raise TypeError(f"polygons should be a dictionary or list, it was {type(polygons)}")
        # Either keep the full polygons, or collapse each to its axis-aligned
        # bounding box (xmin, ymin, xmax, ymax) via per-polygon min/max
        geoms = _polygons if use_polygons else np.concatenate((_polygons.min(axis=1), _polygons.max(axis=1)), axis=1)
        return geoms, polygons_classes

    @property
    def class_names(self):
        # Deduplicated, deterministic ordering of all classes encountered
        return sorted(set(self._class_names))
|
doctr/datasets/doc_artefacts.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from typing import Any, Dict, List, Tuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
from .datasets import VisionDataset
|
| 13 |
+
|
| 14 |
+
__all__ = ["DocArtefacts"]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DocArtefacts(VisionDataset):
    """Object detection dataset for non-textual elements in documents.
    The dataset includes a variety of synthetic document pages with non-textual elements.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/artefacts-grid.png&src=0
        :align: center

    >>> from doctr.datasets import DocArtefacts
    >>> train_set = DocArtefacts(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
    ----
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        **kwargs: keyword arguments from `VisionDataset`.
    """

    URL = "https://doctr-static.mindee.com/models?id=v0.4.0/artefact_detection-13fab8ce.zip&src=0"
    SHA256 = "13fab8ced7f84583d9dccd0c634f046c3417e62a11fe1dea6efbbaba5052471b"
    # Index in this list is used as the integer class id in targets
    CLASSES = ["background", "qr_code", "bar_code", "logo", "photo"]

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        **kwargs: Any,
    ) -> None:
        # Download + extract the single archive containing both splits
        super().__init__(self.URL, None, self.SHA256, True, **kwargs)
        self.train = train

        # Update root
        self.root = os.path.join(self.root, "train" if train else "val")
        # List images
        tmp_root = os.path.join(self.root, "images")
        with open(os.path.join(self.root, "labels.json"), "rb") as f:
            labels = json.load(f)
        self.data: List[Tuple[str, Dict[str, Any]]] = []
        img_list = os.listdir(tmp_root)
        # Sanity check: every image must have an annotation and vice versa
        if len(labels) != len(img_list):
            raise AssertionError("the number of images and labels do not match")
        np_dtype = np.float32
        for img_name, label in labels.items():
            # File existence check
            if not os.path.exists(os.path.join(tmp_root, img_name)):
                raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}")

            # xmin, ymin, xmax, ymax
            boxes: np.ndarray = np.asarray([obj["geometry"] for obj in label], dtype=np_dtype)
            classes: np.ndarray = np.asarray([self.CLASSES.index(obj["label"]) for obj in label], dtype=np.int64)
            if use_polygons:
                # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                # Expands each (xmin, ymin, xmax, ymax) box to a 4-point polygon
                boxes = np.stack(
                    [
                        np.stack([boxes[:, 0], boxes[:, 1]], axis=-1),
                        np.stack([boxes[:, 2], boxes[:, 1]], axis=-1),
                        np.stack([boxes[:, 2], boxes[:, 3]], axis=-1),
                        np.stack([boxes[:, 0], boxes[:, 3]], axis=-1),
                    ],
                    axis=1,
                )
            self.data.append((img_name, dict(boxes=boxes, labels=classes)))
        # From now on, images are resolved relative to the "images" subfolder
        self.root = tmp_root

    def extra_repr(self) -> str:
        # Shown in __repr__ of the base class
        return f"train={self.train}"
|
doctr/datasets/funsd.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from .datasets import VisionDataset
|
| 15 |
+
from .utils import convert_target_to_relative, crop_bboxes_from_image
|
| 16 |
+
|
| 17 |
+
__all__ = ["FUNSD"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class FUNSD(VisionDataset):
    """FUNSD dataset from `"FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents"
    <https://arxiv.org/pdf/1905.13538.pdf>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/funsd-grid.png&src=0
        :align: center

    >>> from doctr.datasets import FUNSD
    >>> train_set = FUNSD(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
    ----
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `VisionDataset`.
    """

    URL = "https://guillaumejaume.github.io/FUNSD/dataset.zip"
    SHA256 = "c31735649e4f441bcbb4fd0f379574f7520b42286e80b01d80b445649d54761f"
    FILE_NAME = "funsd.zip"

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        # VisionDataset handles download/extraction; targets are converted to
        # relative coordinates unless we only need text crops (recognition)
        super().__init__(
            self.URL,
            self.FILE_NAME,
            self.SHA256,
            True,
            pre_transforms=convert_target_to_relative if not recognition_task else None,
            **kwargs,
        )
        self.train = train
        np_dtype = np.float32

        # Use the subset
        subfolder = os.path.join("dataset", "training_data" if train else "testing_data")

        # # List images
        tmp_root = os.path.join(self.root, subfolder, "images")
        # Each entry is either (image path, detection target dict) or
        # (cropped image array, text label) depending on recognition_task
        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
        for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking FUNSD", total=len(os.listdir(tmp_root))):
            # File existence check
            if not os.path.exists(os.path.join(tmp_root, img_path)):
                raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")

            stem = Path(img_path).stem
            with open(os.path.join(self.root, subfolder, "annotations", f"{stem}.json"), "rb") as f:
                data = json.load(f)

            # Keep only words with a non-empty transcription
            _targets = [
                (word["text"], word["box"])
                for block in data["form"]
                for word in block["words"]
                if len(word["text"]) > 0
            ]
            text_targets, box_targets = zip(*_targets)
            if use_polygons:
                # xmin, ymin, xmax, ymax -> (x, y) coordinates of top left, top right, bottom right, bottom left corners
                box_targets = [  # type: ignore[assignment]
                    [
                        [box[0], box[1]],
                        [box[2], box[1]],
                        [box[2], box[3]],
                        [box[0], box[3]],
                    ]
                    for box in box_targets
                ]

            if recognition_task:
                # Recognition mode: one cropped image per word with its text
                crops = crop_bboxes_from_image(
                    img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=np_dtype)
                )
                for crop, label in zip(crops, list(text_targets)):
                    # filter labels with unknown characters
                    # (checkbox glyphs and private-use-area codepoints present
                    # in FUNSD transcriptions that recognition vocabs lack)
                    if not any(char in label for char in ["☑", "☐", "\uf703", "\uf702"]):
                        self.data.append((crop, label))
            else:
                # Detection mode: one entry per page with all word boxes + labels
                self.data.append((
                    img_path,
                    dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=list(text_targets)),
                ))

        # From now on, images are resolved relative to the "images" subfolder
        self.root = tmp_root

    def extra_repr(self) -> str:
        # Shown in __repr__ of the base class
        return f"train={self.train}"
|
doctr/datasets/generator/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Backend dispatch: re-export the TensorFlow implementations when TF is
# available, otherwise the PyTorch ones. Only one backend is exposed at a time.
from doctr.file_utils import is_tf_available, is_torch_available

if is_tf_available():
    from .tensorflow import *
elif is_torch_available():
    from .pytorch import *  # type: ignore[assignment]
|
doctr/datasets/generator/base.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import random
|
| 7 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
from PIL import Image, ImageDraw
|
| 10 |
+
|
| 11 |
+
from doctr.io.image import tensor_from_pil
|
| 12 |
+
from doctr.utils.fonts import get_font
|
| 13 |
+
|
| 14 |
+
from ..datasets import AbstractDataset
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def synthesize_text_img(
    text: str,
    font_size: int = 32,
    font_family: Optional[str] = None,
    background_color: Optional[Tuple[int, int, int]] = None,
    text_color: Optional[Tuple[int, int, int]] = None,
) -> Image.Image:
    """Generate a synthetic text image

    Args:
    ----
        text: the text to render as an image
        font_size: the size of the font
        font_family: the font family (has to be installed on your system)
        background_color: background color of the final image
        text_color: text color on the final image

    Returns:
    -------
        PIL image of the text
    """
    # Defaults: white text on a black background
    background_color = (0, 0, 0) if background_color is None else background_color
    text_color = (255, 255, 255) if text_color is None else text_color

    font = get_font(font_family, font_size)
    # Tight bounding box of the rendered text for this font
    left, top, right, bottom = font.getbbox(text)
    text_w, text_h = right - left, bottom - top
    # Margin around the text: 30% of its height, 10% of its width
    h, w = int(round(1.3 * text_h)), int(round(1.1 * text_w))
    # If single letter, make the image square, otherwise expand to meet the text size
    img_size = (h, w) if len(text) > 1 else (max(h, w), max(h, w))

    # PIL expects (width, height), hence the reversed (h, w) tuple
    img = Image.new("RGB", img_size[::-1], color=background_color)
    d = ImageDraw.Draw(img)

    # Offset so that the text is centered
    # NOTE(review): the (left, top) bbox offsets are not subtracted here, so fonts
    # with a non-zero bearing may render slightly off-center — confirm intended.
    text_pos = (int(round((img_size[1] - text_w) / 2)), int(round((img_size[0] - text_h) / 2)))
    # Draw the text
    d.text(text_pos, text, font=font, fill=text_color)
    return img
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class _CharacterGenerator(AbstractDataset):
    """Backend-agnostic synthetic single-character dataset.

    Each sample is a rendered image of one character from ``vocab``, with the
    character's index in ``vocab`` as the classification target.
    """

    def __init__(
        self,
        vocab: str,
        num_samples: int,
        cache_samples: bool = False,
        font_family: Optional[Union[str, List[str]]] = None,
        img_transforms: Optional[Callable[[Any], Any]] = None,
        sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
    ) -> None:
        self.vocab = vocab
        self._num_samples = num_samples
        # Normalize to a list so a font can be picked uniformly per sample
        self.font_family = font_family if isinstance(font_family, list) else [font_family]  # type: ignore[list-item]
        # Validate fonts upfront. Unlike the previous implementation, this also
        # runs when a single font name (not a list) was passed, which used to
        # bypass validation entirely. ``None`` entries mean "default font" and
        # are not checked.
        for font in self.font_family:
            if font is not None:
                try:
                    _ = get_font(font, 10)
                except OSError as err:
                    raise ValueError(f"unable to locate font: {font}") from err
        self.img_transforms = img_transforms
        self.sample_transforms = sample_transforms

        self._data: List[Image.Image] = []
        if cache_samples:
            # Pre-render every (character, font) combination once
            self._data = [
                (synthesize_text_img(char, font_family=font), idx)  # type: ignore[misc]
                for idx, char in enumerate(self.vocab)
                for font in self.font_family
            ]

    def __len__(self) -> int:
        return self._num_samples

    def _read_sample(self, index: int) -> Tuple[Any, int]:
        # Samples are already cached
        if len(self._data) > 0:
            idx = index % len(self._data)
            pil_img, target = self._data[idx]  # type: ignore[misc]
        else:
            # Render on the fly: the target is the character's index in the vocab
            target = index % len(self.vocab)
            pil_img = synthesize_text_img(self.vocab[target], font_family=random.choice(self.font_family))
        img = tensor_from_pil(pil_img)

        return img, target
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class _WordGenerator(AbstractDataset):
    """Backend-agnostic synthetic word dataset.

    Each sample is a rendered image of a random word built from ``vocab``, with
    the word string itself as the recognition target.
    """

    def __init__(
        self,
        vocab: str,
        min_chars: int,
        max_chars: int,
        num_samples: int,
        cache_samples: bool = False,
        font_family: Optional[Union[str, List[str]]] = None,
        img_transforms: Optional[Callable[[Any], Any]] = None,
        sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
    ) -> None:
        self.vocab = vocab
        self.wordlen_range = (min_chars, max_chars)
        self._num_samples = num_samples
        # Normalize to a list so a font can be picked uniformly per sample
        self.font_family = font_family if isinstance(font_family, list) else [font_family]  # type: ignore[list-item]
        # Validate fonts upfront. Unlike the previous implementation, this also
        # runs when a single font name (not a list) was passed, which used to
        # bypass validation entirely. ``None`` entries mean "default font" and
        # are not checked.
        for font in self.font_family:
            if font is not None:
                try:
                    _ = get_font(font, 10)
                except OSError as err:
                    raise ValueError(f"unable to locate font: {font}") from err
        self.img_transforms = img_transforms
        self.sample_transforms = sample_transforms

        self._data: List[Image.Image] = []
        if cache_samples:
            # Pre-render a fixed pool of random words
            _words = [self._generate_string(*self.wordlen_range) for _ in range(num_samples)]
            self._data = [
                (synthesize_text_img(text, font_family=random.choice(self.font_family)), text)  # type: ignore[misc]
                for text in _words
            ]

    def _generate_string(self, min_chars: int, max_chars: int) -> str:
        # Uniformly pick a word length, then draw each character independently
        num_chars = random.randint(min_chars, max_chars)
        return "".join(random.choice(self.vocab) for _ in range(num_chars))

    def __len__(self) -> int:
        return self._num_samples

    def _read_sample(self, index: int) -> Tuple[Any, str]:
        # Samples are already cached
        if len(self._data) > 0:
            pil_img, target = self._data[index]  # type: ignore[misc]
        else:
            # Render on the fly: the target is the generated word itself
            target = self._generate_string(*self.wordlen_range)
            pil_img = synthesize_text_img(target, font_family=random.choice(self.font_family))
        img = tensor_from_pil(pil_img)

        return img, target
|
doctr/datasets/generator/pytorch.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
from torch.utils.data._utils.collate import default_collate
|
| 7 |
+
|
| 8 |
+
from .base import _CharacterGenerator, _WordGenerator
|
| 9 |
+
|
| 10 |
+
__all__ = ["CharacterGenerator", "WordGenerator"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class CharacterGenerator(_CharacterGenerator):
    """Implements a character image generation dataset

    >>> from doctr.datasets import CharacterGenerator
    >>> ds = CharacterGenerator(vocab='abdef', num_samples=100)
    >>> img, target = ds[0]

    Args:
    ----
        vocab: vocabulary to take the character from
        num_samples: number of samples that will be generated iterating over the dataset
        cache_samples: whether generated images should be cached firsthand
        font_family: font to use to generate the text images
        img_transforms: composable transformations that will be applied to each image
        sample_transforms: composable transformations that will be applied to both the image and the target
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Direct assignment instead of setattr() with a constant attribute name
        # (flake8-bugbear B010); torch's default_collate batches (img, target) pairs.
        self.collate_fn = default_collate
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class WordGenerator(_WordGenerator):
    """Implements a word image generation dataset

    >>> from doctr.datasets import WordGenerator
    >>> ds = WordGenerator(vocab='abdef', min_chars=1, max_chars=32, num_samples=100)
    >>> img, target = ds[0]

    Args:
    ----
        vocab: vocabulary to take the character from
        min_chars: minimum number of characters in a word
        max_chars: maximum number of characters in a word
        num_samples: number of samples that will be generated iterating over the dataset
        cache_samples: whether generated images should be cached firsthand
        font_family: font to use to generate the text images
        img_transforms: composable transformations that will be applied to each image
        sample_transforms: composable transformations that will be applied to both the image and the target
    """

    pass
|
doctr/datasets/generator/tensorflow.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import tensorflow as tf
|
| 7 |
+
|
| 8 |
+
from .base import _CharacterGenerator, _WordGenerator
|
| 9 |
+
|
| 10 |
+
__all__ = ["CharacterGenerator", "WordGenerator"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class CharacterGenerator(_CharacterGenerator):
    """Implements a character image generation dataset

    >>> from doctr.datasets import CharacterGenerator
    >>> ds = CharacterGenerator(vocab='abdef', num_samples=100)
    >>> img, target = ds[0]

    Args:
    ----
        vocab: vocabulary to take the character from
        num_samples: number of samples that will be generated iterating over the dataset
        cache_samples: whether generated images should be cached firsthand
        font_family: font to use to generate the text images
        img_transforms: composable transformations that will be applied to each image
        sample_transforms: composable transformations that will be applied to both the image and the target
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    @staticmethod
    def collate_fn(samples):
        # Batch a list of (image, target) pairs into stacked TF tensors
        imgs, labels = zip(*samples)
        batched_imgs = tf.stack(imgs, axis=0)
        return batched_imgs, tf.convert_to_tensor(labels)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class WordGenerator(_WordGenerator):
    """Implements a word image generation dataset

    >>> from doctr.datasets import WordGenerator
    >>> ds = WordGenerator(vocab='abdef', min_chars=1, max_chars=32, num_samples=100)
    >>> img, target = ds[0]

    Args:
    ----
        vocab: vocabulary to take the character from
        min_chars: minimum number of characters in a word
        max_chars: maximum number of characters in a word
        num_samples: number of samples that will be generated iterating over the dataset
        cache_samples: whether generated images should be cached firsthand
        font_family: font to use to generate the text images
        img_transforms: composable transformations that will be applied to each image
        sample_transforms: composable transformations that will be applied to both the image and the target
    """

    pass
|
doctr/datasets/ic03.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import defusedxml.ElementTree as ET
|
| 10 |
+
import numpy as np
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from .datasets import VisionDataset
|
| 14 |
+
from .utils import convert_target_to_relative, crop_bboxes_from_image
|
| 15 |
+
|
| 16 |
+
__all__ = ["IC03"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class IC03(VisionDataset):
    """IC03 dataset from `"ICDAR 2003 Robust Reading Competitions: Entries, Results and Future Directions"
    <http://www.iapr-tc11.org/mediawiki/index.php?title=ICDAR_2003_Robust_Reading_Competitions>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/ic03-grid.png&src=0
        :align: center

    >>> from doctr.datasets import IC03
    >>> train_set = IC03(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
    ----
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `VisionDataset`.
    """

    # (url, sha256 checksum, local archive name) of each split
    TRAIN = (
        "http://www.iapr-tc11.org/dataset/ICDAR2003_RobustReading/TrialTrain/scene.zip",
        "9d86df514eb09dd693fb0b8c671ef54a0cfe02e803b1bbef9fc676061502eb94",
        "ic03_train.zip",
    )
    TEST = (
        "http://www.iapr-tc11.org/dataset/ICDAR2003_RobustReading/TrialTest/scene.zip",
        "dbc4b5fd5d04616b8464a1b42ea22db351ee22c2546dd15ac35611857ea111f8",
        "ic03_test.zip",
    )

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        url, sha256, file_name = self.TRAIN if train else self.TEST
        # Box coordinates are converted to relative ones unless crops for a
        # recognition task are extracted instead
        super().__init__(
            url,
            file_name,
            sha256,
            True,
            pre_transforms=convert_target_to_relative if not recognition_task else None,
            **kwargs,
        )
        self.train = train
        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
        np_dtype = np.float32

        # Load xml data
        tmp_root = (
            os.path.join(self.root, "SceneTrialTrain" if self.train else "SceneTrialTest") if sha256 else self.root
        )
        xml_tree = ET.parse(os.path.join(tmp_root, "words.xml"))
        xml_root = xml_tree.getroot()

        for image in tqdm(iterable=xml_root, desc="Unpacking IC03", total=len(xml_root)):
            # Each <image> element holds the file name, the resolution and the tagged rectangles
            name, _resolution, rectangles = image

            # File existence check
            if not os.path.exists(os.path.join(tmp_root, name.text)):
                raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, name.text)}")

            if use_polygons:
                # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                _boxes = [
                    [
                        [float(rect.attrib["x"]), float(rect.attrib["y"])],
                        [float(rect.attrib["x"]) + float(rect.attrib["width"]), float(rect.attrib["y"])],
                        [
                            float(rect.attrib["x"]) + float(rect.attrib["width"]),
                            float(rect.attrib["y"]) + float(rect.attrib["height"]),
                        ],
                        [float(rect.attrib["x"]), float(rect.attrib["y"]) + float(rect.attrib["height"])],
                    ]
                    for rect in rectangles
                ]
            else:
                # x_min, y_min, x_max, y_max
                _boxes = [
                    [
                        float(rect.attrib["x"]),  # type: ignore[list-item]
                        float(rect.attrib["y"]),  # type: ignore[list-item]
                        float(rect.attrib["x"]) + float(rect.attrib["width"]),  # type: ignore[list-item]
                        float(rect.attrib["y"]) + float(rect.attrib["height"]),  # type: ignore[list-item]
                    ]
                    for rect in rectangles
                ]

            # filter images without boxes
            if len(_boxes) > 0:
                boxes: np.ndarray = np.asarray(_boxes, dtype=np_dtype)
                # Get the labels
                labels = [lab.text for rect in rectangles for lab in rect if lab.text]

                if recognition_task:
                    # Store one cropped word image per non-degenerate, labeled box
                    crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, name.text), geoms=boxes)
                    for crop, label in zip(crops, labels):
                        if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
                            self.data.append((crop, label))
                else:
                    self.data.append((name.text, dict(boxes=boxes, labels=labels)))

        self.root = tmp_root

    def extra_repr(self) -> str:
        return f"train={self.train}"
|
doctr/datasets/ic13.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import csv
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from .datasets import AbstractDataset
|
| 15 |
+
from .utils import convert_target_to_relative, crop_bboxes_from_image
|
| 16 |
+
|
| 17 |
+
__all__ = ["IC13"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class IC13(AbstractDataset):
    """IC13 dataset from `"ICDAR 2013 Robust Reading Competition" <https://rrc.cvc.uab.es/>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/ic13-grid.png&src=0
        :align: center

    >>> # NOTE: You need to download both image and label parts from Focused Scene Text challenge Task2.1 2013-2015.
    >>> from doctr.datasets import IC13
    >>> train_set = IC13(img_folder="/path/to/Challenge2_Training_Task12_Images",
    >>>                  label_folder="/path/to/Challenge2_Training_Task1_GT")
    >>> img, target = train_set[0]
    >>> test_set = IC13(img_folder="/path/to/Challenge2_Test_Task12_Images",
    >>>                 label_folder="/path/to/Challenge2_Test_Task1_GT")
    >>> img, target = test_set[0]

    Args:
    ----
        img_folder: folder with all the images of the dataset
        label_folder: folder with all annotation files for the images
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `AbstractDataset`.
    """

    def __init__(
        self,
        img_folder: str,
        label_folder: str,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
        )

        # File existence check
        if not os.path.exists(label_folder) or not os.path.exists(img_folder):
            raise FileNotFoundError(
                f"unable to locate {label_folder if not os.path.exists(label_folder) else img_folder}"
            )

        self.data: List[Tuple[Union[Path, np.ndarray], Union[str, Dict[str, Any]]]] = []
        np_dtype = np.float32

        img_names = os.listdir(img_folder)

        for img_name in tqdm(iterable=img_names, desc="Unpacking IC13", total=len(img_names)):
            img_path = Path(img_folder, img_name)
            # Ground-truth files are named "gt_<image stem>.txt"
            label_path = Path(label_folder, "gt_" + Path(img_name).stem + ".txt")

            with open(label_path, newline="\n") as f:
                # Rows are space-separated; trailing commas are stripped from each value
                _lines = [
                    [val[:-1] if val.endswith(",") else val for val in row]
                    for row in csv.reader(f, delimiter=" ", quotechar="'")
                ]
            # Last field of each row is the transcription (quotes removed)
            labels = [line[-1].replace('"', "") for line in _lines]
            # xmin, ymin, xmax, ymax
            box_targets: np.ndarray = np.array([list(map(int, line[:4])) for line in _lines], dtype=np_dtype)
            if use_polygons:
                # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                box_targets = np.array(
                    [
                        [
                            [coords[0], coords[1]],
                            [coords[2], coords[1]],
                            [coords[2], coords[3]],
                            [coords[0], coords[3]],
                        ]
                        for coords in box_targets
                    ],
                    dtype=np_dtype,
                )

            if recognition_task:
                # Store one cropped word image per annotation
                crops = crop_bboxes_from_image(img_path=img_path, geoms=box_targets)
                for crop, label in zip(crops, labels):
                    self.data.append((crop, label))
            else:
                self.data.append((img_path, dict(boxes=box_targets, labels=labels)))
|
doctr/datasets/iiit5k.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import scipy.io as sio
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from .datasets import VisionDataset
|
| 14 |
+
from .utils import convert_target_to_relative
|
| 15 |
+
|
| 16 |
+
__all__ = ["IIIT5K"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class IIIT5K(VisionDataset):
    """IIIT-5K character-level localization dataset from
    `"BMVC 2012 Scene Text Recognition using Higher Order Language Priors"
    <https://cdn.iiit.ac.in/cdn/cvit.iiit.ac.in/images/Projects/SceneTextUnderstanding/home/mishraBMVC12.pdf>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/iiit5k-grid.png&src=0
        :align: center

    >>> # NOTE: this dataset is for character-level localization
    >>> from doctr.datasets import IIIT5K
    >>> train_set = IIIT5K(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
    ----
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `VisionDataset`.
    """

    URL = "https://cvit.iiit.ac.in/images/Projects/SceneTextUnderstanding/IIIT5K-Word_V3.0.tar.gz"
    SHA256 = "7872c9efbec457eb23f3368855e7738f72ce10927f52a382deb4966ca0ffa38e"

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            self.URL,
            None,
            file_hash=self.SHA256,
            extract_archive=True,
            pre_transforms=convert_target_to_relative if not recognition_task else None,
            **kwargs,
        )
        self.train = train

        # Load mat data
        tmp_root = os.path.join(self.root, "IIIT5K") if self.SHA256 else self.root
        mat_file = "trainCharBound" if self.train else "testCharBound"
        mat_data = sio.loadmat(os.path.join(tmp_root, f"{mat_file}.mat"))[mat_file][0]

        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
        np_dtype = np.float32

        for img_path, label, box_targets in tqdm(iterable=mat_data, desc="Unpacking IIIT5K", total=len(mat_data)):
            # .mat entries are wrapped in single-element arrays
            _raw_path = img_path[0]
            _raw_label = label[0]

            # File existence check
            if not os.path.exists(os.path.join(tmp_root, _raw_path)):
                raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, _raw_path)}")

            if recognition_task:
                self.data.append((_raw_path, _raw_label))
            else:
                # Raw boxes are (x, y, width, height) per character, as shown by
                # the x + width / y + height conversions below
                if use_polygons:
                    # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                    box_targets = [
                        [
                            [box[0], box[1]],
                            [box[0] + box[2], box[1]],
                            [box[0] + box[2], box[1] + box[3]],
                            [box[0], box[1] + box[3]],
                        ]
                        for box in box_targets
                    ]
                else:
                    # xmin, ymin, xmax, ymax
                    box_targets = [[box[0], box[1], box[0] + box[2], box[1] + box[3]] for box in box_targets]

                # label are casted to list where each char corresponds to the character's bounding box
                self.data.append((
                    _raw_path,
                    dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=list(_raw_label)),
                ))

        self.root = tmp_root

    def extra_repr(self) -> str:
        return f"train={self.train}"
|
doctr/datasets/iiithws.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from random import sample
|
| 8 |
+
from typing import Any, List, Tuple
|
| 9 |
+
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from .datasets import AbstractDataset
|
| 13 |
+
|
| 14 |
+
__all__ = ["IIITHWS"]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class IIITHWS(AbstractDataset):
    """IIITHWS dataset from `"Generating Synthetic Data for Text Recognition"
    <https://arxiv.org/pdf/1608.04224.pdf>`_ | `"repository" <https://github.com/kris314/hwnet>`_ |
    `"website" <https://cvit.iiit.ac.in/research/projects/cvit-projects/matchdocimgs>`_.

    >>> # NOTE: This is a pure recognition dataset without bounding box labels.
    >>> # NOTE: You need to download the dataset.
    >>> from doctr.datasets import IIITHWS
    >>> train_set = IIITHWS(img_folder="/path/to/iiit-hws/Images_90K_Normalized",
    >>>                     label_path="/path/to/IIIT-HWS-90K.txt",
    >>>                     train=True)
    >>> img, target = train_set[0]
    >>> test_set = IIITHWS(img_folder="/path/to/iiit-hws/Images_90K_Normalized",
    >>>                    label_path="/path/to/IIIT-HWS-90K.txt",
    >>>                    train=False)
    >>> img, target = test_set[0]

    Args:
    ----
        img_folder: folder with all the images of the dataset
        label_path: path to the file with the labels
        train: whether the subset should be the training one
        **kwargs: keyword arguments from `AbstractDataset`.
    """

    def __init__(
        self,
        img_folder: str,
        label_path: str,
        train: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(img_folder, **kwargs)

        # File existence check
        if not os.path.exists(label_path) or not os.path.exists(img_folder):
            raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")

        self.data: List[Tuple[str, str]] = []
        self.train = train

        with open(label_path) as f:
            annotations = f.readlines()

        # Shuffle the dataset otherwise the test set will contain the same labels n times
        # NOTE(review): random.sample is unseeded, so the 90/10 train/test split below
        # differs between runs — confirm whether a fixed seed is intended.
        annotations = sample(annotations, len(annotations))
        train_samples = int(len(annotations) * 0.9)
        set_slice = slice(train_samples) if self.train else slice(train_samples, None)

        for annotation in tqdm(
            iterable=annotations[set_slice], desc="Unpacking IIITHWS", total=len(annotations[set_slice])
        ):
            # Each line starts with "<relative image path> <label>"; any extra fields are ignored
            img_path, label = annotation.split()[0:2]
            img_path = os.path.join(img_folder, img_path)

            self.data.append((img_path, label))

    def extra_repr(self) -> str:
        return f"train={self.train}"
|
doctr/datasets/imgur5k.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import glob
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 11 |
+
|
| 12 |
+
import cv2
|
| 13 |
+
import numpy as np
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
from .datasets import AbstractDataset
|
| 18 |
+
from .utils import convert_target_to_relative, crop_bboxes_from_image
|
| 19 |
+
|
| 20 |
+
__all__ = ["IMGUR5K"]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class IMGUR5K(AbstractDataset):
    """IMGUR5K dataset from `"TextStyleBrush: Transfer of Text Aesthetics from a Single Example"
    <https://arxiv.org/abs/2106.08385>`_ |
    `repository <https://github.com/facebookresearch/IMGUR5K-Handwriting-Dataset>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/imgur5k-grid.png&src=0
        :align: center
        :width: 630
        :height: 400

    >>> # NOTE: You need to download/generate the dataset from the repository.
    >>> from doctr.datasets import IMGUR5K
    >>> train_set = IMGUR5K(train=True, img_folder="/path/to/IMGUR5K-Handwriting-Dataset/images",
    >>>                     label_path="/path/to/IMGUR5K-Handwriting-Dataset/dataset_info/imgur5k_annotations.json")
    >>> img, target = train_set[0]
    >>> test_set = IMGUR5K(train=False, img_folder="/path/to/IMGUR5K-Handwriting-Dataset/images",
    >>>                    label_path="/path/to/IMGUR5K-Handwriting-Dataset/dataset_info/imgur5k_annotations.json")
    >>> img, target = test_set[0]

    Args:
    ----
        img_folder: folder with all the images of the dataset
        label_path: path to the annotations file of the dataset
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `AbstractDataset`.
    """

    def __init__(
        self,
        img_folder: str,
        label_path: str,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        # Detection targets are converted to relative coordinates; for the recognition
        # task crops are extracted in pixel space, so no conversion is applied.
        super().__init__(
            img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
        )

        # File existence check
        if not os.path.exists(label_path) or not os.path.exists(img_folder):
            raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")

        self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any]]]] = []
        self.train = train
        np_dtype = np.float32

        # 90/10 train/test split over the directory listing
        # NOTE(review): os.listdir order is platform-dependent, so the split may differ
        # across systems — confirm whether a sorted listing is expected here.
        img_names = os.listdir(img_folder)
        train_samples = int(len(img_names) * 0.9)
        set_slice = slice(train_samples) if self.train else slice(train_samples, None)

        # define folder to write IMGUR5K recognition dataset
        reco_folder_name = "IMGUR5K_recognition_train" if self.train else "IMGUR5K_recognition_test"
        reco_folder_name = "Poly_" + reco_folder_name if use_polygons else reco_folder_name
        reco_folder_path = os.path.join(os.path.dirname(self.root), reco_folder_name)
        reco_images_counter = 0

        # Reuse a previously materialized recognition set if present, otherwise build it below
        if recognition_task and os.path.isdir(reco_folder_path):
            self._read_from_folder(reco_folder_path)
            return
        elif recognition_task and not os.path.isdir(reco_folder_path):
            os.makedirs(reco_folder_path, exist_ok=False)

        with open(label_path) as f:
            annotation_file = json.load(f)

        for img_name in tqdm(iterable=img_names[set_slice], desc="Unpacking IMGUR5K", total=len(img_names[set_slice])):
            img_path = Path(img_folder, img_name)
            img_id = img_name.split(".")[0]

            # File existence check
            if not os.path.exists(os.path.join(self.root, img_name)):
                raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")

            # some files have no annotations which are marked with only a dot in the 'word' key
            # ref: https://github.com/facebookresearch/IMGUR5K-Handwriting-Dataset/blob/main/README.md
            if img_id not in annotation_file["index_to_ann_map"].keys():
                continue
            ann_ids = annotation_file["index_to_ann_map"][img_id]
            annotations = [annotation_file["ann_id"][a_id] for a_id in ann_ids]

            labels = [ann["word"] for ann in annotations if ann["word"] != "."]
            # x_center, y_center, width, height, angle
            # (the JSON stores the rotated box as a bracketed, comma-separated string)
            _boxes = [
                list(map(float, ann["bounding_box"].strip("[ ]").split(", ")))
                for ann in annotations
                if ann["word"] != "."
            ]
            # (x, y) coordinates of top left, top right, bottom right, bottom left corners
            box_targets = [cv2.boxPoints(((box[0], box[1]), (box[2], box[3]), box[4])) for box in _boxes]  # type: ignore[arg-type]

            if not use_polygons:
                # xmin, ymin, xmax, ymax
                box_targets = [np.concatenate((points.min(0), points.max(0)), axis=-1) for points in box_targets]

            # filter images without boxes
            if len(box_targets) > 0:
                if recognition_task:
                    crops = crop_bboxes_from_image(
                        img_path=os.path.join(self.root, img_name), geoms=np.asarray(box_targets, dtype=np_dtype)
                    )
                    for crop, label in zip(crops, labels):
                        # Keep only non-degenerate crops with a non-empty transcription
                        if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
                            # write data to disk
                            with open(os.path.join(reco_folder_path, f"{reco_images_counter}.txt"), "w") as f:
                                f.write(label)
                            tmp_img = Image.fromarray(crop)
                            tmp_img.save(os.path.join(reco_folder_path, f"{reco_images_counter}.png"))
                            reco_images_counter += 1
                else:
                    self.data.append((img_path, dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=labels)))

        if recognition_task:
            # Load back the crops that were just written to disk
            self._read_from_folder(reco_folder_path)

    def extra_repr(self) -> str:
        return f"train={self.train}"

    def _read_from_folder(self, path: str) -> None:
        # Load (image path, transcription) pairs from a previously written recognition folder;
        # each "<i>.png" crop is paired with its "<i>.txt" label file.
        for img_path in glob.glob(os.path.join(path, "*.png")):
            with open(os.path.join(path, f"{os.path.basename(img_path)[:-4]}.txt"), "r") as f:
                self.data.append((img_path, f.read()))
|
doctr/datasets/loader.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from typing import Callable, Optional
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import tensorflow as tf
|
| 11 |
+
|
| 12 |
+
from doctr.utils.multithreading import multithread_exec
|
| 13 |
+
|
| 14 |
+
__all__ = ["DataLoader"]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def default_collate(samples):
    """Collate multiple elements into batches

    Args:
    ----
        samples: list of N tuples containing M elements

    Returns:
    -------
        Tuple of M sequences containing N elements each
    """
    # Transpose the N samples of M fields into M fields of N elements,
    # stacking each field along a new leading batch axis
    return tuple(tf.stack(field, axis=0) for field in zip(*samples))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class DataLoader:
    """Implements a dataset wrapper for fast data loading

    >>> from doctr.datasets import CORD, DataLoader
    >>> train_set = CORD(train=True, download=True)
    >>> train_loader = DataLoader(train_set, batch_size=32)
    >>> train_iter = iter(train_loader)
    >>> images, targets = next(train_iter)

    Args:
    ----
        dataset: the dataset
        shuffle: whether the samples should be shuffled before passing it to the iterator
        batch_size: number of elements in each batch
        drop_last: if `True`, drops the last batch if it isn't full
        num_workers: number of workers to use for data loading
        collate_fn: function to merge samples into a batch
    """

    def __init__(
        self,
        dataset,
        shuffle: bool = True,
        batch_size: int = 1,
        drop_last: bool = False,
        num_workers: Optional[int] = None,
        collate_fn: Optional[Callable] = None,
    ) -> None:
        self.dataset = dataset
        self.shuffle = shuffle
        self.batch_size = batch_size
        nb = len(self.dataset) / batch_size
        # drop_last discards a trailing partial batch
        self.num_batches = math.floor(nb) if drop_last else math.ceil(nb)
        if collate_fn is None:
            # Prefer the dataset's own collate function when none is provided
            self.collate_fn = self.dataset.collate_fn if hasattr(self.dataset, "collate_fn") else default_collate
        else:
            self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.reset()

    def __len__(self) -> int:
        return self.num_batches

    def reset(self) -> None:
        """Reset the iteration state and (optionally) reshuffle the sample indices."""
        self._num_yielded = 0
        self.indices = np.arange(len(self.dataset))
        # Idiomatic truthiness check (PEP 8) instead of `is True`
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __iter__(self):
        self.reset()
        return self

    def __next__(self):
        # Guard clause: stop as soon as all batches have been served
        if self._num_yielded >= self.num_batches:
            raise StopIteration
        # Indices of the next batch (the last batch may be smaller unless drop_last)
        idx = self._num_yielded * self.batch_size
        indices = self.indices[idx : min(len(self.dataset), idx + self.batch_size)]

        samples = list(multithread_exec(self.dataset.__getitem__, indices, threads=self.num_workers))

        batch_data = self.collate_fn(samples)

        self._num_yielded += 1
        return batch_data
|
doctr/datasets/mjsynth.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Any, List, Tuple
|
| 8 |
+
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
from .datasets import AbstractDataset
|
| 12 |
+
|
| 13 |
+
__all__ = ["MJSynth"]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class MJSynth(AbstractDataset):
    """MJSynth dataset from `"Synthetic Data and Artificial Neural Networks for Natural Scene Text Recognition"
    <https://www.robots.ox.ac.uk/~vgg/data/text/>`_.

    >>> # NOTE: This is a pure recognition dataset without bounding box labels.
    >>> # NOTE: You need to download the dataset.
    >>> from doctr.datasets import MJSynth
    >>> train_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px",
    >>>                     label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt",
    >>>                     train=True)
    >>> img, target = train_set[0]
    >>> test_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px",
    >>>                    label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt",
    >>>                    train=False)
    >>> img, target = test_set[0]

    Args:
    ----
        img_folder: folder with all the images of the dataset
        label_path: path to the file with the labels
        train: whether the subset should be the training one
        **kwargs: keyword arguments from `AbstractDataset`.
    """

    # filter corrupted or missing images (raw lines as read from imlist.txt, newline included)
    BLACKLIST = [
        "./1881/4/225_Marbling_46673.jpg\n",
        "./2069/4/192_whittier_86389.jpg\n",
        "./869/4/234_TRIASSIC_80582.jpg\n",
        "./173/2/358_BURROWING_10395.jpg\n",
        "./913/4/231_randoms_62372.jpg\n",
        "./596/2/372_Ump_81662.jpg\n",
        "./936/2/375_LOCALITIES_44992.jpg\n",
        "./2540/4/246_SQUAMOUS_73902.jpg\n",
        "./1332/4/224_TETHERED_78397.jpg\n",
        "./627/6/83_PATRIARCHATE_55931.jpg\n",
        "./2013/2/370_refract_63890.jpg\n",
        "./2911/6/77_heretical_35885.jpg\n",
        "./1730/2/361_HEREON_35880.jpg\n",
        "./2194/2/334_EFFLORESCENT_24742.jpg\n",
        "./2025/2/364_SNORTERS_72304.jpg\n",
        "./368/4/232_friar_30876.jpg\n",
        "./275/6/96_hackle_34465.jpg\n",
        "./384/4/220_bolts_8596.jpg\n",
        "./905/4/234_Postscripts_59142.jpg\n",
        "./2749/6/101_Chided_13155.jpg\n",
        "./495/6/81_MIDYEAR_48332.jpg\n",
        "./2852/6/60_TOILSOME_79481.jpg\n",
        "./554/2/366_Teleconferences_77948.jpg\n",
        "./1696/4/211_Queened_61779.jpg\n",
        "./2128/2/369_REDACTED_63458.jpg\n",
        "./2557/2/351_DOWN_23492.jpg\n",
        "./2489/4/221_snored_72290.jpg\n",
        "./1650/2/355_stony_74902.jpg\n",
        "./1863/4/223_Diligently_21672.jpg\n",
        "./264/2/362_FORETASTE_30276.jpg\n",
        "./429/4/208_Mainmasts_46140.jpg\n",
        "./1817/2/363_actuating_904.jpg\n",
    ]

    def __init__(
        self,
        img_folder: str,
        label_path: str,
        train: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(img_folder, **kwargs)

        # File existence check
        if not os.path.exists(label_path) or not os.path.exists(img_folder):
            raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")

        self.data: List[Tuple[str, str]] = []
        self.train = train

        with open(label_path) as f:
            img_paths = f.readlines()

        train_samples = int(len(img_paths) * 0.9)
        set_slice = slice(train_samples) if self.train else slice(train_samples, None)

        # Build the blacklist once as a set: O(1) membership test per line instead of
        # scanning the list for each of the ~9M entries of imlist.txt
        blacklist = set(self.BLACKLIST)

        for path in tqdm(iterable=img_paths[set_slice], desc="Unpacking MJSynth", total=len(img_paths[set_slice])):
            if path not in blacklist:
                # Paths look like "./1881/4/225_Marbling_46673.jpg": the label is the
                # middle underscore-separated token of the basename
                label = path.split("_")[1]
                img_path = os.path.join(img_folder, path[2:]).strip()

                self.data.append((img_path, label))

    def extra_repr(self) -> str:
        return f"train={self.train}"
|
doctr/datasets/ocr.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, List, Tuple
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
from .datasets import AbstractDataset
|
| 14 |
+
|
| 15 |
+
__all__ = ["OCRDataset"]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class OCRDataset(AbstractDataset):
    """Implements an OCR dataset

    >>> from doctr.datasets import OCRDataset
    >>> train_set = OCRDataset(img_folder="/path/to/images",
    >>>                        label_file="/path/to/labels.json")
    >>> img, target = train_set[0]

    Args:
    ----
        img_folder: local path to image folder (all jpg at the root)
        label_file: local path to the label file
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        **kwargs: keyword arguments from `AbstractDataset`.
    """

    def __init__(
        self,
        img_folder: str,
        label_file: str,
        use_polygons: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(img_folder, **kwargs)

        np_dtype = np.float32
        with open(label_file, "rb") as f:
            file_annotations = json.load(f)

        # List images
        self.data: List[Tuple[str, Dict[str, Any]]] = []
        for raw_name, annotations in file_annotations.items():
            # Get image path
            img_name = Path(raw_name)
            # File existence check
            full_path = os.path.join(self.root, img_name)
            if not os.path.exists(full_path):
                raise FileNotFoundError(f"unable to locate {full_path}")

            words = annotations["typed_words"]
            # handle empty images
            if not words:
                self.data.append((img_name, dict(boxes=np.zeros((0, 4), dtype=np_dtype), labels=[])))
                continue

            # Unpack the straight boxes (xmin, ymin, xmax, ymax)
            geoms = [list(map(float, obj["geometry"][:4])) for obj in words]
            if use_polygons:
                # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                geoms = [
                    [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
                    for xmin, ymin, xmax, ymax in geoms
                ]

            text_targets = [obj["value"] for obj in words]

            self.data.append((img_name, dict(boxes=np.asarray(geoms, dtype=np_dtype), labels=text_targets)))
|
doctr/datasets/orientation.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Any, List, Tuple
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from .datasets import AbstractDataset
|
| 12 |
+
|
| 13 |
+
__all__ = ["OrientationDataset"]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class OrientationDataset(AbstractDataset):
    """Implements a basic image dataset where targets are filled with zeros.

    >>> from doctr.datasets import OrientationDataset
    >>> train_set = OrientationDataset(img_folder="/path/to/images")
    >>> img, target = train_set[0]

    Args:
    ----
        img_folder: folder with all the images of the dataset
        **kwargs: keyword arguments from `AbstractDataset`.
    """

    def __init__(
        self,
        img_folder: str,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            img_folder,
            **kwargs,
        )

        # Pair every image in the folder with a fresh 0-degree rotation target
        self.data: List[Tuple[str, np.ndarray]] = []
        for img_name in os.listdir(self.root):
            self.data.append((img_name, np.array([0])))
|
doctr/datasets/recognition.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, List, Tuple
|
| 10 |
+
|
| 11 |
+
from .datasets import AbstractDataset
|
| 12 |
+
|
| 13 |
+
__all__ = ["RecognitionDataset"]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class RecognitionDataset(AbstractDataset):
    """Dataset implementation for text recognition tasks

    >>> from doctr.datasets import RecognitionDataset
    >>> train_set = RecognitionDataset(img_folder="/path/to/images",
    >>>                                labels_path="/path/to/labels.json")
    >>> img, target = train_set[0]

    Args:
    ----
        img_folder: path to the images folder
        labels_path: path to the json file containing all labels (character sequences)
        **kwargs: keyword arguments from `AbstractDataset`.
    """

    def __init__(
        self,
        img_folder: str,
        labels_path: str,
        **kwargs: Any,
    ) -> None:
        super().__init__(img_folder, **kwargs)

        with open(labels_path, encoding="utf-8") as f:
            labels = json.load(f)

        self.data: List[Tuple[str, str]] = []
        for img_name, label in labels.items():
            img_path = os.path.join(self.root, img_name)
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"unable to locate {img_path}")
            self.data.append((img_name, label))

    def merge_dataset(self, ds: AbstractDataset) -> None:
        # Rebase our own relative image names on the current root
        self.data = [(str(Path(self.root).joinpath(name)), label) for name, label in self.data]
        # From now on every stored path is absolute
        self.root = Path("/")
        # Append the other dataset's samples, rebased on its own root
        self.data.extend((str(Path(ds.root).joinpath(name)), label) for name, label in ds.data)
|
doctr/datasets/sroie.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import csv
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from .datasets import VisionDataset
|
| 15 |
+
from .utils import convert_target_to_relative, crop_bboxes_from_image
|
| 16 |
+
|
| 17 |
+
__all__ = ["SROIE"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SROIE(VisionDataset):
    """SROIE dataset from `"ICDAR2019 Competition on Scanned Receipt OCR and Information Extraction"
    <https://arxiv.org/pdf/2103.10213.pdf>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/sroie-grid.png&src=0
        :align: center

    >>> from doctr.datasets import SROIE
    >>> train_set = SROIE(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
    ----
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `VisionDataset`.
    """

    TRAIN = (
        "https://doctr-static.mindee.com/models?id=v0.1.1/sroie2019_train_task1.zip&src=0",
        "d4fa9e60abb03500d83299c845b9c87fd9c9430d1aeac96b83c5d0bb0ab27f6f",
        "sroie2019_train_task1.zip",
    )
    TEST = (
        "https://doctr-static.mindee.com/models?id=v0.1.1/sroie2019_test.zip&src=0",
        "41b3c746a20226fddc80d86d4b2a903d43b5be4f521dd1bbe759dbf8844745e2",
        "sroie2019_test.zip",
    )

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        url, sha256, name = self.TRAIN if train else self.TEST
        super().__init__(
            url,
            name,
            sha256,
            True,
            pre_transforms=convert_target_to_relative if not recognition_task else None,
            **kwargs,
        )
        self.train = train

        tmp_root = os.path.join(self.root, "images")
        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
        np_dtype = np.float32

        img_names = os.listdir(tmp_root)
        for img_name in tqdm(iterable=img_names, desc="Unpacking SROIE", total=len(img_names)):
            # File existence check
            if not os.path.exists(os.path.join(tmp_root, img_name)):
                raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}")

            # The annotation file shares the image's stem
            ann_file = os.path.join(self.root, "annotations", f"{Path(img_name).stem}.txt")
            with open(ann_file, encoding="latin") as f:
                rows = [row for row in list(csv.reader(f, delimiter=",")) if len(row) > 0]

            # Everything past the 8 coordinates is the transcription (may itself contain commas)
            texts = [",".join(row[8:]) for row in rows]
            # reorder coordinates (8 -> (4,2) ->
            # (x, y) coordinates of top left, top right, bottom right, bottom left corners) and filter empty lines
            polygons: np.ndarray = np.stack(
                [np.array(list(map(int, row[:8])), dtype=np_dtype).reshape((4, 2)) for row in rows], axis=0
            )

            if not use_polygons:
                # Collapse each polygon to its straight bounding box: xmin, ymin, xmax, ymax
                polygons = np.concatenate((polygons.min(axis=1), polygons.max(axis=1)), axis=1)

            if recognition_task:
                crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_name), geoms=polygons)
                for crop, text in zip(crops, texts):
                    # Keep only non-degenerate crops with a non-empty transcription
                    if crop.shape[0] > 0 and crop.shape[1] > 0 and len(text) > 0:
                        self.data.append((crop, text))
            else:
                self.data.append((img_name, dict(boxes=polygons, labels=texts)))

        self.root = tmp_root

    def extra_repr(self) -> str:
        return f"train={self.train}"
|
doctr/datasets/svhn.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import h5py
|
| 10 |
+
import numpy as np
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from .datasets import VisionDataset
|
| 14 |
+
from .utils import convert_target_to_relative, crop_bboxes_from_image
|
| 15 |
+
|
| 16 |
+
__all__ = ["SVHN"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class SVHN(VisionDataset):
    """SVHN dataset from `"The Street View House Numbers (SVHN) Dataset"
    <http://ufldl.stanford.edu/housenumbers/>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/svhn-grid.png&src=0
        :align: center

    >>> from doctr.datasets import SVHN
    >>> train_set = SVHN(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
    ----
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `VisionDataset`.
    """

    TRAIN = (
        "http://ufldl.stanford.edu/housenumbers/train.tar.gz",
        "4b17bb33b6cd8f963493168f80143da956f28ec406cc12f8e5745a9f91a51898",
        "svhn_train.tar",
    )

    TEST = (
        "http://ufldl.stanford.edu/housenumbers/test.tar.gz",
        "57ac9ceb530e4aa85b55d991be8fc49c695b3d71c6f6a88afea86549efde7fb5",
        "svhn_test.tar",
    )

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        url, sha256, archive_name = self.TRAIN if train else self.TEST
        super().__init__(
            url,
            file_name=archive_name,
            file_hash=sha256,
            extract_archive=True,
            pre_transforms=convert_target_to_relative if not recognition_task else None,
            **kwargs,
        )
        self.train = train
        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
        np_dtype = np.float32

        tmp_root = os.path.join(self.root, "train" if train else "test")

        # Annotations ship as a matlab v7.3 file, which scipy cannot read -> use h5py
        with h5py.File(os.path.join(tmp_root, "digitStruct.mat"), "r") as f:
            img_refs = f["digitStruct/name"]
            box_refs = f["digitStruct/bbox"]
            for img_ref, box_ref in tqdm(iterable=zip(img_refs, box_refs), desc="Unpacking SVHN", total=len(img_refs)):
                # The image file name is stored as a matrix of ascii codes
                img_name = "".join(map(chr, f[img_ref[0]][()].flatten()))

                # Fail fast if the archive extraction is incomplete
                if not os.path.exists(os.path.join(tmp_root, img_name)):
                    raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}")

                # Single-digit entries store scalar values, multi-digit ones store HDF5 references
                box = f[box_ref[0]]
                if box["left"].shape[0] == 1:
                    box_dict = {k: [int(vals[0][0])] for k, vals in box.items()}
                else:
                    box_dict = {k: [int(f[v[0]][()].item()) for v in vals] for k, vals in box.items()}

                # One (x, y, w, h) row per digit
                xywh: np.ndarray = np.array(
                    [box_dict["left"], box_dict["top"], box_dict["width"], box_dict["height"]], dtype=np_dtype
                ).transpose()
                label_targets = list(map(str, box_dict["label"]))

                if use_polygons:
                    # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                    box_targets: np.ndarray = np.stack(
                        [
                            np.stack([xywh[:, 0], xywh[:, 1]], axis=-1),
                            np.stack([xywh[:, 0] + xywh[:, 2], xywh[:, 1]], axis=-1),
                            np.stack([xywh[:, 0] + xywh[:, 2], xywh[:, 1] + xywh[:, 3]], axis=-1),
                            np.stack([xywh[:, 0], xywh[:, 1] + xywh[:, 3]], axis=-1),
                        ],
                        axis=1,
                    )
                else:
                    # x, y, width, height -> xmin, ymin, xmax, ymax
                    box_targets = np.stack(
                        [xywh[:, 0], xywh[:, 1], xywh[:, 0] + xywh[:, 2], xywh[:, 1] + xywh[:, 3]],
                        axis=-1,
                    )

                if recognition_task:
                    crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_name), geoms=box_targets)
                    for crop, label in zip(crops, label_targets):
                        # Discard degenerate crops and empty labels
                        if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
                            self.data.append((crop, label))
                else:
                    self.data.append((img_name, dict(boxes=box_targets, labels=label_targets)))

        self.root = tmp_root

    def extra_repr(self) -> str:
        return f"train={self.train}"
|
doctr/datasets/svt.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import defusedxml.ElementTree as ET
|
| 10 |
+
import numpy as np
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from .datasets import VisionDataset
|
| 14 |
+
from .utils import convert_target_to_relative, crop_bboxes_from_image
|
| 15 |
+
|
| 16 |
+
__all__ = ["SVT"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class SVT(VisionDataset):
    """SVT dataset from `"The Street View Text Dataset - UCSD Computer Vision"
    <http://vision.ucsd.edu/~kai/svt/>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/svt-grid.png&src=0
        :align: center

    >>> from doctr.datasets import SVT
    >>> train_set = SVT(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
    ----
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `VisionDataset`.
    """

    URL = "http://vision.ucsd.edu/~kai/svt/svt.zip"
    SHA256 = "63b3d55e6b6d1e036e2a844a20c034fe3af3c32e4d914d6e0c4a3cd43df3bebf"

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            self.URL,
            None,
            self.SHA256,
            True,
            pre_transforms=convert_target_to_relative if not recognition_task else None,
            **kwargs,
        )
        self.train = train
        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
        np_dtype = np.float32

        # The archive unpacks into an "svt1" sub-folder
        tmp_root = os.path.join(self.root, "svt1") if self.SHA256 else self.root
        # Ground truth is one XML document per split
        xml_tree = ET.parse(os.path.join(tmp_root, "train.xml" if self.train else "test.xml"))
        xml_root = xml_tree.getroot()

        for image in tqdm(iterable=xml_root, desc="Unpacking SVT", total=len(xml_root)):
            name, _, _, _resolution, rectangles = image

            # Fail fast if the archive extraction is incomplete
            if not os.path.exists(os.path.join(tmp_root, name.text)):
                raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, name.text)}")

            _boxes = []
            for rect in rectangles:
                x = float(rect.attrib["x"])
                y = float(rect.attrib["y"])
                w = float(rect.attrib["width"])
                h = float(rect.attrib["height"])
                if use_polygons:
                    # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                    _boxes.append([[x, y], [x + w, y], [x + w, y + h], [x, y + h]])
                else:
                    # x_min, y_min, x_max, y_max
                    _boxes.append([x, y, x + w, y + h])

            boxes: np.ndarray = np.asarray(_boxes, dtype=np_dtype)
            # One <tag> child per rectangle holds the transcription
            labels = [lab.text for rect in rectangles for lab in rect]

            if recognition_task:
                crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, name.text), geoms=boxes)
                for crop, label in zip(crops, labels):
                    # Discard degenerate crops and empty labels
                    if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
                        self.data.append((crop, label))
            else:
                self.data.append((name.text, dict(boxes=boxes, labels=labels)))

        self.root = tmp_root

    def extra_repr(self) -> str:
        return f"train={self.train}"
|
doctr/datasets/synthtext.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import glob
|
| 7 |
+
import os
|
| 8 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from scipy import io as sio
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from .datasets import VisionDataset
|
| 16 |
+
from .utils import convert_target_to_relative, crop_bboxes_from_image
|
| 17 |
+
|
| 18 |
+
__all__ = ["SynthText"]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class SynthText(VisionDataset):
    """SynthText dataset from `"Synthetic Data for Text Localisation in Natural Images"
    <https://arxiv.org/abs/1604.06646>`_ | `"repository" <https://github.com/ankush-me/SynthText>`_ |
    `"website" <https://www.robots.ox.ac.uk/~vgg/data/scenetext/>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/svt-grid.png&src=0
        :align: center

    >>> from doctr.datasets import SynthText
    >>> train_set = SynthText(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
    ----
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `VisionDataset`.
    """

    URL = "https://thor.robots.ox.ac.uk/~vgg/data/scenetext/SynthText.zip"
    SHA256 = "28ab030485ec8df3ed612c568dd71fb2793b9afbfa3a9d9c6e792aef33265bf1"

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            self.URL,
            None,
            file_hash=None,
            extract_archive=True,
            pre_transforms=convert_target_to_relative if not recognition_task else None,
            **kwargs,
        )
        self.train = train
        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
        np_dtype = np.float32

        # Load mat data
        tmp_root = os.path.join(self.root, "SynthText") if self.SHA256 else self.root
        # Folder used to cache the SynthText recognition crops on disk
        reco_folder_name = "SynthText_recognition_train" if self.train else "SynthText_recognition_test"
        reco_folder_name = "Poly_" + reco_folder_name if use_polygons else reco_folder_name
        reco_folder_path = os.path.join(tmp_root, reco_folder_name)
        reco_images_counter = 0

        if recognition_task and os.path.isdir(reco_folder_path):
            # Crops were already extracted on a previous run: reuse them
            self._read_from_folder(reco_folder_path)
            return
        elif recognition_task and not os.path.isdir(reco_folder_path):
            os.makedirs(reco_folder_path, exist_ok=False)

        mat_data = sio.loadmat(os.path.join(tmp_root, "gt.mat"))
        # 90/10 train/test split over the full annotation set
        train_samples = int(len(mat_data["imnames"][0]) * 0.9)
        set_slice = slice(train_samples) if self.train else slice(train_samples, None)
        paths = mat_data["imnames"][0][set_slice]
        boxes = mat_data["wordBB"][0][set_slice]
        labels = mat_data["txt"][0][set_slice]
        del mat_data

        for img_path, word_boxes, txt in tqdm(
            iterable=zip(paths, boxes, labels), desc="Unpacking SynthText", total=len(paths)
        ):
            # File existence check
            if not os.path.exists(os.path.join(tmp_root, img_path[0])):
                raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path[0])}")

            # Fix: use a dedicated name instead of rebinding `labels`, which shadowed the
            # outer array the surrounding zip() iterates over
            word_labels = [elt for word in txt.tolist() for elt in word.split()]
            # (x, y) coordinates of top left, top right, bottom right, bottom left corners
            word_boxes = (
                word_boxes.transpose(2, 1, 0)
                if word_boxes.ndim == 3
                else np.expand_dims(word_boxes.transpose(1, 0), axis=0)
            )

            if not use_polygons:
                # xmin, ymin, xmax, ymax
                word_boxes = np.concatenate((word_boxes.min(axis=1), word_boxes.max(axis=1)), axis=1)

            if recognition_task:
                crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_path[0]), geoms=word_boxes)
                for crop, label in zip(crops, word_labels):
                    # Discard degenerate crops and empty labels
                    if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
                        # write data to disk
                        with open(os.path.join(reco_folder_path, f"{reco_images_counter}.txt"), "w") as f:
                            f.write(label)
                        tmp_img = Image.fromarray(crop)
                        tmp_img.save(os.path.join(reco_folder_path, f"{reco_images_counter}.png"))
                        reco_images_counter += 1
            else:
                self.data.append((img_path[0], dict(boxes=np.asarray(word_boxes, dtype=np_dtype), labels=word_labels)))

        if recognition_task:
            self._read_from_folder(reco_folder_path)

        self.root = tmp_root

    def extra_repr(self) -> str:
        return f"train={self.train}"

    def _read_from_folder(self, path: str) -> None:
        """Populate ``self.data`` from a folder of cached (image, label) crop pairs.

        Args:
        ----
            path: folder containing ``<idx>.png`` crops with matching ``<idx>.txt`` labels
        """
        for img_path in glob.glob(os.path.join(path, "*.png")):
            with open(os.path.join(path, f"{os.path.basename(img_path)[:-4]}.txt"), "r") as f:
                self.data.append((img_path, f.read()))
|
doctr/datasets/utils.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import string
|
| 7 |
+
import unicodedata
|
| 8 |
+
from collections.abc import Sequence
|
| 9 |
+
from functools import partial
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
|
| 12 |
+
from typing import Sequence as SequenceType
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
from PIL import Image
|
| 16 |
+
|
| 17 |
+
from doctr.io.image import get_img_shape
|
| 18 |
+
from doctr.utils.geometry import convert_to_relative_coords, extract_crops, extract_rcrops
|
| 19 |
+
|
| 20 |
+
from .vocabs import VOCABS
|
| 21 |
+
|
| 22 |
+
__all__ = ["translate", "encode_string", "decode_sequence", "encode_sequences", "pre_transform_multiclass"]
|
| 23 |
+
|
| 24 |
+
ImageTensor = TypeVar("ImageTensor")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def translate(
    input_string: str,
    vocab_name: str,
    unknown_char: str = "■",
) -> str:
    """Translate a string input in a given vocabulary

    Args:
    ----
        input_string: input string to translate
        vocab_name: vocabulary to use (french, latin, ...)
        unknown_char: unknown character for non-translatable characters

    Returns:
    -------
        A string translated in a given vocab
    """
    if VOCABS.get(vocab_name) is None:
        raise KeyError("output vocabulary must be in vocabs dictionnary")

    vocab = VOCABS[vocab_name]
    pieces = []
    for symbol in input_string:
        if symbol not in vocab:
            # Whitespace has no counterpart in the vocabs: drop it entirely
            if symbol in string.whitespace:
                continue
            # Try an accent-stripping ASCII normalization first
            symbol = unicodedata.normalize("NFD", symbol).encode("ascii", "ignore").decode("ascii")
            # Fall back to the unknown marker when normalization does not help
            if symbol == "" or symbol not in vocab:
                symbol = unknown_char
        pieces.append(symbol)
    return "".join(pieces)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def encode_string(
    input_string: str,
    vocab: str,
) -> List[int]:
    """Given a predefined mapping, encode the string to a sequence of numbers

    Args:
    ----
        input_string: string to encode
        vocab: vocabulary (string), the encoding is given by the indexing of the character sequence

    Returns:
    -------
        A list encoding the input_string

    Raises:
    ------
        ValueError: if a character of `input_string` is not part of `vocab`
    """
    try:
        return list(map(vocab.index, input_string))
    except ValueError as err:
        # Fix: the previous backslash-continued f-string leaked a run of source
        # indentation into the error message; also chain the original cause.
        raise ValueError(
            f"some characters cannot be found in 'vocab'. "
            f"Please check the input string {input_string} and the vocabulary {vocab}"
        ) from err
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def decode_sequence(
    input_seq: Union[np.ndarray, SequenceType[int]],
    mapping: str,
) -> str:
    """Given a predefined mapping, decode the sequence of numbers to a string

    Args:
    ----
        input_seq: array to decode
        mapping: vocabulary (string), the encoding is given by the indexing of the character sequence

    Returns:
    -------
        A string, decoded from input_seq
    """
    if not isinstance(input_seq, (Sequence, np.ndarray)):
        raise TypeError("Invalid sequence type")
    # Arrays must hold integer indices that fit inside the mapping
    if isinstance(input_seq, np.ndarray) and (input_seq.dtype != np.int_ or input_seq.max() >= len(mapping)):
        raise AssertionError("Input must be an array of int, with max less than mapping size")

    return "".join(mapping[idx] for idx in input_seq)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def encode_sequences(
    sequences: List[str],
    vocab: str,
    target_size: Optional[int] = None,
    eos: int = -1,
    sos: Optional[int] = None,
    pad: Optional[int] = None,
    dynamic_seq_length: bool = False,
) -> np.ndarray:
    """Encode character sequences using a given vocab as mapping

    Args:
    ----
        sequences: the list of character sequences of size N
        vocab: the ordered vocab to use for encoding
        target_size: maximum length of the encoded data
        eos: encoding of End Of String
        sos: optional encoding of Start Of String
        pad: optional encoding for padding. In case of padding, all sequences are followed by 1 EOS then PAD
        dynamic_seq_length: if `target_size` is specified, uses it as upper bound and enables dynamic sequence size

    Returns:
    -------
        the padded encoded data as a tensor
    """
    # The EOS code must not collide with a real character index
    if 0 <= eos < len(vocab):
        raise ValueError("argument 'eos' needs to be outside of vocab possible indices")

    if not isinstance(target_size, int) or dynamic_seq_length:
        # Longest word plus one slot for EOS (plus one each for SOS / PAD when requested)
        max_length = max(len(w) for w in sequences) + 1
        if isinstance(sos, int):
            max_length += 1
        if isinstance(pad, int):
            max_length += 1
        target_size = max_length if not isinstance(target_size, int) else min(max_length, target_size)

    if isinstance(pad, int):
        # Unused slots carry the padding symbol; EOS is appended per word below
        if 0 <= pad < len(vocab):
            raise ValueError("argument 'pad' needs to be outside of vocab possible indices")
        fill_value = pad
    else:
        # No padding symbol: unused slots carry EOS directly
        fill_value = eos
    encoded_data: np.ndarray = np.full([len(sequences), target_size], fill_value, dtype=np.int32)

    # Encode every string and copy it (truncated if needed) into its row
    for row, codes in enumerate(map(partial(encode_string, vocab=vocab), sequences)):
        if isinstance(pad, int):
            codes.append(eos)
        cutoff = min(len(codes), target_size)
        encoded_data[row, :cutoff] = codes[:cutoff]

    if isinstance(sos, int):
        if 0 <= sos < len(vocab):
            raise ValueError("argument 'sos' needs to be outside of vocab possible indices")
        # Shift right by one and stamp SOS into the first column
        encoded_data = np.roll(encoded_data, 1)
        encoded_data[:, 0] = sos

    return encoded_data
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def convert_target_to_relative(img: ImageTensor, target: Dict[str, Any]) -> Tuple[ImageTensor, Dict[str, Any]]:
    """Rescale the target's "boxes" entry from absolute pixels to relative coordinates.

    Args:
    ----
        img: image the boxes refer to (used only for its spatial shape)
        target: target dict whose "boxes" entry is rewritten in place

    Returns:
    -------
        the unchanged image and the updated target
    """
    img_shape = get_img_shape(img)
    target["boxes"] = convert_to_relative_coords(target["boxes"], img_shape)
    return img, target
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def crop_bboxes_from_image(img_path: Union[str, Path], geoms: np.ndarray) -> List[np.ndarray]:
    """Crop a set of bounding boxes from an image

    Args:
    ----
        img_path: path to the image
        geoms: a array of polygons of shape (N, 4, 2) or of straight boxes of shape (N, 4)

    Returns:
    -------
        a list of cropped images
    """
    page: np.ndarray = np.array(Image.open(img_path).convert("RGB"))
    # Rotated polygons: one (4, 2) corner matrix per box
    if geoms.ndim == 3 and geoms.shape[1:] == (4, 2):
        return extract_rcrops(page, geoms.astype(dtype=int))
    # Straight boxes: (xmin, ymin, xmax, ymax) rows
    if geoms.ndim == 2 and geoms.shape[1] == 4:
        return extract_crops(page, geoms.astype(dtype=int))
    raise ValueError("Invalid geometry format")
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def pre_transform_multiclass(img, target: Tuple[np.ndarray, List]) -> Tuple[np.ndarray, Dict[str, List]]:
    """Converts multiclass target to relative coordinates.

    Args:
    ----
        img: Image
        target: tuple of target polygons and their classes names

    Returns:
    -------
        Image and dictionary of boxes, with class names as keys
    """
    rel_boxes = convert_to_relative_coords(target[0], get_img_shape(img))
    class_names = target[1]
    # Group the polygons by class, with deterministic (sorted) key order
    grouped: Dict = {name: [] for name in sorted(set(class_names))}
    for name, poly in zip(class_names, rel_boxes):
        grouped[name].append(poly)
    return img, {name: np.stack(polys, axis=0) for name, polys in grouped.items()}
|
doctr/datasets/vocabs.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import string
|
| 7 |
+
from typing import Dict
|
| 8 |
+
|
| 9 |
+
__all__ = ["VOCABS"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Base character sets; language vocabularies below are built from these pieces.
VOCABS: Dict[str, str] = {
    "digits": string.digits,
    "ascii_letters": string.ascii_letters,
    "punctuation": string.punctuation,
    "currency": "£€¥¢฿",
    "ancient_greek": "αβγδεζηθικλμνξοπρστυφχψωΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩ",
    "arabic_letters": "ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىي",
    "persian_letters": "پچڢڤگ",
    "hindi_digits": "٠١٢٣٤٥٦٧٨٩",
    "arabic_diacritics": "ًٌٍَُِّْ",
    "arabic_punctuation": "؟؛«»—",
}

# Latin core and the English superset most languages extend
VOCABS["latin"] = VOCABS["digits"] + VOCABS["ascii_letters"] + VOCABS["punctuation"]
VOCABS["english"] = VOCABS["latin"] + "°" + VOCABS["currency"]

# European languages: English plus their accented characters
VOCABS["legacy_french"] = VOCABS["latin"] + "°" + "àâéèêëîïôùûçÀÂÉÈËÎÏÔÙÛÇ" + VOCABS["currency"]
VOCABS["french"] = VOCABS["english"] + "àâéèêëîïôùûüçÀÂÉÈÊËÎÏÔÙÛÜÇ"
VOCABS["portuguese"] = VOCABS["english"] + "áàâãéêíïóôõúüçÁÀÂÃÉÊÍÏÓÔÕÚÜÇ"
VOCABS["spanish"] = VOCABS["english"] + "áéíóúüñÁÉÍÓÚÜÑ" + "¡¿"
VOCABS["italian"] = VOCABS["english"] + "àèéìíîòóùúÀÈÉÌÍÎÒÓÙÚ"
VOCABS["german"] = VOCABS["english"] + "äöüßÄÖÜẞ"

# Arabic script: digits, letters, diacritics and punctuation from both scripts
VOCABS["arabic"] = (
    VOCABS["digits"]
    + VOCABS["hindi_digits"]
    + VOCABS["arabic_letters"]
    + VOCABS["persian_letters"]
    + VOCABS["arabic_diacritics"]
    + VOCABS["arabic_punctuation"]
    + VOCABS["punctuation"]
)

VOCABS["czech"] = VOCABS["english"] + "áčďéěíňóřšťúůýžÁČĎÉĚÍŇÓŘŠŤÚŮÝŽ"
VOCABS["polish"] = VOCABS["english"] + "ąćęłńóśźżĄĆĘŁŃÓŚŹŻ"
VOCABS["dutch"] = VOCABS["english"] + "áéíóúüñÁÉÍÓÚÜÑ"

# Scandinavian languages share the same extra characters
VOCABS["norwegian"] = VOCABS["english"] + "æøåÆØÅ"
VOCABS["danish"] = VOCABS["english"] + "æøåÆØÅ"
VOCABS["finnish"] = VOCABS["english"] + "äöÄÖ"
VOCABS["swedish"] = VOCABS["english"] + "åäöÅÄÖ"

VOCABS["vietnamese"] = (
    VOCABS["english"]
    + "áàảạãăắằẳẵặâấầẩẫậéèẻẽẹêếềểễệóòỏõọôốồổộỗơớờởợỡúùủũụưứừửữựiíìỉĩịýỳỷỹỵ"
    + "ÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ"
)
VOCABS["hebrew"] = VOCABS["english"] + "אבגדהוזחטיכלמנסעפצקרשת" + "₪"

# Union of the European vocabularies, de-duplicated while preserving first-seen order
VOCABS["multilingual"] = "".join(
    dict.fromkeys(
        VOCABS["french"]
        + VOCABS["portuguese"]
        + VOCABS["spanish"]
        + VOCABS["german"]
        + VOCABS["czech"]
        + VOCABS["polish"]
        + VOCABS["dutch"]
        + VOCABS["italian"]
        + VOCABS["norwegian"]
        + VOCABS["danish"]
        + VOCABS["finnish"]
        + VOCABS["swedish"]
        + "§"
    )
)
|
doctr/datasets/wildreceipt.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
from .datasets import AbstractDataset
|
| 14 |
+
from .utils import convert_target_to_relative, crop_bboxes_from_image
|
| 15 |
+
|
| 16 |
+
__all__ = ["WILDRECEIPT"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class WILDRECEIPT(AbstractDataset):
    """WildReceipt dataset from `"Spatial Dual-Modality Graph Reasoning for Key Information Extraction"
    <https://arxiv.org/abs/2103.14470v1>`_ |
    `repository <https://download.openmmlab.com/mmocr/data/wildreceipt.tar>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.7.0/wildreceipt-dataset.jpg&src=0
        :align: center

    >>> # NOTE: You need to download the dataset first.
    >>> from doctr.datasets import WILDRECEIPT
    >>> train_set = WILDRECEIPT(train=True, img_folder="/path/to/wildreceipt/",
    >>>                     label_path="/path/to/wildreceipt/train.txt")
    >>> img, target = train_set[0]
    >>> test_set = WILDRECEIPT(train=False, img_folder="/path/to/wildreceipt/",
    >>>                    label_path="/path/to/wildreceipt/test.txt")
    >>> img, target = test_set[0]

    Args:
    ----
        img_folder: folder with all the images of the dataset
        label_path: path to the annotations file of the dataset
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `AbstractDataset`.
    """

    def __init__(
        self,
        img_folder: str,
        label_path: str,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
        )
        # File existence check
        if not os.path.exists(label_path) or not os.path.exists(img_folder):
            raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")

        tmp_root = img_folder
        self.train = train
        np_dtype = np.float32
        self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any]]]] = []

        with open(label_path, "r") as file:
            data = file.read()
        # Split the text file into separate JSON strings (one sample per line)
        json_strings = data.strip().split("\n")
        box: Union[List[float], np.ndarray]
        for json_string in json_strings:
            json_data = json.loads(json_string)
            img_path = json_data["file_name"]
            annotations = json_data["annotations"]
            # BUGFIX: reset the accumulator for every image. It was previously
            # initialized once before this loop, so each sample's target also
            # contained the annotations of all preceding images.
            _targets = []
            for annotation in annotations:
                coordinates = annotation["box"]
                if use_polygons:
                    # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                    box = np.array(
                        [
                            [coordinates[0], coordinates[1]],
                            [coordinates[2], coordinates[3]],
                            [coordinates[4], coordinates[5]],
                            [coordinates[6], coordinates[7]],
                        ],
                        dtype=np_dtype,
                    )
                else:
                    # Collapse the 8 polygon coordinates to a straight (xmin, ymin, xmax, ymax) box
                    x, y = coordinates[::2], coordinates[1::2]
                    box = [min(x), min(y), max(x), max(y)]
                _targets.append((annotation["text"], box))
            text_targets, box_targets = zip(*_targets)

            if recognition_task:
                crops = crop_bboxes_from_image(
                    img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0)
                )
                for crop, label in zip(crops, list(text_targets)):
                    # Keep only non-empty single-word labels for recognition
                    if label and " " not in label:
                        self.data.append((crop, label))
            else:
                self.data.append((
                    img_path,
                    dict(boxes=np.asarray(box_targets, dtype=int).clip(min=0), labels=list(text_targets)),
                ))
        self.root = tmp_root

    def extra_repr(self) -> str:
        return f"train={self.train}"
|
doctr/file_utils.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
|
| 7 |
+
|
| 8 |
+
import importlib.util
|
| 9 |
+
import logging
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
# Default class name used for text detection targets across the library
CLASS_NAME: str = "words"


# importlib.metadata only landed in the stdlib in Python 3.8; fall back to the
# external backport on older interpreters.
if sys.version_info < (3, 8):  # pragma: no cover
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata


__all__ = ["is_tf_available", "is_torch_available", "CLASS_NAME"]

# Values of USE_TF / USE_TORCH interpreted as "enabled"
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

# Environment switches selecting the backend; "AUTO" (default) means
# "use it if it is installed, unless the other backend is forced".
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()


# PyTorch is considered available when it is importable and USE_TF is not forced on
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
    _torch_available = importlib.util.find_spec("torch") is not None
    if _torch_available:
        try:
            _torch_version = importlib_metadata.version("torch")
            logging.info(f"PyTorch version {_torch_version} available.")
        except importlib_metadata.PackageNotFoundError:  # pragma: no cover
            # Spec found but no installed distribution metadata: treat as absent
            _torch_available = False
else:  # pragma: no cover
    logging.info("Disabling PyTorch because USE_TF is set")
    _torch_available = False


# Same detection for TensorFlow, which ships under several distribution names
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
    _tf_available = importlib.util.find_spec("tensorflow") is not None
    if _tf_available:
        candidates = (
            "tensorflow",
            "tensorflow-cpu",
            "tensorflow-gpu",
            "tf-nightly",
            "tf-nightly-cpu",
            "tf-nightly-gpu",
            "intel-tensorflow",
            "tensorflow-rocm",
            "tensorflow-macos",
        )
        _tf_version = None
        # For the metadata, we have to look for both tensorflow and tensorflow-cpu
        for pkg in candidates:
            try:
                _tf_version = importlib_metadata.version(pkg)
                break
            except importlib_metadata.PackageNotFoundError:
                pass
        _tf_available = _tf_version is not None
    if _tf_available:
        # TF 1.x is explicitly rejected
        if int(_tf_version.split(".")[0]) < 2:  # type: ignore[union-attr]  # pragma: no cover
            logging.info(f"TensorFlow found but with version {_tf_version}. DocTR requires version 2 minimum.")
            _tf_available = False
        else:
            logging.info(f"TensorFlow version {_tf_version} available.")
else:  # pragma: no cover
    logging.info("Disabling Tensorflow because USE_TORCH is set")
    _tf_available = False


# The library cannot run without at least one backend; fail fast at import time
if not _torch_available and not _tf_available:  # pragma: no cover
    raise ModuleNotFoundError(
        "DocTR requires either TensorFlow or PyTorch to be installed. Please ensure one of them"
        " is installed and that either USE_TF or USE_TORCH is enabled."
    )


def is_torch_available():
    """Whether PyTorch is installed."""
    return _torch_available


def is_tf_available():
    """Whether TensorFlow is installed."""
    return _tf_available
|
doctr/io/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .elements import *
|
| 2 |
+
from .html import *
|
| 3 |
+
from .image import *
|
| 4 |
+
from .pdf import *
|
| 5 |
+
from .reader import *
|
doctr/io/elements.py
ADDED
|
@@ -0,0 +1,621 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
from defusedxml import defuse_stdlib
|
| 9 |
+
|
| 10 |
+
defuse_stdlib()
|
| 11 |
+
from xml.etree import ElementTree as ET
|
| 12 |
+
from xml.etree.ElementTree import Element as ETElement
|
| 13 |
+
from xml.etree.ElementTree import SubElement
|
| 14 |
+
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
import doctr
|
| 19 |
+
from doctr.utils.common_types import BoundingBox
|
| 20 |
+
from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox
|
| 21 |
+
from doctr.utils.repr import NestedObject
|
| 22 |
+
from doctr.utils.visualization import synthesize_kie_page, synthesize_page, visualize_kie_page, visualize_page
|
| 23 |
+
|
| 24 |
+
__all__ = ["Element", "Word", "Artefact", "Line", "Prediction", "Block", "Page", "KIEPage", "Document"]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Element(NestedObject):
    """Base class for document elements, providing nested-dict export and text rendering."""

    _children_names: List[str] = []
    _exported_keys: List[str] = []

    def __init__(self, **kwargs: Any) -> None:
        for name, child in kwargs.items():
            # Only declared child collections may be set through the constructor
            if name not in self._children_names:
                raise KeyError(f"{self.__class__.__name__} object does not have any attribute named '{name}'")
            setattr(self, name, child)

    def export(self) -> Dict[str, Any]:
        """Exports the object into a nested dict format"""
        out: Dict[str, Any] = {key: getattr(self, key) for key in self._exported_keys}
        for child_name in self._children_names:
            children = getattr(self, child_name)
            if child_name == "predictions":
                # KIE predictions are stored as a mapping from class name to elements
                out[child_name] = {cls_name: [elt.export() for elt in elts] for cls_name, elts in children.items()}
            else:
                out[child_name] = [child.export() for child in children]
        return out

    @classmethod
    def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
        raise NotImplementedError

    def render(self) -> str:
        raise NotImplementedError
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class Word(Element):
    """A single recognized word.

    Args:
    ----
        value: the text string of the word
        confidence: the confidence associated with the text prediction
        geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
            the page's size
        crop_orientation: the general orientation of the crop in degrees and its confidence
    """

    _exported_keys: List[str] = ["value", "confidence", "geometry", "crop_orientation"]
    _children_names: List[str] = []

    def __init__(
        self,
        value: str,
        confidence: float,
        geometry: Union[BoundingBox, np.ndarray],
        crop_orientation: Dict[str, Any],
    ) -> None:
        super().__init__()
        self.value = value
        self.confidence = confidence
        self.geometry = geometry
        self.crop_orientation = crop_orientation

    def render(self) -> str:
        """Return the word's text content."""
        return self.value

    def extra_repr(self) -> str:
        return f"value='{self.value}', confidence={self.confidence:.2}"

    @classmethod
    def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
        """Rebuild a :class:`Word` from its exported dict representation."""
        return cls(**{key: save_dict[key] for key in cls._exported_keys})
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class Artefact(Element):
    """A non-textual page element (e.g. a logo or picture).

    Args:
    ----
        artefact_type: the type of artefact
        confidence: the confidence of the type prediction
        geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
            the page's size.
    """

    _exported_keys: List[str] = ["geometry", "type", "confidence"]
    _children_names: List[str] = []

    def __init__(self, artefact_type: str, confidence: float, geometry: BoundingBox) -> None:
        super().__init__()
        self.geometry = geometry
        # Stored under `type` so it appears under that key in exports
        self.type = artefact_type
        self.confidence = confidence

    def render(self) -> str:
        """Return a textual placeholder for the artefact."""
        return f"[{self.type.upper()}]"

    def extra_repr(self) -> str:
        return f"type='{self.type}', confidence={self.confidence:.2}"

    @classmethod
    def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
        """Rebuild an :class:`Artefact` from its exported dict representation."""
        return cls(**{key: save_dict[key] for key in cls._exported_keys})
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class Line(Element):
    """A line of text, i.e. an ordered collection of words.

    Args:
    ----
        words: list of word elements
        geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
            the page's size. If not specified, it will be resolved by default to the smallest bounding box enclosing
            all words in it.
    """

    _exported_keys: List[str] = ["geometry"]
    _children_names: List[str] = ["words"]
    words: List[Word] = []

    def __init__(
        self,
        words: List[Word],
        geometry: Optional[Union[BoundingBox, np.ndarray]] = None,
    ) -> None:
        if geometry is None:
            # Fall back to the tightest box enclosing every word; a 4-point
            # geometry on the first word signals rotated boxes.
            resolve = resolve_enclosing_rbbox if len(words[0].geometry) == 4 else resolve_enclosing_bbox
            geometry = resolve([word.geometry for word in words])  # type: ignore[operator]

        super().__init__(words=words)
        self.geometry = geometry

    def render(self) -> str:
        """Return the line's text, with words separated by single spaces."""
        return " ".join(word.render() for word in self.words)

    @classmethod
    def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
        """Rebuild a :class:`Line` (including its words) from its exported dict."""
        kwargs = {key: save_dict[key] for key in cls._exported_keys}
        kwargs["words"] = [Word.from_dict(word_dict) for word_dict in save_dict["words"]]
        return cls(**kwargs)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class Prediction(Word):
    """Implements a prediction element

    A KIE prediction: same data as a :class:`Word` (value, confidence, geometry,
    crop_orientation), rendered identically but with a more detailed repr.
    """

    def render(self) -> str:
        """Renders the full text of the element"""
        return self.value

    def extra_repr(self) -> str:
        # Unlike Word.extra_repr, also expose the bounding box for debugging
        return f"value='{self.value}', confidence={self.confidence:.2}, bounding_box={self.geometry}"
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class Block(Element):
    """A block element: a collection of lines and artefacts.

    Args:
    ----
        lines: list of line elements
        artefacts: list of artefacts
        geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
            the page's size. If not specified, it will be resolved by default to the smallest bounding box enclosing
            all lines and artefacts in it.
    """

    _exported_keys: List[str] = ["geometry"]
    _children_names: List[str] = ["lines", "artefacts"]
    lines: List[Line] = []
    artefacts: List[Artefact] = []

    def __init__(
        self,
        # FIX: mutable default arguments ([]) are shared across all calls; use a
        # None sentinel instead (callers passing nothing still get a fresh list).
        lines: Optional[List[Line]] = None,
        artefacts: Optional[List[Artefact]] = None,
        geometry: Optional[Union[BoundingBox, np.ndarray]] = None,
    ) -> None:
        lines = [] if lines is None else lines
        artefacts = [] if artefacts is None else artefacts
        # Resolve the geometry using the smallest enclosing bounding box
        if geometry is None:
            line_boxes = [word.geometry for line in lines for word in line.words]
            artefact_boxes = [artefact.geometry for artefact in artefacts]
            # A np.ndarray geometry on the first line signals rotated boxes
            box_resolution_fn = (
                resolve_enclosing_rbbox if isinstance(lines[0].geometry, np.ndarray) else resolve_enclosing_bbox
            )
            geometry = box_resolution_fn(line_boxes + artefact_boxes)  # type: ignore[operator]

        super().__init__(lines=lines, artefacts=artefacts)
        self.geometry = geometry

    def render(self, line_break: str = "\n") -> str:
        """Return the block's text, lines separated by `line_break`."""
        return line_break.join(line.render() for line in self.lines)

    @classmethod
    def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
        """Rebuild a :class:`Block` (including lines and artefacts) from its exported dict."""
        kwargs = {k: save_dict[k] for k in cls._exported_keys}
        kwargs.update({
            "lines": [Line.from_dict(_dict) for _dict in save_dict["lines"]],
            "artefacts": [Artefact.from_dict(_dict) for _dict in save_dict["artefacts"]],
        })
        return cls(**kwargs)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class Page(Element):
    """Implements a page element as a collection of blocks

    Args:
    ----
        page: image encoded as a numpy array in uint8
        blocks: list of block elements
        page_idx: the index of the page in the input raw document
        dimensions: the page size in pixels in format (height, width)
        orientation: a dictionary with the value of the rotation angle in degress and confidence of the prediction
        language: a dictionary with the language value and confidence of the prediction
    """

    _exported_keys: List[str] = ["page_idx", "dimensions", "orientation", "language"]
    _children_names: List[str] = ["blocks"]
    blocks: List[Block] = []

    def __init__(
        self,
        page: np.ndarray,
        blocks: List[Block],
        page_idx: int,
        dimensions: Tuple[int, int],
        orientation: Optional[Dict[str, Any]] = None,
        language: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(blocks=blocks)
        self.page = page
        self.page_idx = page_idx
        self.dimensions = dimensions
        # Fall back to null-valued dicts so exports always carry both keys
        self.orientation = orientation if isinstance(orientation, dict) else dict(value=None, confidence=None)
        self.language = language if isinstance(language, dict) else dict(value=None, confidence=None)

    def render(self, block_break: str = "\n\n") -> str:
        """Renders the full text of the element"""
        return block_break.join(b.render() for b in self.blocks)

    def extra_repr(self) -> str:
        return f"dimensions={self.dimensions}"

    def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, **kwargs) -> None:
        """Overlay the result on a given image

        Args:
            interactive: whether the display should be interactive
            preserve_aspect_ratio: pass True if you passed True to the predictor
            **kwargs: additional keyword arguments passed to the matplotlib.pyplot.show method
        """
        visualize_page(self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio)
        plt.show(**kwargs)

    def synthesize(self, **kwargs) -> np.ndarray:
        """Synthesize the page from the predictions

        Returns
        -------
            synthesized page
        """
        return synthesize_page(self.export(), **kwargs)

    def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> Tuple[bytes, ET.ElementTree]:
        """Export the page as XML (hOCR-format)
        convention: https://github.com/kba/hocr-spec/blob/master/1.2/spec.md

        Args:
        ----
            file_title: the title of the XML file

        Returns:
        -------
            a tuple of the XML byte string, and its ElementTree
        """
        p_idx = self.page_idx
        # hOCR ids are 1-based and global to the page
        block_count: int = 1
        line_count: int = 1
        word_count: int = 1
        height, width = self.dimensions
        language = self.language if "language" in self.language.keys() else "en"
        # Create the XML root element
        page_hocr = ETElement("html", attrib={"xmlns": "http://www.w3.org/1999/xhtml", "xml:lang": str(language)})
        # Create the header / SubElements of the root element
        head = SubElement(page_hocr, "head")
        SubElement(head, "title").text = file_title
        SubElement(head, "meta", attrib={"http-equiv": "Content-Type", "content": "text/html; charset=utf-8"})
        SubElement(
            head,
            "meta",
            attrib={"name": "ocr-system", "content": f"python-doctr {doctr.__version__}"},  # type: ignore[attr-defined]
        )
        SubElement(
            head,
            "meta",
            attrib={"name": "ocr-capabilities", "content": "ocr_page ocr_carea ocr_par ocr_line ocrx_word"},
        )
        # Create the body
        body = SubElement(page_hocr, "body")
        SubElement(
            body,
            "div",
            attrib={
                "class": "ocr_page",
                "id": f"page_{p_idx + 1}",
                "title": f"image; bbox 0 0 {width} {height}; ppageno 0",
            },
        )
        # iterate over the blocks / lines / words and create the XML elements in body line by line with the attributes
        for block in self.blocks:
            # Rotated geometries are np.ndarrays of 4 points; hOCR only supports straight boxes
            if len(block.geometry) != 2:
                raise TypeError("XML export is only available for straight bounding boxes for now.")
            (xmin, ymin), (xmax, ymax) = block.geometry
            # NOTE: relative coordinates are scaled back to absolute pixels via width/height
            block_div = SubElement(
                body,
                "div",
                attrib={
                    "class": "ocr_carea",
                    "id": f"block_{block_count}",
                    "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
{int(round(xmax * width))} {int(round(ymax * height))}",
                },
            )
            paragraph = SubElement(
                block_div,
                "p",
                attrib={
                    "class": "ocr_par",
                    "id": f"par_{block_count}",
                    "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
{int(round(xmax * width))} {int(round(ymax * height))}",
                },
            )
            block_count += 1
            for line in block.lines:
                (xmin, ymin), (xmax, ymax) = line.geometry
                # NOTE: baseline, x_size, x_descenders, x_ascenders is currently initalized to 0
                line_span = SubElement(
                    paragraph,
                    "span",
                    attrib={
                        "class": "ocr_line",
                        "id": f"line_{line_count}",
                        "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
{int(round(xmax * width))} {int(round(ymax * height))}; \
baseline 0 0; x_size 0; x_descenders 0; x_ascenders 0",
                    },
                )
                line_count += 1
                for word in line.words:
                    (xmin, ymin), (xmax, ymax) = word.geometry
                    conf = word.confidence
                    word_div = SubElement(
                        line_span,
                        "span",
                        attrib={
                            "class": "ocrx_word",
                            "id": f"word_{word_count}",
                            "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
{int(round(xmax * width))} {int(round(ymax * height))}; \
x_wconf {int(round(conf * 100))}",
                        },
                    )
                    # set the text
                    word_div.text = word.value
                    word_count += 1

        return (ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr))

    @classmethod
    def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
        # Rebuild the page's blocks recursively from the exported dict
        kwargs = {k: save_dict[k] for k in cls._exported_keys}
        kwargs.update({"blocks": [Block.from_dict(block_dict) for block_dict in save_dict["blocks"]]})
        return cls(**kwargs)
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
class KIEPage(Element):
    """Implements a KIE page element as a collection of predictions

    Args:
    ----
        predictions: Dictionary with list of block elements for each detection class
        page: image encoded as a numpy array in uint8
        page_idx: the index of the page in the input raw document
        dimensions: the page size in pixels in format (height, width)
        orientation: a dictionary with the value of the rotation angle in degrees and confidence of the prediction
        language: a dictionary with the language value and confidence of the prediction
    """

    _exported_keys: List[str] = ["page_idx", "dimensions", "orientation", "language"]
    _children_names: List[str] = ["predictions"]
    predictions: Dict[str, List[Prediction]] = {}

    def __init__(
        self,
        page: np.ndarray,
        predictions: Dict[str, List[Prediction]],
        page_idx: int,
        dimensions: Tuple[int, int],
        orientation: Optional[Dict[str, Any]] = None,
        language: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(predictions=predictions)
        self.page = page
        self.page_idx = page_idx
        self.dimensions = dimensions
        # Normalize optional metadata so both attrs always expose "value"/"confidence" keys
        self.orientation = orientation if isinstance(orientation, dict) else dict(value=None, confidence=None)
        self.language = language if isinstance(language, dict) else dict(value=None, confidence=None)

    def render(self, prediction_break: str = "\n\n") -> str:
        """Renders the full text of the element"""
        return prediction_break.join(
            f"{class_name}: {p.render()}" for class_name, predictions in self.predictions.items() for p in predictions
        )

    def extra_repr(self) -> str:
        return f"dimensions={self.dimensions}"

    def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, **kwargs) -> None:
        """Overlay the result on a given image

        Args:
        ----
            interactive: whether the display should be interactive
            preserve_aspect_ratio: pass True if you passed True to the predictor
            **kwargs: keyword arguments passed to the matplotlib.pyplot.show method
        """
        visualize_kie_page(
            self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio
        )
        plt.show(**kwargs)

    def synthesize(self, **kwargs) -> np.ndarray:
        """Synthesize the page from the predictions

        Args:
        ----
            **kwargs: keyword arguments passed to the matplotlib.pyplot.show method

        Returns:
        -------
            synthesized page
        """
        return synthesize_kie_page(self.export(), **kwargs)

    def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> Tuple[bytes, ET.ElementTree]:
        """Export the page as XML (hOCR-format)
        convention: https://github.com/kba/hocr-spec/blob/master/1.2/spec.md

        Args:
        ----
            file_title: the title of the XML file

        Returns:
        -------
            a tuple of the XML byte string, and its ElementTree
        """
        p_idx = self.page_idx
        prediction_count: int = 1
        height, width = self.dimensions
        # BUGFIX: the previous check (`"language" in self.language.keys()`) could never match,
        # since `__init__` only stores "value"/"confidence" keys, so the detected language was
        # always discarded (and a raw dict could end up in xml:lang). Use the actual value.
        language = self.language["value"] if self.language.get("value") is not None else "en"
        # Create the XML root element
        page_hocr = ETElement("html", attrib={"xmlns": "http://www.w3.org/1999/xhtml", "xml:lang": str(language)})
        # Create the header / SubElements of the root element
        head = SubElement(page_hocr, "head")
        SubElement(head, "title").text = file_title
        SubElement(head, "meta", attrib={"http-equiv": "Content-Type", "content": "text/html; charset=utf-8"})
        SubElement(
            head,
            "meta",
            attrib={"name": "ocr-system", "content": f"python-doctr {doctr.__version__}"},  # type: ignore[attr-defined]
        )
        SubElement(
            head,
            "meta",
            attrib={"name": "ocr-capabilities", "content": "ocr_page ocr_carea ocr_par ocr_line ocrx_word"},
        )
        # Create the body
        body = SubElement(page_hocr, "body")
        SubElement(
            body,
            "div",
            attrib={
                "class": "ocr_page",
                "id": f"page_{p_idx + 1}",
                "title": f"image; bbox 0 0 {width} {height}; ppageno 0",
            },
        )
        # iterate over the predictions and create one ocr_carea div per straight bounding box
        for class_name, predictions in self.predictions.items():
            for prediction in predictions:
                if len(prediction.geometry) != 2:
                    raise TypeError("XML export is only available for straight bounding boxes for now.")
                (xmin, ymin), (xmax, ymax) = prediction.geometry
                prediction_div = SubElement(
                    body,
                    "div",
                    attrib={
                        "class": "ocr_carea",
                        "id": f"{class_name}_prediction_{prediction_count}",
                        "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
                        {int(round(xmax * width))} {int(round(ymax * height))}",
                    },
                )
                prediction_div.text = prediction.value
                prediction_count += 1

        return ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr)

    @classmethod
    def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
        """Rebuild a KIEPage from its exported dictionary"""
        kwargs = {k: save_dict[k] for k in cls._exported_keys}
        # BUGFIX: iterating the exported dict directly yielded its string keys and produced a
        # flat list, while `__init__`/`render` require a class_name -> List[Prediction] mapping.
        # NOTE(review): assumes `save_dict["predictions"]` mirrors the `predictions` attribute
        # layout (class name -> list of exported prediction dicts) — confirm against Element.export.
        kwargs.update({
            "predictions": {
                class_name: [Prediction.from_dict(prediction_dict) for prediction_dict in prediction_dicts]
                for class_name, prediction_dicts in save_dict["predictions"].items()
            }
        })
        return cls(**kwargs)
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
class Document(Element):
    """Implements a document element as a collection of pages

    Args:
    ----
        pages: list of page elements
    """

    _children_names: List[str] = ["pages"]
    pages: List[Page] = []

    def __init__(
        self,
        pages: List[Page],
    ) -> None:
        # Registration of the page children is delegated to the Element base class
        super().__init__(pages=pages)

    def render(self, page_break: str = "\n\n\n\n") -> str:
        """Renders the full text of the element"""
        rendered_pages = (page.render() for page in self.pages)
        return page_break.join(rendered_pages)

    def show(self, **kwargs) -> None:
        """Overlay the result on a given image"""
        for page in self.pages:
            page.show(**kwargs)

    def synthesize(self, **kwargs) -> List[np.ndarray]:
        """Synthesize all pages from their predictions

        Returns
        -------
            list of synthesized pages
        """
        return [current.synthesize() for current in self.pages]

    def export_as_xml(self, **kwargs) -> List[Tuple[bytes, ET.ElementTree]]:
        """Export the document as XML (hOCR-format)

        Args:
        ----
            **kwargs: additional keyword arguments passed to the Page.export_as_xml method

        Returns:
        -------
            list of tuple of (bytes, ElementTree)
        """
        return [current.export_as_xml(**kwargs) for current in self.pages]

    @classmethod
    def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
        """Rebuild a document from its exported dictionary"""
        init_kwargs = {key: save_dict[key] for key in cls._exported_keys}
        init_kwargs["pages"] = [Page.from_dict(page_dict) for page_dict in save_dict["pages"]]
        return cls(**init_kwargs)
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
class KIEDocument(Document):
    """Implements a document element as a collection of pages

    Args:
    ----
        pages: list of page elements
    """

    # Same children layout as Document, but the page type is narrowed to KIEPage
    _children_names: List[str] = ["pages"]
    pages: List[KIEPage] = []  # type: ignore[assignment]

    def __init__(
        self,
        pages: List[KIEPage],
    ) -> None:
        # Storage of the page children is handled entirely by the Document base class
        super().__init__(pages=pages)  # type: ignore[arg-type]
|
doctr/io/html.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
from weasyprint import HTML
|
| 9 |
+
|
| 10 |
+
__all__ = ["read_html"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def read_html(url: str, **kwargs: Any) -> bytes:
    """Fetch a web page and render it as a PDF byte stream

    >>> from doctr.io import read_html
    >>> doc = read_html("https://www.yoursite.com")

    Args:
    ----
        url: URL of the target web page
        **kwargs: keyword arguments from `weasyprint.HTML`

    Returns:
    -------
        decoded PDF file as a bytes stream
    """
    return HTML(url, **kwargs).write_pdf()
|
doctr/io/image/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from doctr.file_utils import is_tf_available, is_torch_available
|
| 2 |
+
|
| 3 |
+
from .base import *
|
| 4 |
+
|
| 5 |
+
if is_tf_available():
|
| 6 |
+
from .tensorflow import *
|
| 7 |
+
elif is_torch_available():
|
| 8 |
+
from .pytorch import *
|
doctr/io/image/base.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
from doctr.utils.common_types import AbstractFile
|
| 13 |
+
|
| 14 |
+
__all__ = ["read_img_as_numpy"]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def read_img_as_numpy(
    file: AbstractFile,
    output_size: Optional[Tuple[int, int]] = None,
    rgb_output: bool = True,
) -> np.ndarray:
    """Read an image file into numpy format

    >>> from doctr.io import read_img_as_numpy
    >>> page = read_img_as_numpy("path/to/your/doc.jpg")

    Args:
    ----
        file: the path to the image file
        output_size: the expected output size of each page in format H x W
        rgb_output: whether the output ndarray channel order should be RGB instead of BGR.

    Returns:
    -------
        the page decoded as numpy ndarray of shape H x W x 3
    """
    # Decode either an on-disk file or an in-memory buffer
    if isinstance(file, bytes):
        buffer: np.ndarray = np.frombuffer(file, np.uint8)
        img = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    elif isinstance(file, (str, Path)):
        if not Path(file).is_file():
            raise FileNotFoundError(f"unable to access {file}")
        img = cv2.imread(str(file), cv2.IMREAD_COLOR)
    else:
        raise TypeError("unsupported object type for argument 'file'")

    # OpenCV returns None instead of raising on a corrupt image
    if img is None:
        raise ValueError("unable to read file.")
    # Optional resizing (cv2 expects (W, H), hence the reversed tuple)
    if isinstance(output_size, tuple):
        img = cv2.resize(img, output_size[::-1], interpolation=cv2.INTER_LINEAR)
    # OpenCV decodes in BGR order; flip the channels when RGB was requested
    if rgb_output:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
|
doctr/io/image/pytorch.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
from io import BytesIO
|
| 7 |
+
from typing import Tuple
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from torchvision.transforms.functional import to_tensor
|
| 13 |
+
|
| 14 |
+
from doctr.utils.common_types import AbstractPath
|
| 15 |
+
|
| 16 |
+
__all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def tensor_from_pil(pil_img: Image.Image, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Convert a PIL Image to a PyTorch tensor

    Args:
    ----
        pil_img: a PIL image
        dtype: the output tensor data type

    Returns:
    -------
        decoded image as tensor
    """
    # Non-float32 requests are routed through the numpy conversion helper
    if dtype != torch.float32:
        return tensor_from_numpy(np.array(pil_img, np.uint8, copy=True), dtype)
    return to_tensor(pil_img)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def read_img_as_tensor(img_path: AbstractPath, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Read an image file as a PyTorch tensor

    Args:
    ----
        img_path: location of the image file
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
    -------
        decoded image as a tensor

    Raises:
    ------
        ValueError: if `dtype` is not one of torch.uint8, torch.float16, torch.float32
    """
    if dtype not in (torch.uint8, torch.float16, torch.float32):
        # BUGFIX: error-message typo ("insupported" -> "unsupported")
        raise ValueError("unsupported value for dtype")

    # Force a 3-channel RGB image regardless of the source encoding
    pil_img = Image.open(img_path, mode="r").convert("RGB")

    return tensor_from_pil(pil_img, dtype)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Read a byte stream as a PyTorch tensor

    Args:
    ----
        img_content: bytes of a decoded image
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
    -------
        decoded image as a tensor

    Raises:
    ------
        ValueError: if `dtype` is not one of torch.uint8, torch.float16, torch.float32
    """
    if dtype not in (torch.uint8, torch.float16, torch.float32):
        # BUGFIX: error-message typo ("insupported" -> "unsupported")
        raise ValueError("unsupported value for dtype")

    # Force a 3-channel RGB image regardless of the source encoding
    pil_img = Image.open(BytesIO(img_content), mode="r").convert("RGB")

    return tensor_from_pil(pil_img, dtype)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Convert a numpy image to a PyTorch tensor

    Args:
    ----
        npy_img: image encoded as a numpy array of shape (H, W, C) in np.uint8
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
    -------
        same image as a tensor of shape (C, H, W)

    Raises:
    ------
        ValueError: if `dtype` is not one of torch.uint8, torch.float16, torch.float32
    """
    if dtype not in (torch.uint8, torch.float16, torch.float32):
        # BUGFIX: error-message typo ("insupported" -> "unsupported")
        raise ValueError("unsupported value for dtype")

    if dtype == torch.float32:
        # to_tensor handles both the HWC -> CHW transposition and the division by 255
        img = to_tensor(npy_img)
    else:
        img = torch.from_numpy(npy_img)
        # put it from HWC to CHW format
        img = img.permute((2, 0, 1)).contiguous()
        if dtype == torch.float16:
            # Switch to FP16 and rescale to [0, 1]
            img = img.to(dtype=torch.float16).div(255)

    return img
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def get_img_shape(img: torch.Tensor) -> Tuple[int, int]:
    """Get the shape of an image"""
    # The trailing two axes of a (..., H, W) tensor are the spatial dimensions
    spatial_dims = img.shape[-2:]
    return spatial_dims  # type: ignore[return-value]
|
doctr/io/image/tensorflow.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
from typing import Tuple
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import tensorflow as tf
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from tensorflow.keras.utils import img_to_array
|
| 12 |
+
|
| 13 |
+
from doctr.utils.common_types import AbstractPath
|
| 14 |
+
|
| 15 |
+
__all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def tensor_from_pil(pil_img: Image.Image, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
    """Convert a PIL Image to a TensorFlow tensor

    Args:
    ----
        pil_img: a PIL image
        dtype: the output tensor data type

    Returns:
    -------
        decoded image as tensor
    """
    # Convert to an HWC numpy array first, then delegate dtype handling
    as_array = img_to_array(pil_img)
    return tensor_from_numpy(as_array, dtype)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def read_img_as_tensor(img_path: AbstractPath, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
    """Read an image file as a TensorFlow tensor

    Args:
    ----
        img_path: location of the image file
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
    -------
        decoded image as a tensor

    Raises:
    ------
        ValueError: if `dtype` is not one of tf.uint8, tf.float16, tf.float32
    """
    if dtype not in (tf.uint8, tf.float16, tf.float32):
        # BUGFIX: error-message typo ("insupported" -> "unsupported")
        raise ValueError("unsupported value for dtype")

    img = tf.io.read_file(img_path)
    # Force 3 channels so grayscale inputs get a consistent layout
    img = tf.image.decode_jpeg(img, channels=3)

    if dtype != tf.uint8:
        # convert_image_dtype rescales uint8 to [0, 1] for float targets
        img = tf.image.convert_image_dtype(img, dtype=dtype)
        img = tf.clip_by_value(img, 0, 1)

    return img
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def decode_img_as_tensor(img_content: bytes, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
    """Read a byte stream as a TensorFlow tensor

    Args:
    ----
        img_content: bytes of a decoded image
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
    -------
        decoded image as a tensor

    Raises:
    ------
        ValueError: if `dtype` is not one of tf.uint8, tf.float16, tf.float32
    """
    if dtype not in (tf.uint8, tf.float16, tf.float32):
        # BUGFIX: error-message typo ("insupported" -> "unsupported")
        raise ValueError("unsupported value for dtype")

    img = tf.io.decode_image(img_content, channels=3)

    if dtype != tf.uint8:
        # convert_image_dtype rescales uint8 to [0, 1] for float targets
        img = tf.image.convert_image_dtype(img, dtype=dtype)
        img = tf.clip_by_value(img, 0, 1)

    return img
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def tensor_from_numpy(npy_img: np.ndarray, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
    """Convert a numpy image to a TensorFlow tensor

    Args:
    ----
        npy_img: image encoded as a numpy array of shape (H, W, C) in np.uint8
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
    -------
        same image as a tensor of shape (H, W, C)

    Raises:
    ------
        ValueError: if `dtype` is not one of tf.uint8, tf.float16, tf.float32
    """
    if dtype not in (tf.uint8, tf.float16, tf.float32):
        # BUGFIX: error-message typo ("insupported" -> "unsupported")
        raise ValueError("unsupported value for dtype")

    if dtype == tf.uint8:
        img = tf.convert_to_tensor(npy_img, dtype=dtype)
    else:
        # convert_image_dtype rescales uint8 to [0, 1] for float targets
        img = tf.image.convert_image_dtype(npy_img, dtype=dtype)
        img = tf.clip_by_value(img, 0, 1)

    return img
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def get_img_shape(img: tf.Tensor) -> Tuple[int, int]:
    """Get the shape of an image"""
    # TF images are channels-last, so the leading two axes are (H, W)
    spatial_dims = img.shape[:2]
    return spatial_dims
|
doctr/io/pdf.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
from typing import Any, List, Optional
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pypdfium2 as pdfium
|
| 10 |
+
|
| 11 |
+
from doctr.utils.common_types import AbstractFile
|
| 12 |
+
|
| 13 |
+
__all__ = ["read_pdf"]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def read_pdf(
    file: AbstractFile,
    scale: float = 2,
    rgb_mode: bool = True,
    password: Optional[str] = None,
    **kwargs: Any,
) -> List[np.ndarray]:
    """Read a PDF file and convert it into an image in numpy format

    >>> from doctr.io import read_pdf
    >>> doc = read_pdf("path/to/your/doc.pdf")

    Args:
    ----
        file: the path to the PDF file
        scale: rendering scale (1 corresponds to 72dpi)
        rgb_mode: if True, the output will be RGB, otherwise BGR
        password: a password to unlock the document, if encrypted
        **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render`

    Returns:
    -------
        the list of pages decoded as numpy ndarray of shape H x W x C
    """
    # Rasterise pages to numpy ndarrays with pypdfium2
    document = pdfium.PdfDocument(file, password=password, autoclose=True)
    rasterized = []
    for pdf_page in document:
        bitmap = pdf_page.render(scale=scale, rev_byteorder=rgb_mode, **kwargs)
        rasterized.append(bitmap.to_numpy())
    return rasterized
|
doctr/io/reader.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2021-2024, Mindee.
|
| 2 |
+
|
| 3 |
+
# This program is licensed under the Apache License 2.0.
|
| 4 |
+
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
+
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import List, Sequence, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from doctr.utils.common_types import AbstractFile
|
| 12 |
+
|
| 13 |
+
from .html import read_html
|
| 14 |
+
from .image import read_img_as_numpy
|
| 15 |
+
from .pdf import read_pdf
|
| 16 |
+
|
| 17 |
+
__all__ = ["DocumentFile"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class DocumentFile:
    """Read a document from multiple extensions"""

    @classmethod
    def from_pdf(cls, file: AbstractFile, **kwargs) -> List[np.ndarray]:
        """Read a PDF file

        >>> from doctr.io import DocumentFile
        >>> doc = DocumentFile.from_pdf("path/to/your/doc.pdf")

        Args:
        ----
            file: the path to the PDF file or a binary stream
            **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render`

        Returns:
        -------
            the list of pages decoded as numpy ndarray of shape H x W x 3
        """
        # Thin wrapper around the module-level PDF reader
        return read_pdf(file, **kwargs)

    @classmethod
    def from_url(cls, url: str, **kwargs) -> List[np.ndarray]:
        """Interpret a web page as a PDF document

        >>> from doctr.io import DocumentFile
        >>> doc = DocumentFile.from_url("https://www.yoursite.com")

        Args:
        ----
            url: the URL of the target web page
            **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render`

        Returns:
        -------
            the list of pages decoded as numpy ndarray of shape H x W x 3
        """
        # Render the page to a PDF stream first, then reuse the PDF rasterization path
        return cls.from_pdf(read_html(url), **kwargs)

    @classmethod
    def from_images(cls, files: Union[Sequence[AbstractFile], AbstractFile], **kwargs) -> List[np.ndarray]:
        """Read an image file (or a collection of image files) and convert it into an image in numpy format

        >>> from doctr.io import DocumentFile
        >>> pages = DocumentFile.from_images(["path/to/your/page1.png", "path/to/your/page2.png"])

        Args:
        ----
            files: the path to the image file or a binary stream, or a collection of those
            **kwargs: additional parameters to :meth:`doctr.io.image.read_img_as_numpy`

        Returns:
        -------
            the list of pages decoded as numpy ndarray of shape H x W x 3
        """
        # Promote a single path/stream to a one-element collection
        if isinstance(files, (str, Path, bytes)):
            files = [files]

        return [read_img_as_numpy(single_file, **kwargs) for single_file in files]
|