legacies committed on
Commit
0e17e4e
·
1 Parent(s): 4ce5fde

initial files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +140 -0
  2. .pre-commit-config.yaml +23 -0
  3. CODE_OF_CONDUCT.md +128 -0
  4. CONTRIBUTING.md +92 -0
  5. Dockerfile +75 -0
  6. LICENSE +201 -0
  7. Makefile +33 -0
  8. README.md +384 -12
  9. backend/pytorch.py +93 -0
  10. doctr/__init__.py +3 -0
  11. doctr/datasets/__init__.py +26 -0
  12. doctr/datasets/cord.py +121 -0
  13. doctr/datasets/datasets/__init__.py +6 -0
  14. doctr/datasets/datasets/base.py +132 -0
  15. doctr/datasets/datasets/pytorch.py +59 -0
  16. doctr/datasets/datasets/tensorflow.py +59 -0
  17. doctr/datasets/detection.py +98 -0
  18. doctr/datasets/doc_artefacts.py +82 -0
  19. doctr/datasets/funsd.py +112 -0
  20. doctr/datasets/generator/__init__.py +6 -0
  21. doctr/datasets/generator/base.py +155 -0
  22. doctr/datasets/generator/pytorch.py +54 -0
  23. doctr/datasets/generator/tensorflow.py +60 -0
  24. doctr/datasets/ic03.py +126 -0
  25. doctr/datasets/ic13.py +99 -0
  26. doctr/datasets/iiit5k.py +103 -0
  27. doctr/datasets/iiithws.py +75 -0
  28. doctr/datasets/imgur5k.py +147 -0
  29. doctr/datasets/loader.py +102 -0
  30. doctr/datasets/mjsynth.py +106 -0
  31. doctr/datasets/ocr.py +71 -0
  32. doctr/datasets/orientation.py +40 -0
  33. doctr/datasets/recognition.py +56 -0
  34. doctr/datasets/sroie.py +103 -0
  35. doctr/datasets/svhn.py +131 -0
  36. doctr/datasets/svt.py +117 -0
  37. doctr/datasets/synthtext.py +128 -0
  38. doctr/datasets/utils.py +216 -0
  39. doctr/datasets/vocabs.py +71 -0
  40. doctr/datasets/wildreceipt.py +111 -0
  41. doctr/file_utils.py +92 -0
  42. doctr/io/__init__.py +5 -0
  43. doctr/io/elements.py +621 -0
  44. doctr/io/html.py +28 -0
  45. doctr/io/image/__init__.py +8 -0
  46. doctr/io/image/base.py +56 -0
  47. doctr/io/image/pytorch.py +109 -0
  48. doctr/io/image/tensorflow.py +110 -0
  49. doctr/io/pdf.py +42 -0
  50. doctr/io/reader.py +79 -0
.gitignore ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ # Temp files
132
+ doctr/version.py
133
+ logs/
134
+ wandb/
135
+ .idea/
136
+
137
+ # Checkpoints
138
+ *.pt
139
+ *.pb
140
+ *.index
.pre-commit-config.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.5.0
4
+ hooks:
5
+ - id: check-ast
6
+ - id: check-yaml
7
+ exclude: .conda
8
+ - id: check-toml
9
+ - id: check-json
10
+ - id: check-added-large-files
11
+ exclude: docs/images/
12
+ - id: end-of-file-fixer
13
+ - id: trailing-whitespace
14
+ - id: debug-statements
15
+ - id: check-merge-conflict
16
+ - id: no-commit-to-branch
17
+ args: ['--branch', 'main']
18
+ - repo: https://github.com/astral-sh/ruff-pre-commit
19
+ rev: v0.3.2
20
+ hooks:
21
+ - id: ruff
22
+ args: [ --fix ]
23
+ - id: ruff-format
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributor Covenant Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ We as members, contributors, and leaders pledge to make participation in our
6
+ community a harassment-free experience for everyone, regardless of age, body
7
+ size, visible or invisible disability, ethnicity, sex characteristics, gender
8
+ identity and expression, level of experience, education, socio-economic status,
9
+ nationality, personal appearance, race, religion, or sexual identity
10
+ and orientation.
11
+
12
+ We pledge to act and interact in ways that contribute to an open, welcoming,
13
+ diverse, inclusive, and healthy community.
14
+
15
+ ## Our Standards
16
+
17
+ Examples of behavior that contributes to a positive environment for our
18
+ community include:
19
+
20
+ * Demonstrating empathy and kindness toward other people
21
+ * Being respectful of differing opinions, viewpoints, and experiences
22
+ * Giving and gracefully accepting constructive feedback
23
+ * Accepting responsibility and apologizing to those affected by our mistakes,
24
+ and learning from the experience
25
+ * Focusing on what is best not just for us as individuals, but for the
26
+ overall community
27
+
28
+ Examples of unacceptable behavior include:
29
+
30
+ * The use of sexualized language or imagery, and sexual attention or
31
+ advances of any kind
32
+ * Trolling, insulting or derogatory comments, and personal or political attacks
33
+ * Public or private harassment
34
+ * Publishing others' private information, such as a physical or email
35
+ address, without their explicit permission
36
+ * Other conduct which could reasonably be considered inappropriate in a
37
+ professional setting
38
+
39
+ ## Enforcement Responsibilities
40
+
41
+ Community leaders are responsible for clarifying and enforcing our standards of
42
+ acceptable behavior and will take appropriate and fair corrective action in
43
+ response to any behavior that they deem inappropriate, threatening, offensive,
44
+ or harmful.
45
+
46
+ Community leaders have the right and responsibility to remove, edit, or reject
47
+ comments, commits, code, wiki edits, issues, and other contributions that are
48
+ not aligned to this Code of Conduct, and will communicate reasons for moderation
49
+ decisions when appropriate.
50
+
51
+ ## Scope
52
+
53
+ This Code of Conduct applies within all community spaces, and also applies when
54
+ an individual is officially representing the community in public spaces.
55
+ Examples of representing our community include using an official e-mail address,
56
+ posting via an official social media account, or acting as an appointed
57
+ representative at an online or offline event.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported to the community leaders responsible for enforcement at
63
+ contact@mindee.com.
64
+ All complaints will be reviewed and investigated promptly and fairly.
65
+
66
+ All community leaders are obligated to respect the privacy and security of the
67
+ reporter of any incident.
68
+
69
+ ## Enforcement Guidelines
70
+
71
+ Community leaders will follow these Community Impact Guidelines in determining
72
+ the consequences for any action they deem in violation of this Code of Conduct:
73
+
74
+ ### 1. Correction
75
+
76
+ **Community Impact**: Use of inappropriate language or other behavior deemed
77
+ unprofessional or unwelcome in the community.
78
+
79
+ **Consequence**: A private, written warning from community leaders, providing
80
+ clarity around the nature of the violation and an explanation of why the
81
+ behavior was inappropriate. A public apology may be requested.
82
+
83
+ ### 2. Warning
84
+
85
+ **Community Impact**: A violation through a single incident or series
86
+ of actions.
87
+
88
+ **Consequence**: A warning with consequences for continued behavior. No
89
+ interaction with the people involved, including unsolicited interaction with
90
+ those enforcing the Code of Conduct, for a specified period of time. This
91
+ includes avoiding interactions in community spaces as well as external channels
92
+ like social media. Violating these terms may lead to a temporary or
93
+ permanent ban.
94
+
95
+ ### 3. Temporary Ban
96
+
97
+ **Community Impact**: A serious violation of community standards, including
98
+ sustained inappropriate behavior.
99
+
100
+ **Consequence**: A temporary ban from any sort of interaction or public
101
+ communication with the community for a specified period of time. No public or
102
+ private interaction with the people involved, including unsolicited interaction
103
+ with those enforcing the Code of Conduct, is allowed during this period.
104
+ Violating these terms may lead to a permanent ban.
105
+
106
+ ### 4. Permanent Ban
107
+
108
+ **Community Impact**: Demonstrating a pattern of violation of community
109
+ standards, including sustained inappropriate behavior, harassment of an
110
+ individual, or aggression toward or disparagement of classes of individuals.
111
+
112
+ **Consequence**: A permanent ban from any sort of public interaction within
113
+ the community.
114
+
115
+ ## Attribution
116
+
117
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118
+ version 2.0, available at
119
+ https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120
+
121
+ Community Impact Guidelines were inspired by [Mozilla's code of conduct
122
+ enforcement ladder](https://github.com/mozilla/diversity).
123
+
124
+ [homepage]: https://www.contributor-covenant.org
125
+
126
+ For answers to common questions about this code of conduct, see the FAQ at
127
+ https://www.contributor-covenant.org/faq. Translations are available at
128
+ https://www.contributor-covenant.org/translations.
CONTRIBUTING.md ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to docTR
2
+
3
+ Everything you need to know to contribute efficiently to the project.
4
+
5
+ ## Codebase structure
6
+
7
+ - [doctr](https://github.com/mindee/doctr/blob/main/doctr) - The package codebase
8
+ - [tests](https://github.com/mindee/doctr/blob/main/tests) - Python unit tests
9
+ - [docs](https://github.com/mindee/doctr/blob/main/docs) - Library documentation building
10
+ - [scripts](https://github.com/mindee/doctr/blob/main/scripts) - Example scripts
11
+ - [references](https://github.com/mindee/doctr/blob/main/references) - Reference training scripts
12
+ - [demo](https://github.com/mindee/doctr/blob/main/demo) - Small demo app to showcase docTR capabilities
13
+ - [api](https://github.com/mindee/doctr/blob/main/api) - A minimal template to deploy a REST API with docTR
14
+
15
+ ## Continuous Integration
16
+
17
+ This project uses the following integrations to ensure proper codebase maintenance:
18
+
19
 + - [GitHub Workflow](https://help.github.com/en/actions/configuring-and-managing-workflows/configuring-a-workflow) - run jobs for package build and coverage
20
+ - [Codecov](https://codecov.io/) - reports back coverage results
21
+
22
+ As a contributor, you will only have to ensure coverage of your code by adding appropriate unit testing of your code.
23
+
24
+ ## Feedback
25
+
26
+ ### Feature requests & bug report
27
+
28
+ Whether you encountered a problem, or you have a feature suggestion, your input has value and can be used by contributors to reference it in their developments. For this purpose, we advise you to use Github [issues](https://github.com/mindee/doctr/issues).
29
+
30
+ First, check whether the topic wasn't already covered in an open / closed issue. If not, feel free to open a new one! When doing so, use issue templates whenever possible and provide enough information for other contributors to jump in.
31
+
32
+ ### Questions
33
+
34
+ If you are wondering how to do something with docTR, or a more general question, you should consider checking out Github [discussions](https://github.com/mindee/doctr/discussions). See it as a Q&A forum, or the docTR-specific StackOverflow!
35
+
36
+ ## Developing docTR
37
+
38
+ ### Developer mode installation
39
+
40
+ Install all additional dependencies with the following command:
41
+
42
+ ```shell
43
+ python -m pip install --upgrade pip
44
+ pip install -e .[dev]
45
+ pre-commit install
46
+ ```
47
+
48
+ ### Commits
49
+
50
+ - **Code**: ensure to provide docstrings to your Python code. In doing so, please follow [Google-style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) so it can ease the process of documentation later.
51
+ - **Commit message**: please follow [Udacity guide](http://udacity.github.io/git-styleguide/)
52
+
53
+ ### Unit tests
54
+
55
+ In order to run the same unit tests as the CI workflows, you can run unittests locally:
56
+
57
+ ```shell
58
+ make test
59
+ ```
60
+
61
+ ### Code quality
62
+
63
+ To run all quality checks together
64
+
65
+ ```shell
66
+ make quality
67
+ ```
68
+
69
+ #### Code style verification
70
+
71
+ To run all style checks together
72
+
73
+ ```shell
74
+ make style
75
+ ```
76
+
77
+ ### Modifying the documentation
78
+
79
+ The current documentation is built using `sphinx` thanks to our CI.
80
+ You can build the documentation locally:
81
+
82
+ ```shell
83
+ make docs-single-version
84
+ ```
85
+
86
+ Please note that files that have not been modified will not be rebuilt. If you want to force a complete rebuild, you can delete the `_build` directory. Additionally, you may need to clear your web browser's cache to see the modifications.
87
+
88
+ You can now open your local version of the documentation located at `docs/_build/index.html` in your browser
89
+
90
+ ## Let's connect
91
+
92
+ Should you wish to connect somewhere else than on GitHub, feel free to join us on [Slack](https://join.slack.com/t/mindee-community/shared_invite/zt-uzgmljfl-MotFVfH~IdEZxjp~0zldww), where you will find a `#doctr` channel!
Dockerfile ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM ubuntu:22.04
2
+
3
+ ENV DEBIAN_FRONTEND=noninteractive
4
+ ENV LANG=C.UTF-8
5
+ ENV PYTHONUNBUFFERED=1
6
+ ENV PYTHONDONTWRITEBYTECODE=1
7
+
8
+ ARG SYSTEM=gpu
9
+
10
+ # Enroll NVIDIA GPG public key and install CUDA
11
+ RUN if [ "$SYSTEM" = "gpu" ]; then \
12
+ apt-get update && \
13
+ apt-get install -y gnupg ca-certificates wget && \
14
+ # - Install Nvidia repo keys
15
+ # - See: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#network-repo-installation-for-ubuntu
16
+ wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb && \
17
+ dpkg -i cuda-keyring_1.1-1_all.deb && \
18
+ apt-get update && apt-get install -y --no-install-recommends \
19
+ cuda-command-line-tools-11-8 \
20
+ cuda-cudart-dev-11-8 \
21
+ cuda-nvcc-11-8 \
22
+ cuda-cupti-11-8 \
23
+ cuda-nvprune-11-8 \
24
+ cuda-libraries-11-8 \
25
+ cuda-nvrtc-11-8 \
26
+ libcufft-11-8 \
27
+ libcurand-11-8 \
28
+ libcusolver-11-8 \
29
+ libcusparse-11-8 \
30
+ libcublas-11-8 \
31
+ # - CuDNN: https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#ubuntu-network-installation
32
+ libcudnn8=8.6.0.163-1+cuda11.8 \
33
+ libnvinfer-plugin8=8.6.1.6-1+cuda11.8 \
34
+ libnvinfer8=8.6.1.6-1+cuda11.8; \
35
+ fi
36
+
37
+ RUN apt-get update && apt-get install -y --no-install-recommends \
38
+ # - Other packages
39
+ build-essential \
40
+ pkg-config \
41
+ curl \
42
+ wget \
43
+ software-properties-common \
44
+ unzip \
45
+ git \
46
+ # - Packages to build Python
47
+ tar make gcc zlib1g-dev libffi-dev libssl-dev liblzma-dev libbz2-dev libsqlite3-dev \
48
+ # - Packages for docTR
49
+ libgl1-mesa-dev libsm6 libxext6 libxrender-dev libpangocairo-1.0-0 \
50
+ && apt-get clean \
51
+ && rm -rf /var/lib/apt/lists/* \
52
+ fi
53
+
54
+ # Install Python
55
+ ARG PYTHON_VERSION=3.10.13
56
+
57
+ RUN wget http://www.python.org/ftp/python/$PYTHON_VERSION/Python-$PYTHON_VERSION.tgz && \
58
+ tar -zxf Python-$PYTHON_VERSION.tgz && \
59
+ cd Python-$PYTHON_VERSION && \
60
+ mkdir /opt/python/ && \
61
+ ./configure --prefix=/opt/python && \
62
+ make && \
63
+ make install && \
64
+ cd .. && \
65
+ rm Python-$PYTHON_VERSION.tgz && \
66
+ rm -r Python-$PYTHON_VERSION
67
+
68
+ ENV PATH=/opt/python/bin:$PATH
69
+
70
+ # Install docTR
71
+ ARG FRAMEWORK=tf
72
+ ARG DOCTR_REPO='mindee/doctr'
73
+ ARG DOCTR_VERSION=main
74
+ RUN pip3 install -U pip setuptools wheel && \
75
+ pip3 install "python-doctr[$FRAMEWORK]@git+https://github.com/$DOCTR_REPO.git@$DOCTR_VERSION"
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2022 Mindee
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Makefile ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .PHONY: quality style test test-common test-tf test-torch docs-single-version docs
2
+ # this target runs checks on all files
3
+ quality:
4
+ ruff check .
5
+ mypy doctr/
6
+
7
+ # this target runs checks on all files and potentially modifies some of them
8
+ style:
9
+ ruff check --fix .
10
+ ruff format .
11
+
12
+ # Run tests for the library
13
+ test:
14
+ coverage run -m pytest tests/common/
15
+ USE_TF='1' coverage run -m pytest tests/tensorflow/
16
+ USE_TORCH='1' coverage run -m pytest tests/pytorch/
17
+
18
+ test-common:
19
+ coverage run -m pytest tests/common/
20
+
21
+ test-tf:
22
+ USE_TF='1' coverage run -m pytest tests/tensorflow/
23
+
24
+ test-torch:
25
+ USE_TORCH='1' coverage run -m pytest tests/pytorch/
26
+
27
+ # Check that docs can build
28
+ docs-single-version:
29
+ sphinx-build docs/source docs/_build -a
30
+
31
+ # Check that docs can build
32
+ docs:
33
+ cd docs && bash build.sh
README.md CHANGED
@@ -1,12 +1,384 @@
1
- ---
2
- title: Doctr
3
- emoji: 📚
4
- colorFrom: yellow
5
- colorTo: blue
6
- sdk: streamlit
7
- sdk_version: 1.35.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img src="https://github.com/mindee/doctr/raw/main/docs/images/Logo_doctr.gif" width="40%">
3
+ </p>
4
+
5
+ [![Slack Icon](https://img.shields.io/badge/Slack-Community-4A154B?style=flat-square&logo=slack&logoColor=white)](https://slack.mindee.com) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) ![Build Status](https://github.com/mindee/doctr/workflows/builds/badge.svg) [![Docker Images](https://img.shields.io/badge/Docker-4287f5?style=flat&logo=docker&logoColor=white)](https://github.com/mindee/doctr/pkgs/container/doctr) [![codecov](https://codecov.io/gh/mindee/doctr/branch/main/graph/badge.svg?token=577MO567NM)](https://codecov.io/gh/mindee/doctr) [![CodeFactor](https://www.codefactor.io/repository/github/mindee/doctr/badge?s=bae07db86bb079ce9d6542315b8c6e70fa708a7e)](https://www.codefactor.io/repository/github/mindee/doctr) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/340a76749b634586a498e1c0ab998f08)](https://app.codacy.com/gh/mindee/doctr?utm_source=github.com&utm_medium=referral&utm_content=mindee/doctr&utm_campaign=Badge_Grade) [![Doc Status](https://github.com/mindee/doctr/workflows/doc-status/badge.svg)](https://mindee.github.io/doctr) [![Pypi](https://img.shields.io/badge/pypi-v0.8.1-blue.svg)](https://pypi.org/project/python-doctr/) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/mindee/doctr) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mindee/notebooks/blob/main/doctr/quicktour.ipynb)
6
+
7
+
8
+ **Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch**
9
+
10
+ What you can expect from this repository:
11
+
12
+ - efficient ways to parse textual information (localize and identify each word) from your documents
13
+ - guidance on how to integrate this in your current architecture
14
+
15
+ ![OCR_example](https://github.com/mindee/doctr/raw/main/docs/images/ocr.png)
16
+
17
+ ## Quick Tour
18
+
19
+ ### Getting your pretrained model
20
+
21
+ End-to-End OCR is achieved in docTR using a two-stage approach: text detection (localizing words), then text recognition (identify all characters in the word).
22
+ As such, you can select the architecture used for [text detection](https://mindee.github.io/doctr/latest/modules/models.html#doctr-models-detection), and the one for [text recognition](https://mindee.github.io/doctr/latest/modules/models.html#doctr-models-recognition) from the list of available implementations.
23
+
24
+ ```python
25
+ from doctr.models import ocr_predictor
26
+
27
+ model = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
28
+ ```
29
+
30
+ ### Reading files
31
+
32
+ Documents can be interpreted from PDF or images:
33
+
34
+ ```python
35
+ from doctr.io import DocumentFile
36
+ # PDF
37
+ pdf_doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
38
+ # Image
39
+ single_img_doc = DocumentFile.from_images("path/to/your/img.jpg")
40
+ # Webpage
41
+ webpage_doc = DocumentFile.from_url("https://www.yoursite.com")
42
+ # Multiple page images
43
+ multi_img_doc = DocumentFile.from_images(["path/to/page1.jpg", "path/to/page2.jpg"])
44
+ ```
45
+
46
+ ### Putting it together
47
+
48
+ Let's use the default pretrained model for an example:
49
+
50
+ ```python
51
+ from doctr.io import DocumentFile
52
+ from doctr.models import ocr_predictor
53
+
54
+ model = ocr_predictor(pretrained=True)
55
+ # PDF
56
+ doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
57
+ # Analyze
58
+ result = model(doc)
59
+ ```
60
+
61
+ ### Dealing with rotated documents
62
+
63
+ Should you use docTR on documents that include rotated pages, or pages with multiple box orientations,
64
+ you have multiple options to handle it:
65
+
66
+ - If you only use straight document pages with straight words (horizontal, same reading direction),
67
+ consider passing `assume_straight_pages=True` to the ocr_predictor. It will directly fit straight boxes
68
+ on your page and return straight boxes, which makes it the fastest option.
69
+
70
+ - If you want the predictor to output straight boxes (no matter the orientation of your pages, the final localizations
71
+ will be converted to straight boxes), you need to pass `export_as_straight_boxes=True` in the predictor. Otherwise, if `assume_straight_pages=False`, it will return rotated bounding boxes (potentially with an angle of 0°).
72
+
73
+ If both options are set to False, the predictor will always fit and return rotated boxes.
74
+
75
+ To interpret your model's predictions, you can visualize them interactively as follows:
76
+
77
+ ```python
78
+ result.show()
79
+ ```
80
+
81
+ ![Visualization sample](https://github.com/mindee/doctr/raw/main/docs/images/doctr_example_script.gif)
82
+
83
+ Or even rebuild the original document from its predictions:
84
+
85
+ ```python
86
+ import matplotlib.pyplot as plt
87
+
88
+ synthetic_pages = result.synthesize()
89
+ plt.imshow(synthetic_pages[0]); plt.axis('off'); plt.show()
90
+ ```
91
+
92
+ ![Synthesis sample](https://github.com/mindee/doctr/raw/main/docs/images/synthesized_sample.png)
93
+
94
+ The `ocr_predictor` returns a `Document` object with a nested structure (with `Page`, `Block`, `Line`, `Word`, `Artefact`).
95
+ To get a better understanding of our document model, check our [documentation](https://mindee.github.io/doctr/modules/io.html#document-structure):
96
+
97
+ You can also export them as a nested dict, more appropriate for JSON format:
98
+
99
+ ```python
100
+ json_output = result.export()
101
+ ```
102
+
103
+ ### Use the KIE predictor
104
+
105
+ The KIE predictor is a more flexible predictor compared to OCR as your detection model can detect multiple classes in a document. For example, you can have a detection model to detect just dates and addresses in a document.
106
+
107
+ The KIE predictor makes it possible to use a detector with multiple classes with a recognition model and to have the whole pipeline already set up for you.
108
+
109
+ ```python
110
+ from doctr.io import DocumentFile
111
+ from doctr.models import kie_predictor
112
+
113
+ # Model
114
+ model = kie_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
115
+ # PDF
116
+ doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
117
+ # Analyze
118
+ result = model(doc)
119
+
120
+ predictions = result.pages[0].predictions
121
+ for class_name in predictions.keys():
122
+ list_predictions = predictions[class_name]
123
+ for prediction in list_predictions:
124
+ print(f"Prediction for {class_name}: {prediction}")
125
+ ```
126
+
127
+ The KIE predictor results per page are in a dictionary format with each key representing a class name and its value being the predictions for that class.
128
+
129
+ ### If you are looking for support from the Mindee team
130
+
131
+ [![Bad OCR test detection image asking the developer if they need help](https://github.com/mindee/doctr/raw/main/docs/images/doctr-need-help.png)](https://mindee.com/product/doctr)
132
+
133
+ ## Installation
134
+
135
+ ### Prerequisites
136
+
137
+ Python 3.9 (or higher) and [pip](https://pip.pypa.io/en/stable/) are required to install docTR.
138
+
139
+ Since we use [weasyprint](https://weasyprint.org/), you will need extra dependencies if you are not running Linux.
140
+
141
+ For MacOS users, you can install them as follows:
142
+
143
+ ```shell
144
+ brew install cairo pango gdk-pixbuf libffi
145
+ ```
146
+
147
+ For Windows users, those dependencies are included in GTK. You can find the latest installer over [here](https://github.com/tschoonj/GTK-for-Windows-Runtime-Environment-Installer/releases).
148
+
149
+ ### Latest release
150
+
151
+ You can then install the latest release of the package using [pypi](https://pypi.org/project/python-doctr/) as follows:
152
+
153
+ ```shell
154
+ pip install python-doctr
155
+ ```
156
+
157
+ > :warning: Please note that the basic installation is not standalone, as it does not provide a deep learning framework, which is required for the package to run.
158
+
159
+ We try to keep framework-specific dependencies to a minimum. You can install framework-specific builds as follows:
160
+
161
+ ```shell
162
+ # for TensorFlow
163
+ pip install "python-doctr[tf]"
164
+ # for PyTorch
165
+ pip install "python-doctr[torch]"
166
+ ```
167
+
168
+ For MacBooks with M1 chip, you will need some additional packages or specific versions:
169
+
170
+ - TensorFlow 2: [metal plugin](https://developer.apple.com/metal/tensorflow-plugin/)
171
+ - PyTorch: [version >= 1.12.0](https://pytorch.org/get-started/locally/#start-locally)
172
+
173
+ ### Developer mode
174
+
175
+ Alternatively, you can install it from source, which will require you to install [Git](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git).
176
+ First clone the project repository:
177
+
178
+ ```shell
179
+ git clone https://github.com/mindee/doctr.git
180
+ pip install -e doctr/.
181
+ ```
182
+
183
+ Again, if you prefer to avoid the risk of missing dependencies, you can install the TensorFlow or the PyTorch build:
184
+
185
+ ```shell
186
+ # for TensorFlow
187
+ pip install -e doctr/.[tf]
188
+ # for PyTorch
189
+ pip install -e doctr/.[torch]
190
+ ```
191
+
192
+ ## Models architectures
193
+
194
+ Credits where it's due: this repository is implementing, among others, architectures from published research papers.
195
+
196
+ ### Text Detection
197
+
198
+ - DBNet: [Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/pdf/1911.08947.pdf).
199
+ - LinkNet: [LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation](https://arxiv.org/pdf/1707.03718.pdf)
200
+ - FAST: [FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation](https://arxiv.org/pdf/2111.02394.pdf)
201
+
202
+ ### Text Recognition
203
+
204
+ - CRNN: [An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition](https://arxiv.org/pdf/1507.05717.pdf).
205
+ - SAR: [Show, Attend and Read:A Simple and Strong Baseline for Irregular Text Recognition](https://arxiv.org/pdf/1811.00751.pdf).
206
+ - MASTER: [MASTER: Multi-Aspect Non-local Network for Scene Text Recognition](https://arxiv.org/pdf/1910.02562.pdf).
207
+ - ViTSTR: [Vision Transformer for Fast and Efficient Scene Text Recognition](https://arxiv.org/pdf/2105.08582.pdf).
208
+ - PARSeq: [Scene Text Recognition with Permuted Autoregressive Sequence Models](https://arxiv.org/pdf/2207.06966).
209
+
210
+ ## More goodies
211
+
212
+ ### Documentation
213
+
214
+ The full package documentation is available [here](https://mindee.github.io/doctr/) for detailed specifications.
215
+
216
+ ### Demo app
217
+
218
+ A minimal demo app is provided for you to play with our end-to-end OCR models!
219
+
220
+ ![Demo app](https://github.com/mindee/doctr/raw/main/docs/images/demo_update.png)
221
+
222
+ #### Live demo
223
+
224
+ Courtesy of :hugs: [Hugging Face](https://huggingface.co/) :hugs:, docTR has now a fully deployed version available on [Spaces](https://huggingface.co/spaces)!
225
+ Check it out [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/mindee/doctr)
226
+
227
+ #### Running it locally
228
+
229
+ If you prefer to use it locally, there is an extra dependency ([Streamlit](https://streamlit.io/)) that is required.
230
+
231
+ ##### TensorFlow version
232
+
233
+ ```shell
234
+ pip install -r demo/tf-requirements.txt
235
+ ```
236
+
237
+ Then run your app in your default browser with:
238
+
239
+ ```shell
240
+ USE_TF=1 streamlit run demo/app.py
241
+ ```
242
+
243
+ ##### PyTorch version
244
+
245
+ ```shell
246
+ pip install -r demo/pt-requirements.txt
247
+ ```
248
+
249
+ Then run your app in your default browser with:
250
+
251
+ ```shell
252
+ USE_TORCH=1 streamlit run demo/app.py
253
+ ```
254
+
255
+ #### TensorFlow.js
256
+
257
+ Instead of having your demo actually running Python, you would prefer to run everything in your web browser?
258
+ Check out our [TensorFlow.js demo](https://github.com/mindee/doctr-tfjs-demo) to get started!
259
+
260
+ ![TFJS demo](https://github.com/mindee/doctr/raw/main/docs/images/demo_illustration_mini.png)
261
+
262
+ ### Docker container
263
+
264
+ [We offer Docker container support for easy testing and deployment](https://github.com/mindee/doctr/pkgs/container/doctr).
265
+
266
+ #### Using GPU with docTR Docker Images
267
+
268
+ The docTR Docker images are GPU-ready and based on CUDA `11.8`.
269
+ However, to use GPU support with these Docker images, please ensure that Docker is configured to use your GPU.
270
+
271
+ To verify and configure GPU support for Docker, please follow the instructions provided in the [NVIDIA Container Toolkit Installation Guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
272
+
273
+ Once Docker is configured to use GPUs, you can run docTR Docker containers with GPU support:
274
+
275
+ ```shell
276
+ docker run -it --gpus all ghcr.io/mindee/doctr:tf-py3.8.18-gpu-2023-09 bash
277
+ ```
278
+
279
+ #### Available Tags
280
+
281
+ The Docker images for docTR follow a specific tag nomenclature: `<framework>-py<python_version>-<system>-<doctr_version|YYYY-MM>`. Here's a breakdown of the tag structure:
282
+
283
+ - `<framework>`: `tf` (TensorFlow) or `torch` (PyTorch).
284
+ - `<python_version>`: `3.8.18`, `3.9.18`, or `3.10.13`.
285
+ - `<system>`: `cpu` or `gpu`
286
+ - `<doctr_version>`: a tag >= `v0.7.1`
287
+ - `<YYYY-MM>`: e.g. `2023-09`
288
+
289
+ Here are examples of different image tags:
290
+
291
+ | Tag | Description |
292
+ |----------------------------|---------------------------------------------------|
293
+ | `tf-py3.8.18-cpu-v0.7.1` | TensorFlow with Python `3.8.18` and docTR `v0.7.1`. |
294
+ | `torch-py3.9.18-gpu-2023-09`| PyTorch with Python `3.9.18`, GPU support and a monthly build from `2023-09`. |
295
+
296
+ #### Building Docker Images Locally
297
+
298
+ You can also build docTR Docker images locally on your computer.
299
+
300
+ ```shell
301
+ docker build -t doctr .
302
+ ```
303
+
304
+ You can specify custom Python versions and docTR versions using build arguments. For example, to build a docTR image with TensorFlow, Python version `3.9.10`, and docTR version `v0.7.0`, run the following command:
305
+
306
+ ```shell
307
+ docker build -t doctr --build-arg FRAMEWORK=tf --build-arg PYTHON_VERSION=3.9.10 --build-arg DOCTR_VERSION=v0.7.0 .
308
+ ```
309
+
310
+ ### Example script
311
+
312
+ An example script is provided for a simple documentation analysis of a PDF or image file:
313
+
314
+ ```shell
315
+ python scripts/analyze.py path/to/your/doc.pdf
316
+ ```
317
+
318
+ All script arguments can be checked using `python scripts/analyze.py --help`
319
+
320
+ ### Minimal API integration
321
+
322
+ Looking to integrate docTR into your API? Here is a template to get you started with a fully working API using the wonderful [FastAPI](https://github.com/tiangolo/fastapi) framework.
323
+
324
+ #### Deploy your API locally
325
+
326
+ Specific dependencies are required to run the API template, which you can install as follows:
327
+
328
+ ```shell
329
+ cd api/
330
+ pip install poetry
331
+ make lock
332
+ pip install -r requirements.txt
333
+ ```
334
+
335
+ You can now run your API locally:
336
+
337
+ ```shell
338
+ uvicorn --reload --workers 1 --host 0.0.0.0 --port=8002 --app-dir api/ app.main:app
339
+ ```
340
+
341
+ Alternatively, you can run the same server on a docker container if you prefer using:
342
+
343
+ ```shell
344
+ PORT=8002 docker-compose up -d --build
345
+ ```
346
+
347
+ #### What you have deployed
348
+
349
+ Your API should now be running locally on your port 8002. Access your automatically-built documentation at [http://localhost:8002/redoc](http://localhost:8002/redoc) and enjoy your four functional routes ("/detection", "/recognition", "/ocr", "/kie"). Here is an example with Python to send a request to the OCR route:
350
+
351
+ ```python
352
+ import requests
353
+ with open('/path/to/your/doc.jpg', 'rb') as f:
354
+ data = f.read()
355
+ response = requests.post("http://localhost:8002/ocr", files={'file': data}).json()
356
+ ```
357
+
358
+ ### Example notebooks
359
+
360
+ Looking for more illustrations of docTR features? You might want to check the [Jupyter notebooks](https://github.com/mindee/doctr/tree/main/notebooks) designed to give you a broader overview.
361
+
362
+ ## Citation
363
+
364
+ If you wish to cite this project, feel free to use this [BibTeX](http://www.bibtex.org/) reference:
365
+
366
+ ```bibtex
367
+ @misc{doctr2021,
368
+ title={docTR: Document Text Recognition},
369
+ author={Mindee},
370
+ year={2021},
371
+ publisher = {GitHub},
372
+ howpublished = {\url{https://github.com/mindee/doctr}}
373
+ }
374
+ ```
375
+
376
+ ## Contributing
377
+
378
+ If you scrolled down to this section, you most likely appreciate open source. Do you feel like extending the range of our supported characters? Or perhaps submitting a paper implementation? Or contributing in any other way?
379
+
380
+ You're in luck, we compiled a short guide (cf. [`CONTRIBUTING`](https://mindee.github.io/doctr/contributing/contributing.html)) for you to easily do so!
381
+
382
+ ## License
383
+
384
+ Distributed under the Apache 2.0 License. See [`LICENSE`](https://github.com/mindee/doctr?tab=Apache-2.0-1-ov-file#readme) for more information.
backend/pytorch.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from doctr.models import ocr_predictor
10
+ from doctr.models.predictor import OCRPredictor
11
+
12
+ DET_ARCHS = [
13
+ "db_resnet50",
14
+ "db_resnet34",
15
+ "db_mobilenet_v3_large",
16
+ "linknet_resnet18",
17
+ "linknet_resnet34",
18
+ "linknet_resnet50",
19
+ "fast_tiny",
20
+ "fast_small",
21
+ "fast_base",
22
+ ]
23
+
24
+ RECO_ARCHS = [
25
+ "crnn_vgg16_bn",
26
+ "crnn_mobilenet_v3_small",
27
+ "crnn_mobilenet_v3_large",
28
+ "master",
29
+ "sar_resnet31",
30
+ "vitstr_small",
31
+ "vitstr_base",
32
+ "parseq",
33
+ ]
34
+
35
+
36
+ def load_predictor(
37
+ det_arch: str,
38
+ reco_arch: str,
39
+ assume_straight_pages: bool,
40
+ straighten_pages: bool,
41
+ bin_thresh: float,
42
+ box_thresh: float,
43
+ device: torch.device,
44
+ ) -> OCRPredictor:
45
+ """Load a predictor from doctr.models
46
+
47
+ Args:
48
+ ----
49
+ det_arch: detection architecture
50
+ reco_arch: recognition architecture
51
+ assume_straight_pages: whether to assume straight pages or not
52
+ straighten_pages: whether to straighten rotated pages or not
53
+ bin_thresh: binarization threshold for the segmentation map
54
+ box_thresh: minimal objectness score to consider a box
55
+ device: torch.device, the device to load the predictor on
56
+
57
+ Returns:
58
+ -------
59
+ instance of OCRPredictor
60
+ """
61
+ predictor = ocr_predictor(
62
+ det_arch,
63
+ reco_arch,
64
+ pretrained=True,
65
+ assume_straight_pages=assume_straight_pages,
66
+ straighten_pages=straighten_pages,
67
+ export_as_straight_boxes=straighten_pages,
68
+ detect_orientation=not assume_straight_pages,
69
+ ).to(device)
70
+ predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
71
+ predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
72
+ return predictor
73
+
74
+
75
+ def forward_image(predictor: OCRPredictor, image: np.ndarray, device: torch.device) -> np.ndarray:
76
+ """Forward an image through the predictor
77
+
78
+ Args:
79
+ ----
80
+ predictor: instance of OCRPredictor
81
+ image: image to process
82
+ device: torch.device, the device to process the image on
83
+
84
+ Returns:
85
+ -------
86
+ segmentation map
87
+ """
88
+ with torch.no_grad():
89
+ processed_batches = predictor.det_predictor.pre_processor([image])
90
+ out = predictor.det_predictor.model(processed_batches[0].to(device), return_model_output=True)
91
+ seg_map = out["out_map"].to("cpu").numpy()
92
+
93
+ return seg_map
doctr/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import io, models, datasets, transforms, utils
2
+ from .file_utils import is_tf_available, is_torch_available
3
+ from .version import __version__ # noqa: F401
doctr/datasets/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from doctr.file_utils import is_tf_available
2
+
3
+ from .generator import *
4
+ from .cord import *
5
+ from .detection import *
6
+ from .doc_artefacts import *
7
+ from .funsd import *
8
+ from .ic03 import *
9
+ from .ic13 import *
10
+ from .iiit5k import *
11
+ from .iiithws import *
12
+ from .imgur5k import *
13
+ from .mjsynth import *
14
+ from .ocr import *
15
+ from .recognition import *
16
+ from .orientation import *
17
+ from .sroie import *
18
+ from .svhn import *
19
+ from .svt import *
20
+ from .synthtext import *
21
+ from .utils import *
22
+ from .vocabs import *
23
+ from .wildreceipt import *
24
+
25
+ if is_tf_available():
26
+ from .loader import *
doctr/datasets/cord.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List, Tuple, Union
10
+
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+
14
+ from .datasets import VisionDataset
15
+ from .utils import convert_target_to_relative, crop_bboxes_from_image
16
+
17
+ __all__ = ["CORD"]
18
+
19
+
20
+ class CORD(VisionDataset):
21
+ """CORD dataset from `"CORD: A Consolidated Receipt Dataset forPost-OCR Parsing"
22
+ <https://openreview.net/pdf?id=SJl3z659UH>`_.
23
+
24
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/cord-grid.png&src=0
25
+ :align: center
26
+
27
+ >>> from doctr.datasets import CORD
28
+ >>> train_set = CORD(train=True, download=True)
29
+ >>> img, target = train_set[0]
30
+
31
+ Args:
32
+ ----
33
+ train: whether the subset should be the training one
34
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
35
+ recognition_task: whether the dataset should be used for recognition task
36
+ **kwargs: keyword arguments from `VisionDataset`.
37
+ """
38
+
39
+ TRAIN = (
40
+ "https://doctr-static.mindee.com/models?id=v0.1.1/cord_train.zip&src=0",
41
+ "45f9dc77f126490f3e52d7cb4f70ef3c57e649ea86d19d862a2757c9c455d7f8",
42
+ "cord_train.zip",
43
+ )
44
+
45
+ TEST = (
46
+ "https://doctr-static.mindee.com/models?id=v0.1.1/cord_test.zip&src=0",
47
+ "8c895e3d6f7e1161c5b7245e3723ce15c04d84be89eaa6093949b75a66fb3c58",
48
+ "cord_test.zip",
49
+ )
50
+
51
+ def __init__(
52
+ self,
53
+ train: bool = True,
54
+ use_polygons: bool = False,
55
+ recognition_task: bool = False,
56
+ **kwargs: Any,
57
+ ) -> None:
58
+ url, sha256, name = self.TRAIN if train else self.TEST
59
+ super().__init__(
60
+ url,
61
+ name,
62
+ sha256,
63
+ True,
64
+ pre_transforms=convert_target_to_relative if not recognition_task else None,
65
+ **kwargs,
66
+ )
67
+
68
+ # List images
69
+ tmp_root = os.path.join(self.root, "image")
70
+ self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
71
+ self.train = train
72
+ np_dtype = np.float32
73
+ for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking CORD", total=len(os.listdir(tmp_root))):
74
+ # File existence check
75
+ if not os.path.exists(os.path.join(tmp_root, img_path)):
76
+ raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
77
+
78
+ stem = Path(img_path).stem
79
+ _targets = []
80
+ with open(os.path.join(self.root, "json", f"{stem}.json"), "rb") as f:
81
+ label = json.load(f)
82
+ for line in label["valid_line"]:
83
+ for word in line["words"]:
84
+ if len(word["text"]) > 0:
85
+ x = word["quad"]["x1"], word["quad"]["x2"], word["quad"]["x3"], word["quad"]["x4"]
86
+ y = word["quad"]["y1"], word["quad"]["y2"], word["quad"]["y3"], word["quad"]["y4"]
87
+ box: Union[List[float], np.ndarray]
88
+ if use_polygons:
89
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
90
+ box = np.array(
91
+ [
92
+ [x[0], y[0]],
93
+ [x[1], y[1]],
94
+ [x[2], y[2]],
95
+ [x[3], y[3]],
96
+ ],
97
+ dtype=np_dtype,
98
+ )
99
+ else:
100
+ # Reduce 8 coords to 4 -> xmin, ymin, xmax, ymax
101
+ box = [min(x), min(y), max(x), max(y)]
102
+ _targets.append((word["text"], box))
103
+
104
+ text_targets, box_targets = zip(*_targets)
105
+
106
+ if recognition_task:
107
+ crops = crop_bboxes_from_image(
108
+ img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0)
109
+ )
110
+ for crop, label in zip(crops, list(text_targets)):
111
+ self.data.append((crop, label))
112
+ else:
113
+ self.data.append((
114
+ img_path,
115
+ dict(boxes=np.asarray(box_targets, dtype=int).clip(min=0), labels=list(text_targets)),
116
+ ))
117
+
118
+ self.root = tmp_root
119
+
120
+ def extra_repr(self) -> str:
121
+ return f"train={self.train}"
doctr/datasets/datasets/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from doctr.file_utils import is_tf_available, is_torch_available
2
+
3
+ if is_tf_available():
4
+ from .tensorflow import *
5
+ elif is_torch_available():
6
+ from .pytorch import * # type: ignore[assignment]
doctr/datasets/datasets/base.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import os
7
+ import shutil
8
+ from pathlib import Path
9
+ from typing import Any, Callable, List, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+
13
+ from doctr.io.image import get_img_shape
14
+ from doctr.utils.data import download_from_url
15
+
16
+ from ...models.utils import _copy_tensor
17
+
18
+ __all__ = ["_AbstractDataset", "_VisionDataset"]
19
+
20
+
21
+ class _AbstractDataset:
22
+ data: List[Any] = []
23
+ _pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None
24
+
25
+ def __init__(
26
+ self,
27
+ root: Union[str, Path],
28
+ img_transforms: Optional[Callable[[Any], Any]] = None,
29
+ sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
30
+ pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
31
+ ) -> None:
32
+ if not Path(root).is_dir():
33
+ raise ValueError(f"expected a path to a reachable folder: {root}")
34
+
35
+ self.root = root
36
+ self.img_transforms = img_transforms
37
+ self.sample_transforms = sample_transforms
38
+ self._pre_transforms = pre_transforms
39
+ self._get_img_shape = get_img_shape
40
+
41
+ def __len__(self) -> int:
42
+ return len(self.data)
43
+
44
+ def _read_sample(self, index: int) -> Tuple[Any, Any]:
45
+ raise NotImplementedError
46
+
47
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
48
+ # Read image
49
+ img, target = self._read_sample(index)
50
+ # Pre-transforms (format conversion at run-time etc.)
51
+ if self._pre_transforms is not None:
52
+ img, target = self._pre_transforms(img, target)
53
+
54
+ if self.img_transforms is not None:
55
+ # typing issue cf. https://github.com/python/mypy/issues/5485
56
+ img = self.img_transforms(img)
57
+
58
+ if self.sample_transforms is not None:
59
+ # Conditions to assess it is detection model with multiple classes and avoid confusion with other tasks.
60
+ if (
61
+ isinstance(target, dict)
62
+ and all(isinstance(item, np.ndarray) for item in target.values())
63
+ and set(target.keys()) != {"boxes", "labels"} # avoid confusion with obj detection target
64
+ ):
65
+ img_transformed = _copy_tensor(img)
66
+ for class_name, bboxes in target.items():
67
+ img_transformed, target[class_name] = self.sample_transforms(img, bboxes)
68
+ img = img_transformed
69
+ else:
70
+ img, target = self.sample_transforms(img, target)
71
+
72
+ return img, target
73
+
74
+ def extra_repr(self) -> str:
75
+ return ""
76
+
77
+ def __repr__(self) -> str:
78
+ return f"{self.__class__.__name__}({self.extra_repr()})"
79
+
80
+
81
+ class _VisionDataset(_AbstractDataset):
82
+ """Implements an abstract dataset
83
+
84
+ Args:
85
+ ----
86
+ url: URL of the dataset
87
+ file_name: name of the file once downloaded
88
+ file_hash: expected SHA256 of the file
89
+ extract_archive: whether the downloaded file is an archive to be extracted
90
+ download: whether the dataset should be downloaded if not present on disk
91
+ overwrite: whether the archive should be re-extracted
92
+ cache_dir: cache directory
93
+ cache_subdir: subfolder to use in the cache
94
+ """
95
+
96
+ def __init__(
97
+ self,
98
+ url: str,
99
+ file_name: Optional[str] = None,
100
+ file_hash: Optional[str] = None,
101
+ extract_archive: bool = False,
102
+ download: bool = False,
103
+ overwrite: bool = False,
104
+ cache_dir: Optional[str] = None,
105
+ cache_subdir: Optional[str] = None,
106
+ **kwargs: Any,
107
+ ) -> None:
108
+ cache_dir = (
109
+ str(os.environ.get("DOCTR_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "doctr")))
110
+ if cache_dir is None
111
+ else cache_dir
112
+ )
113
+
114
+ cache_subdir = "datasets" if cache_subdir is None else cache_subdir
115
+
116
+ file_name = file_name if isinstance(file_name, str) else os.path.basename(url)
117
+ # Download the file if not present
118
+ archive_path: Union[str, Path] = os.path.join(cache_dir, cache_subdir, file_name)
119
+
120
+ if not os.path.exists(archive_path) and not download:
121
+ raise ValueError("the dataset needs to be downloaded first with download=True")
122
+
123
+ archive_path = download_from_url(url, file_name, file_hash, cache_dir=cache_dir, cache_subdir=cache_subdir)
124
+
125
+ # Extract the archive
126
+ if extract_archive:
127
+ archive_path = Path(archive_path)
128
+ dataset_path = archive_path.parent.joinpath(archive_path.stem)
129
+ if not dataset_path.is_dir() or overwrite:
130
+ shutil.unpack_archive(archive_path, dataset_path)
131
+
132
+ super().__init__(dataset_path if extract_archive else archive_path, **kwargs)
doctr/datasets/datasets/pytorch.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import os
7
+ from copy import deepcopy
8
+ from typing import Any, List, Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from doctr.io import read_img_as_tensor, tensor_from_numpy
14
+
15
+ from .base import _AbstractDataset, _VisionDataset
16
+
17
+ __all__ = ["AbstractDataset", "VisionDataset"]
18
+
19
+
20
class AbstractDataset(_AbstractDataset):
    """Abstract class for all datasets"""

    def _read_sample(self, index: int) -> Tuple[torch.Tensor, Any]:
        """Load one sample: validate its target structure, then read its image as a float tensor."""
        img_name, target = self.data[index]

        # Validate the target layout before paying the cost of loading the image
        if isinstance(target, dict):
            assert "boxes" in target, "Target should contain 'boxes' key"
            assert "labels" in target, "Target should contain 'labels' key"
        elif isinstance(target, tuple):
            assert len(target) == 2
            assert isinstance(
                target[0], (str, np.ndarray)
            ), "first element of the tuple should be a string or a numpy array"
            assert isinstance(target[1], list), "second element of the tuple should be a list"
        else:
            assert isinstance(target, (str, np.ndarray)), "Target should be a string or a numpy array"

        # Either convert an in-memory array, or load the image file from disk
        if isinstance(img_name, np.ndarray):
            img = tensor_from_numpy(img_name, dtype=torch.float32)
        else:
            img = read_img_as_tensor(os.path.join(self.root, img_name), dtype=torch.float32)

        # Deep-copy the target so callers cannot mutate the cached annotations
        return img, deepcopy(target)

    @staticmethod
    def collate_fn(samples: List[Tuple[torch.Tensor, Any]]) -> Tuple[torch.Tensor, List[Any]]:
        """Batch samples: stack images along a new leading dim, keep targets as a list."""
        images, targets = zip(*samples)
        return torch.stack(images, dim=0), list(targets)  # type: ignore[return-value]
56
+
57
+
58
class VisionDataset(AbstractDataset, _VisionDataset):  # noqa: D101
    # PyTorch-backend vision dataset: combines the torch sample loading of
    # AbstractDataset with the download/extract machinery of _VisionDataset.
    pass
doctr/datasets/datasets/tensorflow.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import os
7
+ from copy import deepcopy
8
+ from typing import Any, List, Tuple
9
+
10
+ import numpy as np
11
+ import tensorflow as tf
12
+
13
+ from doctr.io import read_img_as_tensor, tensor_from_numpy
14
+
15
+ from .base import _AbstractDataset, _VisionDataset
16
+
17
+ __all__ = ["AbstractDataset", "VisionDataset"]
18
+
19
+
20
class AbstractDataset(_AbstractDataset):
    """Abstract class for all datasets"""

    def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]:
        """Load one sample: validate its target structure, then read its image as a float tensor."""
        img_name, target = self.data[index]

        # Validate the target layout before paying the cost of loading the image
        if isinstance(target, dict):
            assert "boxes" in target, "Target should contain 'boxes' key"
            assert "labels" in target, "Target should contain 'labels' key"
        elif isinstance(target, tuple):
            assert len(target) == 2
            assert isinstance(
                target[0], (str, np.ndarray)
            ), "first element of the tuple should be a string or a numpy array"
            assert isinstance(target[1], list), "second element of the tuple should be a list"
        else:
            assert isinstance(target, (str, np.ndarray)), "Target should be a string or a numpy array"

        # Either convert an in-memory array, or load the image file from disk
        if isinstance(img_name, np.ndarray):
            img = tensor_from_numpy(img_name, dtype=tf.float32)
        else:
            img = read_img_as_tensor(os.path.join(self.root, img_name), dtype=tf.float32)

        # Deep-copy the target so callers cannot mutate the cached annotations
        return img, deepcopy(target)

    @staticmethod
    def collate_fn(samples: List[Tuple[tf.Tensor, Any]]) -> Tuple[tf.Tensor, List[Any]]:
        """Batch samples: stack images along a new leading axis, keep targets as a list."""
        images, targets = zip(*samples)
        return tf.stack(images, axis=0), list(targets)
56
+
57
+
58
class VisionDataset(AbstractDataset, _VisionDataset):  # noqa: D101
    # TensorFlow-backend vision dataset: combines the tf sample loading of
    # AbstractDataset with the download/extract machinery of _VisionDataset.
    pass
doctr/datasets/detection.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import json
7
+ import os
8
+ from typing import Any, Dict, List, Tuple, Type, Union
9
+
10
+ import numpy as np
11
+
12
+ from doctr.file_utils import CLASS_NAME
13
+
14
+ from .datasets import AbstractDataset
15
+ from .utils import pre_transform_multiclass
16
+
17
+ __all__ = ["DetectionDataset"]
18
+
19
+
20
class DetectionDataset(AbstractDataset):
    """Implements a text detection dataset

    >>> from doctr.datasets import DetectionDataset
    >>> train_set = DetectionDataset(img_folder="/path/to/images",
    >>>                              label_path="/path/to/labels.json")
    >>> img, target = train_set[0]

    Args:
    ----
        img_folder: folder with all the images of the dataset
        label_path: path to the annotations of each image
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        **kwargs: keyword arguments from `AbstractDataset`.
    """

    def __init__(
        self,
        img_folder: str,
        label_path: str,
        use_polygons: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            img_folder,
            pre_transforms=pre_transform_multiclass,
            **kwargs,
        )

        self._class_names: List = []
        # Fail early when the annotation file is missing
        if not os.path.exists(label_path):
            raise FileNotFoundError(f"unable to locate {label_path}")
        with open(label_path, "rb") as f:
            labels = json.load(f)

        np_dtype = np.float32
        self.data: List[Tuple[str, Tuple[np.ndarray, List[str]]]] = []
        for img_name, label in labels.items():
            # Every referenced image must exist on disk
            img_path = os.path.join(self.root, img_name)
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"unable to locate {img_path}")

            geoms, polygons_classes = self.format_polygons(label["polygons"], use_polygons, np_dtype)
            self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))

    def format_polygons(
        self, polygons: Union[List, Dict], use_polygons: bool, np_dtype: Type
    ) -> Tuple[np.ndarray, List[str]]:
        """Format polygons into an array

        Args:
        ----
            polygons: the bounding boxes
            use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
            np_dtype: dtype of array

        Returns:
        -------
            geoms: bounding boxes as np array
            polygons_classes: list of classes for each bounding box
        """
        if isinstance(polygons, list):
            # Single-class annotations: everything belongs to the default class
            self._class_names.append(CLASS_NAME)
            polygons_classes = [CLASS_NAME] * len(polygons)
            _polygons = np.asarray(polygons, dtype=np_dtype)
        elif isinstance(polygons, dict):
            # Multi-class annotations: one polygon list per class name
            self._class_names.extend(polygons.keys())
            polygons_classes = [cls for cls, class_polys in polygons.items() for _ in class_polys]
            _polygons = np.concatenate(
                [np.asarray(class_polys, dtype=np_dtype) for class_polys in polygons.values() if class_polys],
                axis=0,
            )
        else:
            raise TypeError(f"polygons should be a dictionary or list, it was {type(polygons)}")

        if use_polygons:
            geoms = _polygons
        else:
            # Collapse each polygon to its axis-aligned box: (xmin, ymin, xmax, ymax)
            geoms = np.concatenate((_polygons.min(axis=1), _polygons.max(axis=1)), axis=1)
        return geoms, polygons_classes

    @property
    def class_names(self):
        # Deduplicated, deterministic ordering of all classes seen in the labels
        return sorted(set(self._class_names))
doctr/datasets/doc_artefacts.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import json
7
+ import os
8
+ from typing import Any, Dict, List, Tuple
9
+
10
+ import numpy as np
11
+
12
+ from .datasets import VisionDataset
13
+
14
+ __all__ = ["DocArtefacts"]
15
+
16
+
17
class DocArtefacts(VisionDataset):
    """Object detection dataset for non-textual elements in documents.
    The dataset includes a variety of synthetic document pages with non-textual elements.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/artefacts-grid.png&src=0
        :align: center

    >>> from doctr.datasets import DocArtefacts
    >>> train_set = DocArtefacts(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
    ----
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        **kwargs: keyword arguments from `VisionDataset`.
    """

    URL = "https://doctr-static.mindee.com/models?id=v0.4.0/artefact_detection-13fab8ce.zip&src=0"
    SHA256 = "13fab8ced7f84583d9dccd0c634f046c3417e62a11fe1dea6efbbaba5052471b"
    # Position in this list is the integer class id stored in targets
    CLASSES = ["background", "qr_code", "bar_code", "logo", "photo"]

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        **kwargs: Any,
    ) -> None:
        # Download + extract the archive (no explicit file name: defaults to the URL basename)
        super().__init__(self.URL, None, self.SHA256, True, **kwargs)
        self.train = train

        # Update root
        self.root = os.path.join(self.root, "train" if train else "val")
        # List images
        tmp_root = os.path.join(self.root, "images")
        with open(os.path.join(self.root, "labels.json"), "rb") as f:
            labels = json.load(f)
        self.data: List[Tuple[str, Dict[str, Any]]] = []
        img_list = os.listdir(tmp_root)
        if len(labels) != len(img_list):
            raise AssertionError("the number of images and labels do not match")
        np_dtype = np.float32
        for img_name, label in labels.items():
            # File existence check
            if not os.path.exists(os.path.join(tmp_root, img_name)):
                raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}")

            # xmin, ymin, xmax, ymax
            boxes: np.ndarray = np.asarray([obj["geometry"] for obj in label], dtype=np_dtype)
            classes: np.ndarray = np.asarray([self.CLASSES.index(obj["label"]) for obj in label], dtype=np.int64)
            if use_polygons:
                # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                boxes = np.stack(
                    [
                        np.stack([boxes[:, 0], boxes[:, 1]], axis=-1),
                        np.stack([boxes[:, 2], boxes[:, 1]], axis=-1),
                        np.stack([boxes[:, 2], boxes[:, 3]], axis=-1),
                        np.stack([boxes[:, 0], boxes[:, 3]], axis=-1),
                    ],
                    axis=1,
                )
            self.data.append((img_name, dict(boxes=boxes, labels=classes)))
        # Point root at the images folder, since self.data stores bare file names
        self.root = tmp_root

    def extra_repr(self) -> str:
        # Shown in the dataset's repr to distinguish train/val instances
        return f"train={self.train}"
doctr/datasets/funsd.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List, Tuple, Union
10
+
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+
14
+ from .datasets import VisionDataset
15
+ from .utils import convert_target_to_relative, crop_bboxes_from_image
16
+
17
+ __all__ = ["FUNSD"]
18
+
19
+
20
class FUNSD(VisionDataset):
    """FUNSD dataset from `"FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents"
    <https://arxiv.org/pdf/1905.13538.pdf>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/funsd-grid.png&src=0
        :align: center

    >>> from doctr.datasets import FUNSD
    >>> train_set = FUNSD(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
    ----
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `VisionDataset`.
    """

    URL = "https://guillaumejaume.github.io/FUNSD/dataset.zip"
    SHA256 = "c31735649e4f441bcbb4fd0f379574f7520b42286e80b01d80b445649d54761f"
    FILE_NAME = "funsd.zip"

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        # Detection targets get converted to relative coordinates; recognition works on
        # cropped images, so no coordinate transform is applied in that mode
        super().__init__(
            self.URL,
            self.FILE_NAME,
            self.SHA256,
            True,
            pre_transforms=convert_target_to_relative if not recognition_task else None,
            **kwargs,
        )
        self.train = train
        np_dtype = np.float32

        # Use the subset
        subfolder = os.path.join("dataset", "training_data" if train else "testing_data")

        # List images
        tmp_root = os.path.join(self.root, subfolder, "images")
        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
        for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking FUNSD", total=len(os.listdir(tmp_root))):
            # File existence check
            if not os.path.exists(os.path.join(tmp_root, img_path)):
                raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")

            # One JSON annotation file per image, named after the image stem
            stem = Path(img_path).stem
            with open(os.path.join(self.root, subfolder, "annotations", f"{stem}.json"), "rb") as f:
                data = json.load(f)

            # Keep only non-empty words; each entry is (text, box) with box = [xmin, ymin, xmax, ymax]
            _targets = [
                (word["text"], word["box"])
                for block in data["form"]
                for word in block["words"]
                if len(word["text"]) > 0
            ]
            text_targets, box_targets = zip(*_targets)
            if use_polygons:
                # xmin, ymin, xmax, ymax -> (x, y) coordinates of top left, top right, bottom right, bottom left corners
                box_targets = [  # type: ignore[assignment]
                    [
                        [box[0], box[1]],
                        [box[2], box[1]],
                        [box[2], box[3]],
                        [box[0], box[3]],
                    ]
                    for box in box_targets
                ]

            if recognition_task:
                crops = crop_bboxes_from_image(
                    img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=np_dtype)
                )
                for crop, label in zip(crops, list(text_targets)):
                    # filter labels with unknown characters (checkbox symbols and private-use glyphs)
                    if not any(char in label for char in ["☑", "☐", "\uf703", "\uf702"]):
                        self.data.append((crop, label))
            else:
                self.data.append((
                    img_path,
                    dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=list(text_targets)),
                ))

        # Point root at the images folder, since self.data stores bare file names
        self.root = tmp_root

    def extra_repr(self) -> str:
        # Shown in the dataset's repr to distinguish train/test instances
        return f"train={self.train}"
doctr/datasets/generator/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from doctr.file_utils import is_tf_available, is_torch_available
2
+
3
+ if is_tf_available():
4
+ from .tensorflow import *
5
+ elif is_torch_available():
6
+ from .pytorch import * # type: ignore[assignment]
doctr/datasets/generator/base.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import random
7
+ from typing import Any, Callable, List, Optional, Tuple, Union
8
+
9
+ from PIL import Image, ImageDraw
10
+
11
+ from doctr.io.image import tensor_from_pil
12
+ from doctr.utils.fonts import get_font
13
+
14
+ from ..datasets import AbstractDataset
15
+
16
+
17
def synthesize_text_img(
    text: str,
    font_size: int = 32,
    font_family: Optional[str] = None,
    background_color: Optional[Tuple[int, int, int]] = None,
    text_color: Optional[Tuple[int, int, int]] = None,
) -> Image.Image:
    """Generate a synthetic text image

    Args:
    ----
        text: the text to render as an image
        font_size: the size of the font
        font_family: the font family (has to be installed on your system)
        background_color: background color of the final image
        text_color: text color on the final image

    Returns:
    -------
        PIL image of the text
    """
    background_color = (0, 0, 0) if background_color is None else background_color
    text_color = (255, 255, 255) if text_color is None else text_color

    font = get_font(font_family, font_size)
    left, top, right, bottom = font.getbbox(text)
    text_w, text_h = right - left, bottom - top
    # Pad the ink box: ~30% vertically, ~10% horizontally
    h, w = int(round(1.3 * text_h)), int(round(1.1 * text_w))
    # If single letter, make the image square, otherwise expand to meet the text size
    img_size = (h, w) if len(text) > 1 else (max(h, w), max(h, w))

    img = Image.new("RGB", img_size[::-1], color=background_color)
    d = ImageDraw.Draw(img)

    # Offset so that the text is centered. ImageDraw.text anchors at the glyph
    # origin, while getbbox reports the ink box relative to that origin, so the
    # bbox offsets (left, top) must be subtracted — otherwise fonts with a
    # non-zero bearing render below/right of center.
    text_pos = (
        int(round((img_size[1] - text_w) / 2 - left)),
        int(round((img_size[0] - text_h) / 2 - top)),
    )
    # Draw the text
    d.text(text_pos, text, font=font, fill=text_color)
    return img
56
+
57
+
58
class _CharacterGenerator(AbstractDataset):
    """Backend-agnostic core of the character image generator.

    Renders single characters from ``vocab`` into images, either lazily on each
    access or pre-rendered once when ``cache_samples`` is enabled.
    """

    def __init__(
        self,
        vocab: str,
        num_samples: int,
        cache_samples: bool = False,
        font_family: Optional[Union[str, List[str]]] = None,
        img_transforms: Optional[Callable[[Any], Any]] = None,
        sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
    ) -> None:
        self.vocab = vocab
        self._num_samples = num_samples
        if isinstance(font_family, list):
            self.font_family = font_family
            # Make sure every requested font can actually be loaded
            for font in font_family:
                try:
                    _ = get_font(font, 10)
                except OSError:
                    raise ValueError(f"unable to locate font: {font}")
        else:
            self.font_family = [font_family]  # type: ignore[list-item]
        self.img_transforms = img_transforms
        self.sample_transforms = sample_transforms

        self._data: List[Image.Image] = []
        if cache_samples:
            # Pre-render every (character, font) combination once; the target is
            # the character's index in the vocab
            cached = []
            for idx, char in enumerate(self.vocab):
                for font in self.font_family:
                    cached.append((synthesize_text_img(char, font_family=font), idx))
            self._data = cached  # type: ignore[assignment]

    def __len__(self) -> int:
        return self._num_samples

    def _read_sample(self, index: int) -> Tuple[Any, int]:
        if self._data:
            # Cached mode: wrap around the pre-rendered samples
            pil_img, target = self._data[index % len(self._data)]  # type: ignore[misc]
        else:
            # Lazy mode: render the character on demand with a random font
            target = index % len(self.vocab)
            pil_img = synthesize_text_img(self.vocab[target], font_family=random.choice(self.font_family))
        return tensor_from_pil(pil_img), target
103
+
104
+
105
class _WordGenerator(AbstractDataset):
    """Backend-agnostic core of the word image generator.

    Renders random strings drawn from ``vocab`` (length within
    ``[min_chars, max_chars]``) into images, lazily or pre-rendered when
    ``cache_samples`` is enabled.
    """

    def __init__(
        self,
        vocab: str,
        min_chars: int,
        max_chars: int,
        num_samples: int,
        cache_samples: bool = False,
        font_family: Optional[Union[str, List[str]]] = None,
        img_transforms: Optional[Callable[[Any], Any]] = None,
        sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
    ) -> None:
        self.vocab = vocab
        self.wordlen_range = (min_chars, max_chars)
        self._num_samples = num_samples
        if isinstance(font_family, list):
            self.font_family = font_family
            # Make sure every requested font can actually be loaded
            for font in font_family:
                try:
                    _ = get_font(font, 10)
                except OSError:
                    raise ValueError(f"unable to locate font: {font}")
        else:
            self.font_family = [font_family]  # type: ignore[list-item]
        self.img_transforms = img_transforms
        self.sample_transforms = sample_transforms

        self._data: List[Image.Image] = []
        if cache_samples:
            # Two passes on purpose: draw all the random strings first, then pick
            # fonts while rendering (keeps the RNG consumption order stable)
            _words = [self._generate_string(*self.wordlen_range) for _ in range(num_samples)]
            self._data = [
                (synthesize_text_img(text, font_family=random.choice(self.font_family)), text)  # type: ignore[misc]
                for text in _words
            ]

    def _generate_string(self, min_chars: int, max_chars: int) -> str:
        # Uniformly pick a length, then sample each character independently
        length = random.randint(min_chars, max_chars)
        return "".join(random.choice(self.vocab) for _ in range(length))

    def __len__(self) -> int:
        return self._num_samples

    def _read_sample(self, index: int) -> Tuple[Any, str]:
        if self._data:
            # Cached mode: the sample was rendered in __init__
            pil_img, target = self._data[index]  # type: ignore[misc]
        else:
            # Lazy mode: generate a fresh word and render it with a random font
            target = self._generate_string(*self.wordlen_range)
            pil_img = synthesize_text_img(target, font_family=random.choice(self.font_family))
        return tensor_from_pil(pil_img), target
doctr/datasets/generator/pytorch.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ from torch.utils.data._utils.collate import default_collate
7
+
8
+ from .base import _CharacterGenerator, _WordGenerator
9
+
10
+ __all__ = ["CharacterGenerator", "WordGenerator"]
11
+
12
+
13
class CharacterGenerator(_CharacterGenerator):
    """Character image generation dataset (PyTorch backend).

    >>> from doctr.datasets import CharacterGenerator
    >>> ds = CharacterGenerator(vocab='abdef', num_samples=100)
    >>> img, target = ds[0]

    Args:
    ----
        vocab: vocabulary to take the character from
        num_samples: number of samples that will be generated iterating over the dataset
        cache_samples: whether generated images should be cached firsthand
        font_family: font to use to generate the text images
        img_transforms: composable transformations that will be applied to each image
        sample_transforms: composable transformations that will be applied to both the image and the target
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # (image, label-index) pairs batch cleanly with torch's default collation
        self.collate_fn = default_collate
33
+
34
+
35
class WordGenerator(_WordGenerator):
    """Implements a word image generation dataset

    >>> from doctr.datasets import WordGenerator
    >>> ds = WordGenerator(vocab='abdef', min_chars=1, max_chars=32, num_samples=100)
    >>> img, target = ds[0]

    Args:
    ----
        vocab: vocabulary to take the character from
        min_chars: minimum number of characters in a word
        max_chars: maximum number of characters in a word
        num_samples: number of samples that will be generated iterating over the dataset
        cache_samples: whether generated images should be cached firsthand
        font_family: font to use to generate the text images
        img_transforms: composable transformations that will be applied to each image
        sample_transforms: composable transformations that will be applied to both the image and the target
    """

    # All behavior comes from _WordGenerator; torch's default per-sample handling suffices
    pass
doctr/datasets/generator/tensorflow.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import tensorflow as tf
7
+
8
+ from .base import _CharacterGenerator, _WordGenerator
9
+
10
+ __all__ = ["CharacterGenerator", "WordGenerator"]
11
+
12
+
13
class CharacterGenerator(_CharacterGenerator):
    """Character image generation dataset (TensorFlow backend).

    >>> from doctr.datasets import CharacterGenerator
    >>> ds = CharacterGenerator(vocab='abdef', num_samples=100)
    >>> img, target = ds[0]

    Args:
    ----
        vocab: vocabulary to take the character from
        num_samples: number of samples that will be generated iterating over the dataset
        cache_samples: whether generated images should be cached firsthand
        font_family: font to use to generate the text images
        img_transforms: composable transformations that will be applied to each image
        sample_transforms: composable transformations that will be applied to both the image and the target
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    @staticmethod
    def collate_fn(samples):
        """Stack sample images into one batch tensor and gather the label indices."""
        images, targets = zip(*samples)
        # Batch images along a new leading axis; integer targets become a tensor too
        return tf.stack(images, axis=0), tf.convert_to_tensor(targets)
39
+
40
+
41
class WordGenerator(_WordGenerator):
    """Implements a word image generation dataset

    >>> from doctr.datasets import WordGenerator
    >>> ds = WordGenerator(vocab='abdef', min_chars=1, max_chars=32, num_samples=100)
    >>> img, target = ds[0]

    Args:
    ----
        vocab: vocabulary to take the character from
        min_chars: minimum number of characters in a word
        max_chars: maximum number of characters in a word
        num_samples: number of samples that will be generated iterating over the dataset
        cache_samples: whether generated images should be cached firsthand
        font_family: font to use to generate the text images
        img_transforms: composable transformations that will be applied to each image
        sample_transforms: composable transformations that will be applied to both the image and the target
    """

    # All behavior comes from _WordGenerator; no TF-specific collation is needed here
    pass
doctr/datasets/ic03.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import os
7
+ from typing import Any, Dict, List, Tuple, Union
8
+
9
+ import defusedxml.ElementTree as ET
10
+ import numpy as np
11
+ from tqdm import tqdm
12
+
13
+ from .datasets import VisionDataset
14
+ from .utils import convert_target_to_relative, crop_bboxes_from_image
15
+
16
+ __all__ = ["IC03"]
17
+
18
+
19
class IC03(VisionDataset):
    """IC03 dataset from `"ICDAR 2003 Robust Reading Competitions: Entries, Results and Future Directions"
    <http://www.iapr-tc11.org/mediawiki/index.php?title=ICDAR_2003_Robust_Reading_Competitions>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/ic03-grid.png&src=0
        :align: center

    >>> from doctr.datasets import IC03
    >>> train_set = IC03(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
    ----
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `VisionDataset`.
    """

    # (url, sha256, archive file name) per split
    TRAIN = (
        "http://www.iapr-tc11.org/dataset/ICDAR2003_RobustReading/TrialTrain/scene.zip",
        "9d86df514eb09dd693fb0b8c671ef54a0cfe02e803b1bbef9fc676061502eb94",
        "ic03_train.zip",
    )
    TEST = (
        "http://www.iapr-tc11.org/dataset/ICDAR2003_RobustReading/TrialTest/scene.zip",
        "dbc4b5fd5d04616b8464a1b42ea22db351ee22c2546dd15ac35611857ea111f8",
        "ic03_test.zip",
    )

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        url, sha256, file_name = self.TRAIN if train else self.TEST
        # Detection targets get converted to relative coordinates; recognition works on crops
        super().__init__(
            url,
            file_name,
            sha256,
            True,
            pre_transforms=convert_target_to_relative if not recognition_task else None,
            **kwargs,
        )
        self.train = train
        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
        np_dtype = np.float32

        # Load xml data
        tmp_root = (
            os.path.join(self.root, "SceneTrialTrain" if self.train else "SceneTrialTest") if sha256 else self.root
        )
        xml_tree = ET.parse(os.path.join(tmp_root, "words.xml"))
        xml_root = xml_tree.getroot()

        # Each <image> element holds the file name, the resolution, and the word rectangles
        for image in tqdm(iterable=xml_root, desc="Unpacking IC03", total=len(xml_root)):
            name, _resolution, rectangles = image

            # File existence check
            if not os.path.exists(os.path.join(tmp_root, name.text)):
                raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, name.text)}")

            if use_polygons:
                # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                _boxes = [
                    [
                        [float(rect.attrib["x"]), float(rect.attrib["y"])],
                        [float(rect.attrib["x"]) + float(rect.attrib["width"]), float(rect.attrib["y"])],
                        [
                            float(rect.attrib["x"]) + float(rect.attrib["width"]),
                            float(rect.attrib["y"]) + float(rect.attrib["height"]),
                        ],
                        [float(rect.attrib["x"]), float(rect.attrib["y"]) + float(rect.attrib["height"])],
                    ]
                    for rect in rectangles
                ]
            else:
                # x_min, y_min, x_max, y_max
                _boxes = [
                    [
                        float(rect.attrib["x"]),  # type: ignore[list-item]
                        float(rect.attrib["y"]),  # type: ignore[list-item]
                        float(rect.attrib["x"]) + float(rect.attrib["width"]),  # type: ignore[list-item]
                        float(rect.attrib["y"]) + float(rect.attrib["height"]),  # type: ignore[list-item]
                    ]
                    for rect in rectangles
                ]

            # filter images without boxes
            if len(_boxes) > 0:
                boxes: np.ndarray = np.asarray(_boxes, dtype=np_dtype)
                # Get the labels (the text child of each rectangle element)
                labels = [lab.text for rect in rectangles for lab in rect if lab.text]

                if recognition_task:
                    # Recognition mode: store one cropped image per non-degenerate box
                    crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, name.text), geoms=boxes)
                    for crop, label in zip(crops, labels):
                        if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
                            self.data.append((crop, label))
                else:
                    # Detection mode: store the file name with its boxes and labels
                    self.data.append((name.text, dict(boxes=boxes, labels=labels)))

        self.root = tmp_root

    def extra_repr(self) -> str:
        # Shown in the dataset's repr to distinguish train/test instances
        return f"train={self.train}"
doctr/datasets/ic13.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import csv
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List, Tuple, Union
10
+
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+
14
+ from .datasets import AbstractDataset
15
+ from .utils import convert_target_to_relative, crop_bboxes_from_image
16
+
17
+ __all__ = ["IC13"]
18
+
19
+
20
+ class IC13(AbstractDataset):
21
+ """IC13 dataset from `"ICDAR 2013 Robust Reading Competition" <https://rrc.cvc.uab.es/>`_.
22
+
23
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/ic13-grid.png&src=0
24
+ :align: center
25
+
26
+ >>> # NOTE: You need to download both image and label parts from Focused Scene Text challenge Task2.1 2013-2015.
27
+ >>> from doctr.datasets import IC13
28
+ >>> train_set = IC13(img_folder="/path/to/Challenge2_Training_Task12_Images",
29
+ >>> label_folder="/path/to/Challenge2_Training_Task1_GT")
30
+ >>> img, target = train_set[0]
31
+ >>> test_set = IC13(img_folder="/path/to/Challenge2_Test_Task12_Images",
32
+ >>> label_folder="/path/to/Challenge2_Test_Task1_GT")
33
+ >>> img, target = test_set[0]
34
+
35
+ Args:
36
+ ----
37
+ img_folder: folder with all the images of the dataset
38
+ label_folder: folder with all annotation files for the images
39
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
40
+ recognition_task: whether the dataset should be used for recognition task
41
+ **kwargs: keyword arguments from `AbstractDataset`.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ img_folder: str,
47
+ label_folder: str,
48
+ use_polygons: bool = False,
49
+ recognition_task: bool = False,
50
+ **kwargs: Any,
51
+ ) -> None:
52
+ super().__init__(
53
+ img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
54
+ )
55
+
56
+ # File existence check
57
+ if not os.path.exists(label_folder) or not os.path.exists(img_folder):
58
+ raise FileNotFoundError(
59
+ f"unable to locate {label_folder if not os.path.exists(label_folder) else img_folder}"
60
+ )
61
+
62
+ self.data: List[Tuple[Union[Path, np.ndarray], Union[str, Dict[str, Any]]]] = []
63
+ np_dtype = np.float32
64
+
65
+ img_names = os.listdir(img_folder)
66
+
67
+ for img_name in tqdm(iterable=img_names, desc="Unpacking IC13", total=len(img_names)):
68
+ img_path = Path(img_folder, img_name)
69
+ label_path = Path(label_folder, "gt_" + Path(img_name).stem + ".txt")
70
+
71
+ with open(label_path, newline="\n") as f:
72
+ _lines = [
73
+ [val[:-1] if val.endswith(",") else val for val in row]
74
+ for row in csv.reader(f, delimiter=" ", quotechar="'")
75
+ ]
76
+ labels = [line[-1].replace('"', "") for line in _lines]
77
+ # xmin, ymin, xmax, ymax
78
+ box_targets: np.ndarray = np.array([list(map(int, line[:4])) for line in _lines], dtype=np_dtype)
79
+ if use_polygons:
80
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
81
+ box_targets = np.array(
82
+ [
83
+ [
84
+ [coords[0], coords[1]],
85
+ [coords[2], coords[1]],
86
+ [coords[2], coords[3]],
87
+ [coords[0], coords[3]],
88
+ ]
89
+ for coords in box_targets
90
+ ],
91
+ dtype=np_dtype,
92
+ )
93
+
94
+ if recognition_task:
95
+ crops = crop_bboxes_from_image(img_path=img_path, geoms=box_targets)
96
+ for crop, label in zip(crops, labels):
97
+ self.data.append((crop, label))
98
+ else:
99
+ self.data.append((img_path, dict(boxes=box_targets, labels=labels)))
doctr/datasets/iiit5k.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import os
7
+ from typing import Any, Dict, List, Tuple, Union
8
+
9
+ import numpy as np
10
+ import scipy.io as sio
11
+ from tqdm import tqdm
12
+
13
+ from .datasets import VisionDataset
14
+ from .utils import convert_target_to_relative
15
+
16
+ __all__ = ["IIIT5K"]
17
+
18
+
19
+ class IIIT5K(VisionDataset):
20
+ """IIIT-5K character-level localization dataset from
21
+ `"BMVC 2012 Scene Text Recognition using Higher Order Language Priors"
22
+ <https://cdn.iiit.ac.in/cdn/cvit.iiit.ac.in/images/Projects/SceneTextUnderstanding/home/mishraBMVC12.pdf>`_.
23
+
24
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/iiit5k-grid.png&src=0
25
+ :align: center
26
+
27
+ >>> # NOTE: this dataset is for character-level localization
28
+ >>> from doctr.datasets import IIIT5K
29
+ >>> train_set = IIIT5K(train=True, download=True)
30
+ >>> img, target = train_set[0]
31
+
32
+ Args:
33
+ ----
34
+ train: whether the subset should be the training one
35
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
36
+ recognition_task: whether the dataset should be used for recognition task
37
+ **kwargs: keyword arguments from `VisionDataset`.
38
+ """
39
+
40
+ URL = "https://cvit.iiit.ac.in/images/Projects/SceneTextUnderstanding/IIIT5K-Word_V3.0.tar.gz"
41
+ SHA256 = "7872c9efbec457eb23f3368855e7738f72ce10927f52a382deb4966ca0ffa38e"
42
+
43
+ def __init__(
44
+ self,
45
+ train: bool = True,
46
+ use_polygons: bool = False,
47
+ recognition_task: bool = False,
48
+ **kwargs: Any,
49
+ ) -> None:
50
+ super().__init__(
51
+ self.URL,
52
+ None,
53
+ file_hash=self.SHA256,
54
+ extract_archive=True,
55
+ pre_transforms=convert_target_to_relative if not recognition_task else None,
56
+ **kwargs,
57
+ )
58
+ self.train = train
59
+
60
+ # Load mat data
61
+ tmp_root = os.path.join(self.root, "IIIT5K") if self.SHA256 else self.root
62
+ mat_file = "trainCharBound" if self.train else "testCharBound"
63
+ mat_data = sio.loadmat(os.path.join(tmp_root, f"{mat_file}.mat"))[mat_file][0]
64
+
65
+ self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
66
+ np_dtype = np.float32
67
+
68
+ for img_path, label, box_targets in tqdm(iterable=mat_data, desc="Unpacking IIIT5K", total=len(mat_data)):
69
+ _raw_path = img_path[0]
70
+ _raw_label = label[0]
71
+
72
+ # File existence check
73
+ if not os.path.exists(os.path.join(tmp_root, _raw_path)):
74
+ raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, _raw_path)}")
75
+
76
+ if recognition_task:
77
+ self.data.append((_raw_path, _raw_label))
78
+ else:
79
+ if use_polygons:
80
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
81
+ box_targets = [
82
+ [
83
+ [box[0], box[1]],
84
+ [box[0] + box[2], box[1]],
85
+ [box[0] + box[2], box[1] + box[3]],
86
+ [box[0], box[1] + box[3]],
87
+ ]
88
+ for box in box_targets
89
+ ]
90
+ else:
91
+ # xmin, ymin, xmax, ymax
92
+ box_targets = [[box[0], box[1], box[0] + box[2], box[1] + box[3]] for box in box_targets]
93
+
94
+ # labels are cast to a list where each char corresponds to the character's bounding box
95
+ self.data.append((
96
+ _raw_path,
97
+ dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=list(_raw_label)),
98
+ ))
99
+
100
+ self.root = tmp_root
101
+
102
+ def extra_repr(self) -> str:
103
+ return f"train={self.train}"
doctr/datasets/iiithws.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import os
7
+ from random import sample
8
+ from typing import Any, List, Tuple
9
+
10
+ from tqdm import tqdm
11
+
12
+ from .datasets import AbstractDataset
13
+
14
+ __all__ = ["IIITHWS"]
15
+
16
+
17
+ class IIITHWS(AbstractDataset):
18
+ """IIITHWS dataset from `"Generating Synthetic Data for Text Recognition"
19
+ <https://arxiv.org/pdf/1608.04224.pdf>`_ | `"repository" <https://github.com/kris314/hwnet>`_ |
20
+ `"website" <https://cvit.iiit.ac.in/research/projects/cvit-projects/matchdocimgs>`_.
21
+
22
+ >>> # NOTE: This is a pure recognition dataset without bounding box labels.
23
+ >>> # NOTE: You need to download the dataset.
24
+ >>> from doctr.datasets import IIITHWS
25
+ >>> train_set = IIITHWS(img_folder="/path/to/iiit-hws/Images_90K_Normalized",
26
+ >>> label_path="/path/to/IIIT-HWS-90K.txt",
27
+ >>> train=True)
28
+ >>> img, target = train_set[0]
29
+ >>> test_set = IIITHWS(img_folder="/path/to/iiit-hws/Images_90K_Normalized",
30
+ >>> label_path="/path/to/IIIT-HWS-90K.txt",
31
+ >>> train=False)
32
+ >>> img, target = test_set[0]
33
+
34
+ Args:
35
+ ----
36
+ img_folder: folder with all the images of the dataset
37
+ label_path: path to the file with the labels
38
+ train: whether the subset should be the training one
39
+ **kwargs: keyword arguments from `AbstractDataset`.
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ img_folder: str,
45
+ label_path: str,
46
+ train: bool = True,
47
+ **kwargs: Any,
48
+ ) -> None:
49
+ super().__init__(img_folder, **kwargs)
50
+
51
+ # File existence check
52
+ if not os.path.exists(label_path) or not os.path.exists(img_folder):
53
+ raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")
54
+
55
+ self.data: List[Tuple[str, str]] = []
56
+ self.train = train
57
+
58
+ with open(label_path) as f:
59
+ annotations = f.readlines()
60
+
61
+ # Shuffle the dataset otherwise the test set will contain the same labels n times
62
+ annotations = sample(annotations, len(annotations))
63
+ train_samples = int(len(annotations) * 0.9)
64
+ set_slice = slice(train_samples) if self.train else slice(train_samples, None)
65
+
66
+ for annotation in tqdm(
67
+ iterable=annotations[set_slice], desc="Unpacking IIITHWS", total=len(annotations[set_slice])
68
+ ):
69
+ img_path, label = annotation.split()[0:2]
70
+ img_path = os.path.join(img_folder, img_path)
71
+
72
+ self.data.append((img_path, label))
73
+
74
+ def extra_repr(self) -> str:
75
+ return f"train={self.train}"
doctr/datasets/imgur5k.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import glob
7
+ import json
8
+ import os
9
+ from pathlib import Path
10
+ from typing import Any, Dict, List, Tuple, Union
11
+
12
+ import cv2
13
+ import numpy as np
14
+ from PIL import Image
15
+ from tqdm import tqdm
16
+
17
+ from .datasets import AbstractDataset
18
+ from .utils import convert_target_to_relative, crop_bboxes_from_image
19
+
20
+ __all__ = ["IMGUR5K"]
21
+
22
+
23
+ class IMGUR5K(AbstractDataset):
24
+ """IMGUR5K dataset from `"TextStyleBrush: Transfer of Text Aesthetics from a Single Example"
25
+ <https://arxiv.org/abs/2106.08385>`_ |
26
+ `repository <https://github.com/facebookresearch/IMGUR5K-Handwriting-Dataset>`_.
27
+
28
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/imgur5k-grid.png&src=0
29
+ :align: center
30
+ :width: 630
31
+ :height: 400
32
+
33
+ >>> # NOTE: You need to download/generate the dataset from the repository.
34
+ >>> from doctr.datasets import IMGUR5K
35
+ >>> train_set = IMGUR5K(train=True, img_folder="/path/to/IMGUR5K-Handwriting-Dataset/images",
36
+ >>> label_path="/path/to/IMGUR5K-Handwriting-Dataset/dataset_info/imgur5k_annotations.json")
37
+ >>> img, target = train_set[0]
38
+ >>> test_set = IMGUR5K(train=False, img_folder="/path/to/IMGUR5K-Handwriting-Dataset/images",
39
+ >>> label_path="/path/to/IMGUR5K-Handwriting-Dataset/dataset_info/imgur5k_annotations.json")
40
+ >>> img, target = test_set[0]
41
+
42
+ Args:
43
+ ----
44
+ img_folder: folder with all the images of the dataset
45
+ label_path: path to the annotations file of the dataset
46
+ train: whether the subset should be the training one
47
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
48
+ recognition_task: whether the dataset should be used for recognition task
49
+ **kwargs: keyword arguments from `AbstractDataset`.
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ img_folder: str,
55
+ label_path: str,
56
+ train: bool = True,
57
+ use_polygons: bool = False,
58
+ recognition_task: bool = False,
59
+ **kwargs: Any,
60
+ ) -> None:
61
+ super().__init__(
62
+ img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
63
+ )
64
+
65
+ # File existence check
66
+ if not os.path.exists(label_path) or not os.path.exists(img_folder):
67
+ raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")
68
+
69
+ self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any]]]] = []
70
+ self.train = train
71
+ np_dtype = np.float32
72
+
73
+ img_names = os.listdir(img_folder)
74
+ train_samples = int(len(img_names) * 0.9)
75
+ set_slice = slice(train_samples) if self.train else slice(train_samples, None)
76
+
77
+ # define folder to write IMGUR5K recognition dataset
78
+ reco_folder_name = "IMGUR5K_recognition_train" if self.train else "IMGUR5K_recognition_test"
79
+ reco_folder_name = "Poly_" + reco_folder_name if use_polygons else reco_folder_name
80
+ reco_folder_path = os.path.join(os.path.dirname(self.root), reco_folder_name)
81
+ reco_images_counter = 0
82
+
83
+ if recognition_task and os.path.isdir(reco_folder_path):
84
+ self._read_from_folder(reco_folder_path)
85
+ return
86
+ elif recognition_task and not os.path.isdir(reco_folder_path):
87
+ os.makedirs(reco_folder_path, exist_ok=False)
88
+
89
+ with open(label_path) as f:
90
+ annotation_file = json.load(f)
91
+
92
+ for img_name in tqdm(iterable=img_names[set_slice], desc="Unpacking IMGUR5K", total=len(img_names[set_slice])):
93
+ img_path = Path(img_folder, img_name)
94
+ img_id = img_name.split(".")[0]
95
+
96
+ # File existence check
97
+ if not os.path.exists(os.path.join(self.root, img_name)):
98
+ raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")
99
+
100
+ # some files have no annotations which are marked with only a dot in the 'word' key
101
+ # ref: https://github.com/facebookresearch/IMGUR5K-Handwriting-Dataset/blob/main/README.md
102
+ if img_id not in annotation_file["index_to_ann_map"].keys():
103
+ continue
104
+ ann_ids = annotation_file["index_to_ann_map"][img_id]
105
+ annotations = [annotation_file["ann_id"][a_id] for a_id in ann_ids]
106
+
107
+ labels = [ann["word"] for ann in annotations if ann["word"] != "."]
108
+ # x_center, y_center, width, height, angle
109
+ _boxes = [
110
+ list(map(float, ann["bounding_box"].strip("[ ]").split(", ")))
111
+ for ann in annotations
112
+ if ann["word"] != "."
113
+ ]
114
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
115
+ box_targets = [cv2.boxPoints(((box[0], box[1]), (box[2], box[3]), box[4])) for box in _boxes] # type: ignore[arg-type]
116
+
117
+ if not use_polygons:
118
+ # xmin, ymin, xmax, ymax
119
+ box_targets = [np.concatenate((points.min(0), points.max(0)), axis=-1) for points in box_targets]
120
+
121
+ # filter images without boxes
122
+ if len(box_targets) > 0:
123
+ if recognition_task:
124
+ crops = crop_bboxes_from_image(
125
+ img_path=os.path.join(self.root, img_name), geoms=np.asarray(box_targets, dtype=np_dtype)
126
+ )
127
+ for crop, label in zip(crops, labels):
128
+ if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
129
+ # write data to disk
130
+ with open(os.path.join(reco_folder_path, f"{reco_images_counter}.txt"), "w") as f:
131
+ f.write(label)
132
+ tmp_img = Image.fromarray(crop)
133
+ tmp_img.save(os.path.join(reco_folder_path, f"{reco_images_counter}.png"))
134
+ reco_images_counter += 1
135
+ else:
136
+ self.data.append((img_path, dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=labels)))
137
+
138
+ if recognition_task:
139
+ self._read_from_folder(reco_folder_path)
140
+
141
+ def extra_repr(self) -> str:
142
+ return f"train={self.train}"
143
+
144
+ def _read_from_folder(self, path: str) -> None:
145
+ for img_path in glob.glob(os.path.join(path, "*.png")):
146
+ with open(os.path.join(path, f"{os.path.basename(img_path)[:-4]}.txt"), "r") as f:
147
+ self.data.append((img_path, f.read()))
doctr/datasets/loader.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import math
7
+ from typing import Callable, Optional
8
+
9
+ import numpy as np
10
+ import tensorflow as tf
11
+
12
+ from doctr.utils.multithreading import multithread_exec
13
+
14
+ __all__ = ["DataLoader"]
15
+
16
+
17
+ def default_collate(samples):
18
+ """Collate multiple elements into batches
19
+
20
+ Args:
21
+ ----
22
+ samples: list of N tuples containing M elements
23
+
24
+ Returns:
25
+ -------
26
+ Tuple of M sequences containing N elements each
27
+ """
28
+ batch_data = zip(*samples)
29
+
30
+ tf_data = tuple(tf.stack(elt, axis=0) for elt in batch_data)
31
+
32
+ return tf_data
33
+
34
+
35
+ class DataLoader:
36
+ """Implements a dataset wrapper for fast data loading
37
+
38
+ >>> from doctr.datasets import CORD, DataLoader
39
+ >>> train_set = CORD(train=True, download=True)
40
+ >>> train_loader = DataLoader(train_set, batch_size=32)
41
+ >>> train_iter = iter(train_loader)
42
+ >>> images, targets = next(train_iter)
43
+
44
+ Args:
45
+ ----
46
+ dataset: the dataset
47
+ shuffle: whether the samples should be shuffled before passing it to the iterator
48
+ batch_size: number of elements in each batch
49
+ drop_last: if `True`, drops the last batch if it isn't full
50
+ num_workers: number of workers to use for data loading
51
+ collate_fn: function to merge samples into a batch
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ dataset,
57
+ shuffle: bool = True,
58
+ batch_size: int = 1,
59
+ drop_last: bool = False,
60
+ num_workers: Optional[int] = None,
61
+ collate_fn: Optional[Callable] = None,
62
+ ) -> None:
63
+ self.dataset = dataset
64
+ self.shuffle = shuffle
65
+ self.batch_size = batch_size
66
+ nb = len(self.dataset) / batch_size
67
+ self.num_batches = math.floor(nb) if drop_last else math.ceil(nb)
68
+ if collate_fn is None:
69
+ self.collate_fn = self.dataset.collate_fn if hasattr(self.dataset, "collate_fn") else default_collate
70
+ else:
71
+ self.collate_fn = collate_fn
72
+ self.num_workers = num_workers
73
+ self.reset()
74
+
75
+ def __len__(self) -> int:
76
+ return self.num_batches
77
+
78
+ def reset(self) -> None:
79
+ # Updates indices after each epoch
80
+ self._num_yielded = 0
81
+ self.indices = np.arange(len(self.dataset))
82
+ if self.shuffle is True:
83
+ np.random.shuffle(self.indices)
84
+
85
+ def __iter__(self):
86
+ self.reset()
87
+ return self
88
+
89
+ def __next__(self):
90
+ if self._num_yielded < self.num_batches:
91
+ # Get next indices
92
+ idx = self._num_yielded * self.batch_size
93
+ indices = self.indices[idx : min(len(self.dataset), idx + self.batch_size)]
94
+
95
+ samples = list(multithread_exec(self.dataset.__getitem__, indices, threads=self.num_workers))
96
+
97
+ batch_data = self.collate_fn(samples)
98
+
99
+ self._num_yielded += 1
100
+ return batch_data
101
+ else:
102
+ raise StopIteration
doctr/datasets/mjsynth.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import os
7
+ from typing import Any, List, Tuple
8
+
9
+ from tqdm import tqdm
10
+
11
+ from .datasets import AbstractDataset
12
+
13
+ __all__ = ["MJSynth"]
14
+
15
+
16
+ class MJSynth(AbstractDataset):
17
+ """MJSynth dataset from `"Synthetic Data and Artificial Neural Networks for Natural Scene Text Recognition"
18
+ <https://www.robots.ox.ac.uk/~vgg/data/text/>`_.
19
+
20
+ >>> # NOTE: This is a pure recognition dataset without bounding box labels.
21
+ >>> # NOTE: You need to download the dataset.
22
+ >>> from doctr.datasets import MJSynth
23
+ >>> train_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px",
24
+ >>> label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt",
25
+ >>> train=True)
26
+ >>> img, target = train_set[0]
27
+ >>> test_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px",
28
+ >>> label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt",
29
+ >>> train=False)
30
+ >>> img, target = test_set[0]
31
+
32
+ Args:
33
+ ----
34
+ img_folder: folder with all the images of the dataset
35
+ label_path: path to the file with the labels
36
+ train: whether the subset should be the training one
37
+ **kwargs: keyword arguments from `AbstractDataset`.
38
+ """
39
+
40
+ # filter corrupted or missing images
41
+ BLACKLIST = [
42
+ "./1881/4/225_Marbling_46673.jpg\n",
43
+ "./2069/4/192_whittier_86389.jpg\n",
44
+ "./869/4/234_TRIASSIC_80582.jpg\n",
45
+ "./173/2/358_BURROWING_10395.jpg\n",
46
+ "./913/4/231_randoms_62372.jpg\n",
47
+ "./596/2/372_Ump_81662.jpg\n",
48
+ "./936/2/375_LOCALITIES_44992.jpg\n",
49
+ "./2540/4/246_SQUAMOUS_73902.jpg\n",
50
+ "./1332/4/224_TETHERED_78397.jpg\n",
51
+ "./627/6/83_PATRIARCHATE_55931.jpg\n",
52
+ "./2013/2/370_refract_63890.jpg\n",
53
+ "./2911/6/77_heretical_35885.jpg\n",
54
+ "./1730/2/361_HEREON_35880.jpg\n",
55
+ "./2194/2/334_EFFLORESCENT_24742.jpg\n",
56
+ "./2025/2/364_SNORTERS_72304.jpg\n",
57
+ "./368/4/232_friar_30876.jpg\n",
58
+ "./275/6/96_hackle_34465.jpg\n",
59
+ "./384/4/220_bolts_8596.jpg\n",
60
+ "./905/4/234_Postscripts_59142.jpg\n",
61
+ "./2749/6/101_Chided_13155.jpg\n",
62
+ "./495/6/81_MIDYEAR_48332.jpg\n",
63
+ "./2852/6/60_TOILSOME_79481.jpg\n",
64
+ "./554/2/366_Teleconferences_77948.jpg\n",
65
+ "./1696/4/211_Queened_61779.jpg\n",
66
+ "./2128/2/369_REDACTED_63458.jpg\n",
67
+ "./2557/2/351_DOWN_23492.jpg\n",
68
+ "./2489/4/221_snored_72290.jpg\n",
69
+ "./1650/2/355_stony_74902.jpg\n",
70
+ "./1863/4/223_Diligently_21672.jpg\n",
71
+ "./264/2/362_FORETASTE_30276.jpg\n",
72
+ "./429/4/208_Mainmasts_46140.jpg\n",
73
+ "./1817/2/363_actuating_904.jpg\n",
74
+ ]
75
+
76
+ def __init__(
77
+ self,
78
+ img_folder: str,
79
+ label_path: str,
80
+ train: bool = True,
81
+ **kwargs: Any,
82
+ ) -> None:
83
+ super().__init__(img_folder, **kwargs)
84
+
85
+ # File existence check
86
+ if not os.path.exists(label_path) or not os.path.exists(img_folder):
87
+ raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")
88
+
89
+ self.data: List[Tuple[str, str]] = []
90
+ self.train = train
91
+
92
+ with open(label_path) as f:
93
+ img_paths = f.readlines()
94
+
95
+ train_samples = int(len(img_paths) * 0.9)
96
+ set_slice = slice(train_samples) if self.train else slice(train_samples, None)
97
+
98
+ for path in tqdm(iterable=img_paths[set_slice], desc="Unpacking MJSynth", total=len(img_paths[set_slice])):
99
+ if path not in self.BLACKLIST:
100
+ label = path.split("_")[1]
101
+ img_path = os.path.join(img_folder, path[2:]).strip()
102
+
103
+ self.data.append((img_path, label))
104
+
105
+ def extra_repr(self) -> str:
106
+ return f"train={self.train}"
doctr/datasets/ocr.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List, Tuple
10
+
11
+ import numpy as np
12
+
13
+ from .datasets import AbstractDataset
14
+
15
+ __all__ = ["OCRDataset"]
16
+
17
+
18
+ class OCRDataset(AbstractDataset):
19
+ """Implements an OCR dataset
20
+
21
+ >>> from doctr.datasets import OCRDataset
22
+ >>> train_set = OCRDataset(img_folder="/path/to/images",
23
+ >>> label_file="/path/to/labels.json")
24
+ >>> img, target = train_set[0]
25
+
26
+ Args:
27
+ ----
28
+ img_folder: local path to image folder (all jpg at the root)
29
+ label_file: local path to the label file
30
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
31
+ **kwargs: keyword arguments from `AbstractDataset`.
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ img_folder: str,
37
+ label_file: str,
38
+ use_polygons: bool = False,
39
+ **kwargs: Any,
40
+ ) -> None:
41
+ super().__init__(img_folder, **kwargs)
42
+
43
+ # List images
44
+ self.data: List[Tuple[str, Dict[str, Any]]] = []
45
+ np_dtype = np.float32
46
+ with open(label_file, "rb") as f:
47
+ data = json.load(f)
48
+
49
+ for img_name, annotations in data.items():
50
+ # Get image path
51
+ img_name = Path(img_name)
52
+ # File existence check
53
+ if not os.path.exists(os.path.join(self.root, img_name)):
54
+ raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")
55
+
56
+ # handle empty images
57
+ if len(annotations["typed_words"]) == 0:
58
+ self.data.append((img_name, dict(boxes=np.zeros((0, 4), dtype=np_dtype), labels=[])))
59
+ continue
60
+ # Unpack the straight boxes (xmin, ymin, xmax, ymax)
61
+ geoms = [list(map(float, obj["geometry"][:4])) for obj in annotations["typed_words"]]
62
+ if use_polygons:
63
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
64
+ geoms = [
65
+ [geom[:2], [geom[2], geom[1]], geom[2:], [geom[0], geom[3]]] # type: ignore[list-item]
66
+ for geom in geoms
67
+ ]
68
+
69
+ text_targets = [obj["value"] for obj in annotations["typed_words"]]
70
+
71
+ self.data.append((img_name, dict(boxes=np.asarray(geoms, dtype=np_dtype), labels=text_targets)))
doctr/datasets/orientation.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import os
7
+ from typing import Any, List, Tuple
8
+
9
+ import numpy as np
10
+
11
+ from .datasets import AbstractDataset
12
+
13
+ __all__ = ["OrientationDataset"]
14
+
15
+
16
+ class OrientationDataset(AbstractDataset):
17
+ """Implements a basic image dataset where targets are filled with zeros.
18
+
19
+ >>> from doctr.datasets import OrientationDataset
20
+ >>> train_set = OrientationDataset(img_folder="/path/to/images")
21
+ >>> img, target = train_set[0]
22
+
23
+ Args:
24
+ ----
25
+ img_folder: folder with all the images of the dataset
26
+ **kwargs: keyword arguments from `AbstractDataset`.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ img_folder: str,
32
+ **kwargs: Any,
33
+ ) -> None:
34
+ super().__init__(
35
+ img_folder,
36
+ **kwargs,
37
+ )
38
+
39
+ # initialize dataset with 0 degree rotation targets
40
+ self.data: List[Tuple[str, np.ndarray]] = [(img_name, np.array([0])) for img_name in os.listdir(self.root)]
doctr/datasets/recognition.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any, List, Tuple
10
+
11
+ from .datasets import AbstractDataset
12
+
13
+ __all__ = ["RecognitionDataset"]
14
+
15
+
16
+ class RecognitionDataset(AbstractDataset):
17
+ """Dataset implementation for text recognition tasks
18
+
19
+ >>> from doctr.datasets import RecognitionDataset
20
+ >>> train_set = RecognitionDataset(img_folder="/path/to/images",
21
+ >>> labels_path="/path/to/labels.json")
22
+ >>> img, target = train_set[0]
23
+
24
+ Args:
25
+ ----
26
+ img_folder: path to the images folder
27
+ labels_path: path to the json file containing all labels (character sequences)
28
+ **kwargs: keyword arguments from `AbstractDataset`.
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ img_folder: str,
34
+ labels_path: str,
35
+ **kwargs: Any,
36
+ ) -> None:
37
+ super().__init__(img_folder, **kwargs)
38
+
39
+ self.data: List[Tuple[str, str]] = []
40
+ with open(labels_path, encoding="utf-8") as f:
41
+ labels = json.load(f)
42
+
43
+ for img_name, label in labels.items():
44
+ if not os.path.exists(os.path.join(self.root, img_name)):
45
+ raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")
46
+
47
+ self.data.append((img_name, label))
48
+
49
+ def merge_dataset(self, ds: AbstractDataset) -> None:
50
+ # Update data with new root for self
51
+ self.data = [(str(Path(self.root).joinpath(img_path)), label) for img_path, label in self.data]
52
+ # Define new root
53
+ self.root = Path("/")
54
+ # Merge with ds data
55
+ for img_path, label in ds.data:
56
+ self.data.append((str(Path(ds.root).joinpath(img_path)), label))
doctr/datasets/sroie.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import csv
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List, Tuple, Union
10
+
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+
14
+ from .datasets import VisionDataset
15
+ from .utils import convert_target_to_relative, crop_bboxes_from_image
16
+
17
+ __all__ = ["SROIE"]
18
+
19
+
20
+ class SROIE(VisionDataset):
21
+ """SROIE dataset from `"ICDAR2019 Competition on Scanned Receipt OCR and Information Extraction"
22
+ <https://arxiv.org/pdf/2103.10213.pdf>`_.
23
+
24
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/sroie-grid.png&src=0
25
+ :align: center
26
+
27
+ >>> from doctr.datasets import SROIE
28
+ >>> train_set = SROIE(train=True, download=True)
29
+ >>> img, target = train_set[0]
30
+
31
+ Args:
32
+ ----
33
+ train: whether the subset should be the training one
34
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
35
+ recognition_task: whether the dataset should be used for recognition task
36
+ **kwargs: keyword arguments from `VisionDataset`.
37
+ """
38
+
39
+ TRAIN = (
40
+ "https://doctr-static.mindee.com/models?id=v0.1.1/sroie2019_train_task1.zip&src=0",
41
+ "d4fa9e60abb03500d83299c845b9c87fd9c9430d1aeac96b83c5d0bb0ab27f6f",
42
+ "sroie2019_train_task1.zip",
43
+ )
44
+ TEST = (
45
+ "https://doctr-static.mindee.com/models?id=v0.1.1/sroie2019_test.zip&src=0",
46
+ "41b3c746a20226fddc80d86d4b2a903d43b5be4f521dd1bbe759dbf8844745e2",
47
+ "sroie2019_test.zip",
48
+ )
49
+
50
+ def __init__(
51
+ self,
52
+ train: bool = True,
53
+ use_polygons: bool = False,
54
+ recognition_task: bool = False,
55
+ **kwargs: Any,
56
+ ) -> None:
57
+ url, sha256, name = self.TRAIN if train else self.TEST
58
+ super().__init__(
59
+ url,
60
+ name,
61
+ sha256,
62
+ True,
63
+ pre_transforms=convert_target_to_relative if not recognition_task else None,
64
+ **kwargs,
65
+ )
66
+ self.train = train
67
+
68
+ tmp_root = os.path.join(self.root, "images")
69
+ self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
70
+ np_dtype = np.float32
71
+
72
+ for img_path in tqdm(iterable=os.listdir(tmp_root), desc="Unpacking SROIE", total=len(os.listdir(tmp_root))):
73
+ # File existence check
74
+ if not os.path.exists(os.path.join(tmp_root, img_path)):
75
+ raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")
76
+
77
+ stem = Path(img_path).stem
78
+ with open(os.path.join(self.root, "annotations", f"{stem}.txt"), encoding="latin") as f:
79
+ _rows = [row for row in list(csv.reader(f, delimiter=",")) if len(row) > 0]
80
+
81
+ labels = [",".join(row[8:]) for row in _rows]
82
+ # reorder coordinates (8 -> (4,2) ->
83
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners) and filter empty lines
84
+ coords: np.ndarray = np.stack(
85
+ [np.array(list(map(int, row[:8])), dtype=np_dtype).reshape((4, 2)) for row in _rows], axis=0
86
+ )
87
+
88
+ if not use_polygons:
89
+ # xmin, ymin, xmax, ymax
90
+ coords = np.concatenate((coords.min(axis=1), coords.max(axis=1)), axis=1)
91
+
92
+ if recognition_task:
93
+ crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_path), geoms=coords)
94
+ for crop, label in zip(crops, labels):
95
+ if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
96
+ self.data.append((crop, label))
97
+ else:
98
+ self.data.append((img_path, dict(boxes=coords, labels=labels)))
99
+
100
+ self.root = tmp_root
101
+
102
+ def extra_repr(self) -> str:
103
+ return f"train={self.train}"
doctr/datasets/svhn.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import os
7
+ from typing import Any, Dict, List, Tuple, Union
8
+
9
+ import h5py
10
+ import numpy as np
11
+ from tqdm import tqdm
12
+
13
+ from .datasets import VisionDataset
14
+ from .utils import convert_target_to_relative, crop_bboxes_from_image
15
+
16
+ __all__ = ["SVHN"]
17
+
18
+
19
class SVHN(VisionDataset):
    """SVHN dataset from `"The Street View House Numbers (SVHN) Dataset"
    <http://ufldl.stanford.edu/housenumbers/>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/svhn-grid.png&src=0
        :align: center

    >>> from doctr.datasets import SVHN
    >>> train_set = SVHN(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
    ----
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `VisionDataset`.
    """

    # (url, sha256 checksum, archive name) per split
    TRAIN = (
        "http://ufldl.stanford.edu/housenumbers/train.tar.gz",
        "4b17bb33b6cd8f963493168f80143da956f28ec406cc12f8e5745a9f91a51898",
        "svhn_train.tar",
    )

    TEST = (
        "http://ufldl.stanford.edu/housenumbers/test.tar.gz",
        "57ac9ceb530e4aa85b55d991be8fc49c695b3d71c6f6a88afea86549efde7fb5",
        "svhn_test.tar",
    )

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        url, sha256, name = self.TRAIN if train else self.TEST
        super().__init__(
            url,
            file_name=name,
            file_hash=sha256,
            extract_archive=True,
            # Detection targets are stored relative to the image size; crops for
            # recognition are taken in absolute coordinates, so skip the transform then
            pre_transforms=convert_target_to_relative if not recognition_task else None,
            **kwargs,
        )
        self.train = train
        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
        np_dtype = np.float32

        tmp_root = os.path.join(self.root, "train" if train else "test")

        # Load mat data (matlab v7.3 - can not be loaded with scipy)
        with h5py.File(os.path.join(tmp_root, "digitStruct.mat"), "r") as f:
            img_refs = f["digitStruct/name"]
            box_refs = f["digitStruct/bbox"]
            for img_ref, box_ref in tqdm(iterable=zip(img_refs, box_refs), desc="Unpacking SVHN", total=len(img_refs)):
                # convert ascii matrix to string
                img_name = "".join(map(chr, f[img_ref[0]][()].flatten()))

                # File existence check
                if not os.path.exists(os.path.join(tmp_root, img_name)):
                    raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}")

                # Unpack the information
                box = f[box_ref[0]]
                # presumably: single-digit images store scalars directly, multi-digit
                # ones store HDF5 references that must be dereferenced via f[...] — TODO confirm
                if box["left"].shape[0] == 1:
                    box_dict = {k: [int(vals[0][0])] for k, vals in box.items()}
                else:
                    box_dict = {k: [int(f[v[0]][()].item()) for v in vals] for k, vals in box.items()}

                # Convert it to the right format: one (left, top, width, height) row per digit
                coords: np.ndarray = np.array(
                    [box_dict["left"], box_dict["top"], box_dict["width"], box_dict["height"]], dtype=np_dtype
                ).transpose()
                label_targets = list(map(str, box_dict["label"]))

                if use_polygons:
                    # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                    box_targets: np.ndarray = np.stack(
                        [
                            np.stack([coords[:, 0], coords[:, 1]], axis=-1),
                            np.stack([coords[:, 0] + coords[:, 2], coords[:, 1]], axis=-1),
                            np.stack([coords[:, 0] + coords[:, 2], coords[:, 1] + coords[:, 3]], axis=-1),
                            np.stack([coords[:, 0], coords[:, 1] + coords[:, 3]], axis=-1),
                        ],
                        axis=1,
                    )
                else:
                    # x, y, width, height -> xmin, ymin, xmax, ymax
                    box_targets = np.stack(
                        [
                            coords[:, 0],
                            coords[:, 1],
                            coords[:, 0] + coords[:, 2],
                            coords[:, 1] + coords[:, 3],
                        ],
                        axis=-1,
                    )

                if recognition_task:
                    # Store (crop, label) pairs, skipping empty crops/labels
                    crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_name), geoms=box_targets)
                    for crop, label in zip(crops, label_targets):
                        if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
                            self.data.append((crop, label))
                else:
                    self.data.append((img_name, dict(boxes=box_targets, labels=label_targets)))

        self.root = tmp_root

    def extra_repr(self) -> str:
        return f"train={self.train}"
doctr/datasets/svt.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import os
7
+ from typing import Any, Dict, List, Tuple, Union
8
+
9
+ import defusedxml.ElementTree as ET
10
+ import numpy as np
11
+ from tqdm import tqdm
12
+
13
+ from .datasets import VisionDataset
14
+ from .utils import convert_target_to_relative, crop_bboxes_from_image
15
+
16
+ __all__ = ["SVT"]
17
+
18
+
19
class SVT(VisionDataset):
    """SVT dataset from `"The Street View Text Dataset - UCSD Computer Vision"
    <http://vision.ucsd.edu/~kai/svt/>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/svt-grid.png&src=0
        :align: center

    >>> from doctr.datasets import SVT
    >>> train_set = SVT(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
    ----
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `VisionDataset`.
    """

    URL = "http://vision.ucsd.edu/~kai/svt/svt.zip"
    SHA256 = "63b3d55e6b6d1e036e2a844a20c034fe3af3c32e4d914d6e0c4a3cd43df3bebf"

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            self.URL,
            None,
            self.SHA256,
            True,
            pre_transforms=convert_target_to_relative if not recognition_task else None,
            **kwargs,
        )
        self.train = train
        self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
        np_dtype = np.float32

        # Load xml data
        # NOTE(review): the archive extracts into an "svt1" subfolder; the SHA256
        # check here looks like a sentinel for "downloaded archive vs local folder"
        tmp_root = os.path.join(self.root, "svt1") if self.SHA256 else self.root
        xml_tree = (
            ET.parse(os.path.join(tmp_root, "train.xml"))
            if self.train
            else ET.parse(os.path.join(tmp_root, "test.xml"))
        )
        xml_root = xml_tree.getroot()

        for image in tqdm(iterable=xml_root, desc="Unpacking SVT", total=len(xml_root)):
            # Positional unpacking of the <image> children; order matters here
            name, _, _, _resolution, rectangles = image

            # File existence check
            if not os.path.exists(os.path.join(tmp_root, name.text)):
                raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, name.text)}")

            if use_polygons:
                # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                _boxes = [
                    [
                        [float(rect.attrib["x"]), float(rect.attrib["y"])],
                        [float(rect.attrib["x"]) + float(rect.attrib["width"]), float(rect.attrib["y"])],
                        [
                            float(rect.attrib["x"]) + float(rect.attrib["width"]),
                            float(rect.attrib["y"]) + float(rect.attrib["height"]),
                        ],
                        [float(rect.attrib["x"]), float(rect.attrib["y"]) + float(rect.attrib["height"])],
                    ]
                    for rect in rectangles
                ]
            else:
                # x_min, y_min, x_max, y_max
                _boxes = [
                    [
                        float(rect.attrib["x"]),  # type: ignore[list-item]
                        float(rect.attrib["y"]),  # type: ignore[list-item]
                        float(rect.attrib["x"]) + float(rect.attrib["width"]),  # type: ignore[list-item]
                        float(rect.attrib["y"]) + float(rect.attrib["height"]),  # type: ignore[list-item]
                    ]
                    for rect in rectangles
                ]

            boxes: np.ndarray = np.asarray(_boxes, dtype=np_dtype)
            # Get the labels
            # presumably each <taggedRectangle> holds exactly one <tag> child — TODO confirm
            labels = [lab.text for rect in rectangles for lab in rect]

            if recognition_task:
                crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, name.text), geoms=boxes)
                for crop, label in zip(crops, labels):
                    if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
                        self.data.append((crop, label))
            else:
                self.data.append((name.text, dict(boxes=boxes, labels=labels)))

        self.root = tmp_root

    def extra_repr(self) -> str:
        return f"train={self.train}"
doctr/datasets/synthtext.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import glob
7
+ import os
8
+ from typing import Any, Dict, List, Tuple, Union
9
+
10
+ import numpy as np
11
+ from PIL import Image
12
+ from scipy import io as sio
13
+ from tqdm import tqdm
14
+
15
+ from .datasets import VisionDataset
16
+ from .utils import convert_target_to_relative, crop_bboxes_from_image
17
+
18
+ __all__ = ["SynthText"]
19
+
20
+
21
+ class SynthText(VisionDataset):
22
+ """SynthText dataset from `"Synthetic Data for Text Localisation in Natural Images"
23
+ <https://arxiv.org/abs/1604.06646>`_ | `"repository" <https://github.com/ankush-me/SynthText>`_ |
24
+ `"website" <https://www.robots.ox.ac.uk/~vgg/data/scenetext/>`_.
25
+
26
+ .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/svt-grid.png&src=0
27
+ :align: center
28
+
29
+ >>> from doctr.datasets import SynthText
30
+ >>> train_set = SynthText(train=True, download=True)
31
+ >>> img, target = train_set[0]
32
+
33
+ Args:
34
+ ----
35
+ train: whether the subset should be the training one
36
+ use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
37
+ recognition_task: whether the dataset should be used for recognition task
38
+ **kwargs: keyword arguments from `VisionDataset`.
39
+ """
40
+
41
+ URL = "https://thor.robots.ox.ac.uk/~vgg/data/scenetext/SynthText.zip"
42
+ SHA256 = "28ab030485ec8df3ed612c568dd71fb2793b9afbfa3a9d9c6e792aef33265bf1"
43
+
44
+ def __init__(
45
+ self,
46
+ train: bool = True,
47
+ use_polygons: bool = False,
48
+ recognition_task: bool = False,
49
+ **kwargs: Any,
50
+ ) -> None:
51
+ super().__init__(
52
+ self.URL,
53
+ None,
54
+ file_hash=None,
55
+ extract_archive=True,
56
+ pre_transforms=convert_target_to_relative if not recognition_task else None,
57
+ **kwargs,
58
+ )
59
+ self.train = train
60
+ self.data: List[Tuple[Union[str, np.ndarray], Union[str, Dict[str, Any]]]] = []
61
+ np_dtype = np.float32
62
+
63
+ # Load mat data
64
+ tmp_root = os.path.join(self.root, "SynthText") if self.SHA256 else self.root
65
+ # define folder to write SynthText recognition dataset
66
+ reco_folder_name = "SynthText_recognition_train" if self.train else "SynthText_recognition_test"
67
+ reco_folder_name = "Poly_" + reco_folder_name if use_polygons else reco_folder_name
68
+ reco_folder_path = os.path.join(tmp_root, reco_folder_name)
69
+ reco_images_counter = 0
70
+
71
+ if recognition_task and os.path.isdir(reco_folder_path):
72
+ self._read_from_folder(reco_folder_path)
73
+ return
74
+ elif recognition_task and not os.path.isdir(reco_folder_path):
75
+ os.makedirs(reco_folder_path, exist_ok=False)
76
+
77
+ mat_data = sio.loadmat(os.path.join(tmp_root, "gt.mat"))
78
+ train_samples = int(len(mat_data["imnames"][0]) * 0.9)
79
+ set_slice = slice(train_samples) if self.train else slice(train_samples, None)
80
+ paths = mat_data["imnames"][0][set_slice]
81
+ boxes = mat_data["wordBB"][0][set_slice]
82
+ labels = mat_data["txt"][0][set_slice]
83
+ del mat_data
84
+
85
+ for img_path, word_boxes, txt in tqdm(
86
+ iterable=zip(paths, boxes, labels), desc="Unpacking SynthText", total=len(paths)
87
+ ):
88
+ # File existence check
89
+ if not os.path.exists(os.path.join(tmp_root, img_path[0])):
90
+ raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path[0])}")
91
+
92
+ labels = [elt for word in txt.tolist() for elt in word.split()]
93
+ # (x, y) coordinates of top left, top right, bottom right, bottom left corners
94
+ word_boxes = (
95
+ word_boxes.transpose(2, 1, 0)
96
+ if word_boxes.ndim == 3
97
+ else np.expand_dims(word_boxes.transpose(1, 0), axis=0)
98
+ )
99
+
100
+ if not use_polygons:
101
+ # xmin, ymin, xmax, ymax
102
+ word_boxes = np.concatenate((word_boxes.min(axis=1), word_boxes.max(axis=1)), axis=1)
103
+
104
+ if recognition_task:
105
+ crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_path[0]), geoms=word_boxes)
106
+ for crop, label in zip(crops, labels):
107
+ if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
108
+ # write data to disk
109
+ with open(os.path.join(reco_folder_path, f"{reco_images_counter}.txt"), "w") as f:
110
+ f.write(label)
111
+ tmp_img = Image.fromarray(crop)
112
+ tmp_img.save(os.path.join(reco_folder_path, f"{reco_images_counter}.png"))
113
+ reco_images_counter += 1
114
+ else:
115
+ self.data.append((img_path[0], dict(boxes=np.asarray(word_boxes, dtype=np_dtype), labels=labels)))
116
+
117
+ if recognition_task:
118
+ self._read_from_folder(reco_folder_path)
119
+
120
+ self.root = tmp_root
121
+
122
+ def extra_repr(self) -> str:
123
+ return f"train={self.train}"
124
+
125
+ def _read_from_folder(self, path: str) -> None:
126
+ for img_path in glob.glob(os.path.join(path, "*.png")):
127
+ with open(os.path.join(path, f"{os.path.basename(img_path)[:-4]}.txt"), "r") as f:
128
+ self.data.append((img_path, f.read()))
doctr/datasets/utils.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import string
7
+ import unicodedata
8
+ from collections.abc import Sequence
9
+ from functools import partial
10
+ from pathlib import Path
11
+ from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
12
+ from typing import Sequence as SequenceType
13
+
14
+ import numpy as np
15
+ from PIL import Image
16
+
17
+ from doctr.io.image import get_img_shape
18
+ from doctr.utils.geometry import convert_to_relative_coords, extract_crops, extract_rcrops
19
+
20
+ from .vocabs import VOCABS
21
+
22
+ __all__ = ["translate", "encode_string", "decode_sequence", "encode_sequences", "pre_transform_multiclass"]
23
+
24
+ ImageTensor = TypeVar("ImageTensor")
25
+
26
+
27
def translate(
    input_string: str,
    vocab_name: str,
    unknown_char: str = "■",
) -> str:
    """Project a string onto a target vocabulary.

    Characters already in the vocabulary pass through unchanged; whitespace is
    dropped; other characters are ASCII-folded (accents stripped) and, if still
    not representable, replaced by ``unknown_char``.

    Args:
    ----
        input_string: input string to translate
        vocab_name: vocabulary to use (french, latin, ...)
        unknown_char: unknown character for non-translatable characters

    Returns:
    -------
        A string translated in a given vocab
    """
    if VOCABS.get(vocab_name) is None:
        raise KeyError("output vocabulary must be in vocabs dictionnary")

    vocab = VOCABS[vocab_name]
    out_chars = []
    for char in input_string:
        if char not in vocab:
            # Whitespace is simply dropped
            if char in string.whitespace:
                continue
            # Try an ASCII fold (e.g. strip diacritics) before giving up
            char = unicodedata.normalize("NFD", char).encode("ascii", "ignore").decode("ascii")
            if char == "" or char not in vocab:
                # Normalization failed or still out of vocab -> unknown marker
                char = unknown_char
        out_chars.append(char)
    return "".join(out_chars)
61
+
62
+
63
def encode_string(
    input_string: str,
    vocab: str,
) -> List[int]:
    """Given a predefined mapping, encode the string to a sequence of numbers

    Args:
    ----
        input_string: string to encode
        vocab: vocabulary (string), the encoding is given by the indexing of the character sequence

    Returns:
    -------
        A list encoding the input_string

    Raises:
    ------
        ValueError: if a character of ``input_string`` is absent from ``vocab``
    """
    try:
        return list(map(vocab.index, input_string))
    except ValueError as err:
        # Fix: the original message used a backslash line-continuation inside the
        # f-string, which embedded a long run of spaces in the middle of the text.
        # Also chain the original exception for a complete traceback.
        raise ValueError(
            f"some characters cannot be found in 'vocab'. "
            f"Please check the input string {input_string} and the vocabulary {vocab}"
        ) from err
85
+
86
+
87
def decode_sequence(
    input_seq: Union[np.ndarray, SequenceType[int]],
    mapping: str,
) -> str:
    """Map a sequence of vocabulary indices back to the corresponding string.

    Args:
    ----
        input_seq: array to decode
        mapping: vocabulary (string), the encoding is given by the indexing of the character sequence

    Returns:
    -------
        A string, decoded from input_seq
    """
    if not isinstance(input_seq, (Sequence, np.ndarray)):
        raise TypeError("Invalid sequence type")
    # Arrays must hold in-range integer indices
    if isinstance(input_seq, np.ndarray) and (input_seq.dtype != np.int_ or input_seq.max() >= len(mapping)):
        raise AssertionError("Input must be an array of int, with max less than mapping size")

    return "".join(mapping[idx] for idx in input_seq)
108
+
109
+
110
def encode_sequences(
    sequences: List[str],
    vocab: str,
    target_size: Optional[int] = None,
    eos: int = -1,
    sos: Optional[int] = None,
    pad: Optional[int] = None,
    dynamic_seq_length: bool = False,
) -> np.ndarray:
    """Encode character sequences using a given vocab as mapping

    Args:
    ----
        sequences: the list of character sequences of size N
        vocab: the ordered vocab to use for encoding
        target_size: maximum length of the encoded data
        eos: encoding of End Of String
        sos: optional encoding of Start Of String
        pad: optional encoding for padding. In case of padding, all sequences are followed by 1 EOS then PAD
        dynamic_seq_length: if `target_size` is specified, uses it as upper bound and enables dynamic sequence size

    Returns:
    -------
        the padded encoded data as a tensor
    """
    # eos must not collide with a valid vocab index
    if 0 <= eos < len(vocab):
        raise ValueError("argument 'eos' needs to be outside of vocab possible indices")

    if not isinstance(target_size, int) or dynamic_seq_length:
        # Maximum string length + EOS
        max_length = max(len(w) for w in sequences) + 1
        if isinstance(sos, int):
            max_length += 1
        if isinstance(pad, int):
            max_length += 1
        target_size = max_length if not isinstance(target_size, int) else min(max_length, target_size)

    # Pad all sequences
    if isinstance(pad, int):  # pad with padding symbol
        if 0 <= pad < len(vocab):
            raise ValueError("argument 'pad' needs to be outside of vocab possible indices")
        # In that case, add EOS at the end of the word before padding
        default_symbol = pad
    else:  # pad with eos symbol
        default_symbol = eos
    # Pre-fill with the pad/eos symbol; real tokens are written over the prefix below
    encoded_data: np.ndarray = np.full([len(sequences), target_size], default_symbol, dtype=np.int32)

    # Encode the strings
    for idx, seq in enumerate(map(partial(encode_string, vocab=vocab), sequences)):
        if isinstance(pad, int):  # add eos at the end of the sequence
            seq.append(eos)
        # Truncate sequences longer than target_size
        encoded_data[idx, : min(len(seq), target_size)] = seq[: min(len(seq), target_size)]

    if isinstance(sos, int):  # place sos symbol at the beginning of each sequence
        if 0 <= sos < len(vocab):
            raise ValueError("argument 'sos' needs to be outside of vocab possible indices")
        # NOTE: np.roll without an axis rolls the *flattened* array, so column 0 of each
        # row receives the last element of the previous row — then it is immediately
        # overwritten with sos, which makes this equivalent to a right-shift per row
        encoded_data = np.roll(encoded_data, 1)
        encoded_data[:, 0] = sos

    return encoded_data
170
+
171
+
172
def convert_target_to_relative(img: ImageTensor, target: Dict[str, Any]) -> Tuple[ImageTensor, Dict[str, Any]]:
    """Rescale the target's absolute box coordinates to be relative to the image size.

    Note: mutates ``target`` in place (the "boxes" entry is overwritten) and returns it.
    """
    target["boxes"] = convert_to_relative_coords(target["boxes"], get_img_shape(img))
    return img, target
175
+
176
+
177
def crop_bboxes_from_image(img_path: Union[str, Path], geoms: np.ndarray) -> List[np.ndarray]:
    """Extract the image patch of every bounding box found in an image.

    Args:
    ----
        img_path: path to the image
        geoms: a array of polygons of shape (N, 4, 2) or of straight boxes of shape (N, 4)

    Returns:
    -------
        a list of cropped images
    """
    page: np.ndarray = np.array(Image.open(img_path).convert("RGB"))
    # Rotated polygons: (N, 4, 2) corner coordinates
    if geoms.ndim == 3 and geoms.shape[1:] == (4, 2):
        return extract_rcrops(page, geoms.astype(dtype=int))
    # Straight boxes: (N, 4) as xmin, ymin, xmax, ymax
    if geoms.ndim == 2 and geoms.shape[1] == 4:
        return extract_crops(page, geoms.astype(dtype=int))
    raise ValueError("Invalid geometry format")
196
+
197
+
198
def pre_transform_multiclass(img, target: Tuple[np.ndarray, List]) -> Tuple[np.ndarray, Dict[str, List]]:
    """Converts multiclass target to relative coordinates.

    Args:
    ----
        img: Image
        target: tuple of target polygons and their classes names

    Returns:
    -------
        Image and dictionary of boxes, with class names as keys
    """
    rel_boxes = convert_to_relative_coords(target[0], get_img_shape(img))
    class_names = target[1]
    # Bucket the polygons by class, with deterministically (alphabetically) ordered keys
    grouped: Dict = {name: [] for name in sorted(set(class_names))}
    for name, poly in zip(class_names, rel_boxes):
        grouped[name].append(poly)
    # Stack each class' polygons into a single (M, ...) array
    return img, {name: np.stack(polys, axis=0) for name, polys in grouped.items()}
doctr/datasets/vocabs.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import string
7
+ from typing import Dict
8
+
9
+ __all__ = ["VOCABS"]
10
+
11
+
12
# Base character sets; language vocabs below are composed from these building blocks
VOCABS: Dict[str, str] = {
    "digits": string.digits,
    "ascii_letters": string.ascii_letters,
    "punctuation": string.punctuation,
    "currency": "£€¥¢฿",
    "ancient_greek": "αβγδεζηθικλμνξοπρστυφχψωΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩ",
    "arabic_letters": "ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىي",
    "persian_letters": "پچڢڤگ",
    "hindi_digits": "٠١٢٣٤٥٦٧٨٩",
    "arabic_diacritics": "ًٌٍَُِّْ",
    "arabic_punctuation": "؟؛«»—",
}

# Layered composition: latin -> english -> per-language extensions.
# NOTE: vocab strings are index-sensitive (see encode_string), so the
# concatenation order below must not change for pretrained models.
VOCABS["latin"] = VOCABS["digits"] + VOCABS["ascii_letters"] + VOCABS["punctuation"]
VOCABS["english"] = VOCABS["latin"] + "°" + VOCABS["currency"]
VOCABS["legacy_french"] = VOCABS["latin"] + "°" + "àâéèêëîïôùûçÀÂÉÈËÎÏÔÙÛÇ" + VOCABS["currency"]
VOCABS["french"] = VOCABS["english"] + "àâéèêëîïôùûüçÀÂÉÈÊËÎÏÔÙÛÜÇ"
VOCABS["portuguese"] = VOCABS["english"] + "áàâãéêíïóôõúüçÁÀÂÃÉÊÍÏÓÔÕÚÜÇ"
VOCABS["spanish"] = VOCABS["english"] + "áéíóúüñÁÉÍÓÚÜÑ" + "¡¿"
VOCABS["italian"] = VOCABS["english"] + "àèéìíîòóùúÀÈÉÌÍÎÒÓÙÚ"
VOCABS["german"] = VOCABS["english"] + "äöüßÄÖÜẞ"
VOCABS["arabic"] = (
    VOCABS["digits"]
    + VOCABS["hindi_digits"]
    + VOCABS["arabic_letters"]
    + VOCABS["persian_letters"]
    + VOCABS["arabic_diacritics"]
    + VOCABS["arabic_punctuation"]
    + VOCABS["punctuation"]
)
VOCABS["czech"] = VOCABS["english"] + "áčďéěíňóřšťúůýžÁČĎÉĚÍŇÓŘŠŤÚŮÝŽ"
VOCABS["polish"] = VOCABS["english"] + "ąćęłńóśźżĄĆĘŁŃÓŚŹŻ"
VOCABS["dutch"] = VOCABS["english"] + "áéíóúüñÁÉÍÓÚÜÑ"
VOCABS["norwegian"] = VOCABS["english"] + "æøåÆØÅ"
VOCABS["danish"] = VOCABS["english"] + "æøåÆØÅ"
VOCABS["finnish"] = VOCABS["english"] + "äöÄÖ"
VOCABS["swedish"] = VOCABS["english"] + "åäöÅÄÖ"
VOCABS["vietnamese"] = (
    VOCABS["english"]
    + "áàảạãăắằẳẵặâấầẩẫậéèẻẽẹêếềểễệóòỏõọôốồổộỗơớờởợỡúùủũụưứừửữựiíìỉĩịýỳỷỹỵ"
    + "ÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ"
)
VOCABS["hebrew"] = VOCABS["english"] + "אבגדהוזחטיכלמנסעפצקרשת" + "₪"
# dict.fromkeys de-duplicates while preserving first-occurrence order
VOCABS["multilingual"] = "".join(
    dict.fromkeys(
        VOCABS["french"]
        + VOCABS["portuguese"]
        + VOCABS["spanish"]
        + VOCABS["german"]
        + VOCABS["czech"]
        + VOCABS["polish"]
        + VOCABS["dutch"]
        + VOCABS["italian"]
        + VOCABS["norwegian"]
        + VOCABS["danish"]
        + VOCABS["finnish"]
        + VOCABS["swedish"]
        + "§"
    )
)
doctr/datasets/wildreceipt.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List, Tuple, Union
10
+
11
+ import numpy as np
12
+
13
+ from .datasets import AbstractDataset
14
+ from .utils import convert_target_to_relative, crop_bboxes_from_image
15
+
16
+ __all__ = ["WILDRECEIPT"]
17
+
18
+
19
class WILDRECEIPT(AbstractDataset):
    """WildReceipt dataset from `"Spatial Dual-Modality Graph Reasoning for Key Information Extraction"
    <https://arxiv.org/abs/2103.14470v1>`_ |
    `repository <https://download.openmmlab.com/mmocr/data/wildreceipt.tar>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.7.0/wildreceipt-dataset.jpg&src=0
        :align: center

    >>> # NOTE: You need to download the dataset first.
    >>> from doctr.datasets import WILDRECEIPT
    >>> train_set = WILDRECEIPT(train=True, img_folder="/path/to/wildreceipt/",
    >>>                     label_path="/path/to/wildreceipt/train.txt")
    >>> img, target = train_set[0]
    >>> test_set = WILDRECEIPT(train=False, img_folder="/path/to/wildreceipt/",
    >>>                    label_path="/path/to/wildreceipt/test.txt")
    >>> img, target = test_set[0]

    Args:
    ----
        img_folder: folder with all the images of the dataset
        label_path: path to the annotations file of the dataset
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        **kwargs: keyword arguments from `AbstractDataset`.
    """

    def __init__(
        self,
        img_folder: str,
        label_path: str,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            img_folder, pre_transforms=convert_target_to_relative if not recognition_task else None, **kwargs
        )
        # File existence check
        if not os.path.exists(label_path) or not os.path.exists(img_folder):
            raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")

        tmp_root = img_folder
        self.train = train
        np_dtype = np.float32
        self.data: List[Tuple[Union[str, Path, np.ndarray], Union[str, Dict[str, Any]]]] = []

        with open(label_path, "r") as file:
            data = file.read()
        # Split the text file into separate JSON strings (one record per line)
        json_strings = data.strip().split("\n")
        box: Union[List[float], np.ndarray]
        for json_string in json_strings:
            json_data = json.loads(json_string)
            img_path = json_data["file_name"]
            annotations = json_data["annotations"]
            # BUGFIX: reset the accumulator for every image. It was previously
            # initialized once before this loop, so each sample's target also
            # contained the boxes/labels of every previously parsed image.
            _targets = []
            for annotation in annotations:
                coordinates = annotation["box"]
                if use_polygons:
                    # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                    box = np.array(
                        [
                            [coordinates[0], coordinates[1]],
                            [coordinates[2], coordinates[3]],
                            [coordinates[4], coordinates[5]],
                            [coordinates[6], coordinates[7]],
                        ],
                        dtype=np_dtype,
                    )
                else:
                    # Collapse the 8-point polygon to a straight xmin, ymin, xmax, ymax box
                    x, y = coordinates[::2], coordinates[1::2]
                    box = [min(x), min(y), max(x), max(y)]
                _targets.append((annotation["text"], box))
            text_targets, box_targets = zip(*_targets)

            if recognition_task:
                crops = crop_bboxes_from_image(
                    img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0)
                )
                for crop, label in zip(crops, list(text_targets)):
                    # Skip empty labels and multi-word labels for the recognition task
                    if label and " " not in label:
                        self.data.append((crop, label))
            else:
                self.data.append((
                    img_path,
                    dict(boxes=np.asarray(box_targets, dtype=int).clip(min=0), labels=list(text_targets)),
                ))
        self.root = tmp_root

    def extra_repr(self) -> str:
        return f"train={self.train}"
doctr/file_utils.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ # Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
7
+
8
+ import importlib.util
9
+ import logging
10
+ import os
11
+ import sys
12
+
13
+ CLASS_NAME: str = "words"
14
+
15
+
16
+ if sys.version_info < (3, 8): # pragma: no cover
17
+ import importlib_metadata
18
+ else:
19
+ import importlib.metadata as importlib_metadata
20
+
21
+
22
+ __all__ = ["is_tf_available", "is_torch_available", "CLASS_NAME"]
23
+
24
+ ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
25
+ ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
26
+
27
+ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
28
+ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
29
+
30
+
31
+ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
32
+ _torch_available = importlib.util.find_spec("torch") is not None
33
+ if _torch_available:
34
+ try:
35
+ _torch_version = importlib_metadata.version("torch")
36
+ logging.info(f"PyTorch version {_torch_version} available.")
37
+ except importlib_metadata.PackageNotFoundError: # pragma: no cover
38
+ _torch_available = False
39
+ else: # pragma: no cover
40
+ logging.info("Disabling PyTorch because USE_TF is set")
41
+ _torch_available = False
42
+
43
+
44
+ if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
45
+ _tf_available = importlib.util.find_spec("tensorflow") is not None
46
+ if _tf_available:
47
+ candidates = (
48
+ "tensorflow",
49
+ "tensorflow-cpu",
50
+ "tensorflow-gpu",
51
+ "tf-nightly",
52
+ "tf-nightly-cpu",
53
+ "tf-nightly-gpu",
54
+ "intel-tensorflow",
55
+ "tensorflow-rocm",
56
+ "tensorflow-macos",
57
+ )
58
+ _tf_version = None
59
+ # For the metadata, we have to look for both tensorflow and tensorflow-cpu
60
+ for pkg in candidates:
61
+ try:
62
+ _tf_version = importlib_metadata.version(pkg)
63
+ break
64
+ except importlib_metadata.PackageNotFoundError:
65
+ pass
66
+ _tf_available = _tf_version is not None
67
+ if _tf_available:
68
+ if int(_tf_version.split(".")[0]) < 2: # type: ignore[union-attr] # pragma: no cover
69
+ logging.info(f"TensorFlow found but with version {_tf_version}. DocTR requires version 2 minimum.")
70
+ _tf_available = False
71
+ else:
72
+ logging.info(f"TensorFlow version {_tf_version} available.")
73
+ else: # pragma: no cover
74
+ logging.info("Disabling Tensorflow because USE_TORCH is set")
75
+ _tf_available = False
76
+
77
+
78
+ if not _torch_available and not _tf_available: # pragma: no cover
79
+ raise ModuleNotFoundError(
80
+ "DocTR requires either TensorFlow or PyTorch to be installed. Please ensure one of them"
81
+ " is installed and that either USE_TF or USE_TORCH is enabled."
82
+ )
83
+
84
+
85
+ def is_torch_available():
86
+ """Whether PyTorch is installed."""
87
+ return _torch_available
88
+
89
+
90
+ def is_tf_available():
91
+ """Whether TensorFlow is installed."""
92
+ return _tf_available
doctr/io/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .elements import *
2
+ from .html import *
3
+ from .image import *
4
+ from .pdf import *
5
+ from .reader import *
doctr/io/elements.py ADDED
@@ -0,0 +1,621 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ from typing import Any, Dict, List, Optional, Tuple, Union
7
+
8
+ from defusedxml import defuse_stdlib
9
+
10
+ defuse_stdlib()
11
+ from xml.etree import ElementTree as ET
12
+ from xml.etree.ElementTree import Element as ETElement
13
+ from xml.etree.ElementTree import SubElement
14
+
15
+ import matplotlib.pyplot as plt
16
+ import numpy as np
17
+
18
+ import doctr
19
+ from doctr.utils.common_types import BoundingBox
20
+ from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox
21
+ from doctr.utils.repr import NestedObject
22
+ from doctr.utils.visualization import synthesize_kie_page, synthesize_page, visualize_kie_page, visualize_page
23
+
24
+ __all__ = ["Element", "Word", "Artefact", "Line", "Prediction", "Block", "Page", "KIEPage", "Document"]
25
+
26
+
27
class Element(NestedObject):
    """Abstract document element providing nested-dict export and text rendering."""

    _children_names: List[str] = []
    _exported_keys: List[str] = []

    def __init__(self, **kwargs: Any) -> None:
        # Only declared child collections may be attached; anything else is a typo.
        for name, child in kwargs.items():
            if name not in self._children_names:
                raise KeyError(f"{self.__class__.__name__} object does not have any attribute named '{name}'")
            setattr(self, name, child)

    def export(self) -> Dict[str, Any]:
        """Serialize the element and all of its children into a nested dict."""
        exported = {key: getattr(self, key) for key in self._exported_keys}
        for child_name in self._children_names:
            children = getattr(self, child_name)
            if child_name in ["predictions"]:
                # predictions are stored per detection class: {class_name: [elements]}
                exported[child_name] = {label: [item.export() for item in items] for label, items in children.items()}
            else:
                exported[child_name] = [child.export() for child in children]
        return exported

    @classmethod
    def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
        raise NotImplementedError

    def render(self) -> str:
        raise NotImplementedError
59
+
60
+
61
class Word(Element):
    """A single recognized word.

    Args:
    ----
        value: the text string of the word
        confidence: the confidence associated with the text prediction
        geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
            the page's size
        crop_orientation: the general orientation of the crop in degrees and its confidence
    """

    _exported_keys: List[str] = ["value", "confidence", "geometry", "crop_orientation"]
    _children_names: List[str] = []

    def __init__(
        self,
        value: str,
        confidence: float,
        geometry: Union[BoundingBox, np.ndarray],
        crop_orientation: Dict[str, Any],
    ) -> None:
        super().__init__()
        self.value = value
        self.confidence = confidence
        self.geometry = geometry
        self.crop_orientation = crop_orientation

    def render(self) -> str:
        """Return the word's text."""
        return self.value

    def extra_repr(self) -> str:
        return f"value='{self.value}', confidence={self.confidence:.2}"

    @classmethod
    def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
        """Rebuild a Word from its exported dict representation."""
        return cls(**{key: save_dict[key] for key in cls._exported_keys})
100
+
101
+
102
class Artefact(Element):
    """Implements a non-textual element

    Args:
    ----
        artefact_type: the type of artefact
        confidence: the confidence of the type prediction
        geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
            the page's size.
    """

    _exported_keys: List[str] = ["geometry", "type", "confidence"]
    _children_names: List[str] = []

    def __init__(self, artefact_type: str, confidence: float, geometry: BoundingBox) -> None:
        super().__init__()
        self.geometry = geometry
        self.type = artefact_type
        self.confidence = confidence

    def render(self) -> str:
        """Renders the full text of the element as a placeholder tag"""
        return f"[{self.type.upper()}]"

    def extra_repr(self) -> str:
        return f"type='{self.type}', confidence={self.confidence:.2}"

    @classmethod
    def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
        """Rebuild an Artefact from its exported dict representation.

        Note: the exported key is "type" while the constructor argument is named
        "artefact_type", so the keys cannot be forwarded verbatim (doing so raised
        a TypeError for unexpected keyword argument 'type').
        """
        return cls(
            artefact_type=save_dict["type"],
            confidence=save_dict["confidence"],
            geometry=save_dict["geometry"],
        )
133
+
134
+
135
class Line(Element):
    """A text line, i.e. an ordered collection of words.

    Args:
    ----
        words: list of word elements
        geometry: bounding box of the line in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
            the page's size. If not specified, it will be resolved by default to the smallest bounding box enclosing
            all words in it.
    """

    _exported_keys: List[str] = ["geometry"]
    _children_names: List[str] = ["words"]
    words: List[Word] = []

    def __init__(
        self,
        words: List[Word],
        geometry: Optional[Union[BoundingBox, np.ndarray]] = None,
    ) -> None:
        if geometry is None:
            # Resolve the geometry as the smallest box enclosing every word.
            # Rotated boxes carry 4 points, straight ones only 2 corners.
            if len(words[0].geometry) == 4:
                resolve = resolve_enclosing_rbbox
            else:
                resolve = resolve_enclosing_bbox
            geometry = resolve([w.geometry for w in words])  # type: ignore[operator]

        super().__init__(words=words)
        self.geometry = geometry

    def render(self) -> str:
        """Return the line's text, words separated by single spaces."""
        return " ".join(w.render() for w in self.words)

    @classmethod
    def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
        """Rebuild a Line (and its words) from its exported dict representation."""
        restored = {k: save_dict[k] for k in cls._exported_keys}
        restored["words"] = [Word.from_dict(word_dict) for word_dict in save_dict["words"]]
        return cls(**restored)
175
+
176
+
177
class Prediction(Word):
    """A KIE prediction: a word-like element attached to a detection class."""

    def render(self) -> str:
        """Return the predicted text."""
        return self.value

    def extra_repr(self) -> str:
        return f"value='{self.value}', confidence={self.confidence:.2}, bounding_box={self.geometry}"
186
+
187
+
188
class Block(Element):
    """Implements a block element as a collection of lines and artefacts

    Args:
    ----
        lines: list of line elements
        artefacts: list of artefacts
        geometry: bounding box of the block in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to
            the page's size. If not specified, it will be resolved by default to the smallest bounding box enclosing
            all lines and artefacts in it.
    """

    _exported_keys: List[str] = ["geometry"]
    _children_names: List[str] = ["lines", "artefacts"]
    lines: List[Line] = []
    artefacts: List[Artefact] = []

    def __init__(
        self,
        lines: Optional[List[Line]] = None,
        artefacts: Optional[List[Artefact]] = None,
        geometry: Optional[Union[BoundingBox, np.ndarray]] = None,
    ) -> None:
        # Avoid the shared-mutable-default-argument pitfall: the previous `lines=[]`
        # defaults exposed one list object across all calls.
        lines = [] if lines is None else lines
        artefacts = [] if artefacts is None else artefacts
        # Resolve the geometry using the smallest enclosing bounding box
        if geometry is None:
            # NOTE: requires at least one line (lines[0] is inspected to pick the box type)
            line_boxes = [word.geometry for line in lines for word in line.words]
            artefact_boxes = [artefact.geometry for artefact in artefacts]
            box_resolution_fn = (
                resolve_enclosing_rbbox if isinstance(lines[0].geometry, np.ndarray) else resolve_enclosing_bbox
            )
            geometry = box_resolution_fn(line_boxes + artefact_boxes)  # type: ignore[operator]

        super().__init__(lines=lines, artefacts=artefacts)
        self.geometry = geometry

    def render(self, line_break: str = "\n") -> str:
        """Renders the full text of the element"""
        return line_break.join(line.render() for line in self.lines)

    @classmethod
    def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
        """Rebuild a Block (and its lines/artefacts) from its exported dict representation."""
        kwargs = {k: save_dict[k] for k in cls._exported_keys}
        kwargs.update({
            "lines": [Line.from_dict(_dict) for _dict in save_dict["lines"]],
            "artefacts": [Artefact.from_dict(_dict) for _dict in save_dict["artefacts"]],
        })
        return cls(**kwargs)
235
+
236
+
237
class Page(Element):
    """Implements a page element as a collection of blocks

    Args:
    ----
        page: image encoded as a numpy array in uint8
        blocks: list of block elements
        page_idx: the index of the page in the input raw document
        dimensions: the page size in pixels in format (height, width)
        orientation: a dictionary with the value of the rotation angle in degrees and confidence of the prediction
        language: a dictionary with the language value and confidence of the prediction
    """

    _exported_keys: List[str] = ["page_idx", "dimensions", "orientation", "language"]
    _children_names: List[str] = ["blocks"]
    blocks: List[Block] = []

    def __init__(
        self,
        page: np.ndarray,
        blocks: List[Block],
        page_idx: int,
        dimensions: Tuple[int, int],
        orientation: Optional[Dict[str, Any]] = None,
        language: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(blocks=blocks)
        self.page = page
        self.page_idx = page_idx
        self.dimensions = dimensions
        # Fall back to unknown orientation/language when none is provided
        self.orientation = orientation if isinstance(orientation, dict) else dict(value=None, confidence=None)
        self.language = language if isinstance(language, dict) else dict(value=None, confidence=None)

    def render(self, block_break: str = "\n\n") -> str:
        """Renders the full text of the element"""
        return block_break.join(b.render() for b in self.blocks)

    def extra_repr(self) -> str:
        return f"dimensions={self.dimensions}"

    def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, **kwargs) -> None:
        """Overlay the result on a given image

        Args:
        ----
            interactive: whether the display should be interactive
            preserve_aspect_ratio: pass True if you passed True to the predictor
            **kwargs: additional keyword arguments passed to the matplotlib.pyplot.show method
        """
        visualize_page(self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio)
        plt.show(**kwargs)

    def synthesize(self, **kwargs) -> np.ndarray:
        """Synthesize the page from the predictions

        Returns
        -------
            synthesized page
        """
        return synthesize_page(self.export(), **kwargs)

    def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> Tuple[bytes, ET.ElementTree]:
        """Export the page as XML (hOCR-format)
        convention: https://github.com/kba/hocr-spec/blob/master/1.2/spec.md

        Args:
        ----
            file_title: the title of the XML file

        Returns:
        -------
            a tuple of the XML byte string, and its ElementTree

        Raises:
        ------
            TypeError: if any block uses rotated (non 2-point) geometry
        """
        p_idx = self.page_idx
        # hOCR element ids are 1-based and unique per element kind
        block_count: int = 1
        line_count: int = 1
        word_count: int = 1
        height, width = self.dimensions
        # NOTE(review): self.language holds {"value", "confidence"} keys, so this membership
        # test is effectively always False and falls back to "en" — confirm intended behavior.
        language = self.language if "language" in self.language.keys() else "en"
        # Create the XML root element
        page_hocr = ETElement("html", attrib={"xmlns": "http://www.w3.org/1999/xhtml", "xml:lang": str(language)})
        # Create the header / SubElements of the root element
        head = SubElement(page_hocr, "head")
        SubElement(head, "title").text = file_title
        SubElement(head, "meta", attrib={"http-equiv": "Content-Type", "content": "text/html; charset=utf-8"})
        SubElement(
            head,
            "meta",
            attrib={"name": "ocr-system", "content": f"python-doctr {doctr.__version__}"},  # type: ignore[attr-defined]
        )
        SubElement(
            head,
            "meta",
            attrib={"name": "ocr-capabilities", "content": "ocr_page ocr_carea ocr_par ocr_line ocrx_word"},
        )
        # Create the body
        body = SubElement(page_hocr, "body")
        SubElement(
            body,
            "div",
            attrib={
                "class": "ocr_page",
                "id": f"page_{p_idx + 1}",
                "title": f"image; bbox 0 0 {width} {height}; ppageno 0",
            },
        )
        # iterate over the blocks / lines / words and create the XML elements in body line by line with the attributes
        for block in self.blocks:
            if len(block.geometry) != 2:
                raise TypeError("XML export is only available for straight bounding boxes for now.")
            # Relative coordinates are scaled back to absolute pixels for the hOCR bbox
            (xmin, ymin), (xmax, ymax) = block.geometry
            block_div = SubElement(
                body,
                "div",
                attrib={
                    "class": "ocr_carea",
                    "id": f"block_{block_count}",
                    "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
                    {int(round(xmax * width))} {int(round(ymax * height))}",
                },
            )
            paragraph = SubElement(
                block_div,
                "p",
                attrib={
                    "class": "ocr_par",
                    "id": f"par_{block_count}",
                    "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
                    {int(round(xmax * width))} {int(round(ymax * height))}",
                },
            )
            block_count += 1
            for line in block.lines:
                (xmin, ymin), (xmax, ymax) = line.geometry
                # NOTE: baseline, x_size, x_descenders, x_ascenders is currently initalized to 0
                line_span = SubElement(
                    paragraph,
                    "span",
                    attrib={
                        "class": "ocr_line",
                        "id": f"line_{line_count}",
                        "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
                        {int(round(xmax * width))} {int(round(ymax * height))}; \
                        baseline 0 0; x_size 0; x_descenders 0; x_ascenders 0",
                    },
                )
                line_count += 1
                for word in line.words:
                    (xmin, ymin), (xmax, ymax) = word.geometry
                    conf = word.confidence
                    word_div = SubElement(
                        line_span,
                        "span",
                        attrib={
                            "class": "ocrx_word",
                            "id": f"word_{word_count}",
                            "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
                            {int(round(xmax * width))} {int(round(ymax * height))}; \
                            x_wconf {int(round(conf * 100))}",
                        },
                    )
                    # set the text
                    word_div.text = word.value
                    word_count += 1

        return (ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr))

    @classmethod
    def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
        # NOTE: the raw page image is not part of the export, so it is not restored here
        kwargs = {k: save_dict[k] for k in cls._exported_keys}
        kwargs.update({"blocks": [Block.from_dict(block_dict) for block_dict in save_dict["blocks"]]})
        return cls(**kwargs)
408
+
409
+
410
class KIEPage(Element):
    """Implements a KIE page element as a collection of predictions

    Args:
    ----
        predictions: Dictionary with list of block elements for each detection class
        page: image encoded as a numpy array in uint8
        page_idx: the index of the page in the input raw document
        dimensions: the page size in pixels in format (height, width)
        orientation: a dictionary with the value of the rotation angle in degrees and confidence of the prediction
        language: a dictionary with the language value and confidence of the prediction
    """

    _exported_keys: List[str] = ["page_idx", "dimensions", "orientation", "language"]
    _children_names: List[str] = ["predictions"]
    predictions: Dict[str, List[Prediction]] = {}

    def __init__(
        self,
        page: np.ndarray,
        predictions: Dict[str, List[Prediction]],
        page_idx: int,
        dimensions: Tuple[int, int],
        orientation: Optional[Dict[str, Any]] = None,
        language: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(predictions=predictions)
        self.page = page
        self.page_idx = page_idx
        self.dimensions = dimensions
        # Fall back to unknown orientation/language when none is provided
        self.orientation = orientation if isinstance(orientation, dict) else dict(value=None, confidence=None)
        self.language = language if isinstance(language, dict) else dict(value=None, confidence=None)

    def render(self, prediction_break: str = "\n\n") -> str:
        """Renders the full text of the element"""
        return prediction_break.join(
            f"{class_name}: {p.render()}" for class_name, predictions in self.predictions.items() for p in predictions
        )

    def extra_repr(self) -> str:
        return f"dimensions={self.dimensions}"

    def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, **kwargs) -> None:
        """Overlay the result on a given image

        Args:
        ----
            interactive: whether the display should be interactive
            preserve_aspect_ratio: pass True if you passed True to the predictor
            **kwargs: keyword arguments passed to the matplotlib.pyplot.show method
        """
        visualize_kie_page(
            self.export(), self.page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio
        )
        plt.show(**kwargs)

    def synthesize(self, **kwargs) -> np.ndarray:
        """Synthesize the page from the predictions

        Args:
        ----
            **kwargs: keyword arguments passed to the matplotlib.pyplot.show method

        Returns:
        -------
            synthesized page
        """
        return synthesize_kie_page(self.export(), **kwargs)

    def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> Tuple[bytes, ET.ElementTree]:
        """Export the page as XML (hOCR-format)
        convention: https://github.com/kba/hocr-spec/blob/master/1.2/spec.md

        Args:
        ----
            file_title: the title of the XML file

        Returns:
        -------
            a tuple of the XML byte string, and its ElementTree

        Raises:
        ------
            TypeError: if any prediction uses rotated (non 2-point) geometry
        """
        p_idx = self.page_idx
        prediction_count: int = 1
        height, width = self.dimensions
        # NOTE(review): self.language holds {"value", "confidence"} keys, so this membership
        # test is effectively always False and falls back to "en" — confirm intended behavior.
        language = self.language if "language" in self.language.keys() else "en"
        # Create the XML root element
        page_hocr = ETElement("html", attrib={"xmlns": "http://www.w3.org/1999/xhtml", "xml:lang": str(language)})
        # Create the header / SubElements of the root element
        head = SubElement(page_hocr, "head")
        SubElement(head, "title").text = file_title
        SubElement(head, "meta", attrib={"http-equiv": "Content-Type", "content": "text/html; charset=utf-8"})
        SubElement(
            head,
            "meta",
            attrib={"name": "ocr-system", "content": f"python-doctr {doctr.__version__}"},  # type: ignore[attr-defined]
        )
        SubElement(
            head,
            "meta",
            attrib={"name": "ocr-capabilities", "content": "ocr_page ocr_carea ocr_par ocr_line ocrx_word"},
        )
        # Create the body
        body = SubElement(page_hocr, "body")
        SubElement(
            body,
            "div",
            attrib={
                "class": "ocr_page",
                "id": f"page_{p_idx + 1}",
                "title": f"image; bbox 0 0 {width} {height}; ppageno 0",
            },
        )
        # iterate over the predictions and create the XML elements in body line by line with the attributes
        for class_name, predictions in self.predictions.items():
            for prediction in predictions:
                if len(prediction.geometry) != 2:
                    raise TypeError("XML export is only available for straight bounding boxes for now.")
                # Relative coordinates are scaled back to absolute pixels for the hOCR bbox
                (xmin, ymin), (xmax, ymax) = prediction.geometry
                prediction_div = SubElement(
                    body,
                    "div",
                    attrib={
                        "class": "ocr_carea",
                        "id": f"{class_name}_prediction_{prediction_count}",
                        "title": f"bbox {int(round(xmin * width))} {int(round(ymin * height))} \
                        {int(round(xmax * width))} {int(round(ymax * height))}",
                    },
                )
                prediction_div.text = prediction.value
                prediction_count += 1

        return ET.tostring(page_hocr, encoding="utf-8", method="xml"), ET.ElementTree(page_hocr)

    @classmethod
    def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
        """Rebuild a KIEPage from its exported dict representation.

        Note: `export()` serializes predictions as a dict mapping each detection class to a
        list of prediction dicts. The previous implementation iterated the dict directly
        (yielding class-name strings) and produced a flat list, which broke both
        `__init__` (expects Dict[str, List[Prediction]]) and `render`.
        """
        kwargs = {k: save_dict[k] for k in cls._exported_keys}
        kwargs.update({
            "predictions": {
                class_name: [Prediction.from_dict(prediction_dict) for prediction_dict in prediction_dicts]
                for class_name, prediction_dicts in save_dict["predictions"].items()
            }
        })
        return cls(**kwargs)
549
+
550
+
551
class Document(Element):
    """A whole document, i.e. an ordered collection of pages.

    Args:
    ----
        pages: list of page elements
    """

    _children_names: List[str] = ["pages"]
    pages: List[Page] = []

    def __init__(
        self,
        pages: List[Page],
    ) -> None:
        super().__init__(pages=pages)

    def render(self, page_break: str = "\n\n\n\n") -> str:
        """Concatenate the rendered text of every page."""
        rendered_pages = [page.render() for page in self.pages]
        return page_break.join(rendered_pages)

    def show(self, **kwargs) -> None:
        """Display each page with its predictions overlaid."""
        for page in self.pages:
            page.show(**kwargs)

    def synthesize(self, **kwargs) -> List[np.ndarray]:
        """Synthesize all pages from their predictions

        Returns
        -------
            list of synthesized pages
        """
        return [page.synthesize() for page in self.pages]

    def export_as_xml(self, **kwargs) -> List[Tuple[bytes, ET.ElementTree]]:
        """Export the document as XML (hOCR-format)

        Args:
        ----
            **kwargs: additional keyword arguments passed to the Page.export_as_xml method

        Returns:
        -------
            list of tuple of (bytes, ElementTree)
        """
        return [page.export_as_xml(**kwargs) for page in self.pages]

    @classmethod
    def from_dict(cls, save_dict: Dict[str, Any], **kwargs):
        """Rebuild a Document (and its pages) from its exported dict representation."""
        restored = {k: save_dict[k] for k in cls._exported_keys}
        restored["pages"] = [Page.from_dict(page_dict) for page_dict in save_dict["pages"]]
        return cls(**restored)
604
+
605
+
606
class KIEDocument(Document):
    """Implements a KIE document element as a collection of KIE pages

    Args:
    ----
        pages: list of page elements
    """

    _children_names: List[str] = ["pages"]
    # Narrowed from List[Page]: KIE documents hold per-class prediction pages
    pages: List[KIEPage] = []  # type: ignore[assignment]

    def __init__(
        self,
        pages: List[KIEPage],
    ) -> None:
        super().__init__(pages=pages)  # type: ignore[arg-type]
doctr/io/html.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ from typing import Any
7
+
8
+ from weasyprint import HTML
9
+
10
+ __all__ = ["read_html"]
11
+
12
+
13
def read_html(url: str, **kwargs: Any) -> bytes:
    """Fetch a web page and render it into a PDF byte stream

    >>> from doctr.io import read_html
    >>> doc = read_html("https://www.yoursite.com")

    Args:
    ----
        url: URL of the target web page
        **kwargs: keyword arguments from `weasyprint.HTML`

    Returns:
    -------
        the rendered page as a PDF bytes stream
    """
    html_doc = HTML(url, **kwargs)
    return html_doc.write_pdf()
doctr/io/image/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from doctr.file_utils import is_tf_available, is_torch_available
2
+
3
+ from .base import *
4
+
5
+ if is_tf_available():
6
+ from .tensorflow import *
7
+ elif is_torch_available():
8
+ from .pytorch import *
doctr/io/image/base.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ from pathlib import Path
7
+ from typing import Optional, Tuple
8
+
9
+ import cv2
10
+ import numpy as np
11
+
12
+ from doctr.utils.common_types import AbstractFile
13
+
14
+ __all__ = ["read_img_as_numpy"]
15
+
16
+
17
def read_img_as_numpy(
    file: AbstractFile,
    output_size: Optional[Tuple[int, int]] = None,
    rgb_output: bool = True,
) -> np.ndarray:
    """Read an image file into numpy format

    >>> from doctr.io import read_img_as_numpy
    >>> page = read_img_as_numpy("path/to/your/doc.jpg")

    Args:
    ----
        file: the path to the image file
        output_size: the expected output size of each page in format H x W
        rgb_output: whether the output ndarray channel order should be RGB instead of BGR.

    Returns:
    -------
        the page decoded as numpy ndarray of shape H x W x 3
    """
    # Decode either from disk or from an in-memory byte buffer
    if isinstance(file, (str, Path)):
        if not Path(file).is_file():
            raise FileNotFoundError(f"unable to access {file}")
        img = cv2.imread(str(file), cv2.IMREAD_COLOR)
    elif isinstance(file, bytes):
        raw_buffer: np.ndarray = np.frombuffer(file, np.uint8)
        img = cv2.imdecode(raw_buffer, cv2.IMREAD_COLOR)
    else:
        raise TypeError("unsupported object type for argument 'file'")

    # OpenCV signals decoding failure by returning None rather than raising
    if img is None:
        raise ValueError("unable to read file.")
    # Optional resizing (cv2 expects W x H)
    if isinstance(output_size, tuple):
        img = cv2.resize(img, output_size[::-1], interpolation=cv2.INTER_LINEAR)
    # OpenCV decodes as BGR; switch to RGB when requested
    if rgb_output:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
doctr/io/image/pytorch.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ from io import BytesIO
7
+ from typing import Tuple
8
+
9
+ import numpy as np
10
+ import torch
11
+ from PIL import Image
12
+ from torchvision.transforms.functional import to_tensor
13
+
14
+ from doctr.utils.common_types import AbstractPath
15
+
16
+ __all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"]
17
+
18
+
19
def tensor_from_pil(pil_img: Image.Image, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Convert a PIL Image to a PyTorch tensor

    Args:
    ----
        pil_img: a PIL image
        dtype: the output tensor data type

    Returns:
    -------
        decoded image as tensor
    """
    # to_tensor already performs the HWC->CHW permutation and division by 255
    if dtype == torch.float32:
        return to_tensor(pil_img)
    # other dtypes go through the numpy conversion path
    return tensor_from_numpy(np.array(pil_img, np.uint8, copy=True), dtype)
37
+
38
+
39
def read_img_as_tensor(img_path: "AbstractPath", dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Read an image file as a PyTorch tensor

    Args:
    ----
        img_path: location of the image file
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
    -------
        decoded image as a tensor

    Raises:
    ------
        ValueError: if dtype is not one of torch.uint8, torch.float16, torch.float32
    """
    if dtype not in (torch.uint8, torch.float16, torch.float32):
        # fixed typo in the error message ("insupported" -> "unsupported")
        raise ValueError("unsupported value for dtype")

    # Normalize to 3-channel RGB regardless of the source format
    pil_img = Image.open(img_path, mode="r").convert("RGB")

    return tensor_from_pil(pil_img, dtype)
57
+
58
+
59
def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Read a byte stream as a PyTorch tensor

    Args:
    ----
        img_content: bytes of a decoded image
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
    -------
        decoded image as a tensor

    Raises:
    ------
        ValueError: if dtype is not one of torch.uint8, torch.float16, torch.float32
    """
    if dtype not in (torch.uint8, torch.float16, torch.float32):
        # fixed typo in the error message ("insupported" -> "unsupported")
        raise ValueError("unsupported value for dtype")

    # Normalize to 3-channel RGB regardless of the source format
    pil_img = Image.open(BytesIO(img_content), mode="r").convert("RGB")

    return tensor_from_pil(pil_img, dtype)
77
+
78
+
79
def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Convert a numpy image into a PyTorch tensor

    Args:
    ----
        npy_img: image encoded as a numpy array of shape (H, W, C) in np.uint8
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
    -------
        same image as a tensor of shape (C, H, W)

    Raises:
    ------
        ValueError: if dtype is not one of torch.uint8, torch.float16, torch.float32
    """
    if dtype not in (torch.uint8, torch.float16, torch.float32):
        # fixed typo in the error message ("insupported" -> "unsupported")
        raise ValueError("unsupported value for dtype")

    if dtype == torch.float32:
        # to_tensor handles the HWC -> CHW permutation and the division by 255
        img = to_tensor(npy_img)
    else:
        img = torch.from_numpy(npy_img)
        # put it from HWC to CHW format
        img = img.permute((2, 0, 1)).contiguous()
        if dtype == torch.float16:
            # Switch to FP16 and rescale to [0, 1]
            img = img.to(dtype=torch.float16).div(255)

    return img
105
+
106
+
107
def get_img_shape(img: torch.Tensor) -> Tuple[int, int]:
    """Get the (height, width) of a CHW image tensor"""
    # Last two dims are spatial for CHW (and batched NCHW) layouts
    return img.shape[-2:]  # type: ignore[return-value]
doctr/io/image/tensorflow.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ from typing import Tuple
7
+
8
+ import numpy as np
9
+ import tensorflow as tf
10
+ from PIL import Image
11
+ from tensorflow.keras.utils import img_to_array
12
+
13
+ from doctr.utils.common_types import AbstractPath
14
+
15
+ __all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"]
16
+
17
+
18
def tensor_from_pil(pil_img: Image.Image, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
    """Convert a PIL Image to a TensorFlow tensor

    Args:
    ----
        pil_img: a PIL image
        dtype: the output tensor data type

    Returns:
    -------
        decoded image as tensor
    """
    # Delegate dtype validation and conversion to the numpy path
    as_array = img_to_array(pil_img)
    return tensor_from_numpy(as_array, dtype)
33
+
34
+
35
def read_img_as_tensor(img_path: "AbstractPath", dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
    """Read an image file as a TensorFlow tensor

    Args:
    ----
        img_path: location of the image file
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
    -------
        decoded image as a tensor

    Raises:
    ------
        ValueError: if dtype is not one of tf.uint8, tf.float16, tf.float32
    """
    if dtype not in (tf.uint8, tf.float16, tf.float32):
        # fixed typo in the error message ("insupported" -> "unsupported")
        raise ValueError("unsupported value for dtype")

    img = tf.io.read_file(img_path)
    img = tf.image.decode_jpeg(img, channels=3)

    if dtype != tf.uint8:
        # convert_image_dtype rescales uint8 values into [0, 1]
        img = tf.image.convert_image_dtype(img, dtype=dtype)
        img = tf.clip_by_value(img, 0, 1)

    return img
58
+
59
+
60
def decode_img_as_tensor(img_content: bytes, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
    """Read a byte stream as a TensorFlow tensor

    Args:
    ----
        img_content: bytes of a decoded image
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
    -------
        decoded image as a tensor

    Raises:
    ------
        ValueError: if dtype is not one of tf.uint8, tf.float16, tf.float32
    """
    if dtype not in (tf.uint8, tf.float16, tf.float32):
        # fixed typo in the error message ("insupported" -> "unsupported")
        raise ValueError("unsupported value for dtype")

    img = tf.io.decode_image(img_content, channels=3)

    if dtype != tf.uint8:
        # convert_image_dtype rescales uint8 values into [0, 1]
        img = tf.image.convert_image_dtype(img, dtype=dtype)
        img = tf.clip_by_value(img, 0, 1)

    return img
82
+
83
+
84
def tensor_from_numpy(npy_img: np.ndarray, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
    """Convert a numpy image into a TensorFlow tensor

    Args:
    ----
        npy_img: image encoded as a numpy array of shape (H, W, C) in np.uint8
        dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.

    Returns:
    -------
        same image as a tensor of shape (H, W, C)

    Raises:
    ------
        ValueError: if dtype is not one of tf.uint8, tf.float16, tf.float32
    """
    if dtype not in (tf.uint8, tf.float16, tf.float32):
        # fixed typo in the error message ("insupported" -> "unsupported")
        raise ValueError("unsupported value for dtype")

    if dtype == tf.uint8:
        img = tf.convert_to_tensor(npy_img, dtype=dtype)
    else:
        # convert_image_dtype rescales uint8 values into [0, 1]
        img = tf.image.convert_image_dtype(npy_img, dtype=dtype)
        img = tf.clip_by_value(img, 0, 1)

    return img
106
+
107
+
108
def get_img_shape(img: tf.Tensor) -> Tuple[int, int]:
    """Get the (height, width) of an HWC image tensor"""
    # First two dims are spatial for the HWC layout used by the TF backend
    return img.shape[:2]
doctr/io/pdf.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ from typing import Any, List, Optional
7
+
8
+ import numpy as np
9
+ import pypdfium2 as pdfium
10
+
11
+ from doctr.utils.common_types import AbstractFile
12
+
13
+ __all__ = ["read_pdf"]
14
+
15
+
16
def read_pdf(
    file: AbstractFile,
    scale: float = 2,
    rgb_mode: bool = True,
    password: Optional[str] = None,
    **kwargs: Any,
) -> List[np.ndarray]:
    """Read a PDF file and convert it into an image in numpy format

    >>> from doctr.io import read_pdf
    >>> doc = read_pdf("path/to/your/doc.pdf")

    Args:
    ----
        file: the path to the PDF file
        scale: rendering scale (1 corresponds to 72dpi)
        rgb_mode: if True, the output will be RGB, otherwise BGR
        password: a password to unlock the document, if encrypted
        **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render`

    Returns:
    -------
        the list of pages decoded as numpy ndarray of shape H x W x C
    """
    # Rasterise pages to numpy ndarrays with pypdfium2
    document = pdfium.PdfDocument(file, password=password, autoclose=True)
    rendered_pages: List[np.ndarray] = []
    for page in document:
        bitmap = page.render(scale=scale, rev_byteorder=rgb_mode, **kwargs)
        rendered_pages.append(bitmap.to_numpy())
    return rendered_pages
doctr/io/reader.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2021-2024, Mindee.
2
+
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
+
6
+ from pathlib import Path
7
+ from typing import List, Sequence, Union
8
+
9
+ import numpy as np
10
+
11
+ from doctr.utils.common_types import AbstractFile
12
+
13
+ from .html import read_html
14
+ from .image import read_img_as_numpy
15
+ from .pdf import read_pdf
16
+
17
+ __all__ = ["DocumentFile"]
18
+
19
+
20
class DocumentFile:
    """Read a document from multiple extensions"""

    @classmethod
    def from_pdf(cls, file: AbstractFile, **kwargs) -> List[np.ndarray]:
        """Read a PDF file

        >>> from doctr.io import DocumentFile
        >>> doc = DocumentFile.from_pdf("path/to/your/doc.pdf")

        Args:
        ----
            file: the path to the PDF file or a binary stream
            **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render`

        Returns:
        -------
            the list of pages decoded as numpy ndarray of shape H x W x 3
        """
        # Delegate rasterisation to the pdf module
        return read_pdf(file, **kwargs)

    @classmethod
    def from_url(cls, url: str, **kwargs) -> List[np.ndarray]:
        """Interpret a web page as a PDF document

        >>> from doctr.io import DocumentFile
        >>> doc = DocumentFile.from_url("https://www.yoursite.com")

        Args:
        ----
            url: the URL of the target web page
            **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render`

        Returns:
        -------
            the list of pages decoded as numpy ndarray of shape H x W x 3
        """
        # Render the page to a PDF byte stream, then reuse the PDF pipeline
        return cls.from_pdf(read_html(url), **kwargs)

    @classmethod
    def from_images(cls, files: Union[Sequence[AbstractFile], AbstractFile], **kwargs) -> List[np.ndarray]:
        """Read an image file (or a collection of image files) and convert it into an image in numpy format

        >>> from doctr.io import DocumentFile
        >>> pages = DocumentFile.from_images(["path/to/your/page1.png", "path/to/your/page2.png"])

        Args:
        ----
            files: the path to the image file or a binary stream, or a collection of those
            **kwargs: additional parameters to :meth:`doctr.io.image.read_img_as_numpy`

        Returns:
        -------
            the list of pages decoded as numpy ndarray of shape H x W x 3
        """
        # A single path / byte stream is promoted to a one-element collection
        if isinstance(files, (str, Path, bytes)):
            files = [files]

        return [read_img_as_numpy(single_file, **kwargs) for single_file in files]