Merge pull request #472 from VikParuchuri/vik_dev
This view is limited to 50 files because the pull request contains too many changes.
- .github/workflows/scripts.yml +29 -0
- README.md +55 -10
- benchmarks/table/scoring.py +109 -0
- benchmarks/table/table.py +187 -0
- chunk_convert.py +2 -20
- convert.py +2 -115
- convert_single.py +2 -41
- marker/builders/document.py +6 -1
- marker/builders/layout.py +37 -15
- marker/builders/llm_layout.py +59 -44
- marker/builders/ocr.py +7 -11
- marker/config/parser.py +17 -0
- marker/config/printer.py +42 -16
- marker/converters/pdf.py +34 -27
- marker/converters/table.py +49 -0
- marker/models.py +34 -71
- marker/processors/debug.py +2 -2
- marker/processors/equation.py +4 -6
- marker/processors/ignoretext.py +1 -3
- marker/processors/llm/__init__.py +4 -13
- marker/processors/llm/llm_complex.py +17 -11
- marker/processors/llm/llm_equation.py +82 -0
- marker/processors/llm/llm_form.py +57 -35
- marker/processors/llm/llm_handwriting.py +86 -0
- marker/processors/llm/llm_image_description.py +5 -2
- marker/processors/llm/llm_table.py +55 -23
- marker/processors/llm/llm_table_merge.py +318 -0
- marker/processors/llm/llm_text.py +8 -5
- marker/processors/llm/utils.py +5 -2
- marker/processors/table.py +204 -44
- marker/providers/__init__.py +13 -1
- marker/providers/image.py +52 -0
- marker/providers/pdf.py +21 -6
- marker/providers/registry.py +12 -0
- marker/renderers/__init__.py +14 -8
- marker/renderers/html.py +16 -8
- marker/renderers/json.py +3 -0
- marker/renderers/markdown.py +84 -8
- marker/schema/__init__.py +1 -0
- marker/schema/blocks/__init__.py +1 -0
- marker/schema/blocks/base.py +25 -5
- marker/schema/blocks/basetable.py +39 -0
- marker/schema/blocks/caption.py +2 -0
- marker/schema/blocks/code.py +2 -1
- marker/schema/blocks/complexregion.py +3 -2
- marker/schema/blocks/equation.py +4 -3
- marker/schema/blocks/figure.py +3 -2
- marker/schema/blocks/footnote.py +1 -0
- marker/schema/blocks/form.py +4 -15
- marker/schema/blocks/handwriting.py +8 -0
.github/workflows/scripts.yml
ADDED
```yaml
name: Test CLI scripts

on: [push]

env:
  TORCH_DEVICE: "cpu"
  OCR_ENGINE: "surya"

jobs:
  tests:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.11
        uses: actions/setup-python@v4
        with:
          python-version: 3.11
      - name: Install python dependencies
        run: |
          pip install poetry
          poetry install
      - name: Download benchmark data
        run: |
          wget -O benchmark_data.zip "https://drive.google.com/uc?export=download&id=1NHrdYatR1rtqs2gPVfdvO0BAvocH8CJi"
          unzip -o benchmark_data.zip
      - name: Test single script
        run: poetry run marker_single benchmark_data/pdfs/switch_trans.pdf --page_range 0
      - name: Test convert script
        run: poetry run marker benchmark_data/pdfs --max_files 1 --workers 1 --page_range 0
```
README.md
CHANGED
````diff
@@ -1,13 +1,11 @@
 # Marker
 
-Marker converts PDFs to markdown, JSON, and HTML quickly and accurately.
+Marker converts PDFs and images to markdown, JSON, and HTML quickly and accurately.
 
-- Supports a wide range of documents
-- Supports all languages
+- Supports a range of documents in all languages
 - Removes headers/footers/other artifacts
-- Formats tables, forms, and code blocks
+- Formats tables, forms, equations, links, and code blocks
 - Extracts and saves images along with the markdown
-- Converts equations to latex
 - Easily extensible with your own formatting and logic
 - Optionally boost accuracy with an LLM
 - Works on GPU, CPU, or MPS
@@ -63,11 +61,11 @@ There's a hosted API for marker available [here](https://www.datalab.to/):
 PDF is a tricky format, so marker will not always work perfectly. Here are some known limitations that are on the roadmap to address:
 
 - Marker will only convert block equations
 - Tables are not always formatted 100% correctly
 - Forms are not converted optimally
 - Very complex layouts, with nested tables and forms, may not work
 
-Note: Passing the `--use_llm` flag will mostly solve
+Note: Passing the `--use_llm` flag will mostly solve these issues.
 
 # Installation
 
@@ -84,7 +82,7 @@ pip install marker-pdf
 First, some configuration:
 
 - Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`.
-- Some PDFs, even digital ones, have bad text in them. Set the `force_ocr` flag on the CLI or via configuration to ensure your PDF runs through OCR.
+- Some PDFs, even digital ones, have bad text in them. Set the `force_ocr` flag on the CLI or via configuration to ensure your PDF runs through OCR, or the `strip_existing_ocr` to keep all digital text, and only strip out any existing OCR text.
 
 ## Interactive App
 
@@ -101,9 +99,12 @@ marker_gui
 marker_single /path/to/file.pdf
 ```
 
+You can pass in PDFs or images.
+
 Options:
 - `--output_dir PATH`: Directory where output files will be saved. Defaults to the value specified in settings.OUTPUT_DIR.
 - `--output_format [markdown|json|html]`: Specify the format for the output results.
+- `--paginate_output`: Paginates the output, using `\n\n{PAGE_NUMBER}` followed by `-` * 48, then `\n\n`
 - `--use_llm`: Uses an LLM to improve accuracy. You must set your Gemini API key using the `GOOGLE_API_KEY` env var.
 - `--disable_image_extraction`: Don't extract images from the PDF. If you also specify `--use_llm`, then images will be replaced with a description.
 - `--page_range TEXT`: Specify which pages to process. Accepts comma-separated page numbers and ranges. Example: `--page_range "0,5-10,20"` will process pages 0, 5 through 10, and page 20.
@@ -114,6 +115,7 @@ Options:
 - `--config_json PATH`: Path to a JSON configuration file containing additional settings.
 - `--languages TEXT`: Optionally specify which languages to use for OCR processing. Accepts a comma-separated list. Example: `--languages "en,fr,de"` for English, French, and German.
 - `config --help`: List all available builders, processors, and converters, and their associated configuration. These values can be used to build a JSON configuration file for additional tweaking of marker defaults.
+- `--converter_cls`: One of `marker.converters.pdf.PdfConverter` (default) or `marker.converters.table.TableConverter`. The `PdfConverter` will convert the whole PDF, the `TableConverter` will only extract and convert tables.
 
 The list of supported languages for surya OCR is [here](https://github.com/VikParuchuri/surya/blob/master/surya/languages.py). If you don't need OCR, marker can work with any language.
 
@@ -179,7 +181,7 @@ rendered = converter("FILEPATH")
 
 ### Extract blocks
 
-Each document consists of one or more pages. Pages contain blocks, which can themselves contain other blocks. It's possible to
+Each document consists of one or more pages. Pages contain blocks, which can themselves contain other blocks. It's possible to programmatically manipulate these blocks.
 
 Here's an example of extracting all forms from a document:
 
@@ -197,6 +199,28 @@ forms = document.contained_blocks((BlockTypes.Form,))
 
 Look at the processors for more examples of extracting and manipulating blocks.
 
+## Other converters
+
+You can also use other converters that define different conversion pipelines:
+
+### Extract tables
+
+The `TableConverter` will only convert and extract tables:
+
+```python
+from marker.converters.table import TableConverter
+from marker.models import create_model_dict
+from marker.output import text_from_rendered
+
+converter = TableConverter(
+    artifact_dict=create_model_dict(),
+)
+rendered = converter("FILEPATH")
+text, _, images = text_from_rendered(rendered)
+```
+
+This takes all the same configuration as the PdfConverter. You can specify the configuration `force_layout_block=Table` to avoid layout detection and instead assume every page is a table.
+
 # Output Formats
 
 ## Markdown
@@ -348,7 +372,7 @@ There are some settings that you may find useful if things aren't working the way you want:
 Pass the `debug` option to activate debug mode. This will save images of each page with detected layout and text, as well as output a json file with additional bounding box information.
 
 # Benchmarks
-
+## Overall PDF Conversion
 Benchmarking PDF extraction quality is hard. I've created a test set by finding books and scientific papers that have a pdf version and a latex source. I convert the latex to text, and compare the reference to the output of text extraction methods. It's noisy, but at least directionally correct.
 
 **Speed**
@@ -371,6 +395,18 @@ Marker takes about 6GB of VRAM on average per task, so you can convert 8 documents in parallel on an A6000.
 
 ![Benchmark results](data/images/per_doc.png)
 
+## Table Conversion
+Marker can extract tables from PDFs using `marker.converters.table.TableConverter`. The table extraction performance is measured by comparing the extracted HTML representation of tables against the original HTML representations using the test split of [FinTabNet](https://developer.ibm.com/exchanges/data/all/fintabnet/). The HTML representations are compared using a tree edit distance based metric to judge both structure and content. Marker detects and identifies the structure of all tables in a PDF page and achieves these scores:
+
+| Avg score | Total tables | use_llm |
+|-----------|--------------|---------|
+| 0.824     | 54           | False   |
+| 0.873     | 54           | True    |
+
+The `--use_llm` flag can significantly improve table recognition performance, as you can see.
+
+We filter out tables that we cannot align with the ground truth, since fintabnet and our layout model have slightly different detection methods (this results in some tables being split/merged).
+
 ## Running your own benchmarks
 
 You can benchmark the performance of marker on your machine. Install marker manually with:
@@ -380,12 +416,21 @@ git clone https://github.com/VikParuchuri/marker.git
 poetry install
 ```
 
+### Overall PDF Conversion
+
 Download the benchmark data [here](https://drive.google.com/file/d/1ZSeWDo2g1y0BRLT7KnbmytV2bjWARWba/view?usp=sharing) and unzip. Then run the overall benchmark like this:
 
 ```shell
 python benchmarks/overall.py data/pdfs data/references report.json
 ```
 
+### Table Conversion
+The processed FinTabNet dataset is hosted [here](https://huggingface.co/datasets/datalab-to/fintabnet-test) and is automatically downloaded. Run the benchmark with:
+
+```shell
+python benchmarks/table/table.py table_report.json --max_rows 1000
+```
+
 # Thanks
 
 This work would not have been possible without amazing open source models and datasets, including (but not limited to):
````
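The README changes above note that `TableConverter` takes the same configuration as `PdfConverter`, including `force_layout_block=Table`. A minimal sketch of passing that setting through the `config` dict argument (the same wiring `benchmarks/table/table.py` uses via `ConfigParser`); `FILEPATH` is a placeholder:

```python
from marker.converters.table import TableConverter
from marker.models import create_model_dict
from marker.output import text_from_rendered

# force_layout_block skips layout detection and treats every page as a table
converter = TableConverter(
    artifact_dict=create_model_dict(),
    config={"force_layout_block": "Table"},
)
rendered = converter("FILEPATH")
text, _, images = text_from_rendered(rendered)
```

This is mainly useful for documents known to be all tables, since it sidesteps layout-model misdetections entirely.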
benchmarks/table/scoring.py
ADDED
```python
"""
TEDS Code Adapted from https://github.com/ibm-aur-nlp/EDD
"""

import distance
from apted import APTED, Config
from apted.helpers import Tree
from lxml import html
from collections import deque


def wrap_table_html(table_html: str) -> str:
    return f'<html><body>{table_html}</body></html>'


class TableTree(Tree):
    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content

        # Sets self.name and self.children
        super().__init__(tag, *children)

    def bracket(self):
        """Show tree using brackets notation"""
        if self.tag == 'td':
            result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
                     (self.tag, self.colspan, self.rowspan, self.content)
        else:
            result = '"tag": %s' % self.tag
        for child in self.children:
            result += child.bracket()
        return "{{{}}}".format(result)


class CustomConfig(Config):
    @staticmethod
    def maximum(*sequences):
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)

    def rename(self, node1, node2):
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:
                return self.normalized_distance(node1.content, node2.content)
        return 0.


def tokenize(node):
    """
    Tokenizes table cells
    """
    global __tokens__
    __tokens__.append('<%s>' % node.tag)
    if node.text is not None:
        __tokens__ += list(node.text)
    for n in node.getchildren():
        tokenize(n)
    if node.tag != 'unk':
        __tokens__.append('</%s>' % node.tag)
    if node.tag != 'td' and node.tail is not None:
        __tokens__ += list(node.tail)


def tree_convert_html(node, convert_cell=False, parent=None):
    """
    Converts HTML tree to the format required by apted
    """
    global __tokens__
    if node.tag == 'td':
        if convert_cell:
            __tokens__ = []
            tokenize(node)
            cell = __tokens__[1:-1].copy()
        else:
            cell = []
        new_node = TableTree(node.tag,
                             int(node.attrib.get('colspan', '1')),
                             int(node.attrib.get('rowspan', '1')),
                             cell, *deque())
    else:
        new_node = TableTree(node.tag, None, None, None, *deque())
    if parent is not None:
        parent.children.append(new_node)
    if node.tag != 'td':
        for n in node.getchildren():
            tree_convert_html(n, convert_cell, new_node)
    if parent is None:
        return new_node


def similarity_eval_html(pred, true, structure_only=False):
    """
    Computes TEDS score between the prediction and the ground truth of a given sample
    """
    pred, true = html.fromstring(pred), html.fromstring(true)
    if pred.xpath('body/table') and true.xpath('body/table'):
        pred = pred.xpath('body/table')[0]
        true = true.xpath('body/table')[0]
        n_nodes_pred = len(pred.xpath(".//*"))
        n_nodes_true = len(true.xpath(".//*"))
        tree_pred = tree_convert_html(pred, convert_cell=not structure_only)
        tree_true = tree_convert_html(true, convert_cell=not structure_only)
        n_nodes = max(n_nodes_pred, n_nodes_true)
        distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
        return 1.0 - (float(distance) / n_nodes)
    else:
        return 0.0
```
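As a quick sanity check of the scoring code above, a minimal sketch of scoring one predicted table against a ground truth (assuming `scoring.py` is importable and the `distance`, `apted`, and `lxml` dependencies are installed):

```python
from scoring import wrap_table_html, similarity_eval_html

pred = "<table><tr><td>a</td><td>b</td></tr></table>"
gt = "<table><tr><td>a</td><td>c</td></tr></table>"

# wrap_table_html adds the <html><body> wrapper that similarity_eval_html expects
score = similarity_eval_html(wrap_table_html(pred), wrap_table_html(gt))
print(f"TEDS score: {score:.3f}")  # 1.0 means identical structure and content
```

Passing `structure_only=True` compares only tags, rowspans, and colspans, ignoring cell text.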
benchmarks/table/table.py
ADDED
```python
import os
from typing import List

import numpy as np

from marker.renderers.json import JSONOutput, JSONBlockOutput

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # Transformers uses .isin for a simple op, which is not supported on MPS

import base64
import time
import datasets
from tqdm import tqdm
import tempfile
import click
from tabulate import tabulate
import json
from bs4 import BeautifulSoup
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from pypdfium2._helpers.misc import PdfiumError
from marker.util import matrix_intersection_area

from marker.config.parser import ConfigParser
from marker.converters.table import TableConverter
from marker.models import create_model_dict

from scoring import wrap_table_html, similarity_eval_html


def update_teds_score(result):
    prediction, ground_truth = result['marker_table'], result['gt_table']
    prediction, ground_truth = wrap_table_html(prediction), wrap_table_html(ground_truth)
    score = similarity_eval_html(prediction, ground_truth)
    result.update({'score': score})
    return result


def extract_tables(children: List[JSONBlockOutput]):
    tables = []
    for child in children:
        if child.block_type == 'Table':
            tables.append(child)
        elif child.children:
            tables.extend(extract_tables(child.children))
    return tables


@click.command(help="Benchmark Table to HTML Conversion")
@click.argument("out_file", type=str)
@click.option("--dataset", type=str, default="datalab-to/fintabnet-test", help="Dataset to use")
@click.option("--max_rows", type=int, default=None, help="Maximum number of PDFs to process")
@click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
@click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.")
@click.option("--table_rec_batch_size", type=int, default=None, help="Batch size for table recognition.")
def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm: bool, table_rec_batch_size: int | None):
    models = create_model_dict()
    config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size})
    start = time.time()

    dataset = datasets.load_dataset(dataset, split='train')
    dataset = dataset.shuffle(seed=0)

    iterations = len(dataset)
    if max_rows is not None:
        iterations = min(max_rows, len(dataset))

    results = []
    total_unaligned = 0
    for i in tqdm(range(iterations), desc='Converting Tables'):
        try:
            row = dataset[i]
            pdf_binary = base64.b64decode(row['pdf'])
            gt_tables = row['tables']  # Already sorted by reading order, which is what marker returns

            converter = TableConverter(
                config=config_parser.generate_config_dict(),
                artifact_dict=models,
                processor_list=config_parser.get_processors(),
                renderer=config_parser.get_renderer()
            )

            with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as temp_pdf_file:
                temp_pdf_file.write(pdf_binary)
                temp_pdf_file.seek(0)
                tqdm.disable = True
                marker_json = converter(temp_pdf_file.name).children
                tqdm.disable = False

            if len(marker_json) == 0 or len(gt_tables) == 0:
                print('No tables detected, skipping...')
                total_unaligned += len(gt_tables)
                continue

            marker_tables = extract_tables(marker_json)
            marker_table_boxes = [table.bbox for table in marker_tables]
            page_bbox = marker_json[0].bbox

            # Normalize the bboxes
            for bbox in marker_table_boxes:
                bbox[0] = bbox[0] / page_bbox[2]
                bbox[1] = bbox[1] / page_bbox[3]
                bbox[2] = bbox[2] / page_bbox[2]
                bbox[3] = bbox[3] / page_bbox[3]

            gt_boxes = [table['normalized_bbox'] for table in gt_tables]
            gt_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in gt_boxes]
            marker_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in marker_table_boxes]
            table_alignments = matrix_intersection_area(gt_boxes, marker_table_boxes)

            aligned_tables = []
            used_tables = set()
            unaligned_tables = set()
            for table_idx, alignment in enumerate(table_alignments):
                try:
                    max_area = np.max(alignment)
                    aligned_idx = np.argmax(alignment)
                except ValueError:
                    # No alignment found
                    unaligned_tables.add(table_idx)
                    continue

                if aligned_idx in used_tables:
                    # Marker table already aligned with another gt table
                    unaligned_tables.add(table_idx)
                    continue

                # Gt table doesn't align well with any marker table
                gt_table_pct = gt_areas[table_idx] / max_area
                if not .75 < gt_table_pct < 1.25:
                    unaligned_tables.add(table_idx)
                    continue

                # Marker table doesn't align with gt table
                marker_table_pct = marker_areas[aligned_idx] / max_area
                if not .75 < marker_table_pct < 1.25:
                    unaligned_tables.add(table_idx)
                    continue

                aligned_tables.append(
                    (marker_tables[aligned_idx], gt_tables[table_idx])
                )
                used_tables.add(aligned_idx)

            total_unaligned += len(unaligned_tables)

            for marker_table, gt_table in aligned_tables:
                gt_table_html = gt_table['html']

                # marker wraps the table in <tbody> which fintabnet data doesn't
                # Fintabnet doesn't use th tags, need to be replaced for fair comparison
                marker_table_soup = BeautifulSoup(marker_table.html, 'html.parser')
                marker_table_soup.find('tbody').unwrap()
                for th_tag in marker_table_soup.find_all('th'):
                    th_tag.name = 'td'
                marker_table_html = str(marker_table_soup)
                marker_table_html = marker_table_html.replace("\n", " ")  # Fintabnet uses spaces instead of newlines

                results.append({
                    "marker_table": marker_table_html,
                    "gt_table": gt_table_html
                })
        except PdfiumError:
            print('Broken PDF, Skipping...')
            continue

    print(f"Total time: {time.time() - start}.")
    print(f"Could not align {total_unaligned} tables from fintabnet.")

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        results = list(
            tqdm(
                executor.map(update_teds_score, results), desc='Computing alignment scores', total=len(results)
            )
        )
    avg_score = sum([r["score"] for r in results]) / len(results)

    headers = ["Avg score", "Total tables"]
    data = [f"{avg_score:.3f}", len(results)]
    table = tabulate([data], headers=headers, tablefmt="github")
    print(table)
    print("Avg score computed by comparing marker predicted HTML with original HTML")

    with open(out_file, "w+") as f:
        json.dump(results, f, indent=2)


if __name__ == '__main__':
    main()
```
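The benchmark's converter wiring and `extract_tables` walk can also be reused outside the benchmark. A sketch mirroring the `ConfigParser` setup from `main()` above; `FILEPATH` is a placeholder:

```python
from marker.config.parser import ConfigParser
from marker.converters.table import TableConverter
from marker.models import create_model_dict

config_parser = ConfigParser({"output_format": "json"})
converter = TableConverter(
    config=config_parser.generate_config_dict(),
    artifact_dict=create_model_dict(),
    processor_list=config_parser.get_processors(),
    renderer=config_parser.get_renderer(),
)

# .children is the JSONBlockOutput tree, as used in main() above
marker_json = converter("FILEPATH").children

def extract_tables(children):
    # Same recursive walk as the benchmark's extract_tables()
    tables = []
    for child in children:
        if child.block_type == "Table":
            tables.append(child)
        elif child.children:
            tables.extend(extract_tables(child.children))
    return tables

tables = extract_tables(marker_json)
print(f"Found {len(tables)} tables")
```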
chunk_convert.py
CHANGED
```diff
@@ -1,22 +1,4 @@
-import argparse
-import subprocess
-import pkg_resources
-
-
-def main():
-    parser = argparse.ArgumentParser(description="Convert a folder of PDFs to a folder of markdown files in chunks.")
-    parser.add_argument("in_folder", help="Input folder with pdfs.")
-    parser.add_argument("out_folder", help="Output folder")
-    args = parser.parse_args()
-
-    script_path = pkg_resources.resource_filename(__name__, 'chunk_convert.sh')
-
-    # Construct the command
-    cmd = f"{script_path} {args.in_folder} {args.out_folder}"
-
-    # Execute the shell script
-    subprocess.run(cmd, shell=True, check=True)
-
+from marker.scripts import chunk_convert_cli
 
 if __name__ == "__main__":
-    main()
+    chunk_convert_cli()
```
convert.py
CHANGED
```diff
@@ -1,117 +1,4 @@
-import os
-
-os.environ["GRPC_VERBOSITY"] = "ERROR"
-os.environ["GLOG_minloglevel"] = "2"
-os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
-os.environ["IN_STREAMLIT"] = "true" # Avoid multiprocessing inside surya
-
-import math
-import traceback
-
-import click
-import torch.multiprocessing as mp
-from tqdm import tqdm
-
-from marker.config.parser import ConfigParser
-from marker.config.printer import CustomClickPrinter
-from marker.converters.pdf import PdfConverter
-from marker.logger import configure_logging
-from marker.models import create_model_dict
-from marker.output import output_exists, save_output
-from marker.settings import settings
-
-configure_logging()
-
-
-def worker_init(model_dict):
-    if model_dict is None:
-        model_dict = create_model_dict()
-
-    global model_refs
-    model_refs = model_dict
-
-
-def worker_exit():
-    global model_refs
-    del model_refs
-
-
-def process_single_pdf(args):
-    fpath, cli_options = args
-    config_parser = ConfigParser(cli_options)
-
-    out_folder = config_parser.get_output_folder(fpath)
-    base_name = config_parser.get_base_filename(fpath)
-    if cli_options.get('skip_existing') and output_exists(out_folder, base_name):
-        return
-
-    try:
-        converter = PdfConverter(
-            config=config_parser.generate_config_dict(),
-            artifact_dict=model_refs,
-            processor_list=config_parser.get_processors(),
-            renderer=config_parser.get_renderer()
-        )
-        rendered = converter(fpath)
-        out_folder = config_parser.get_output_folder(fpath)
-        save_output(rendered, out_folder, base_name)
-    except Exception as e:
-        print(f"Error converting {fpath}: {e}")
-        print(traceback.format_exc())
-
-
-@click.command(cls=CustomClickPrinter)
-@click.argument("in_folder", type=str)
-@ConfigParser.common_options
-@click.option("--chunk_idx", type=int, default=0, help="Chunk index to convert")
-@click.option("--num_chunks", type=int, default=1, help="Number of chunks being processed in parallel")
-@click.option("--max_files", type=int, default=None, help="Maximum number of pdfs to convert")
-@click.option("--workers", type=int, default=5, help="Number of worker processes to use.")
-@click.option("--skip_existing", is_flag=True, default=False, help="Skip existing converted files.")
-def main(in_folder: str, **kwargs):
-    in_folder = os.path.abspath(in_folder)
-    files = [os.path.join(in_folder, f) for f in os.listdir(in_folder)]
-    files = [f for f in files if os.path.isfile(f)]
-
-    # Handle chunks if we're processing in parallel
-    # Ensure we get all files into a chunk
-    chunk_size = math.ceil(len(files) / kwargs["num_chunks"])
-    start_idx = kwargs["chunk_idx"] * chunk_size
-    end_idx = start_idx + chunk_size
-    files_to_convert = files[start_idx:end_idx]
-
-    # Limit files converted if needed
-    if kwargs["max_files"]:
-        files_to_convert = files_to_convert[:kwargs["max_files"]]
-
-    # Disable nested multiprocessing
-    kwargs["disable_multiprocessing"] = True
-
-    total_processes = min(len(files_to_convert), kwargs["workers"])
-
-    try:
-        mp.set_start_method('spawn') # Required for CUDA, forkserver doesn't work
-    except RuntimeError:
-        raise RuntimeError("Set start method to spawn twice. This may be a temporary issue with the script. Please try running it again.")
-
-    if settings.TORCH_DEVICE == "mps" or settings.TORCH_DEVICE_MODEL == "mps":
-        model_dict = None
-    else:
-        model_dict = create_model_dict()
-        for k, v in model_dict.items():
-            v.share_memory()
-
-    print(f"Converting {len(files_to_convert)} pdfs in chunk {kwargs['chunk_idx'] + 1}/{kwargs['num_chunks']} with {total_processes} processes and saving to {kwargs['output_dir']}")
-    task_args = [(f, kwargs) for f in files_to_convert]
-
-    with mp.Pool(processes=total_processes, initializer=worker_init, initargs=(model_dict,)) as pool:
-        list(tqdm(pool.imap(process_single_pdf, task_args), total=len(task_args), desc="Processing PDFs", unit="pdf"))
-
-        pool._worker_handler.terminate = worker_exit
-
-    # Delete all CUDA tensors
-    del model_dict
-
+from marker.scripts import convert_cli
 
 if __name__ == "__main__":
-    main()
+    convert_cli()
```
convert_single.py
CHANGED
```diff
@@ -1,43 +1,4 @@
-import os
-
-os.environ["GRPC_VERBOSITY"] = "ERROR"
-os.environ["GLOG_minloglevel"] = "2"
-os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
-
-import time
-import click
-
-from marker.config.parser import ConfigParser
-from marker.config.printer import CustomClickPrinter
-from marker.converters.pdf import PdfConverter
-from marker.logger import configure_logging
-from marker.models import create_model_dict
-from marker.output import save_output
-
-configure_logging()
-
-
-@click.command(cls=CustomClickPrinter, help="Convert a single PDF to markdown.")
-@click.argument("fpath", type=str)
-@ConfigParser.common_options
-def main(fpath: str, **kwargs):
-    models = create_model_dict()
-    start = time.time()
-    config_parser = ConfigParser(kwargs)
-
-    converter = PdfConverter(
-        config=config_parser.generate_config_dict(),
-        artifact_dict=models,
-        processor_list=config_parser.get_processors(),
-        renderer=config_parser.get_renderer()
-    )
-    rendered = converter(fpath)
-    out_folder = config_parser.get_output_folder(fpath)
-    save_output(rendered, out_folder, config_parser.get_base_filename(fpath))
-
-    print(f"Saved markdown to {out_folder}")
-    print(f"Total time: {time.time() - start}")
-
+from marker.scripts import convert_single_cli
 
 if __name__ == "__main__":
-    main()
+    convert_single_cli()
```
marker/builders/document.py
CHANGED
```diff
@@ -22,11 +22,16 @@ class DocumentBuilder(BaseBuilder):
         int,
         "DPI setting for high-resolution page images used for OCR.",
     ] = 192
+    disable_ocr: Annotated[
+        bool,
+        "Disable OCR processing.",
+    ] = False
 
     def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder, ocr_builder: OcrBuilder):
        document = self.build_document(provider)
        layout_builder(document, provider)
-       ocr_builder(document, provider)
+       if not self.disable_ocr:
+           ocr_builder(document, provider)
        return document
 
     def build_document(self, provider: PdfProvider):
```
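The new `disable_ocr` option is a plain builder setting, so it should be reachable through the usual config dict like other annotated fields. A hedged sketch (assuming config keys propagate to `DocumentBuilder`; `FILEPATH` is a placeholder):

```python
from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict

converter = PdfConverter(
    artifact_dict=create_model_dict(),
    # Assumption: config keys map onto builder fields; DocumentBuilder will then
    # skip the ocr_builder(document, provider) call shown in the diff above
    config={"disable_ocr": True},
)
rendered = converter("FILEPATH")
```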
marker/builders/layout.py
CHANGED
```diff
@@ -1,11 +1,10 @@
 from typing import Annotated, List, Optional, Tuple
 
 import numpy as np
-from surya.layout import batch_layout_detection
-from surya.model.layout.encoderdecoder import SuryaLayoutModel
-from surya.model.ocr_error.model import DistilBertForSequenceClassification
-from surya.ocr_error import batch_ocr_error_detection
-from surya.schema import LayoutResult, OCRErrorDetectionResult
+from surya.layout import LayoutPredictor
+from surya.layout.schema import LayoutResult, LayoutBox
+from surya.ocr_error import OCRErrorPredictor
+from surya.ocr_error.schema import OCRErrorDetectionResult
 
 from marker.builders import BaseBuilder
 from marker.providers import ProviderOutput, ProviderPageLines
@@ -51,15 +50,23 @@ class LayoutBuilder(BaseBuilder):
         Tuple[BlockTypes],
         "A list of block types to exclude from the layout coverage check.",
     ] = (BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup)
+    force_layout_block: Annotated[
+        str,
+        "Skip layout and force every page to be treated as a specific block type.",
+    ] = None
 
-    def __init__(self, layout_model: SuryaLayoutModel, ocr_error_model: DistilBertForSequenceClassification, config=None):
+    def __init__(self, layout_model: LayoutPredictor, ocr_error_model: OCRErrorPredictor, config=None):
         self.layout_model = layout_model
         self.ocr_error_model = ocr_error_model
 
         super().__init__(config)
 
     def __call__(self, document: Document, provider: PdfProvider):
-        layout_results = self.surya_layout(document.pages)
+        if self.force_layout_block is not None:
+            # Assign the full content of every page to a single layout type
+            layout_results = self.forced_layout(document.pages)
+        else:
+            layout_results = self.surya_layout(document.pages)
         self.add_blocks_to_pages(document.pages, layout_results)
         self.merge_blocks(document.pages, provider.page_lines)
 
@@ -70,12 +77,29 @@ class LayoutBuilder(BaseBuilder):
             return 6
         return 6
 
+    def forced_layout(self, pages: List[PageGroup]) -> List[LayoutResult]:
+        layout_results = []
+        for page in pages:
+            layout_results.append(
+                LayoutResult(
+                    image_bbox=page.polygon.bbox,
+                    bboxes=[
+                        LayoutBox(
+                            label=self.force_layout_block,
+                            position=0,
+                            top_k={self.force_layout_block: 1},
+                            polygon=page.polygon.polygon,
+                        ),
+                    ],
+                    sliced=False
+                )
+            )
+        return layout_results
+
+
     def surya_layout(self, pages: List[PageGroup]) -> List[LayoutResult]:
-        layout_results = batch_layout_detection(
-            [p.lowres_image for p in pages],
-            self.layout_model,
-            processor,
+        layout_results = self.layout_model(
+            [p.get_image(highres=False) for p in pages],
             batch_size=int(self.get_batch_size())
         )
         return layout_results
@@ -97,10 +121,8 @@ class LayoutBuilder(BaseBuilder):
 
             page_texts.append(page_text)
 
-        ocr_error_detection_results = batch_ocr_error_detection(
+        ocr_error_detection_results = self.ocr_error_model(
             page_texts,
-            self.ocr_error_model,
-            self.ocr_error_model.tokenizer,
             batch_size=int(self.get_batch_size()) # TODO Better Multiplier
         )
         return ocr_error_detection_results
```
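The diff above replaces surya's function-based API (`batch_layout_detection` with separate model and processor arguments) with callable predictor objects. A hedged sketch of the new calling convention, assuming `LayoutPredictor()` loads its default checkpoint when constructed without arguments:

```python
from PIL import Image
from surya.layout import LayoutPredictor

# Assumption: the no-argument constructor loads the default layout checkpoint
layout_model = LayoutPredictor()

pages = [Image.open("page.png")]  # hypothetical page image
layout_results = layout_model(pages, batch_size=6)

for result in layout_results:
    for box in result.bboxes:
        # Each LayoutBox carries a label, reading-order position, and polygon,
        # matching the fields constructed in forced_layout() above
        print(box.label, box.position, box.polygon)
```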
marker/builders/llm_layout.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
| 1 |
-
import json
|
| 2 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 3 |
-
from typing import Annotated
|
| 4 |
|
| 5 |
from google.ai.generativelanguage_v1beta.types import content
|
| 6 |
-
from surya.
|
| 7 |
-
from surya.
|
| 8 |
from tqdm import tqdm
|
| 9 |
|
| 10 |
from marker.builders.layout import LayoutBuilder
|
|
@@ -24,16 +23,16 @@ class LLMLayoutBuilder(LayoutBuilder):
|
|
| 24 |
"""
|
| 25 |
|
| 26 |
google_api_key: Annotated[
|
| 27 |
-
|
| 28 |
"The Google API key to use for the Gemini model.",
|
| 29 |
] = settings.GOOGLE_API_KEY
|
| 30 |
confidence_threshold: Annotated[
|
| 31 |
float,
|
| 32 |
-
"The confidence threshold to use for relabeling.",
|
| 33 |
-
] = 0.
|
| 34 |
picture_height_threshold: Annotated[
|
| 35 |
float,
|
| 36 |
-
"The height threshold for pictures that may actually be complex regions.",
|
| 37 |
] = 0.8
|
| 38 |
model_name: Annotated[
|
| 39 |
str,
|
|
@@ -55,43 +54,47 @@ class LLMLayoutBuilder(LayoutBuilder):
|
|
| 55 |
str,
|
| 56 |
"The prompt to use for relabelling blocks.",
|
| 57 |
"Default is a string containing the Gemini relabelling prompt."
|
| 58 |
-
] = """You
|
| 59 |
Your task is to relabel layout blocks in images to improve the accuracy of an existing layout model.
|
| 60 |
-
You will be provided with an image of a layout block and the top k predictions from the current model, along with
|
| 61 |
Your job is to analyze the image and choose the single most appropriate label from the provided top k predictions.
|
| 62 |
Do not invent any new labels.
|
| 63 |
-
Carefully examine the image and consider the provided predictions.
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
-
Here are the top k predictions from the model
|
| 67 |
|
|
|
|
| 68 |
"""
|
| 69 |
complex_relabeling_prompt: Annotated[
|
| 70 |
str,
|
| 71 |
"The prompt to use for complex relabelling blocks.",
|
| 72 |
"Default is a string containing the complex relabelling prompt."
|
| 73 |
-
] = """You
|
| 74 |
Your task is to relabel layout blocks in images to improve the accuracy of an existing layout model.
|
| 75 |
-
You will be provided with an image of a layout block and some potential labels.
|
| 76 |
Your job is to analyze the image and choose the single most appropriate label from the provided labels.
|
| 77 |
Do not invent any new labels.
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
| 80 |
|
| 81 |
Potential labels:
|
| 82 |
|
| 83 |
-
|
| 84 |
-
- Table
|
| 85 |
-
- Form
|
| 86 |
-
- Figure - A graph or diagram with text.
|
| 87 |
-
- ComplexRegion - a complex region containing multiple text and other elements.
|
| 88 |
|
| 89 |
Respond only with one of `Figure`, `Picture`, `ComplexRegion`, `Table`, or `Form`.
|
| 90 |
-
|
| 91 |
-
Here is the image of the layout block:
|
| 92 |
"""
|
| 93 |
|
| 94 |
-
def __init__(self, layout_model:
|
| 95 |
super().__init__(layout_model, ocr_error_model, config)
|
| 96 |
|
| 97 |
self.model = GoogleModel(self.google_api_key, self.model_name)
|
|
@@ -114,10 +117,10 @@ Here is the image of the layout block:
|
|
| 114 |
confidence = block.top_k.get(block.block_type)
|
| 115 |
# Case when the block is detected as a different type with low confidence
|
| 116 |
if confidence < self.confidence_threshold:
|
| 117 |
-
futures.append(executor.submit(self.process_block_topk_relabeling, page, block))
|
| 118 |
# Case when the block is detected as a picture or figure, but is actually complex
|
| 119 |
elif block.block_type in (BlockTypes.Picture, BlockTypes.Figure, BlockTypes.SectionHeader) and block.polygon.height > page.polygon.height * self.picture_height_threshold:
|
| 120 |
-
futures.append(executor.submit(self.process_block_complex_relabeling, page, block))
|
| 121 |
|
| 122 |
for future in as_completed(futures):
|
| 123 |
future.result() # Raise exceptions if any occurred
|
|
@@ -125,23 +128,40 @@ Here is the image of the layout block:
|
|
| 125 |
|
| 126 |
pbar.close()
|
| 127 |
|
| 128 |
-
def process_block_topk_relabeling(self, page: PageGroup, block: Block):
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
-
prompt = self.topk_relabelling_prompt
|
| 132 |
-
return self.process_block_relabeling(page, block, prompt)
|
| 133 |
|
| 134 |
-
|
| 135 |
-
complex_prompt = self.complex_relabeling_prompt
|
| 136 |
-
return self.process_block_relabeling(page, block, complex_prompt)
|
| 137 |
|
| 138 |
-
def
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
response_schema = content.Schema(
|
| 141 |
type=content.Type.OBJECT,
|
| 142 |
enum=[],
|
| 143 |
-
required=["label"],
|
| 144 |
properties={
|
|
|
|
|
|
|
|
|
|
| 145 |
"label": content.Schema(
|
| 146 |
type=content.Type.STRING,
|
| 147 |
),
|
|
@@ -162,10 +182,5 @@ Here is the image of the layout block:
|
|
| 162 |
)
|
| 163 |
page.replace_block(block, generated_block)
|
| 164 |
|
| 165 |
-
def extract_image(self,
|
| 166 |
-
|
| 167 |
-
image_box = image_block.polygon\
|
| 168 |
-
.rescale(page.polygon.size, page_img.size)\
|
| 169 |
-
.expand(expand, expand)
|
| 170 |
-
cropped = page_img.crop(image_box.bbox)
|
| 171 |
-
return cropped
|
|
|
|
|
|
|
| 1 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 2 |
+
from typing import Annotated
|
| 3 |
|
| 4 |
from google.ai.generativelanguage_v1beta.types import content
|
| 5 |
+
from surya.layout import LayoutPredictor
|
| 6 |
+
from surya.ocr_error import OCRErrorPredictor
|
| 7 |
from tqdm import tqdm
|
| 8 |
|
| 9 |
from marker.builders.layout import LayoutBuilder
|
|
|
|
| 23 |
"""
|
| 24 |
|
| 25 |
google_api_key: Annotated[
|
| 26 |
+
str,
|
| 27 |
"The Google API key to use for the Gemini model.",
|
| 28 |
] = settings.GOOGLE_API_KEY
|
| 29 |
confidence_threshold: Annotated[
|
| 30 |
float,
|
| 31 |
+
"The confidence threshold to use for relabeling (anything below is relabeled).",
|
| 32 |
+
] = 0.7
|
| 33 |
picture_height_threshold: Annotated[
|
| 34 |
float,
|
| 35 |
+
"The height threshold for pictures that may actually be complex regions. (anything above this ratio against the page is relabeled)",
|
| 36 |
] = 0.8
|
| 37 |
model_name: Annotated[
|
| 38 |
str,
|
|
|
|
| 54 |
str,
|
| 55 |
"The prompt to use for relabelling blocks.",
|
| 56 |
"Default is a string containing the Gemini relabelling prompt."
|
| 57 |
+
] = """You're a layout expert specializing in document analysis.
|
| 58 |
Your task is to relabel layout blocks in images to improve the accuracy of an existing layout model.
|
| 59 |
+
You will be provided with an image of a layout block and the top k predictions from the current model, along with the per-label confidence scores.
|
| 60 |
Your job is to analyze the image and choose the single most appropriate label from the provided top k predictions.
|
| 61 |
Do not invent any new labels.
|
| 62 |
+
 Carefully examine the image and consider the provided predictions. Take the model confidence scores into account. The confidence is reported on a 0-1 scale, with 1 being 100% confident. If the existing label is the most appropriate, you should not change it.
+
+**Instructions**
+1. Analyze the image and consider the provided top k predictions.
+2. Write a short description of the image, and which of the potential labels you believe is the most accurate representation of the layout block.
+3. Choose the single most appropriate label from the provided top k predictions.
+
+Here are descriptions of the layout blocks you can choose from:
+
+{potential_labels}
 
+Here are the top k predictions from the model:
 
+{top_k}
 """
     complex_relabeling_prompt: Annotated[
         str,
         "The prompt to use for complex relabelling blocks.",
         "Default is a string containing the complex relabelling prompt."
+    ] = """You're a layout expert specializing in document analysis.
 Your task is to relabel layout blocks in images to improve the accuracy of an existing layout model.
+You will be provided with an image of a layout block and some potential labels that might be appropriate.
 Your job is to analyze the image and choose the single most appropriate label from the provided labels.
 Do not invent any new labels.
+**Instructions**
+1. Analyze the image and consider the potential labels.
+2. Write a short description of the image, and which of the potential labels you believe is the most accurate representation of the layout block.
+3. Choose the single most appropriate label from the provided labels.
 
 Potential labels:
 
+{potential_labels}
 
 Respond only with one of `Figure`, `Picture`, `ComplexRegion`, `Table`, or `Form`.
 """
 
+    def __init__(self, layout_model: LayoutPredictor, ocr_error_model: OCRErrorPredictor, config=None):
         super().__init__(layout_model, ocr_error_model, config)
 
         self.model = GoogleModel(self.google_api_key, self.model_name)
@@
                 confidence = block.top_k.get(block.block_type)
                 # Case when the block is detected as a different type with low confidence
                 if confidence < self.confidence_threshold:
+                    futures.append(executor.submit(self.process_block_topk_relabeling, document, page, block))
                 # Case when the block is detected as a picture or figure, but is actually complex
                 elif block.block_type in (BlockTypes.Picture, BlockTypes.Figure, BlockTypes.SectionHeader) and block.polygon.height > page.polygon.height * self.picture_height_threshold:
+                    futures.append(executor.submit(self.process_block_complex_relabeling, document, page, block))
 
             for future in as_completed(futures):
                 future.result()  # Raise exceptions if any occurred
 
         pbar.close()
 
+    def process_block_topk_relabeling(self, document: Document, page: PageGroup, block: Block):
+        topk_types = list(block.top_k.keys())
+        potential_labels = ""
+        for block_type in topk_types:
+            label_cls = get_block_class(block_type)
+            potential_labels += f"- `{block_type}` - {label_cls.model_fields['block_description'].default}\n"
+
+        topk = ""
+        for k, v in block.top_k.items():
+            topk += f"- `{k}` - Confidence {round(v, 3)}\n"
 
+        prompt = self.topk_relabelling_prompt.replace("{potential_labels}", potential_labels).replace("{top_k}", topk)
 
+        return self.process_block_relabeling(document, page, block, prompt)
 
+    def process_block_complex_relabeling(self, document: Document, page: PageGroup, block: Block):
+        potential_labels = ""
+        for block_type in [BlockTypes.Figure, BlockTypes.Picture, BlockTypes.ComplexRegion, BlockTypes.Table, BlockTypes.Form]:
+            label_cls = get_block_class(block_type)
+            potential_labels += f"- `{block_type}` - {label_cls.model_fields['block_description'].default}\n"
+
+        complex_prompt = self.complex_relabeling_prompt.replace("{potential_labels}", potential_labels)
+        return self.process_block_relabeling(document, page, block, complex_prompt)
+
+    def process_block_relabeling(self, document: Document, page: PageGroup, block: Block, prompt: str):
+        image = self.extract_image(document, block)
         response_schema = content.Schema(
             type=content.Type.OBJECT,
             enum=[],
+            required=["image_description", "label"],
             properties={
+                "image_description": content.Schema(
+                    type=content.Type.STRING,
+                ),
                 "label": content.Schema(
                     type=content.Type.STRING,
                 ),
@@
         )
         page.replace_block(block, generated_block)
 
+    def extract_image(self, document: Document, image_block: Block, expand: float = 0.01):
+        return image_block.get_image(document, highres=False, expansion=(expand, expand))
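The two relabeling paths above are gated by confidence and by block geometry. A minimal sketch of that gating, with toy numbers (the thresholds and block values here are hypothetical, not marker defaults):

```python
# Toy illustration of LLMLayoutBuilder's two relabeling triggers.
top_k = {"Picture": 0.45, "Figure": 0.30, "Table": 0.25}  # hypothetical model scores
confidence_threshold = 0.7      # assumed value for the sketch
picture_height_threshold = 0.8  # assumed fraction of page height

block_type, block_height, page_height = "Picture", 900, 1000

if top_k.get(block_type, 0.0) < confidence_threshold:
    print("low confidence -> top-k relabeling prompt")
elif block_type in ("Picture", "Figure", "SectionHeader") and block_height > page_height * picture_height_threshold:
    print("tall figure-like block -> complex relabeling prompt")
```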
marker/builders/ocr.py
CHANGED

@@ -1,9 +1,8 @@
 from typing import Annotated, List, Optional
 
 from ftfy import fix_text
-from surya.
-from surya.
-from surya.ocr import run_ocr
+from surya.detection import DetectionPredictor
+from surya.recognition import RecognitionPredictor
 
 from marker.builders import BaseBuilder
 from marker.providers import ProviderOutput, ProviderPageLines
@@ -37,7 +36,7 @@ class OcrBuilder(BaseBuilder):
         "Default is None."
     ] = None
 
-    def __init__(self, detection_model:
+    def __init__(self, detection_model: DetectionPredictor, recognition_model: RecognitionPredictor, config=None):
         super().__init__(config)
 
         self.detection_model = detection_model
@@ -65,16 +64,13 @@ class OcrBuilder(BaseBuilder):
 
     def ocr_extraction(self, document: Document, provider: PdfProvider) -> ProviderPageLines:
         page_list = [page for page in document.pages if page.text_extraction_method == "surya"]
-        recognition_results =
-            images=[page.
+        recognition_results = self.recognition_model(
+            images=[page.get_image(highres=False) for page in page_list],
             langs=[self.languages] * len(page_list),
-
-            det_processor=self.detection_model.processor,
-            rec_model=self.recognition_model,
-            rec_processor=self.recognition_model.processor,
+            det_predictor=self.detection_model,
             detection_batch_size=int(self.get_detection_batch_size()),
             recognition_batch_size=int(self.get_recognition_batch_size()),
-            highres_images=[page.
+            highres_images=[page.get_image(highres=True) for page in page_list]
         )
 
         page_lines = {}
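The substance of this hunk is the move from surya's load_model/load_processor pairs and run_ocr to callable predictor objects. A rough standalone sketch of the new call shape, mirroring the arguments above (the blank page image is a placeholder):

```python
from PIL import Image
from surya.detection import DetectionPredictor
from surya.recognition import RecognitionPredictor

det_predictor = DetectionPredictor()
rec_predictor = RecognitionPredictor()

page_image = Image.new("RGB", (1200, 1600), "white")  # placeholder for a rendered page
results = rec_predictor(
    images=[page_image],
    langs=[None],                 # marker passes self.languages, which defaults to None
    det_predictor=det_predictor,  # detection now runs inside the recognition call
    highres_images=[page_image],
)
```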
marker/config/parser.py
CHANGED

@@ -5,11 +5,13 @@ from typing import Dict
 import click
 
 from marker.config.crawler import crawler
+from marker.converters.pdf import PdfConverter
 from marker.renderers.html import HTMLRenderer
 from marker.renderers.json import JSONRenderer
 from marker.renderers.markdown import MarkdownRenderer
 from marker.settings import settings
 from marker.util import classes_to_strings, parse_range_str, strings_to_classes
+from marker.schema import BlockTypes
 
 
 class ConfigParser:
@@ -39,6 +41,10 @@ class ConfigParser:
         # we put common options here
         fn = click.option("--google_api_key", type=str, default=None, help="Google API key for using LLMs.")(fn)
         fn = click.option("--use_llm", is_flag=True, default=False, help="Enable higher quality processing with LLMs.")(fn)
+        fn = click.option("--converter_cls", type=str, default=None, help="Converter class to use. Defaults to PDF converter.")(fn)
+
+        # enum options
+        fn = click.option("--force_layout_block", type=click.Choice(choices=[t.name for t in BlockTypes]), default=None,)(fn)
         return fn
 
     def generate_config_dict(self) -> Dict[str, any]:
@@ -95,6 +101,17 @@ class ConfigParser:
 
         return processors
 
+    def get_converter_cls(self):
+        converter_cls = self.cli_options.get("converter_cls", None)
+        if converter_cls is not None:
+            try:
+                return strings_to_classes([converter_cls])[0]
+            except Exception as e:
+                print(f"Error loading converter: {converter_cls} with error: {e}")
+                raise
+
+        return PdfConverter
+
     def get_output_folder(self, filepath: str):
         output_dir = self.cli_options.get("output_dir", settings.OUTPUT_DIR)
         fname_base = os.path.splitext(os.path.basename(filepath))[0]
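A quick sketch of how the new --converter_cls plumbing resolves; the dict-based constructor call is inferred from the self.cli_options usage above, so treat it as illustrative:

```python
from marker.config.parser import ConfigParser

parser = ConfigParser({"converter_cls": "marker.converters.table.TableConverter"})
converter_cls = parser.get_converter_cls()  # import errors are printed, then re-raised
# Without a converter_cls option, this falls back to PdfConverter.
```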
marker/config/printer.py
CHANGED

@@ -6,19 +6,47 @@ from marker.config.crawler import crawler
 
 
 class CustomClickPrinter(click.Command):
-    def get_help(self, ctx):
-        additional_help = (
-            "\n\nTip: Use 'config --help' to display all the attributes of the Builders, Processors, and Converters in Marker."
-        )
-        help_text = super().get_help(ctx)
-        help_text = help_text + additional_help
-        click.echo(help_text)
-
     def parse_args(self, ctx, args):
+
         display_help = 'config' in args and '--help' in args
         if display_help:
-            click.echo(
+            click.echo(
+                "Here is a list of all the Builders, Processors, Converters, Providers and Renderers in Marker along with their attributes:")
+
+        # Keep track of shared attributes and their types
+        shared_attrs = {}
 
+        # First pass: identify shared attributes and verify compatibility
+        for base_type, base_type_dict in crawler.class_config_map.items():
+            for class_name, class_map in base_type_dict.items():
+                for attr, (attr_type, formatted_type, default, metadata) in class_map['config'].items():
+                    if attr not in shared_attrs:
+                        shared_attrs[attr] = {
+                            'classes': [],
+                            'type': attr_type,
+                            'is_flag': attr_type in [bool, Optional[bool]] and not default,
+                            'metadata': metadata,
+                            'default': default
+                        }
+                    shared_attrs[attr]['classes'].append(class_name)
+
+        # These are the types of attrs that can be set from the command line
+        attr_types = [str, int, float, bool, Optional[int], Optional[float], Optional[str]]
+
+        # Add shared attribute options first
+        for attr, info in shared_attrs.items():
+            if info['type'] in attr_types:
+                ctx.command.params.append(
+                    click.Option(
+                        ["--" + attr],
+                        type=info['type'],
+                        help=" ".join(info['metadata']) + f" (Applies to: {', '.join(info['classes'])})",
+                        default=info['default'],
+                        is_flag=info['is_flag'],
+                    )
+                )
+
+        # Second pass: create class-specific options
         for base_type, base_type_dict in crawler.class_config_map.items():
             if display_help:
                 click.echo(f"{base_type}s:")
@@ -32,16 +60,14 @@ class CustomClickPrinter(click.Command):
             if display_help:
                 click.echo(" " * 8 + f"{attr} ({formatted_type}):")
                 click.echo("\n".join([f'{" " * 12}' + desc for desc in metadata]))
-
+
+            if attr_type in attr_types:
                 is_flag = attr_type in [bool, Optional[bool]] and not default
-
-
-            else:
-                options = ["--" + attr, "--" + class_name_attr]
-                options.append(class_name_attr)
+
+                # Only add class-specific options
                 ctx.command.params.append(
                     click.Option(
-
+                        ["--" + class_name_attr, class_name_attr],
                         type=attr_type,
                         help=" ".join(metadata),
                         is_flag=is_flag,
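The pattern of registering click options during parse_args, which the refactor above uses for both shared and class-specific attributes, can be shown in isolation. This toy command is hypothetical and not part of marker:

```python
import click

class DynamicOptionCommand(click.Command):
    def parse_args(self, ctx, args):
        # Append an option at parse time, as CustomClickPrinter does above.
        ctx.command.params.append(
            click.Option(["--batch_size"], type=int, default=4, help="Toy shared option.")
        )
        return super().parse_args(ctx, args)

@click.command(cls=DynamicOptionCommand)
@click.pass_context
def main(ctx, **kwargs):
    click.echo(ctx.params)  # includes the dynamically added batch_size

if __name__ == "__main__":
    main()
```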
marker/converters/pdf.py
CHANGED

@@ -1,12 +1,14 @@
 import os
-
 os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disables a tokenizers warning
 
 import inspect
 from collections import defaultdict
 from functools import cache
-from typing import Annotated, Any, Dict, List, Optional, Type
+from typing import Annotated, Any, Dict, List, Optional, Type, Tuple
 
+from marker.processors import BaseProcessor
+from marker.processors.llm.llm_table_merge import LLMTableMergeProcessor
+from marker.providers.registry import provider_from_filepath
 from marker.builders.document import DocumentBuilder
 from marker.builders.layout import LayoutBuilder
 from marker.builders.llm_layout import LLMLayoutBuilder
@@ -32,12 +34,13 @@ from marker.processors.reference import ReferenceProcessor
 from marker.processors.sectionheader import SectionHeaderProcessor
 from marker.processors.table import TableProcessor
 from marker.processors.text import TextProcessor
-from marker.
+from marker.processors.llm.llm_equation import LLMEquationProcessor
 from marker.renderers.markdown import MarkdownRenderer
 from marker.schema import BlockTypes
 from marker.schema.blocks import Block
 from marker.schema.registry import register_block_class
 from marker.util import strings_to_classes
+from marker.processors.llm.llm_handwriting import LLMHandwritingProcessor
 
 
 class PdfConverter(BaseConverter):
@@ -55,6 +58,30 @@ class PdfConverter(BaseConverter):
         bool,
         "Enable higher quality processing with LLMs.",
     ] = False
+    default_processors: Tuple[BaseProcessor, ...] = (
+        BlockquoteProcessor,
+        CodeProcessor,
+        DocumentTOCProcessor,
+        EquationProcessor,
+        FootnoteProcessor,
+        IgnoreTextProcessor,
+        LineNumbersProcessor,
+        ListProcessor,
+        PageHeaderProcessor,
+        SectionHeaderProcessor,
+        TableProcessor,
+        LLMTableProcessor,
+        LLMTableMergeProcessor,
+        LLMFormProcessor,
+        TextProcessor,
+        LLMTextProcessor,
+        LLMComplexRegionProcessor,
+        LLMImageDescriptionProcessor,
+        LLMEquationProcessor,
+        LLMHandwritingProcessor,
+        ReferenceProcessor,
+        DebugProcessor,
+    )
 
     def __init__(self, artifact_dict: Dict[str, Any], processor_list: Optional[List[str]] = None, renderer: str | None = None, config=None):
         super().__init__(config)
@@ -65,27 +92,7 @@ class PdfConverter(BaseConverter):
         if processor_list:
             processor_list = strings_to_classes(processor_list)
         else:
-            processor_list =
-                BlockquoteProcessor,
-                CodeProcessor,
-                DocumentTOCProcessor,
-                EquationProcessor,
-                FootnoteProcessor,
-                IgnoreTextProcessor,
-                LineNumbersProcessor,
-                ListProcessor,
-                PageHeaderProcessor,
-                SectionHeaderProcessor,
-                TableProcessor,
-                LLMTableProcessor,
-                LLMFormProcessor,
-                TextProcessor,
-                LLMTextProcessor,
-                LLMComplexRegionProcessor,
-                LLMImageDescriptionProcessor,
-                ReferenceProcessor,
-                DebugProcessor,
-            ]
+            processor_list = self.default_processors
 
         if renderer:
             renderer = strings_to_classes([renderer])[0]
@@ -121,11 +128,11 @@ class PdfConverter(BaseConverter):
 
     @cache
     def build_document(self, filepath: str):
+        provider_cls = provider_from_filepath(filepath)
         layout_builder = self.resolve_dependencies(self.layout_builder_class)
         ocr_builder = self.resolve_dependencies(OcrBuilder)
-
-
-        document = DocumentBuilder(self.config)(pdf_provider, layout_builder, ocr_builder)
+        with provider_cls(filepath, self.config) as provider:
+            document = DocumentBuilder(self.config)(provider, layout_builder, ocr_builder)
         StructureBuilder(self.config)(document)
 
         for processor_cls in self.processor_list:
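With these changes the converter resolves a provider from the file path and builds the document inside a context manager, but the outward API stays the same. A usage sketch (assuming PdfConverter keeps the __call__(filepath) entry point that TableConverter overrides below):

```python
from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict

converter = PdfConverter(artifact_dict=create_model_dict())
rendered = converter("paper.pdf")  # provider_from_filepath picks the right provider
```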
marker/converters/table.py
ADDED

@@ -0,0 +1,49 @@
+from functools import cache
+from typing import Tuple, List
+
+from marker.builders.document import DocumentBuilder
+from marker.builders.ocr import OcrBuilder
+from marker.converters.pdf import PdfConverter
+from marker.processors import BaseProcessor
+from marker.processors.llm.llm_complex import LLMComplexRegionProcessor
+from marker.processors.llm.llm_form import LLMFormProcessor
+from marker.processors.llm.llm_table import LLMTableProcessor
+from marker.processors.llm.llm_table_merge import LLMTableMergeProcessor
+from marker.processors.table import TableProcessor
+from marker.providers.registry import provider_from_filepath
+from marker.schema import BlockTypes
+
+
+class TableConverter(PdfConverter):
+    default_processors: Tuple[BaseProcessor, ...] = (
+        TableProcessor,
+        LLMTableProcessor,
+        LLMTableMergeProcessor,
+        LLMFormProcessor,
+        LLMComplexRegionProcessor,
+    )
+    converter_block_types: List[BlockTypes] = (BlockTypes.Table, BlockTypes.Form, BlockTypes.TableOfContents)
+
+    @cache
+    def build_document(self, filepath: str):
+        provider_cls = provider_from_filepath(filepath)
+        layout_builder = self.resolve_dependencies(self.layout_builder_class)
+        ocr_builder = self.resolve_dependencies(OcrBuilder)
+        document_builder = DocumentBuilder(self.config)
+        document_builder.disable_ocr = True
+        with provider_cls(filepath, self.config) as provider:
+            document = document_builder(provider, layout_builder, ocr_builder)
+
+        for page in document.pages:
+            page.structure = [p for p in page.structure if p.block_type in self.converter_block_types]
+
+        for processor_cls in self.processor_list:
+            processor = self.resolve_dependencies(processor_cls)
+            processor(document)
+
+        return document
+
+    def __call__(self, filepath: str):
+        document = self.build_document(filepath)
+        renderer = self.resolve_dependencies(self.renderer)
+        return renderer(document)
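Because TableConverter inherits PdfConverter's constructor, it is driven the same way; the structure filter then drops everything except table-like blocks. A usage sketch (the input filename is illustrative):

```python
from marker.converters.table import TableConverter
from marker.models import create_model_dict

converter = TableConverter(artifact_dict=create_model_dict())
rendered = converter("report.pdf")  # keeps only Table, Form, and TableOfContents blocks
```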
marker/models.py
CHANGED

@@ -1,86 +1,49 @@
 import os
 
-os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # Transformers uses .isin for a simple op, which is not supported on MPS
-
-from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
-from surya.model.layout.model import load_model as load_layout_model
-from surya.model.layout.processor import load_processor as load_layout_processor
-from texify.model.model import load_model as load_texify_model
-from texify.model.processor import load_processor as load_texify_processor
 from marker.settings import settings
-from surya.model.recognition.model import load_model as load_recognition_model
-from surya.model.recognition.processor import load_processor as load_recognition_processor
-from surya.model.table_rec.model import load_model as load_table_model
-from surya.model.table_rec.processor import load_processor as load_table_processor
-from surya.model.ocr_error.model import load_model as load_ocr_error_model
-from surya.model.ocr_error.model import load_tokenizer as load_ocr_error_tokenizer
-
-from texify.model.model import GenerateVisionEncoderDecoderModel
-from surya.model.layout.encoderdecoder import SuryaLayoutModel
-from surya.model.detection.model import EfficientViTForSemanticSegmentation
-from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel
-from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
-from surya.model.ocr_error.model import DistilBertForSequenceClassification
 
-
-    if device:
-        table_model = load_table_model(device=device, dtype=dtype)
-    else:
-        table_model = load_table_model()
-    table_model.processor = load_table_processor()
-    return table_model
-
-
-def setup_recognition_model(device=None, dtype=None) -> OCREncoderDecoderModel:
-    if device:
-        rec_model = load_recognition_model(device=device, dtype=dtype)
-    else:
-        rec_model = load_recognition_model()
-    rec_model.processor = load_recognition_processor()
-    return rec_model
-
-
-
-
-
-
-    model.processor = load_detection_processor()
-    return model
-
-
-
-
-
-
-
-    return texify_model
-
-def
-
-
-
-
-
-
-def setup_ocr_error_model(device=None, dtype=None) -> DistilBertForSequenceClassification:
-    if device:
-        model = load_ocr_error_model(device=device, dtype=dtype)
-    else:
-        model = load_ocr_error_model()
-    model.tokenizer = load_ocr_error_tokenizer()
-    return model
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # Transformers uses .isin for a simple op, which is not supported on MPS
+
+from typing import List
+from PIL import Image
+
+from surya.detection import DetectionPredictor
+from surya.layout import LayoutPredictor
+from surya.ocr_error import OCRErrorPredictor
+from surya.recognition import RecognitionPredictor
+from surya.table_rec import TableRecPredictor
+
+from texify.model.model import load_model as load_texify_model
+from texify.model.processor import load_processor as load_texify_processor
+from texify.inference import batch_inference
+
+
+class TexifyPredictor:
+    def __init__(self, device=None, dtype=None):
+        if not device:
+            device = settings.TORCH_DEVICE_MODEL
+        if not dtype:
+            dtype = settings.TEXIFY_DTYPE
+
+        self.model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=device, dtype=dtype)
+        self.processor = load_texify_processor()
+        self.device = device
+        self.dtype = dtype
+
+    def __call__(self, batch_images: List[Image.Image], max_tokens: int):
+        return batch_inference(
+            batch_images,
+            self.model,
+            self.processor,
+            max_tokens=max_tokens
+        )
 
 def create_model_dict(device=None, dtype=None) -> dict:
     return {
-        "layout_model":
-        "texify_model":
-        "recognition_model":
-        "table_rec_model":
-        "detection_model":
-        "ocr_error_model":
+        "layout_model": LayoutPredictor(device=device, dtype=dtype),
+        "texify_model": TexifyPredictor(device=device, dtype=dtype),
+        "recognition_model": RecognitionPredictor(device=device, dtype=dtype),
+        "table_rec_model": TableRecPredictor(device=device, dtype=dtype),
+        "detection_model": DetectionPredictor(device=device, dtype=dtype),
+        "ocr_error_model": OCRErrorPredictor(device=device, dtype=dtype)
     }
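The net effect is that create_model_dict now returns ready-to-call predictor objects rather than model/processor pairs. A sketch of how the pieces line up (the sample image and token cap are placeholders):

```python
from PIL import Image
from marker.models import create_model_dict

models = create_model_dict()        # loads the six predictors listed above
texify = models["texify_model"]     # TexifyPredictor defined in this file
latex = texify([Image.new("RGB", (400, 100), "white")], max_tokens=256)
```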
marker/processors/debug.py
CHANGED

@@ -68,7 +68,7 @@ class DebugProcessor(BaseProcessor):
 
     def draw_pdf_debug_images(self, document: Document):
         for page in document.pages:
-            png_image = page.
+            png_image = page.get_image(highres=True).copy()
 
             line_bboxes = []
             span_bboxes = []
@@ -90,7 +90,7 @@ class DebugProcessor(BaseProcessor):
 
     def draw_layout_debug_images(self, document: Document, pdf_mode=False):
         for page in document.pages:
-            img_size = page.
+            img_size = page.get_image(highres=True).size
             png_image = Image.new("RGB", img_size, color="white")
 
             line_bboxes = []
marker/processors/equation.py
CHANGED

@@ -4,6 +4,7 @@ from texify.inference import batch_inference
 from texify.model.model import GenerateVisionEncoderDecoderModel
 from tqdm import tqdm
 
+from marker.models import TexifyPredictor
 from marker.processors import BaseProcessor
 from marker.schema import BlockTypes
 from marker.schema.document import Document
@@ -32,7 +33,7 @@ class EquationProcessor(BaseProcessor):
         "The number of tokens to buffer above max for the Texify model.",
     ] = 256
 
-    def __init__(self, texify_model:
+    def __init__(self, texify_model: TexifyPredictor, config=None):
         super().__init__(config)
 
         self.texify_model = texify_model
@@ -42,8 +43,7 @@ class EquationProcessor(BaseProcessor):
 
         for page in document.pages:
             for block in page.contained_blocks(document, self.block_types):
-
-                image = page.lowres_image.crop(image_poly.bbox).convert("RGB")
+                image = block.get_image(document, highres=False).convert("RGB")
                 raw_text = block.raw_text(document)
                 token_count = self.get_total_texify_tokens(raw_text)
 
@@ -92,10 +92,8 @@ class EquationProcessor(BaseProcessor):
 
         batch_images = [eq["image"] for eq in batch_equations]
 
-        model_output =
+        model_output = self.texify_model(
             batch_images,
-            self.texify_model,
-            self.texify_model.processor,
             max_tokens=max_length
         )
 
marker/processors/ignoretext.py
CHANGED

@@ -17,8 +17,7 @@ class IgnoreTextProcessor(BaseProcessor):
     These blocks often represent repetitive or non-essential elements, such as headers, footers, or page numbers.
     """
     block_types = (
-        BlockTypes.Text, BlockTypes.
-        BlockTypes.PageFooter, BlockTypes.SectionHeader,
+        BlockTypes.Text, BlockTypes.SectionHeader,
         BlockTypes.TextInlineMath
     )
     common_element_threshold: Annotated[
@@ -47,7 +46,6 @@ class IgnoreTextProcessor(BaseProcessor):
         last_blocks = []
         for page in document.pages:
             initial_block = None
-            block = None
             last_block = None
             for block in page.contained_blocks(document, self.block_types):
                 if block.structure is not None:
marker/processors/llm/__init__.py
CHANGED

@@ -1,3 +1,4 @@
+import traceback
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Annotated, Optional
 
@@ -16,7 +17,7 @@ class BaseLLMProcessor(BaseProcessor):
     A processor for using LLMs to convert blocks.
     """
     google_api_key: Annotated[
-
+        str,
         "The Google API key to use for the Gemini model.",
     ] = settings.GOOGLE_API_KEY
     model_name: Annotated[
@@ -39,11 +40,6 @@ class BaseLLMProcessor(BaseProcessor):
         float,
         "The ratio to expand the image by when cropping.",
     ] = 0.01
-    gemini_rewriting_prompt: Annotated[
-        str,
-        "The prompt to use for rewriting text.",
-        "Default is a string containing the Gemini rewriting prompt."
-    ] = ''
     use_llm: Annotated[
         bool,
         "Whether to use the LLM model.",
@@ -84,10 +80,5 @@ class BaseLLMProcessor(BaseProcessor):
 
         pbar.close()
 
-    def extract_image(self,
-
-        image_box = image_block.polygon\
-            .rescale(page.polygon.size, page_img.size)\
-            .expand(self.image_expansion_ratio, self.image_expansion_ratio)
-        cropped = page_img.crop(image_box.bbox)
-        return cropped
+    def extract_image(self, document: Document, image_block: Block):
+        return image_block.get_image(document, highres=False, expansion=(self.image_expansion_ratio, self.image_expansion_ratio))
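All the LLM processors now crop through the block itself instead of rescaling polygons by hand. A small helper expressing the same contract as extract_image above (expansion is a fraction of the block's width and height on each side):

```python
from marker.schema.blocks import Block
from marker.schema.document import Document

def crop_block_image(document: Document, block: Block, expand: float = 0.01):
    # Same call BaseLLMProcessor.extract_image makes above.
    return block.get_image(document, highres=False, expansion=(expand, expand))
```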
marker/processors/llm/llm_complex.py
CHANGED

@@ -12,9 +12,9 @@ from marker.schema.groups.page import PageGroup
 
 class LLMComplexRegionProcessor(BaseLLMProcessor):
     block_types = (BlockTypes.ComplexRegion,)
-
+    complex_region_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
 You will receive an image of a text block and the text that can be extracted from the image.
-Your task is to
+Your task is to generate markdown to properly represent the content of the image. Do not omit any text present in the image - make sure everything is included in the markdown representation. The markdown representation should be as faithful to the original image as possible.
 
 Formatting should be in markdown, with the following rules:
 - * for italics, ** for bold, and ` for inline code.
@@ -29,27 +29,32 @@ Formatting should be in markdown, with the following rules:
 
 **Instructions:**
 1. Carefully examine the provided block image.
-2. Analyze the text representation
-3.
-4. If the text representation contains errors, generate the corrected markdown representation.
-5. Output only either the corrected markdown representation or "No corrections needed."
+2. Analyze the existing text representation.
+3. Generate the markdown representation of the content in the image.
 **Example:**
 Input:
 ```text
-
+Table 1: Car Sales
 ```
 Output:
 ```markdown
-
+## Table 1: Car Sales
+
+| Car | Sales |
+| --- | --- |
+| Honda | 100 |
+| Toyota | 200 |
 ```
 **Input:**
+```text
+{extracted_text}
+```
 """
 
     def process_rewriting(self, document: Document, page: PageGroup, block: Block):
         text = block.raw_text(document)
-
-
-        image = self.extract_image(page, block)
+        prompt = self.complex_region_prompt.replace("{extracted_text}", text)
+        image = self.extract_image(document, block)
         response_schema = content.Schema(
             type=content.Type.OBJECT,
             enum=[],
@@ -79,4 +84,5 @@ No corrections needed.
             return
 
         # Convert LLM markdown to html
+        corrected_markdown = corrected_markdown.strip().lstrip("```markdown").rstrip("```").strip()
         block.html = markdown2.markdown(corrected_markdown)
marker/processors/llm/llm_equation.py
ADDED

@@ -0,0 +1,82 @@
+from marker.processors.llm import BaseLLMProcessor
+
+from google.ai.generativelanguage_v1beta.types import content
+
+from marker.schema import BlockTypes
+from marker.schema.blocks import Equation
+from marker.schema.document import Document
+from marker.schema.groups.page import PageGroup
+
+from typing import Annotated
+
+
+class LLMEquationProcessor(BaseLLMProcessor):
+    block_types = (BlockTypes.Equation,)
+    min_equation_height: Annotated[
+        float,
+        "The minimum ratio between equation height and page height to consider for processing.",
+    ] = 0.1
+    equation_latex_prompt: Annotated[
+        str,
+        "The prompt to use for generating LaTeX from equations.",
+        "Default is a string containing the Gemini prompt."
+    ] = """You're an expert mathematician who is good at writing LaTeX code for equations'.
+You will receive an image of a math block that may contain one or more equations. Your job is to write the LaTeX code for the equation, along with markdown for any other text.
+
+Some guidelines:
+- Keep the LaTeX code simple and concise.
+- Make it KaTeX compatible.
+- Use $$ as a block equation delimiter and $ for inline equations. Block equations should also be on their own line. Do not use any other delimiters.
+- You can include text in between equation blocks as needed. Try to put long text segments into plain text and not inside the equations.
+
+**Instructions:**
+1. Carefully examine the provided image.
+2. Analyze the existing markdown, which may include LaTeX code.
+3. If the markdown and LaTeX are correct, write "No corrections needed."
+4. If the markdown and LaTeX are incorrect, generate the corrected markdown and LaTeX.
+5. Output only the corrected text or "No corrections needed."
+**Example:**
+Input:
+```markdown
+Equation 1:
+$$x^2 + y^2 = z2$$
+```
+Output:
+```markdown
+Equation 1:
+$$x^2 + y^2 = z^2$$
+```
+**Input:**
+```markdown
+{equation}
+```
+"""
+
+    def process_rewriting(self, document: Document, page: PageGroup, block: Equation):
+        text = block.latex if block.latex else block.raw_text(document)
+        prompt = self.equation_latex_prompt.replace("{equation}", text)
+
+        image = self.extract_image(document, block)
+        response_schema = content.Schema(
+            type=content.Type.OBJECT,
+            enum=[],
+            required=["markdown_equation"],
+            properties={
+                "markdown_equation": content.Schema(
+                    type=content.Type.STRING
+                )
+            },
+        )
+
+        response = self.model.generate_response(prompt, image, block, response_schema)
+
+        if not response or "markdown_equation" not in response:
+            block.update_metadata(llm_error_count=1)
+            return
+
+        markdown_equation = response["markdown_equation"]
+        if len(markdown_equation) < len(text) * .5:
+            block.update_metadata(llm_error_count=1)
+            return
+
+        block.latex = markdown_equation
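The length-ratio check at the end of process_rewriting treats a reply much shorter than the input as a likely truncated generation. The same guard in isolation, with toy strings:

```python
# Toy illustration of LLMEquationProcessor's partial-response guard.
original = "$$x^2 + y^2 = z2$$ plus surrounding text"
reply = "$$x^2$$"
if len(reply) < len(original) * 0.5:
    print("suspiciously short; keep the original LaTeX and bump llm_error_count")
```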
marker/processors/llm/llm_form.py
CHANGED

@@ -1,9 +1,6 @@
-import markdown2
-
 from marker.processors.llm import BaseLLMProcessor
 
 from google.ai.generativelanguage_v1beta.types import content
-from tabled.formats import markdown_format
 
 from marker.schema import BlockTypes
 from marker.schema.blocks import Block
@@ -13,48 +10,75 @@ from marker.schema.groups.page import PageGroup
 
 class LLMFormProcessor(BaseLLMProcessor):
     block_types = (BlockTypes.Form,)
-
-You will receive an image of a text block and
-Your task is to correct any errors in the
-Values and labels should appear in
+    form_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
+You will receive an image of a text block and an html representation of the form in the image.
+Your task is to correct any errors in the html representation, and format it properly.
+Values and labels should appear in html tables, with the labels on the left side, and values on the right. The headers should be "Labels" and "Values". Other text in the form can appear between the tables. Only use the tags `table, p, span, i, b, th, td, tr, and div`. Do not omit any text from the form - make sure everything is included in the html representation. It should be as faithful to the original form as possible.
 **Instructions:**
 1. Carefully examine the provided form block image.
-2. Analyze the
-3. If the
-4. If the
-5. Output only either the corrected
+2. Analyze the html representation of the form.
+3. If the html representation is largely correct, then write "No corrections needed."
+4. If the html representation contains errors, generate the corrected html representation.
+5. Output only either the corrected html representation or "No corrections needed."
 **Example:**
 Input:
-```
-
-
-
+```html
+<table>
+    <tr>
+        <td>Label 1</td>
+        <td>Label 2</td>
+        <td>Label 3</td>
+    </tr>
+    <tr>
+        <td>Value 1</td>
+        <td>Value 2</td>
+        <td>Value 3</td>
+    </tr>
+</table>
 ```
 Output:
-```
-
-
-
-
-
+```html
+<table>
+    <tr>
+        <th>Labels</th>
+        <th>Values</th>
+    </tr>
+    <tr>
+        <td>Label 1</td>
+        <td>Value 1</td>
+    </tr>
+    <tr>
+        <td>Label 2</td>
+        <td>Value 2</td>
+    </tr>
+    <tr>
+        <td>Label 3</td>
+        <td>Value 3</td>
+    </tr>
+</table>
 ```
 **Input:**
+```html
+{block_html}
+```
 """
 
     def process_rewriting(self, document: Document, page: PageGroup, block: Block):
-
-        if
+        children = block.contained_blocks(document, (BlockTypes.TableCell,))
+        if not children:
             # Happens if table/form processors didn't run
             return
 
-
-
+        block_html = block.render(document).html
+        prompt = self.form_rewriting_prompt.replace("{block_html}", block_html)
+
+        image = self.extract_image(document, block)
         response_schema = content.Schema(
             type=content.Type.OBJECT,
             enum=[],
-            required=["
+            required=["corrected_html"],
             properties={
-                "
+                "corrected_html": content.Schema(
                     type=content.Type.STRING
                 )
             },
@@ -62,22 +86,20 @@ Output:
 
         response = self.model.generate_response(prompt, image, block, response_schema)
 
-        if not response or "
+        if not response or "corrected_html" not in response:
             block.update_metadata(llm_error_count=1)
             return
 
-
+        corrected_html = response["corrected_html"]
 
         # The original table is okay
-        if "no corrections" in
+        if "no corrections" in corrected_html.lower():
             return
 
-        orig_cell_text = "".join([cell.text for cell in cells])
-
         # Potentially a partial response
-        if len(
+        if len(corrected_html) < len(block_html) * .33:
            block.update_metadata(llm_error_count=1)
            return
 
-
-        block.html =
+        corrected_html = corrected_html.strip().lstrip("```html").rstrip("```").strip()
+        block.html = corrected_html
marker/processors/llm/llm_handwriting.py
ADDED

@@ -0,0 +1,86 @@
+import markdown2
+
+from marker.processors.llm import BaseLLMProcessor
+
+from google.ai.generativelanguage_v1beta.types import content
+
+from marker.schema import BlockTypes
+from marker.schema.blocks import Equation
+from marker.schema.document import Document
+from marker.schema.groups.page import PageGroup
+
+from typing import Annotated
+
+
+class LLMHandwritingProcessor(BaseLLMProcessor):
+    block_types = (BlockTypes.Equation,)
+    min_handwriting_height: Annotated[
+        float,
+        "The minimum ratio between handwriting height and page height to consider for processing.",
+    ] = 0.1
+    handwriting_generation_prompt: Annotated[
+        str,
+        "The prompt to use for OCRing handwriting.",
+        "Default is a string containing the Gemini prompt."
+    ] = """You are an expert editor specializing in accurately reproducing text from images.
+You will receive an image of a text block, along with the text that can be extracted. Your task is to generate markdown to properly represent the content of the image. Do not omit any text present in the image - make sure everything is included in the markdown representation. The markdown representation should be as faithful to the original image as possible.
+
+Formatting should be in markdown, with the following rules:
+- * for italics, ** for bold, and ` for inline code.
+- Headers should be formatted with #, with one # for the largest header, and up to 6 for the smallest.
+- Lists should be formatted with either - or 1. for unordered and ordered lists, respectively.
+- Links should be formatted with [text](url).
+- Use ``` for code blocks.
+- Inline math should be formatted with <math>math expression</math>.
+- Display math should be formatted with <math display="block">math expression</math>.
+- Values and labels should be extracted from forms, and put into markdown tables, with the labels on the left side, and values on the right. The headers should be "Labels" and "Values". Other text in the form can appear between the tables.
+- Tables should be formatted with markdown tables, with the headers bolded.
+
+**Instructions:**
+1. Carefully examine the provided block image.
+2. Analyze the existing text representation.
+3. Output the markdown representing the content of the image.
+**Example:**
+Input:
+```text
+This i sm handwritting.
+```
+Output:
+```markdown
+This is some *handwriting*.
+```
+**Input:**
+```text
+{extracted_text}
+```
+"""
+
+    def process_rewriting(self, document: Document, page: PageGroup, block: Equation):
+        text = block.raw_text(document)
+        prompt = self.handwriting_generation_prompt.replace("{handwriting_text}", text)
+
+        image = self.extract_image(document, block)
+        response_schema = content.Schema(
+            type=content.Type.OBJECT,
+            enum=[],
+            required=["markdown"],
+            properties={
+                "markdown": content.Schema(
+                    type=content.Type.STRING
+                )
+            },
+        )
+
+        response = self.model.generate_response(prompt, image, block, response_schema)
+
+        if not response or "markdown" not in response:
+            block.update_metadata(llm_error_count=1)
+            return
+
+        markdown = response["markdown"]
+        if len(markdown) < len(text) * .5:
+            block.update_metadata(llm_error_count=1)
+            return
+
+        markdown = markdown.strip().lstrip("```markdown").rstrip("```").strip()
+        block.html = markdown2.markdown(markdown)
marker/processors/llm/llm_image_description.py
CHANGED

@@ -36,6 +36,9 @@ Apples, Bananas, Oranges
 Output:
 In this figure, a bar chart titled "Fruit Preference Survey" is showing the number of people who prefer different types of fruits. The x-axis shows the types of fruits, and the y-axis shows the number of people. The bar chart shows that most people prefer apples, followed by bananas and oranges. 20 people prefer apples, 15 people prefer bananas, and 10 people prefer oranges.
 **Input:**
+```text
+{raw_text}
+```
 """
 
     def process_rewriting(self, document: Document, page: PageGroup, block: Block):
@@ -44,8 +47,8 @@ In this figure, a bar chart titled "Fruit Preference Survey" is showing the numb
         # Since this processor replaces images with descriptions
         return
 
-        prompt = self.image_description_prompt
-        image = self.extract_image(
+        prompt = self.image_description_prompt.replace("{raw_text}", block.raw_text(document))
+        image = self.extract_image(document, block)
         response_schema = content.Schema(
             type=content.Type.OBJECT,
             enum=[],
marker/processors/llm/llm_table.py
CHANGED

@@ -2,12 +2,10 @@

 from bs4 import BeautifulSoup
 from google.ai.generativelanguage_v1beta.types import content
-from tabled.formats import html_format
-from tabled.schema import SpanTableCell

 from marker.processors.llm import BaseLLMProcessor
 from marker.schema import BlockTypes
-from marker.schema.blocks import Block
+from marker.schema.blocks import Block, TableCell
 from marker.schema.document import Document
 from marker.schema.groups.page import PageGroup
 from marker.schema.polygon import PolygonBox

@@ -17,19 +15,26 @@ class LLMTableProcessor(BaseLLMProcessor):
     block_types: Annotated[
         Tuple[BlockTypes],
         "The block types to process.",
-    ] = (BlockTypes.Table,)
-
+    ] = (BlockTypes.Table, BlockTypes.TableOfContents)
+    table_rewriting_prompt: Annotated[
         str,
         "The prompt to use for rewriting text.",
         "Default is a string containing the Gemini rewriting prompt."
     ] = """You are a text correction expert specializing in accurately reproducing text from images.
 You will receive an image of a text block and an html representation of the table in the image.
 Your task is to correct any errors in the html representation. The html representation should be as faithful to the original table as possible.
+
+Some guidelines:
+- Make sure to reproduce the original values as faithfully as possible.
+- If you see any math in a table cell, fence it with the <math display="inline"> tag. Block math should be fenced with <math display="block">.
+- Replace any images with a description, like "Image: [description]".
+- Only use the tags th, td, tr, span, i, b, math, and table. Only use the attributes display, style, colspan, and rowspan if necessary.
+
 **Instructions:**
 1. Carefully examine the provided text block image.
 2. Analyze the html representation of the table.
 3. If the html representation is largely correct, then write "No corrections needed."
 4. If the html representation contains errors, generate the corrected html representation.
 5. Output only either the corrected html representation or "No corrections needed."
 **Example:**
 Input:

@@ -52,16 +57,21 @@ Output:
 No corrections needed.
 ```
 **Input:**
+```html
+{block_html}
+```
 """

     def process_rewriting(self, document: Document, page: PageGroup, block: Block):
-
-        if
+        children = block.contained_blocks(document, (BlockTypes.TableCell,))
+        if not children:
             # Happens if table/form processors didn't run
             return

-
-
+        block_html = block.render(document).html
+        prompt = self.table_rewriting_prompt.replace("{block_html}", block_html)
+
+        image = self.extract_image(document, block)
         response_schema = content.Schema(
             type=content.Type.OBJECT,
             enum=[],

@@ -85,31 +95,49 @@ No corrections needed.
         if "no corrections" in corrected_html.lower():
             return

-
+        corrected_html = corrected_html.strip().lstrip("```html").rstrip("```").strip()
+        parsed_cells = self.parse_html_table(corrected_html, block, page)
         if len(parsed_cells) <= 1:
             block.update_metadata(llm_error_count=1)
             return

         parsed_cell_text = "".join([cell.text for cell in parsed_cells])
-        orig_cell_text = "".join([cell.text for cell in
+        orig_cell_text = "".join([cell.text for cell in children])

         # Potentially a partial response
         if len(parsed_cell_text) < len(orig_cell_text) * .5:
             block.update_metadata(llm_error_count=1)
             return

-        block.
+        block.structure = []
+        for cell in parsed_cells:
+            page.add_full_block(cell)
+            block.add_structure(cell)

+    @staticmethod
+    def get_cell_text(element, keep_tags=('br',)):
+        for tag in element.find_all(True):
+            if tag.name not in keep_tags:
+                tag.unwrap()
+        return element.decode_contents().replace("<br>", "\n")
+
+    def parse_html_table(self, html_text: str, block: Block, page: PageGroup) -> List[TableCell]:
         soup = BeautifulSoup(html_text, 'html.parser')
         table = soup.find('table')

         # Initialize grid
         rows = table.find_all('tr')
         cells = []
-
-
-
+
+        # Find maximum number of columns in colspan-aware way
+        max_cols = 0
+        for row in rows:
+            row_tds = row.find_all(['td', 'th'])
+            curr_cols = 0
+            for cell in row_tds:
+                colspan = int(cell.get('colspan', 1))
+                curr_cols += colspan
+            if curr_cols > max_cols:
+                max_cols = curr_cols

         grid = [[True] * max_cols for _ in range(len(rows))]

@@ -124,7 +152,7 @@
                 print("Table parsing warning: too many columns found")
                 break

-            cell_text = cell.
+            cell_text = self.get_cell_text(cell).strip()
             rowspan = min(int(cell.get('rowspan', 1)), len(rows) - i)
             colspan = min(int(cell.get('colspan', 1)), max_cols - cur_col)
             cell_rows = list(range(i, i + rowspan))

@@ -146,11 +174,15 @@
             ]
             cell_polygon = PolygonBox.from_bbox(cell_bbox)

-            cell_obj =
+            cell_obj = TableCell(
                 text=cell_text,
-
-
-
+                row_id=i,
+                col_id=cur_col,
+                rowspan=rowspan,
+                colspan=colspan,
+                is_header=cell.name == 'th',
+                polygon=cell_polygon,
+                page_id=page.page_id,
             )
             cells.append(cell_obj)
             cur_col += colspan
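The colspan-aware width computation in `parse_html_table` is what keeps a ragged LLM response from corrupting the cell grid: counting `<td>` tags per row under-reports the table width whenever a cell spans columns. A minimal standalone sketch of the same idea, assuming only BeautifulSoup (the `sample_html` table is made up for illustration):

```python
from bs4 import BeautifulSoup

# Hypothetical table: the first row has a cell spanning two columns, so a
# naive per-row tag count would report 2 columns instead of 3.
sample_html = """
<table>
  <tr><td colspan="2">Header</td><td>C</td></tr>
  <tr><td>1</td><td>2</td><td>3</td></tr>
</table>
"""

rows = BeautifulSoup(sample_html, "html.parser").find("table").find_all("tr")

# Colspan-aware width: sum colspans within each row, then take the max.
max_cols = 0
for row in rows:
    curr_cols = sum(int(cell.get("colspan", 1)) for cell in row.find_all(["td", "th"]))
    max_cols = max(max_cols, curr_cols)

print(max_cols)  # 3
```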
marker/processors/llm/llm_table_merge.py
ADDED

@@ -0,0 +1,318 @@
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Annotated, List, Tuple, Literal
+
+from google.ai.generativelanguage_v1beta.types import content
+from tqdm import tqdm
+from PIL import Image
+
+from marker.processors.llm import BaseLLMProcessor
+from marker.schema import BlockTypes
+from marker.schema.blocks import Block, TableCell
+from marker.schema.document import Document
+
+
+class LLMTableMergeProcessor(BaseLLMProcessor):
+    block_types: Annotated[
+        Tuple[BlockTypes],
+        "The block types to process.",
+    ] = (BlockTypes.Table, BlockTypes.TableOfContents)
+    table_height_threshold: Annotated[
+        float,
+        "The minimum height ratio relative to the page for the first table in a pair to be considered for merging.",
+    ] = 0.6
+    table_start_threshold: Annotated[
+        float,
+        "The maximum percentage down the page the second table can start to be considered for merging."
+    ] = 0.2
+    vertical_table_height_threshold: Annotated[
+        float,
+        "The height tolerance for 2 adjacent tables to be merged into one."
+    ] = 0.25
+    vertical_table_distance_threshold: Annotated[
+        int,
+        "The maximum distance between table edges for adjacency."
+    ] = 20
+    column_gap_threshold: Annotated[
+        int,
+        "The maximum gap between columns to merge tables"
+    ] = 50
+    table_merge_prompt: Annotated[
+        str,
+        "The prompt to use for rewriting text.",
+        "Default is a string containing the Gemini rewriting prompt."
+    ] = """You're a text correction expert specializing in accurately reproducing tables from PDFs.
+You'll receive two images of tables from successive pages of a PDF. Table 1 is from the first page, and Table 2 is from the second page. Both tables may actually be part of the same larger table. Your job is to decide if Table 2 should be merged with Table 1, and how they should be joined. They should only be merged if they're part of the same larger table, and Table 2 cannot be interpreted without merging.
+
+You'll specify your judgement in json format - first whether Table 2 should be merged with Table 1, then the direction of the merge, either `bottom` or `right`. A bottom merge means that the rows of Table 2 are joined to the rows of Table 1. A right merge means that the columns of Table 2 are joined to the columns of Table 1. (bottom merge is equal to np.vstack, right merge is equal to np.hstack)
+
+Table 2 should be merged at the bottom of Table 1 if Table 2 has no headers, and the rows have similar values, meaning that Table 2 continues Table 1. Table 2 should be merged to the right of Table 1 if each row in Table 2 matches a row in Table 1, meaning that Table 2 contains additional columns that augment Table 1.
+
+Only merge Table 1 and Table 2 if Table 2 cannot be interpreted without merging.
+
+**Instructions:**
+1. Carefully examine the provided table images. Table 1 is the first image, and Table 2 is the second image.
+2. Examine the provided html representations of Table 1 and Table 2.
+3. Write a description of Table 1.
+4. Write a description of Table 2.
+5. Analyze whether Table 2 should be merged into Table 1, and write an explanation.
+6. Output your decision on whether they should be merged, and the merge direction.
+**Example:**
+Input:
+Table 1
+```html
+<table>
+    <tr>
+        <th>Name</th>
+        <th>Age</th>
+        <th>City</th>
+        <th>State</th>
+    </tr>
+    <tr>
+        <td>John</td>
+        <td>25</td>
+        <td>Chicago</td>
+        <td>IL</td>
+    </tr>
+```
+Table 2
+```html
+<table>
+    <tr>
+        <td>Jane</td>
+        <td>30</td>
+        <td>Los Angeles</td>
+        <td>CA</td>
+    </tr>
+```
+Output:
+```json
+{
+    "table1_description": "Table 1 has 4 headers, and 1 row. The headers are Name, Age, City, and State.",
+    "table2_description": "Table 2 has no headers, but the values appear to represent a person's name, age, city, and state.",
+    "explanation": "The values in Table 2 match the headers in Table 1, and Table 2 has no headers. Table 2 should be merged to the bottom of Table 1.",
+    "merge": "true",
+    "direction": "bottom"
+}
+```
+**Input:**
+Table 1
+```html
+{{table1}}
+```
+Table 2
+```html
+{{table2}}
+```
+"""
+
+    @staticmethod
+    def get_row_count(cells: List[TableCell]):
+        max_rows = None
+        for col_id in set([cell.col_id for cell in cells]):
+            col_cells = [cell for cell in cells if cell.col_id == col_id]
+            rows = 0
+            for cell in col_cells:
+                rows += cell.rowspan
+            if max_rows is None or rows > max_rows:
+                max_rows = rows
+        return max_rows
+
+    @staticmethod
+    def get_column_count(cells: List[TableCell]):
+        max_cols = None
+        for row_id in set([cell.row_id for cell in cells]):
+            row_cells = [cell for cell in cells if cell.row_id == row_id]
+            cols = 0
+            for cell in row_cells:
+                cols += cell.colspan
+            if max_cols is None or cols > max_cols:
+                max_cols = cols
+        return max_cols
+
+    def rewrite_blocks(self, document: Document):
+        pbar = tqdm(desc=f"{self.__class__.__name__} running")
+        table_runs = []
+        table_run = []
+        prev_block = None
+        prev_page_block_count = None
+        for page in document.pages:
+            page_blocks = page.contained_blocks(document, self.block_types)
+            for block in page_blocks:
+                merge_condition = False
+                if prev_block is not None:
+                    prev_cells = prev_block.contained_blocks(document, (BlockTypes.TableCell,))
+                    curr_cells = block.contained_blocks(document, (BlockTypes.TableCell,))
+                    row_match = abs(self.get_row_count(prev_cells) - self.get_row_count(curr_cells)) < 5  # Similar number of rows
+                    col_match = abs(self.get_column_count(prev_cells) - self.get_column_count(curr_cells)) < 2
+
+                    subsequent_page_table = all([
+                        prev_block.page_id == block.page_id - 1,  # Subsequent pages
+                        max(prev_block.polygon.height / page.polygon.height,
+                            block.polygon.height / page.polygon.height) > self.table_height_threshold,  # Take up most of the page height
+                        (len(page_blocks) == 1 or prev_page_block_count == 1),  # Only table on the page
+                        (row_match or col_match)
+                    ])
+
+                    same_page_vertical_table = all([
+                        prev_block.page_id == block.page_id,  # On the same page
+                        (1 - self.vertical_table_height_threshold) < prev_block.polygon.height / block.polygon.height < (1 + self.vertical_table_height_threshold),  # Similar height
+                        abs(block.polygon.x_start - prev_block.polygon.x_end) < self.vertical_table_distance_threshold,  # Close together in x
+                        abs(block.polygon.y_start - prev_block.polygon.y_start) < self.vertical_table_distance_threshold,  # Close together in y
+                        row_match
+                    ])
+
+                    same_page_new_column = all([
+                        prev_block.page_id == block.page_id,  # On the same page
+                        abs(block.polygon.x_start - prev_block.polygon.x_end) < self.column_gap_threshold,
+                        block.polygon.y_start < prev_block.polygon.y_end,
+                        block.polygon.width * (1 - self.vertical_table_height_threshold) < prev_block.polygon.width < block.polygon.width * (1 + self.vertical_table_height_threshold),  # Similar width
+                        col_match
+                    ])
+                    merge_condition = any([subsequent_page_table, same_page_vertical_table, same_page_new_column])
+
+                if prev_block is not None and merge_condition:
+                    if prev_block not in table_run:
+                        table_run.append(prev_block)
+                    table_run.append(block)
+                else:
+                    if table_run:
+                        table_runs.append(table_run)
+                    table_run = []
+                prev_block = block
+                prev_page_block_count = len(page_blocks)
+
+        if table_run:
+            table_runs.append(table_run)
+
+        with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor:
+            for future in as_completed([
+                executor.submit(self.process_rewriting, document, blocks)
+                for blocks in table_runs
+            ]):
+                future.result()  # Raise exceptions if any occurred
+                pbar.update(1)
+
+        pbar.close()
+
+    def process_rewriting(self, document: Document, blocks: List[Block]):
+        if len(blocks) < 2:
+            # Can't merge single tables
+            return
+
+        start_block = blocks[0]
+        for i in range(1, len(blocks)):
+            curr_block = blocks[i]
+            children = start_block.contained_blocks(document, (BlockTypes.TableCell,))
+            children_curr = curr_block.contained_blocks(document, (BlockTypes.TableCell,))
+            if not children or not children_curr:
+                # Happens if table/form processors didn't run
+                break
+
+            start_image = start_block.get_image(document, highres=False)
+            curr_image = curr_block.get_image(document, highres=False)
+            start_html = start_block.render(document).html
+            curr_html = curr_block.render(document).html
+
+            prompt = self.table_merge_prompt.replace("{{table1}}", start_html).replace("{{table2}}", curr_html)
+
+            response_schema = content.Schema(
+                type=content.Type.OBJECT,
+                enum=[],
+                required=["table1_description", "table2_description", "explanation", "merge", "direction"],
+                properties={
+                    "table1_description": content.Schema(
+                        type=content.Type.STRING
+                    ),
+                    "table2_description": content.Schema(
+                        type=content.Type.STRING
+                    ),
+                    "explanation": content.Schema(
+                        type=content.Type.STRING
+                    ),
+                    "merge": content.Schema(
+                        type=content.Type.STRING,
+                        enum=["true", "false"]
+                    ),
+                    "direction": content.Schema(
+                        type=content.Type.STRING,
+                        enum=["bottom", "right"]
+                    ),
+                },
+            )
+
+            response = self.model.generate_response(
+                prompt,
+                [start_image, curr_image],
+                curr_block,
+                response_schema
+            )
+
+            if not response or ("direction" not in response or "merge" not in response):
+                curr_block.update_metadata(llm_error_count=1)
+                break
+
+            merge = response["merge"]
+
+            # The original table is okay
+            if "true" not in merge:
+                start_block = curr_block
+                continue
+
+            # Merge the cells and images of the tables
+            direction = response["direction"]
+            if not self.validate_merge(children, children_curr, direction):
+                start_block = curr_block
+                continue
+
+            merged_image = self.join_images(start_image, curr_image, direction)
+            merged_cells = self.join_cells(children, children_curr, direction)
+            curr_block.structure = []
+            start_block.structure = [b.id for b in merged_cells]
+            start_block.lowres_image = merged_image
+
+    def validate_merge(self, cells1: List[TableCell], cells2: List[TableCell], direction: Literal['right', 'bottom'] = 'right'):
+        if direction == "right":
+            # Check if the number of rows is the same
+            cells1_row_count = self.get_row_count(cells1)
+            cells2_row_count = self.get_row_count(cells2)
+            return abs(cells1_row_count - cells2_row_count) < 5
+        elif direction == "bottom":
+            # Check if the number of columns is the same
+            cells1_col_count = self.get_column_count(cells1)
+            cells2_col_count = self.get_column_count(cells2)
+            return abs(cells1_col_count - cells2_col_count) < 2
+
+
+    def join_cells(self, cells1: List[TableCell], cells2: List[TableCell], direction: Literal['right', 'bottom'] = 'right') -> List[TableCell]:
+        if direction == 'right':
+            # Shift columns right
+            col_count = self.get_column_count(cells1)
+            for cell in cells2:
+                cell.col_id += col_count
+            new_cells = cells1 + cells2
+        else:
+            # Shift rows down
+            row_count = self.get_row_count(cells1)
+            for cell in cells2:
+                cell.row_id += row_count
+            new_cells = cells1 + cells2
+        return new_cells
+
+    @staticmethod
+    def join_images(image1: Image.Image, image2: Image.Image, direction: Literal['right', 'bottom'] = 'right') -> Image.Image:
+        # Get dimensions
+        w1, h1 = image1.size
+        w2, h2 = image2.size
+
+        if direction == 'right':
+            new_height = max(h1, h2)
+            new_width = w1 + w2
+            new_img = Image.new('RGB', (new_width, new_height), 'white')
+            new_img.paste(image1, (0, 0))
+            new_img.paste(image2, (w1, 0))
+        else:
+            new_width = max(w1, w2)
+            new_height = h1 + h2
+            new_img = Image.new('RGB', (new_width, new_height), 'white')
+            new_img.paste(image1, (0, 0))
+            new_img.paste(image2, (0, h1))
+        return new_img
marker/processors/llm/llm_text.py
CHANGED

@@ -1,4 +1,5 @@
 import json
+import textwrap

 from marker.processors.llm import BaseLLMProcessor
 from bs4 import BeautifulSoup

@@ -13,10 +14,10 @@ from marker.schema.text.span import Span

 class LLMTextProcessor(BaseLLMProcessor):
     block_types = (BlockTypes.TextInlineMath, BlockTypes.Handwriting)
-
+    text_math_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
 You will receive an image of a text block and a set of extracted lines corresponding to the text in the image.
 Your task is to correct any errors in the extracted lines, including math, formatting, and other inaccuracies, and output the corrected lines in a JSON format.
-The number of output lines MUST match the number of input lines.
+The number of output lines MUST match the number of input lines. Stay as faithful to the original text as possible.

 **Instructions:**

@@ -64,7 +65,9 @@ Output:
 ```

 **Input:**
+```json
+{extracted_lines}
+```
 """

     def process_rewriting(self, document: Document, page: PageGroup, block: Block):

@@ -73,8 +76,8 @@ Output:
         text_lines = block.contained_blocks(document, (BlockTypes.Line,))
         extracted_lines = [line.formatted_text(document) for line in text_lines]

-        prompt = self.
-        image = self.extract_image(
+        prompt = self.text_math_rewriting_prompt.replace("{extracted_lines}", json.dumps({"extracted_lines": extracted_lines}, indent=2))
+        image = self.extract_image(document, block)
         response_schema = content.Schema(
             type=content.Type.OBJECT,
             enum=[],
marker/processors/llm/utils.py
CHANGED

@@ -1,5 +1,6 @@
 import json
 import time
+from typing import List

 import PIL
 import google.generativeai as genai

@@ -25,17 +26,19 @@ class GoogleModel:
     def generate_response(
         self,
         prompt: str,
-        image: PIL.Image.Image,
+        image: PIL.Image.Image | List[PIL.Image.Image],
         block: Block,
         response_schema: content.Schema,
         max_retries: int = 3,
         timeout: int = 60
     ):
+        if not isinstance(image, list):
+            image = [image]
         tries = 0
         while tries < max_retries:
             try:
                 responses = self.model.generate_content(
-                    [prompt, image],
+                    image + [prompt],  # According to gemini docs, it performs better if the image is the first element
                     stream=False,
                     generation_config={
                         "temperature": 0,
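The new signature accepts either a single image or a list, and always sends images ahead of the prompt text. A minimal sketch of the normalization the method now performs, assuming only Pillow (the function name is mine):

```python
from typing import List, Union
from PIL import Image

def normalize_parts(prompt: str, image: Union[Image.Image, List[Image.Image]]) -> list:
    # Wrap a lone image in a list, then put images before the prompt;
    # the Gemini docs suggest this ordering works better for multimodal prompts.
    images = image if isinstance(image, list) else [image]
    return images + [prompt]

parts = normalize_parts("Describe the table.", Image.new("RGB", (8, 8)))
print(len(parts), type(parts[-1]).__name__)  # 2 str
```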
marker/processors/table.py
CHANGED

@@ -1,18 +1,22 @@
-
-from
+import re
+from collections import defaultdict
+from copy import deepcopy
+from typing import Annotated, List

 from ftfy import fix_text
-from surya.
-from surya.
-from surya.
-from surya.
-from
-from tabled.inference.recognition import get_cells, recognize_tables
+from surya.detection import DetectionPredictor
+from surya.recognition import RecognitionPredictor, OCRResult
+from surya.table_rec import TableRecPredictor
+from surya.table_rec.schema import TableResult, TableCell as SuryaTableCell
+from pdftext.extraction import table_output

 from marker.processors import BaseProcessor
 from marker.schema import BlockTypes
+from marker.schema.blocks.tablecell import TableCell
 from marker.schema.document import Document
+from marker.schema.polygon import PolygonBox
 from marker.settings import settings
+from marker.util import matrix_intersection_area


 class TableProcessor(BaseProcessor):

@@ -42,9 +46,9 @@ class TableProcessor(BaseProcessor):

     def __init__(
         self,
-        detection_model:
-        recognition_model:
-        table_rec_model:
+        detection_model: DetectionPredictor,
+        recognition_model: RecognitionPredictor,
+        table_rec_model: TableRecPredictor,
         config=None
     ):
         super().__init__(config)

@@ -59,51 +63,207 @@ class TableProcessor(BaseProcessor):
         table_data = []
         for page in document.pages:
             for block in page.contained_blocks(document, self.block_types):
-                if block.text_extraction_method == "surya":
-                    text_lines = None
-                else:
-                    text_lines = get_page_text_lines(
-                        filepath,
-                        [page.page_id],
-                        [page.highres_image.size],
-                        flatten_pdf=True
-                    )[0]
+                image = block.get_image(document, highres=True, expansion=(.01, .01))
+                image_poly = block.polygon.rescale((page.polygon.width, page.polygon.height), page.get_image(highres=True).size)

                 table_data.append({
                     "block_id": block.id,
+                    "page_id": page.page_id,
                     "table_image": image,
                     "table_bbox": image_poly.bbox,
-                    "
-                    "
+                    "img_size": page.get_image(highres=True).size,
+                    "ocr_block": page.text_extraction_method == "surya",
                 })

-
-
-
-            detect_boxes=self.detect_boxes,
-            detector_batch_size=self.get_detector_batch_size()
-        )
-
-        tables =
+        extract_blocks = [t for t in table_data if not t["ocr_block"]]
+        self.assign_pdftext_lines(extract_blocks, filepath)  # Handle tables where good text exists in the PDF
+
+        ocr_blocks = [t for t in table_data if t["ocr_block"]]
+        self.assign_ocr_lines(ocr_blocks)  # Handle tables where OCR is needed
+        assert all("table_text_lines" in t for t in table_data), "All table data must have table cells"
+
+        tables: List[TableResult] = self.table_rec_model(
             [t["table_image"] for t in table_data],
-
-            needs_ocr,
-            [self.table_rec_model, self.table_rec_model.processor, self.recognition_model, self.recognition_model.processor],
-            table_rec_batch_size=self.get_table_rec_batch_size(),
-            ocr_batch_size=self.get_recognition_batch_size()
+            batch_size=self.get_table_rec_batch_size()
         )
-
-        for table_d, table_res in zip(table_data, tables):
-            block = document.get_block(table_d["block_id"])
-            cells = assign_rows_columns(table_res, table_d["img_size"])
-            for cell in cells:
-                cell.text = fix_text(cell.text)
-            block.cells = cells
+        self.assign_text_to_cells(tables, table_data)
+        self.split_combined_rows(tables)  # Split up rows that were combined
+
+        # Assign table cells to the table
+        table_idx = 0
+        for page in document.pages:
+            for block in page.contained_blocks(document, self.block_types):
+                block.structure = []  # Remove any existing lines, spans, etc.
+                cells: List[SuryaTableCell] = tables[table_idx].cells
+                for cell in cells:
+                    # Rescale the cell polygon to the page size
+                    cell_polygon = PolygonBox(polygon=cell.polygon).rescale(page.get_image(highres=True).size, page.polygon.size)
+                    cell_block = TableCell(
+                        polygon=cell_polygon,
+                        text=self.finalize_cell_text(cell),
+                        rowspan=cell.rowspan,
+                        colspan=cell.colspan,
+                        row_id=cell.row_id,
+                        col_id=cell.col_id,
+                        is_header=bool(cell.is_header),
+                        page_id=page.page_id,
+                    )
+                    page.add_full_block(cell_block)
+                    block.add_structure(cell_block)
+                table_idx += 1
+
+    def finalize_cell_text(self, cell: SuryaTableCell):
+        text = "\n".join([t["text"].strip() for t in cell.text_lines]) if cell.text_lines else ""
+        text = re.sub(r"(\s\.){2,}", "", text)  # Replace . . .
+        text = re.sub(r"\.{2,}", "", text)  # Replace ..., like in table of contents
+        return self.normalize_spaces(fix_text(text))
+
+    @staticmethod
+    def normalize_spaces(text):
+        space_chars = [
+            '\u2003',  # em space
+            '\u2002',  # en space
+            '\u00A0',  # non-breaking space
+            '\u200B',  # zero-width space
+            '\u3000',  # ideographic space
+        ]
+        for space in space_chars:
+            text = text.replace(space, ' ')
+        return text
+
+    def split_combined_rows(self, tables: List[TableResult]):
+        for table in tables:
+            if len(table.cells) == 0:
+                # Skip empty tables
+                continue
+            unique_rows = sorted(list(set([c.row_id for c in table.cells])))
+            new_cells = []
+            shift_up = 0
+            max_cell_id = max([c.cell_id for c in table.cells])
+            new_cell_count = 0
+            for row in unique_rows:
+                # Cells in this row
+                # Deepcopy is because we do an in-place mutation later, and that can cause rows to shift to match rows in unique_rows
+                # making them be processed twice
+                row_cells = deepcopy([c for c in table.cells if c.row_id == row])
+                rowspans = [c.rowspan for c in row_cells]
+                line_lens = [len(c.text_lines) if isinstance(c.text_lines, list) else 1 for c in row_cells]
+
+                # Other cells that span into this row
+                rowspan_cells = [c for c in table.cells if c.row_id != row and c.row_id + c.rowspan > row > c.row_id]
+                should_split = all([
+                    len(row_cells) > 0,
+                    len(rowspan_cells) == 0,
+                    all([r == 1 for r in rowspans]),
+                    all([l > 1 for l in line_lens]),
+                    all([l == line_lens[0] for l in line_lens])
+                ])
+                if should_split:
+                    for i in range(0, line_lens[0]):
+                        for cell in row_cells:
+                            line = cell.text_lines[i]
+                            cell_id = max_cell_id + new_cell_count
+                            new_cells.append(
+                                SuryaTableCell(
+                                    polygon=line["bbox"],
+                                    text_lines=[line],
+                                    rowspan=1,
+                                    colspan=cell.colspan,
+                                    row_id=cell.row_id + shift_up + i,
+                                    col_id=cell.col_id,
+                                    is_header=cell.is_header and i == 0,  # Only first line is header
+                                    within_row_id=cell.within_row_id,
+                                    cell_id=cell_id
+                                )
+                            )
+                            new_cell_count += 1
+
+                    # For each new row we add, shift up subsequent rows
+                    shift_up += line_lens[0] - 1
+                else:
+                    for cell in row_cells:
+                        cell.row_id += shift_up
+                        new_cells.append(cell)
+
+            # Only update the cells if we added new cells
+            if len(new_cells) > len(table.cells):
+                table.cells = new_cells
+
+    def assign_text_to_cells(self, tables: List[TableResult], table_data: list):
+        for table_result, table_page_data in zip(tables, table_data):
+            table_text_lines = table_page_data["table_text_lines"]
+            table_cells: List[SuryaTableCell] = table_result.cells
+            text_line_bboxes = [t["bbox"] for t in table_text_lines]
+            table_cell_bboxes = [c.bbox for c in table_cells]
+
+            intersection_matrix = matrix_intersection_area(text_line_bboxes, table_cell_bboxes)
+
+            cell_text = defaultdict(list)
+            for text_line_idx, table_text_line in enumerate(table_text_lines):
+                intersections = intersection_matrix[text_line_idx]
+                if intersections.sum() == 0:
+                    continue
+
+                max_intersection = intersections.argmax()
+                cell_text[max_intersection].append(table_text_line)
+
+            for k in cell_text:
+                # TODO: see if the text needs to be sorted (based on rotation)
+                text = cell_text[k]
+                assert all("text" in t for t in text), "All text lines must have text"
+                assert all("bbox" in t for t in text), "All text lines must have a bbox"
+                table_cells[k].text_lines = text
+
+    def assign_pdftext_lines(self, extract_blocks: list, filepath: str):
+        table_inputs = []
+        unique_pages = list(set([t["page_id"] for t in extract_blocks]))
+        if len(unique_pages) == 0:
+            return
+
+        for page in unique_pages:
+            tables = []
+            img_size = None
+            for block in extract_blocks:
+                if block["page_id"] == page:
+                    tables.append(block["table_bbox"])
+                    img_size = block["img_size"]
+
+            table_inputs.append({
+                "tables": tables,
+                "img_size": img_size
+            })
+        cell_text = table_output(filepath, table_inputs, page_range=unique_pages)
+        assert len(cell_text) == len(unique_pages), "Number of pages and table inputs must match"
+
+        for pidx, (page_tables, pnum) in enumerate(zip(cell_text, unique_pages)):
+            table_idx = 0
+            for block in extract_blocks:
+                if block["page_id"] == pnum:
+                    block["table_text_lines"] = page_tables[table_idx]
+                    table_idx += 1
+            assert table_idx == len(page_tables), "Number of tables and table inputs must match"
+
+    def assign_ocr_lines(self, ocr_blocks: list):
+        det_images = [t["table_image"] for t in ocr_blocks]
+        ocr_results: List[OCRResult] = self.recognition_model(
+            det_images,
+            [None] * len(det_images),
+            self.detection_model,
+            recognition_batch_size=self.get_recognition_batch_size(),
+            detection_batch_size=self.get_detector_batch_size()
+        )
+
+        for block, ocr_res in zip(ocr_blocks, ocr_results):
+            table_cells = []
+            for line in ocr_res.text_lines:
+                # Don't need to correct back to image size
+                # Table rec boxes are relative to the table
+                table_cells.append({
+                    "bbox": line.bbox,
+                    "text": line.text
+                })
+            block["table_text_lines"] = table_cells

     def get_detector_batch_size(self):
         if self.detector_batch_size is not None:
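`assign_text_to_cells` leans on `marker.util.matrix_intersection_area` to give each extracted text line the table cell it overlaps most. A hedged numpy sketch of that kind of pairwise intersection-area matrix (marker's actual helper may be implemented differently):

```python
import numpy as np

def intersection_area_matrix(boxes_a, boxes_b):
    # Boxes are [x0, y0, x1, y1]; entry [i, j] holds the overlap area of
    # boxes_a[i] with boxes_b[j], and zero when they are disjoint.
    a = np.asarray(boxes_a, dtype=float)[:, None, :]  # shape (N, 1, 4)
    b = np.asarray(boxes_b, dtype=float)[None, :, :]  # shape (1, M, 4)
    dx = np.minimum(a[..., 2], b[..., 2]) - np.maximum(a[..., 0], b[..., 0])
    dy = np.minimum(a[..., 3], b[..., 3]) - np.maximum(a[..., 1], b[..., 1])
    return np.clip(dx, 0, None) * np.clip(dy, 0, None)

lines = [[0, 0, 10, 5]]                # one text line
cells = [[0, 0, 4, 5], [4, 0, 10, 5]]  # two candidate cells
print(intersection_area_matrix(lines, cells).argmax(axis=1))  # [1]
```

Each line then lands in exactly one cell (its argmax), which is why a line that partially overlaps two cells is never duplicated across both.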
marker/providers/__init__.py
CHANGED

@@ -3,6 +3,9 @@ from typing import List, Optional, Dict
 from PIL import Image
 from pydantic import BaseModel

+from pdftext.schema import Reference
+
+from marker.schema.polygon import PolygonBox
 from marker.schema.text import Span
 from marker.schema.text.line import Line
 from marker.util import assign_config

@@ -29,8 +32,17 @@ class BaseProvider:
     def get_images(self, idxs: List[int], dpi: int) -> List[Image.Image]:
         pass

-    def get_page_bbox(self, idx: int) ->
+    def get_page_bbox(self, idx: int) -> PolygonBox | None:
         pass

     def get_page_lines(self, idx: int) -> List[Line]:
         pass
+
+    def get_page_refs(self, idx: int) -> List[Reference]:
+        pass
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        raise NotImplementedError
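The base class now defines the context-manager protocol itself: `__enter__` hands back the provider, and `__exit__` is left for subclasses to implement their own cleanup. A small sketch of the contract, using a stand-in class rather than a real provider:

```python
class DemoProvider:
    # Mirrors the shape BaseProvider now expects of subclasses:
    # __enter__ returns self, __exit__ releases whatever the provider holds.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        print("resources released")

with DemoProvider() as provider:
    print("processing with", type(provider).__name__)
```

Having the base `__exit__` raise `NotImplementedError` turns a forgotten cleanup override into a loud failure the first time a provider is used in a `with` block.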
marker/providers/image.py
ADDED

@@ -0,0 +1,52 @@
+from typing import List, Annotated, Optional
+from PIL import Image
+
+from marker.providers import ProviderPageLines, BaseProvider
+from marker.schema.polygon import PolygonBox
+from marker.schema.text import Line
+from pdftext.schema import Reference
+
+
+class ImageProvider(BaseProvider):
+    page_range: Annotated[
+        Optional[List[int]],
+        "The range of pages to process.",
+        "Default is None, which will process all pages."
+    ] = None
+
+    image_count: int = 1
+
+    def __init__(self, filepath: str, config=None):
+        super().__init__(filepath, config)
+
+        self.images = [Image.open(filepath)]
+        self.page_lines: ProviderPageLines = {i: [] for i in range(self.image_count)}
+
+        if self.page_range is None:
+            self.page_range = range(self.image_count)
+
+        assert max(self.page_range) < self.image_count and min(self.page_range) >= 0, \
+            f"Invalid page range, values must be between 0 and {self.image_count - 1}. Min of provided page range is {min(self.page_range)} and max is {max(self.page_range)}."
+
+        self.page_bboxes = {i: [0, 0, self.images[i].size[0], self.images[i].size[1]] for i in self.page_range}
+
+    def __len__(self):
+        return self.image_count
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        pass
+
+    def get_images(self, idxs: List[int], dpi: int) -> List[Image.Image]:
+        return [self.images[i] for i in idxs]
+
+    def get_page_bbox(self, idx: int) -> PolygonBox | None:
+        bbox = self.page_bboxes[idx]
+        if bbox:
+            return PolygonBox.from_bbox(bbox)
+
+    def get_page_lines(self, idx: int) -> List[Line]:
+        return self.page_lines[idx]
+
+    def get_page_refs(self, idx: int) -> List[Reference]:
+        return []
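A brief usage sketch of the new provider, treating one image file as a one-page document whose page bbox equals the image dimensions (the `page.png` path is a throwaway created for the demo):

```python
from PIL import Image

Image.new("RGB", (640, 480), "white").save("page.png")

from marker.providers.image import ImageProvider

provider = ImageProvider("page.png")
print(len(provider))                   # 1
print(provider.get_page_bbox(0).bbox)  # [0, 0, 640, 480]
```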
marker/providers/pdf.py
CHANGED

@@ -9,6 +9,7 @@ from ftfy import fix_text
 from pdftext.extraction import dictionary_output
 from pdftext.schema import Reference
 from PIL import Image
+from pypdfium2 import PdfiumError

 from marker.providers import BaseProvider, ProviderOutput, ProviderPageLines
 from marker.providers.utils import alphanum_ratio

@@ -91,9 +92,6 @@ class PdfProvider(BaseProvider):

         atexit.register(self.cleanup_pdf_doc)

-    def __enter__(self):
-        return self
-
     def __exit__(self, exc_type, exc_value, traceback):
         self.cleanup_pdf_doc()

@@ -155,6 +153,19 @@
         formats.add("italic")
         return formats

+    @staticmethod
+    def normalize_spaces(text):
+        space_chars = [
+            '\u2003',  # em space
+            '\u2002',  # en space
+            '\u00A0',  # non-breaking space
+            '\u200B',  # zero-width space
+            '\u3000',  # ideographic space
+        ]
+        for space in space_chars:
+            text = text.replace(space, ' ')
+        return text
+
     def pdftext_extraction(self) -> ProviderPageLines:
         page_lines: ProviderPageLines = {}
         page_char_blocks = dictionary_output(

@@ -191,7 +202,7 @@
             spans.append(
                 SpanClass(
                     polygon=polygon,
-                    text=fix_text(span["text"]),
+                    text=self.normalize_spaces(fix_text(span["text"])),
                     font=font_name,
                     font_weight=font_weight,
                     font_size=font_size,

@@ -234,7 +245,11 @@
     def check_page(self, page_id: int) -> bool:
         page = self.doc.get_page(page_id)
         page_bbox = PolygonBox.from_bbox(page.get_bbox())
-        page_objs = list(page.get_objects(filter=[pdfium_c.FPDF_PAGEOBJ_TEXT, pdfium_c.FPDF_PAGEOBJ_IMAGE]))
+        try:
+            page_objs = list(page.get_objects(filter=[pdfium_c.FPDF_PAGEOBJ_TEXT, pdfium_c.FPDF_PAGEOBJ_IMAGE]))
+        except PdfiumError:
+            # Happens when pdfium fails to get the number of page objects
+            return False

         # if we do not see any text objects in the pdf, we can skip this page
         if not any([obj.type == pdfium_c.FPDF_PAGEOBJ_TEXT for obj in page_objs]):

@@ -313,7 +328,7 @@
     def get_page_lines(self, idx: int) -> List[ProviderOutput]:
         return self.page_lines[idx]

-    def get_page_refs(self, idx: int):
+    def get_page_refs(self, idx: int) -> List[Reference]:
         return self.page_refs[idx]

     @staticmethod
marker/providers/registry.py
ADDED

@@ -0,0 +1,12 @@
+import filetype
+
+from marker.providers.image import ImageProvider
+from marker.providers.pdf import PdfProvider
+
+
+def provider_from_filepath(filepath: str):
+    kind = filetype.image_match(filepath)
+    if kind is not None:
+        return ImageProvider
+
+    return PdfProvider
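The registry dispatches on file content rather than extension: `filetype.image_match` sniffs magic bytes, so a PNG named `report.pdf` would still route to `ImageProvider`. A quick sketch (the file is created just for the demo):

```python
from PIL import Image
from marker.providers.registry import provider_from_filepath

# Throwaway image; filetype reads the magic bytes, not the extension.
Image.new("RGB", (10, 10)).save("scan.png")
print(provider_from_filepath("scan.png").__name__)  # ImageProvider
```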
marker/renderers/__init__.py
CHANGED

@@ -2,7 +2,7 @@ import base64
 import io
 import re
 from collections import Counter
-from typing import Annotated, Optional, Tuple
+from typing import Annotated, Optional, Tuple, Literal

 from bs4 import BeautifulSoup
 from pydantic import BaseModel

@@ -17,6 +17,11 @@ from marker.util import assign_config
 class BaseRenderer:
     image_blocks: Annotated[Tuple[BlockTypes, ...], "The block types to consider as images."] = (BlockTypes.Picture, BlockTypes.Figure)
     extract_images: Annotated[bool, "Extract images from the document."] = True
+    image_extraction_mode: Annotated[
+        Literal["lowres", "highres"],
+        "The mode to use for extracting images.",
+    ] = "highres"
+

     def __init__(self, config: Optional[BaseModel | dict] = None):
         assign_config(self, config)

@@ -25,13 +30,10 @@ class BaseRenderer:
         # Children are in reading order
         raise NotImplementedError

-    @staticmethod
-    def extract_image(document: Document, image_id, to_base64=False):
+    def extract_image(self, document: Document, image_id, to_base64=False):
         image_block = document.get_block(image_id)
-
-
-        image_box = image_block.polygon.rescale(page.polygon.size, page_img.size)
-        cropped = page_img.crop(image_box.bbox)
+        cropped = image_block.get_image(document, highres=self.image_extraction_mode == "highres")
+
         if to_base64:
             image_buffer = io.BytesIO()
             cropped.save(image_buffer, format=settings.OUTPUT_IMAGE_FORMAT)

@@ -44,7 +46,11 @@
         return html

        def replace_whitespace(match):
-
+            whitespace = match.group(1)
+            if len(whitespace) == 0:
+                return ""
+            else:
+                return " "

         pattern = fr'</{tag}>(\s*)<{tag}>'
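`replace_whitespace` now collapses adjacent identical tags outright when nothing separates them, and to a single space otherwise. A standalone rendition of the regex at work, outside the renderer:

```python
import re

def merge_consecutive_tags(html: str, tag: str) -> str:
    # Adjacent identical tags merge; whitespace between them collapses to a
    # single space, and no whitespace collapses to nothing.
    def replace_whitespace(match):
        whitespace = match.group(1)
        return "" if len(whitespace) == 0 else " "

    pattern = fr"</{tag}>(\s*)<{tag}>"
    return re.sub(pattern, replace_whitespace, html)

print(merge_consecutive_tags("<b>one</b><b>two</b>", "b"))    # <b>onetwo</b>
print(merge_consecutive_tags("<b>one</b>  <b>two</b>", "b"))  # <b>one two</b>
```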
marker/renderers/html.py
CHANGED

@@ -1,3 +1,5 @@
+import textwrap
+
 from PIL import Image
 from typing import Annotated, Literal, Tuple
 
@@ -35,17 +37,10 @@ class HTMLRenderer(BaseRenderer):
         bool,
         "Whether to paginate the output.",
     ] = False
-    image_extraction_mode: Annotated[
-        Literal["lowres", "highres"],
-        "The mode to use for extracting images.",
-    ] = "highres"
 
     def extract_image(self, document, image_id):
         image_block = document.get_block(image_id)
-        page = document.get_page(image_block.page_id)
-        page_img = page.lowres_image if self.image_extraction_mode == "lowres" else page.highres_image
-        image_box = image_block.polygon.rescale(page.polygon.size, page_img.size)
-        cropped = page_img.crop(image_box.bbox)
+        cropped = image_block.get_image(document, highres=self.image_extraction_mode == "highres")
         return cropped
 
     def extract_html(self, document, document_output, level=0):
@@ -87,12 +82,25 @@ class HTMLRenderer(BaseRenderer):
         if level == 0:
             output = self.merge_consecutive_tags(output, 'b')
             output = self.merge_consecutive_tags(output, 'i')
+            output = textwrap.dedent(f"""
+            <!DOCTYPE html>
+            <html>
+                <head>
+                    <meta charset="utf-8" />
+                </head>
+                <body>
+                    {output}
+                </body>
+            </html>
+            """)
 
         return output, images
 
     def __call__(self, document) -> HTMLOutput:
         document_output = document.render()
         full_html, images = self.extract_html(document, document_output)
+        soup = BeautifulSoup(full_html, 'html.parser')
+        full_html = soup.prettify()  # Add indentation to the HTML
         return HTMLOutput(
             html=full_html,
             images=images,
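
The wrap-and-prettify step added to __call__ can be reproduced standalone. This sketch applies the same textwrap.dedent wrapper and BeautifulSoup prettify call to a bare fragment to show the effect:

    import textwrap
    from bs4 import BeautifulSoup

    output = "<p>Hello <b>world</b></p>"
    full_html = textwrap.dedent(f"""
    <!DOCTYPE html>
    <html>
    <head><meta charset="utf-8" /></head>
    <body>{output}</body>
    </html>
    """)
    # prettify() re-indents the document, one tag per line
    print(BeautifulSoup(full_html, "html.parser").prettify())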
marker/renderers/json.py
CHANGED

@@ -14,6 +14,7 @@ class JSONBlockOutput(BaseModel):
     block_type: str
     html: str
     polygon: List[List[float]]
+    bbox: List[float]
     children: List['JSONBlockOutput'] | None = None
     section_hierarchy: Dict[int, str] | None = None
     images: dict | None = None
@@ -52,6 +53,7 @@ class JSONRenderer(BaseRenderer):
         return JSONBlockOutput(
             html=html,
             polygon=block_output.polygon.polygon,
+            bbox=block_output.polygon.bbox,
             id=str(block_output.id),
             block_type=str(block_output.id.block_type),
             images=images,
@@ -66,6 +68,7 @@ class JSONRenderer(BaseRenderer):
         return JSONBlockOutput(
             html=block_output.html,
             polygon=block_output.polygon.polygon,
+            bbox=block_output.polygon.bbox,
             id=str(block_output.id),
             block_type=str(block_output.id.block_type),
             children=children,
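
The new bbox field exposes the axis-aligned rectangle next to the raw four-corner polygon, so JSON consumers no longer have to derive it themselves. The relationship is the usual min/max reduction over the corners (a sketch of the convention; PolygonBox.bbox is assumed to compute the equivalent):

    # polygon: four [x, y] corners in page coordinates, as emitted in the JSON output
    polygon = [[10.0, 20.0], [110.0, 22.0], [108.0, 80.0], [12.0, 78.0]]

    xs = [p[0] for p in polygon]
    ys = [p[1] for p in polygon]
    bbox = [min(xs), min(ys), max(xs), max(ys)]  # [x0, y0, x1, y1]
    print(bbox)  # [10.0, 20.0, 110.0, 80.0]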
marker/renderers/markdown.py
CHANGED

@@ -1,4 +1,5 @@
 import re
+from collections import defaultdict
 from typing import Annotated, Tuple
 
 import regex
@@ -13,7 +14,7 @@ from marker.schema.document import Document
 def cleanup_text(full_text):
     full_text = re.sub(r'\n{3,}', '\n\n', full_text)
     full_text = re.sub(r'(\n\s){3,}', '\n\n', full_text)
-    return full_text
+    return full_text.strip()
 
 
 class Markdownify(MarkdownConverter):
@@ -53,13 +54,88 @@ class Markdownify(MarkdownConverter):
         else:
            return "\n" + self.block_math_delimiters[0] + text + self.block_math_delimiters[1] + "\n"
 
-    def convert_table(self, el, text, convert_as_inline):
-        …
+    def convert_table(self, el, text, convert_as_inline):
+        total_rows = len(el.find_all('tr'))
+        colspans = []
+        rowspan_cols = defaultdict(int)
+        for i, row in enumerate(el.find_all('tr')):
+            row_cols = rowspan_cols[i]
+            for cell in row.find_all(['td', 'th']):
+                colspan = int(cell.get('colspan', 1))
+                row_cols += colspan
+                for r in range(int(cell.get('rowspan', 1)) - 1):
+                    rowspan_cols[i + r] += colspan  # Add the colspan to the next rows, so they get the correct number of columns
+            colspans.append(row_cols)
+        total_cols = max(colspans)
+
+        grid = [[None for _ in range(total_cols)] for _ in range(total_rows)]
+
+        for row_idx, tr in enumerate(el.find_all('tr')):
+            col_idx = 0
+            for cell in tr.find_all(['td', 'th']):
+                # Skip filled positions
+                while col_idx < total_cols and grid[row_idx][col_idx] is not None:
+                    col_idx += 1
+
+                # Fill in grid
+                value = cell.get_text(strip=True).replace("\n", " ").replace("|", " ")
+                rowspan = int(cell.get('rowspan', 1))
+                colspan = int(cell.get('colspan', 1))
+
+                if col_idx >= total_cols:
+                    # Skip this cell if we're out of bounds
+                    continue
+
+                for r in range(rowspan):
+                    for c in range(colspan):
+                        try:
+                            if r == 0 and c == 0:
+                                grid[row_idx][col_idx] = value
+                            else:
+                                grid[row_idx + r][col_idx + c] = ''
+                        except IndexError:
+                            # Sometimes the colspan/rowspan predictions can overflow
+                            print(f"Overflow in columns: {col_idx + c} >= {total_cols}")
+                            continue
+
+                col_idx += colspan
+
+        markdown_lines = []
+        col_widths = [0] * total_cols
+        for row in grid:
+            for col_idx, cell in enumerate(row):
+                if cell is not None:
+                    col_widths[col_idx] = max(col_widths[col_idx], len(str(cell)))
+
+        add_header_line = lambda: markdown_lines.append('|' + '|'.join('-' * (width + 2) for width in col_widths) + '|')
+
+        # Generate markdown rows
+        added_header = False
+        for i, row in enumerate(grid):
+            is_empty_line = all(not cell for cell in row)
+            if is_empty_line and not added_header:
+                # Skip leading blank lines
+                continue
+
+            line = []
+            for col_idx, cell in enumerate(row):
+                if cell is None:
+                    cell = ''
+                padding = col_widths[col_idx] - len(str(cell))
+                line.append(f" {cell}{' ' * padding} ")
+            markdown_lines.append('|' + '|'.join(line) + '|')
+
+            if not added_header:
+                # Skip empty lines when adding the header row
+                add_header_line()
+                added_header = True
+
+        # Handle one row tables
+        if total_rows == 1:
+            add_header_line()
+
+        table_md = '\n'.join(markdown_lines)
+        return "\n\n" + table_md + "\n\n"
 
     def convert_a(self, el, text, convert_as_inline):
         text = self.escape(text)
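
For intuition about the grid pass in convert_table: the anchor position of a spanned cell keeps the text and every other covered position becomes an explicit empty string, so each emitted markdown row has exactly total_cols cells. A small worked example (illustrative only; padding in the real output depends on column widths):

    from bs4 import BeautifulSoup

    html = """<table>
      <tr><td rowspan="2">a</td><td colspan="2">b</td></tr>
      <tr><td>c</td><td>d</td></tr>
    </table>"""
    table = BeautifulSoup(html, "html.parser").table

    # Running the grid pass over this table gives total_cols == 3 and
    #   [['a', 'b', ''],
    #    ['',  'c', 'd']]
    # which renders as a 3-column markdown table, with the header
    # separator line inserted after the first non-empty row.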
marker/schema/__init__.py
CHANGED

@@ -27,6 +27,7 @@ class BlockTypes(str, Enum):
     TableOfContents = auto()
     Document = auto()
     ComplexRegion = auto()
+    TableCell = auto()
     Reference = auto()
 
     def __str__(self):
marker/schema/blocks/__init__.py
CHANGED

@@ -18,4 +18,5 @@ from marker.schema.blocks.table import Table
 from marker.schema.blocks.text import Text
 from marker.schema.blocks.toc import TableOfContents
 from marker.schema.blocks.complexregion import ComplexRegion
+from marker.schema.blocks.tablecell import TableCell
 from marker.schema.blocks.reference import Reference
marker/schema/blocks/base.py
CHANGED

@@ -1,8 +1,9 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Sequence
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Sequence, Tuple
 
 from pydantic import BaseModel, ConfigDict, field_validator
+from PIL import Image
 
 from marker.schema import BlockTypes
 from marker.schema.polygon import PolygonBox
@@ -71,6 +72,7 @@ class BlockId(BaseModel):
 
 class Block(BaseModel):
     polygon: PolygonBox
+    block_description: str
     block_type: Optional[BlockTypes] = None
     block_id: Optional[int] = None
     page_id: Optional[int] = None
@@ -81,6 +83,8 @@ class Block(BaseModel):
     source: Literal['layout', 'heuristics', 'processor'] = 'layout'
     top_k: Optional[Dict[BlockTypes, float]] = None
     metadata: BlockMetadata | None = None
+    lowres_image: Image.Image | None = None
+    highres_image: Image.Image | None = None
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
@@ -97,6 +101,21 @@ class Block(BaseModel):
         block_attrs = block.model_dump(exclude=["id", "block_id", "block_type"])
         return cls(**block_attrs)
 
+    def get_image(self, document: Document, highres: bool = False, expansion: Tuple[float, float] | None = None) -> Image.Image | None:
+        image = self.highres_image if highres else self.lowres_image
+        if image is None:
+            page = document.get_page(self.page_id)
+            page_image = page.highres_image if highres else page.lowres_image
+
+            # Scale to the image size
+            bbox = self.polygon.rescale((page.polygon.width, page.polygon.height), page_image.size)
+            if expansion:
+                bbox = bbox.expand(*expansion)
+            bbox = bbox.bbox
+            image = page_image.crop(bbox)
+        return image
+
+
     def structure_blocks(self, document_page: Document | PageGroup) -> List[Block]:
         if self.structure is None:
             return []
@@ -163,7 +182,7 @@ class Block(BaseModel):
             text += "\n"
         return text
 
-    def assemble_html(self, child_blocks: List[BlockOutput], parent_structure: Optional[List[str]] = None):
+    def assemble_html(self, document: Document, child_blocks: List[BlockOutput], parent_structure: Optional[List[str]] = None):
         if self.ignore_for_output:
             return ""
 
@@ -172,7 +191,8 @@ class Block(BaseModel):
             template += f"<content-ref src='{c.id}'></content-ref>"
 
         if self.replace_output_newlines:
-            template = …
+            template = template.replace("\n", " ")
+            template = "<p>" + template + "</p>"
 
         return template
 
@@ -205,7 +225,7 @@ class Block(BaseModel):
             self.structure[i] = new_block.id
             break
 
-    def render(self, document: Document, parent_structure: Optional[List[str]], section_hierarchy=None):
+    def render(self, document: Document, parent_structure: Optional[List[str]] = None, section_hierarchy: dict | None = None):
         child_content = []
         if section_hierarchy is None:
             section_hierarchy = {}
@@ -219,7 +239,7 @@ class Block(BaseModel):
             child_content.append(rendered)
 
         return BlockOutput(
-            html=self.assemble_html(child_content, parent_structure),
+            html=self.assemble_html(document, child_content, parent_structure),
             polygon=self.polygon,
             id=self.id,
             children=child_content,
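
get_image is what the renderer changes above lean on: it prefers a cached per-block image, and otherwise rescales the block polygon from PDF coordinates to the rendered page image before cropping. A hedged usage sketch, assuming a document produced by the PDF converter and a valid block_id:

    block = document.get_block(block_id)            # any block with a polygon
    low = block.get_image(document)                 # cropped from the lowres page image
    high = block.get_image(document, highres=True)  # cropped from the highres page image
    # expansion pads the crop via PolygonBox.expand before cropping
    padded = block.get_image(document, highres=True, expansion=(0.05, 0.05))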
marker/schema/blocks/basetable.py
ADDED

@@ -0,0 +1,39 @@
+from typing import List
+
+from marker.schema import BlockTypes
+from marker.schema.blocks import Block, BlockOutput
+from marker.schema.blocks.tablecell import TableCell
+
+
+class BaseTable(Block):
+    block_type: BlockTypes | None = None
+    html: str | None = None
+
+    def format_cells(self, document, child_blocks):
+        child_cells: List[TableCell] = [document.get_block(c.id) for c in child_blocks]
+        unique_rows = sorted(list(set([c.row_id for c in child_cells])))
+        html_repr = "<table><tbody>"
+        for row_id in unique_rows:
+            row_cells = sorted([c for c in child_cells if c.row_id == row_id], key=lambda x: x.col_id)
+            html_repr += "<tr>"
+            for cell in row_cells:
+                html_repr += cell.assemble_html(document, child_blocks, None)
+            html_repr += "</tr>"
+        html_repr += "</tbody></table>"
+        return html_repr
+
+
+    def assemble_html(self, document, child_blocks: List[BlockOutput], parent_structure=None):
+        # Filter out the table cells, so they don't render twice
+        child_ref_blocks = [block for block in child_blocks if block.id.block_type == BlockTypes.Reference]
+        template = super().assemble_html(document, child_ref_blocks, parent_structure)
+
+        if self.html:
+            # LLM processor
+            return template + self.html
+        elif len(child_blocks) > 0 and child_blocks[0].id.block_type == BlockTypes.TableCell:
+            # Table processor
+            return template + self.format_cells(document, child_blocks)
+        else:
+            # Default text lines and spans
+            return f"<p>{template}</p>"
marker/schema/blocks/caption.py
CHANGED

@@ -4,4 +4,6 @@ from marker.schema.blocks import Block
 
 class Caption(Block):
     block_type: BlockTypes = BlockTypes.Caption
+    block_description: str = "A text caption that is directly above or below an image or table. Only used for text describing the image or table."
     replace_output_newlines: bool = True
+
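
This block_description pattern repeats across the block types below; the descriptions give the LLM processors a per-type summary they can splice into prompts. A hedged sketch of that kind of consumption (the prompt wording is illustrative, not marker's actual prompt text):

    from marker.schema.blocks import Caption, Footnote

    # Build a hypothetical prompt fragment from the pydantic field defaults.
    for cls in (Caption, Footnote):
        desc = cls.model_fields["block_description"].default
        print(f"- {cls.__name__}: {desc}")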
marker/schema/blocks/code.py
CHANGED

@@ -7,8 +7,9 @@ from marker.schema.blocks import Block
 
 class Code(Block):
     block_type: BlockTypes = BlockTypes.Code
     code: str | None = None
+    block_description: str = "A programming code block."
 
-    def assemble_html(self, child_blocks, parent_structure):
+    def assemble_html(self, document, child_blocks, parent_structure):
         code = self.code or ""
         return (f"<pre>"
                 f"{html.escape(code)}"
marker/schema/blocks/complexregion.py
CHANGED

@@ -5,10 +5,11 @@ from marker.schema.blocks import Block
 
 class ComplexRegion(Block):
     block_type: BlockTypes = BlockTypes.ComplexRegion
     html: str | None = None
+    block_description: str = "A complex region that can consist of multiple different types of blocks mixed with images. This block is chosen when it is difficult to categorize the region as a single block type."
 
-    def assemble_html(self, child_blocks, parent_structure):
+    def assemble_html(self, document, child_blocks, parent_structure):
         if self.html:
             return self.html
         else:
-            template = super().assemble_html(child_blocks, parent_structure)
+            template = super().assemble_html(document, child_blocks, parent_structure)
             return f"<p>{template}</p>"
marker/schema/blocks/equation.py
CHANGED

@@ -7,11 +7,12 @@ from marker.schema.blocks import Block
 
 class Equation(Block):
     block_type: BlockTypes = BlockTypes.Equation
     latex: str | None = None
+    block_description: str = "A block math equation."
 
-    def assemble_html(self, child_blocks, parent_structure=None):
+    def assemble_html(self, document, child_blocks, parent_structure=None):
         if self.latex:
             child_ref_blocks = [block for block in child_blocks if block.id.block_type == BlockTypes.Reference]
-            html_out = super().assemble_html(child_ref_blocks, parent_structure)
+            html_out = super().assemble_html(document, child_ref_blocks, parent_structure)
             html_out += f"<p block-type='{self.block_type}'>"
 
             try:
@@ -33,7 +34,7 @@ class Equation(Block):
             html_out += "</p>"
             return html_out
         else:
-            template = super().assemble_html(child_blocks, parent_structure)
+            template = super().assemble_html(document, child_blocks, parent_structure)
             return f"<p block-type='{self.block_type}'>{template}</p>"
 
     @staticmethod
marker/schema/blocks/figure.py
CHANGED

@@ -5,10 +5,11 @@ from marker.schema.blocks import Block
 
 class Figure(Block):
     block_type: BlockTypes = BlockTypes.Figure
     description: str | None = None
+    block_description: str = "A chart or other image that contains data."
 
-    def assemble_html(self, child_blocks, parent_structure):
+    def assemble_html(self, document, child_blocks, parent_structure):
         child_ref_blocks = [block for block in child_blocks if block.id.block_type == BlockTypes.Reference]
-        html = super().assemble_html(child_ref_blocks, parent_structure)
+        html = super().assemble_html(document, child_ref_blocks, parent_structure)
         if self.description:
             html += f"<p role='img' data-original-image-id='{self.id}'>Image {self.id} description: {self.description}</p>"
         return html
marker/schema/blocks/footnote.py
CHANGED

@@ -4,4 +4,5 @@ from marker.schema.blocks import Block
 
 class Footnote(Block):
     block_type: BlockTypes = BlockTypes.Footnote
+    block_description: str = "A footnote that explains a term or concept in the document."
     replace_output_newlines: bool = True
marker/schema/blocks/form.py
CHANGED

@@ -1,20 +1,9 @@
 from typing import List
 
-from tabled.formats import html_format
-from tabled.schema import SpanTableCell
-
 from marker.schema import BlockTypes
-from marker.schema.blocks import Block
+from marker.schema.blocks.basetable import BaseTable
 
-
-class Form(Block):
-    block_type: str = BlockTypes.Form
-    cells: List[SpanTableCell] | None = None
-    html: str | None = None
-
-    def assemble_html(self, child_blocks, parent_structure=None):
-        # Some processors convert the form to html
-        if self.html is not None:
-            return self.html
 
-        …
+
+class Form(BaseTable):
+    block_type: BlockTypes = BlockTypes.Form
+    block_description: str = "A form, such as a tax form, that contains fields and labels. It most likely doesn't have a table structure."
marker/schema/blocks/handwriting.py
CHANGED

@@ -4,4 +4,12 @@ from marker.schema.blocks import Block
 
 class Handwriting(Block):
     block_type: BlockTypes = BlockTypes.Handwriting
+    block_description: str = "A region that contains handwriting."
+    html: str | None = None
     replace_output_newlines: bool = True
+
+    def assemble_html(self, document, child_blocks, parent_structure):
+        if self.html:
+            return self.html
+        else:
+            return super().assemble_html(document, child_blocks, parent_structure)