Add some tests
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .github/workflows/{benchmark.yml → benchmarks.yml} +13 -14
- .github/workflows/ci.yml +0 -4
- .github/workflows/scripts.yml +0 -4
- README.md +111 -57
- benchmarks/__init__.py +0 -0
- benchmarks/overall.py +0 -132
- benchmarks/overall/__init__.py +0 -0
- benchmarks/overall/display/__init__.py +0 -0
- benchmarks/overall/display/dataset.py +48 -0
- benchmarks/overall/display/table.py +68 -0
- benchmarks/overall/download/__init__.py +0 -0
- benchmarks/overall/download/base.py +60 -0
- benchmarks/overall/download/llamaparse.py +64 -0
- benchmarks/overall/download/main.py +23 -0
- benchmarks/overall/download/mathpix.py +80 -0
- benchmarks/overall/elo.py +225 -0
- benchmarks/overall/methods/__init__.py +100 -0
- benchmarks/overall/methods/docling.py +26 -0
- benchmarks/overall/methods/gt.py +29 -0
- benchmarks/overall/methods/llamaparse.py +22 -0
- benchmarks/overall/methods/marker.py +29 -0
- benchmarks/overall/methods/mathpix.py +22 -0
- benchmarks/overall/methods/schema.py +6 -0
- benchmarks/overall/overall.py +148 -0
- benchmarks/overall/registry.py +20 -0
- benchmarks/overall/schema.py +12 -0
- benchmarks/overall/scorers/__init__.py +11 -0
- benchmarks/overall/scorers/clean.py +113 -0
- benchmarks/overall/scorers/heuristic.py +96 -0
- benchmarks/overall/scorers/llm.py +147 -0
- benchmarks/overall/scorers/schema.py +6 -0
- benchmarks/scoring.py +0 -36
- benchmarks/table/__init__.py +0 -0
- benchmarks/table/gemini.py +17 -18
- benchmarks/table/inference.py +182 -0
- benchmarks/table/table.py +15 -148
- benchmarks/throughput/__init__.py +0 -0
- benchmarks/throughput/main.py +39 -0
- benchmarks/verify_scores.py +6 -6
- marker/builders/document.py +3 -1
- marker/builders/layout.py +11 -99
- marker/builders/line.py +512 -0
- marker/builders/llm_layout.py +22 -30
- marker/builders/ocr.py +62 -71
- marker/config/crawler.py +2 -1
- marker/config/parser.py +17 -1
- marker/converters/__init__.py +48 -2
- marker/converters/pdf.py +35 -29
- marker/converters/table.py +6 -5
- marker/logger.py +4 -0
.github/workflows/{benchmark.yml → benchmarks.yml}
RENAMED
|
@@ -1,33 +1,32 @@
|
|
| 1 |
-
name: Integration test
|
| 2 |
|
| 3 |
on: [push]
|
| 4 |
|
| 5 |
env:
|
| 6 |
-
|
| 7 |
|
| 8 |
jobs:
|
| 9 |
benchmark:
|
| 10 |
-
runs-on:
|
| 11 |
steps:
|
| 12 |
- uses: actions/checkout@v3
|
| 13 |
- name: Set up Python 3.11
|
| 14 |
uses: actions/setup-python@v4
|
| 15 |
with:
|
| 16 |
python-version: 3.11
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
- name: Install python dependencies
|
| 18 |
run: |
|
| 19 |
pip install poetry
|
| 20 |
poetry install
|
| 21 |
-
poetry remove torch
|
| 22 |
-
poetry run pip install torch --index-url https://download.pytorch.org/whl/cpu
|
| 23 |
-
- name: Download benchmark data
|
| 24 |
-
run: |
|
| 25 |
-
wget -O benchmark_data.zip "https://drive.google.com/uc?export=download&id=1NHrdYatR1rtqs2gPVfdvO0BAvocH8CJi"
|
| 26 |
-
unzip -o benchmark_data.zip
|
| 27 |
- name: Run benchmark test
|
| 28 |
run: |
|
| 29 |
-
poetry run python benchmarks/overall.py
|
| 30 |
-
poetry run python benchmarks/verify_scores.py
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
| 1 |
+
name: Integration test
|
| 2 |
|
| 3 |
on: [push]
|
| 4 |
|
| 5 |
env:
|
| 6 |
+
PYTHONIOENCODING: "utf-8"
|
| 7 |
|
| 8 |
jobs:
|
| 9 |
benchmark:
|
| 10 |
+
runs-on: ubuntu-latest
|
| 11 |
steps:
|
| 12 |
- uses: actions/checkout@v3
|
| 13 |
- name: Set up Python 3.11
|
| 14 |
uses: actions/setup-python@v4
|
| 15 |
with:
|
| 16 |
python-version: 3.11
|
| 17 |
+
- name: Install apt dependencies
|
| 18 |
+
run: |
|
| 19 |
+
sudo apt-get update
|
| 20 |
+
sudo apt-get install -y pandoc
|
| 21 |
- name: Install python dependencies
|
| 22 |
run: |
|
| 23 |
pip install poetry
|
| 24 |
poetry install
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
- name: Run benchmark test
|
| 26 |
run: |
|
| 27 |
+
poetry run python benchmarks/overall/overall.py --max_rows 5
|
| 28 |
+
poetry run python benchmarks/verify_scores.py conversion_results/benchmark/overall/result.json --type marker
|
| 29 |
+
- name: Run table benchmark
|
| 30 |
+
run: |
|
| 31 |
+
poetry run python benchmarks/table/table.py --max_rows 5
|
| 32 |
+
poetry run python benchmarks/verify_scores.py conversion_results/benchmark/table/table.json --type table
|
.github/workflows/ci.yml
CHANGED
|
@@ -2,10 +2,6 @@ name: CI tests
|
|
| 2 |
|
| 3 |
on: [push]
|
| 4 |
|
| 5 |
-
env:
|
| 6 |
-
TORCH_DEVICE: "cpu"
|
| 7 |
-
OCR_ENGINE: "surya"
|
| 8 |
-
|
| 9 |
jobs:
|
| 10 |
tests:
|
| 11 |
runs-on: ubuntu-latest
|
|
|
|
| 2 |
|
| 3 |
on: [push]
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
jobs:
|
| 6 |
tests:
|
| 7 |
runs-on: ubuntu-latest
|
.github/workflows/scripts.yml
CHANGED
|
@@ -2,10 +2,6 @@ name: Test CLI scripts
|
|
| 2 |
|
| 3 |
on: [push]
|
| 4 |
|
| 5 |
-
env:
|
| 6 |
-
TORCH_DEVICE: "cpu"
|
| 7 |
-
OCR_ENGINE: "surya"
|
| 8 |
-
|
| 9 |
jobs:
|
| 10 |
tests:
|
| 11 |
runs-on: ubuntu-latest
|
|
|
|
| 2 |
|
| 3 |
on: [push]
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
jobs:
|
| 6 |
tests:
|
| 7 |
runs-on: ubuntu-latest
|
README.md
CHANGED
|
@@ -3,24 +3,32 @@
|
|
| 3 |
Marker converts PDFs and images to markdown, JSON, and HTML quickly and accurately.
|
| 4 |
|
| 5 |
- Supports a range of documents in all languages
|
| 6 |
-
- Formats tables, forms, equations, links, references, and code blocks
|
| 7 |
-
- Extracts and saves images
|
| 8 |
- Removes headers/footers/other artifacts
|
| 9 |
-
-
|
| 10 |
-
- Optionally boost accuracy with
|
| 11 |
- Works on GPU, CPU, or MPS
|
| 12 |
|
| 13 |
-
##
|
| 14 |
|
| 15 |
-
|
| 16 |
|
| 17 |
-
|
| 18 |
-
- Detect page layout and find reading order ([surya](https://github.com/VikParuchuri/surya))
|
| 19 |
-
- Clean and format each block (heuristics, [texify](https://github.com/VikParuchuri/texify), [surya](https://github.com/VikParuchuri/surya))
|
| 20 |
-
- Optionally use an LLM to improve quality
|
| 21 |
-
- Combine blocks and postprocess complete text
|
| 22 |
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
## Examples
|
| 26 |
|
|
@@ -30,19 +38,11 @@ It only uses models where necessary, which improves speed and accuracy.
|
|
| 30 |
| [Switch Transformers](https://arxiv.org/pdf/2101.03961.pdf) | arXiv paper | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/markdown/switch_transformers/switch_trans.md) | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/json/switch_trans.json) |
|
| 31 |
| [Multi-column CNN](https://arxiv.org/pdf/1804.07821.pdf) | arXiv paper | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/markdown/multicolcnn/multicolcnn.md) | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/json/multicolcnn.json) |
|
| 32 |
|
| 33 |
-
## Performance
|
| 34 |
-
|
| 35 |
-

|
| 36 |
-
|
| 37 |
-
The above results are with marker setup so it takes ~7GB of VRAM on an A10.
|
| 38 |
-
|
| 39 |
-
See [below](#benchmarks) for detailed speed and accuracy benchmarks, and instructions on how to run your own benchmarks.
|
| 40 |
-
|
| 41 |
# Commercial usage
|
| 42 |
|
| 43 |
I want marker to be as widely accessible as possible, while still funding my development/training costs. Research and personal usage is always okay, but there are some restrictions on commercial usage.
|
| 44 |
|
| 45 |
-
The weights for the models are licensed `cc-by-nc-sa-4.0`, but I will waive that for any organization under
|
| 46 |
|
| 47 |
# Hosted API
|
| 48 |
|
|
@@ -56,17 +56,6 @@ There's a hosted API for marker available [here](https://www.datalab.to/):
|
|
| 56 |
|
| 57 |
[Discord](https://discord.gg//KuZwXNGnfH) is where we discuss future development.
|
| 58 |
|
| 59 |
-
# Limitations
|
| 60 |
-
|
| 61 |
-
PDF is a tricky format, so marker will not always work perfectly. Here are some known limitations that are on the roadmap to address:
|
| 62 |
-
|
| 63 |
-
- Marker will only convert block equations
|
| 64 |
-
- Tables are not always formatted 100% correctly
|
| 65 |
-
- Forms are not converted optimally
|
| 66 |
-
- Very complex layouts, with nested tables and forms, may not work
|
| 67 |
-
|
| 68 |
-
Note: Passing the `--use_llm` flag will mostly solve these issues.
|
| 69 |
-
|
| 70 |
# Installation
|
| 71 |
|
| 72 |
You'll need python 3.10+ and PyTorch. You may need to install the CPU version of torch first if you're not using a Mac or a GPU machine. See [here](https://pytorch.org/get-started/locally/) for more details.
|
|
@@ -82,7 +71,7 @@ pip install marker-pdf
|
|
| 82 |
First, some configuration:
|
| 83 |
|
| 84 |
- Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`.
|
| 85 |
-
- Some PDFs, even digital ones, have bad text in them. Set the `force_ocr` flag
|
| 86 |
|
| 87 |
## Interactive App
|
| 88 |
|
|
@@ -116,6 +105,8 @@ Options:
|
|
| 116 |
- `--languages TEXT`: Optionally specify which languages to use for OCR processing. Accepts a comma-separated list. Example: `--languages "en,fr,de"` for English, French, and German.
|
| 117 |
- `config --help`: List all available builders, processors, and converters, and their associated configuration. These values can be used to build a JSON configuration file for additional tweaking of marker defaults.
|
| 118 |
- `--converter_cls`: One of `marker.converters.pdf.PdfConverter` (default) or `marker.converters.table.TableConverter`. The `PdfConverter` will convert the whole PDF, the `TableConverter` will only extract and convert tables.
|
|
|
|
|
|
|
| 119 |
|
| 120 |
The list of supported languages for surya OCR is [here](https://github.com/VikParuchuri/surya/blob/master/surya/recognition/languages.py). If you don't need OCR, marker can work with any language.
|
| 121 |
|
|
@@ -157,7 +148,7 @@ text, _, images = text_from_rendered(rendered)
|
|
| 157 |
|
| 158 |
### Custom configuration
|
| 159 |
|
| 160 |
-
You can pass configuration using the `ConfigParser
|
| 161 |
|
| 162 |
```python
|
| 163 |
from marker.converters.pdf import PdfConverter
|
|
@@ -174,7 +165,8 @@ converter = PdfConverter(
|
|
| 174 |
config=config_parser.generate_config_dict(),
|
| 175 |
artifact_dict=create_model_dict(),
|
| 176 |
processor_list=config_parser.get_processors(),
|
| 177 |
-
renderer=config_parser.get_renderer()
|
|
|
|
| 178 |
)
|
| 179 |
rendered = converter("FILEPATH")
|
| 180 |
```
|
|
@@ -219,11 +211,11 @@ rendered = converter("FILEPATH")
|
|
| 219 |
text, _, images = text_from_rendered(rendered)
|
| 220 |
```
|
| 221 |
|
| 222 |
-
This takes all the same configuration as the PdfConverter. You can specify the configuration
|
| 223 |
|
| 224 |
You can also run this via the CLI with
|
| 225 |
```shell
|
| 226 |
-
|
| 227 |
```
|
| 228 |
|
| 229 |
# Output Formats
|
|
@@ -321,6 +313,16 @@ All output formats will return a metadata dictionary, with the following fields:
|
|
| 321 |
}
|
| 322 |
```
|
| 323 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
# Internals
|
| 325 |
|
| 326 |
Marker is easy to extend. The core units of marker are:
|
|
@@ -377,36 +379,55 @@ There are some settings that you may find useful if things aren't working the wa
|
|
| 377 |
Pass the `debug` option to activate debug mode. This will save images of each page with detected layout and text, as well as output a json file with additional bounding box information.
|
| 378 |
|
| 379 |
# Benchmarks
|
|
|
|
| 380 |
## Overall PDF Conversion
|
| 381 |
-
Benchmarking PDF extraction quality is hard. I've created a test set by finding books and scientific papers that have a pdf version and a latex source. I convert the latex to text, and compare the reference to the output of text extraction methods. It's noisy, but at least directionally correct.
|
| 382 |
|
| 383 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
|
| 385 |
-
|
| 386 |
-
|---------|----------------|---------------|------------------|
|
| 387 |
-
| marker | 0.625115 | 0.234184 | 21.545 |
|
| 388 |
|
| 389 |
-
|
| 390 |
|
| 391 |
-
|
|
| 392 |
-
|
| 393 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
|
| 395 |
-
|
| 396 |
|
| 397 |
-
|
| 398 |
|
| 399 |
-
|
|
|
|
|
|
|
| 400 |
|
| 401 |
-
|
| 402 |
|
| 403 |
## Table Conversion
|
|
|
|
| 404 |
Marker can extract tables from PDFs using `marker.converters.table.TableConverter`. The table extraction performance is measured by comparing the extracted HTML representation of tables against the original HTML representations using the test split of [FinTabNet](https://developer.ibm.com/exchanges/data/all/fintabnet/). The HTML representations are compared using a tree edit distance based metric to judge both structure and content. Marker detects and identifies the structure of all tables in a PDF page and achieves these scores:
|
| 405 |
|
| 406 |
-
| Avg score | Total tables |
|
| 407 |
-
|
| 408 |
-
| 0.
|
| 409 |
-
| 0.
|
|
|
|
| 410 |
|
| 411 |
The `--use_llm` flag can significantly improve table recognition performance, as you can see.
|
| 412 |
|
|
@@ -426,16 +447,49 @@ poetry install
|
|
| 426 |
Download the benchmark data [here](https://drive.google.com/file/d/1ZSeWDo2g1y0BRLT7KnbmytV2bjWARWba/view?usp=sharing) and unzip. Then run the overall benchmark like this:
|
| 427 |
|
| 428 |
```shell
|
| 429 |
-
python benchmarks/overall.py
|
| 430 |
```
|
| 431 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
### Table Conversion
|
| 433 |
The processed FinTabNet dataset is hosted [here](https://huggingface.co/datasets/datalab-to/fintabnet-test) and is automatically downloaded. Run the benchmark with:
|
| 434 |
|
| 435 |
```shell
|
| 436 |
-
python benchmarks/table/table.py
|
| 437 |
```
|
| 438 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
# Thanks
|
| 440 |
|
| 441 |
This work would not have been possible without amazing open source models and datasets, including (but not limited to):
|
|
@@ -445,4 +499,4 @@ This work would not have been possible without amazing open source models and da
|
|
| 445 |
- Pypdfium2/pdfium
|
| 446 |
- DocLayNet from IBM
|
| 447 |
|
| 448 |
-
Thank you to the authors of these models and datasets for making them available to the community!
|
|
|
|
| 3 |
Marker converts PDFs and images to markdown, JSON, and HTML quickly and accurately.
|
| 4 |
|
| 5 |
- Supports a range of documents in all languages
|
| 6 |
+
- Formats tables, forms, equations, inline math, links, references, and code blocks
|
| 7 |
+
- Extracts and saves images
|
| 8 |
- Removes headers/footers/other artifacts
|
| 9 |
+
- Extensible with your own formatting and logic
|
| 10 |
+
- Optionally boost accuracy with LLMs
|
| 11 |
- Works on GPU, CPU, or MPS
|
| 12 |
|
| 13 |
+
## Performance
|
| 14 |
|
| 15 |
+
<img src="data/images/overall.png" width="800px"/>
|
| 16 |
|
| 17 |
+
Marker benchmarks favorably compared to cloud services like Llamaparse and Mathpix, as well as other open source tools.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
+
The above results are running single PDF pages serially. Marker is significantly faster when running in batch mode, with a projected throughput of 122 pages/second on an H100 (.18 seconds per page across 22 processes).
|
| 20 |
+
|
| 21 |
+
See [below](#benchmarks) for detailed speed and accuracy benchmarks, and instructions on how to run your own benchmarks.
|
| 22 |
+
|
| 23 |
+
## Hybrid Mode
|
| 24 |
+
|
| 25 |
+
For the highest accuracy, pass the `--use_llm` flag to use an LLM alongside marker. This will do things like merge tables across pages, handle inline math, format tables properly, and extract values from forms. It can use any gemini or ollama model. By default, it uses `gemini-2.0-flash`. See [below](#llm-services) for details.
|
| 26 |
+
|
| 27 |
+
Here is a table benchmark comparing marker, gemini flash alone, and marker with use_llm:
|
| 28 |
+
|
| 29 |
+
<img src="data/images/table.png" width="400px"/>
|
| 30 |
+
|
| 31 |
+
As you can see, the use_llm mode offers higher accuracy than marker or gemini alone.
|
| 32 |
|
| 33 |
## Examples
|
| 34 |
|
|
|
|
| 38 |
| [Switch Transformers](https://arxiv.org/pdf/2101.03961.pdf) | arXiv paper | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/markdown/switch_transformers/switch_trans.md) | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/json/switch_trans.json) |
|
| 39 |
| [Multi-column CNN](https://arxiv.org/pdf/1804.07821.pdf) | arXiv paper | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/markdown/multicolcnn/multicolcnn.md) | [View](https://github.com/VikParuchuri/marker/blob/master/data/examples/json/multicolcnn.json) |
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
# Commercial usage
|
| 42 |
|
| 43 |
I want marker to be as widely accessible as possible, while still funding my development/training costs. Research and personal usage is always okay, but there are some restrictions on commercial usage.
|
| 44 |
|
| 45 |
+
The weights for the models are licensed `cc-by-nc-sa-4.0`, but I will waive that for any organization under \$5M USD in gross revenue in the most recent 12-month period AND under $5M in lifetime VC/angel funding raised. You also must not be competitive with the [Datalab API](https://www.datalab.to/). If you want to remove the GPL license requirements (dual-license) and/or use the weights commercially over the revenue limit, check out the options [here](https://www.datalab.to).
|
| 46 |
|
| 47 |
# Hosted API
|
| 48 |
|
|
|
|
| 56 |
|
| 57 |
[Discord](https://discord.gg//KuZwXNGnfH) is where we discuss future development.
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
# Installation
|
| 60 |
|
| 61 |
You'll need python 3.10+ and PyTorch. You may need to install the CPU version of torch first if you're not using a Mac or a GPU machine. See [here](https://pytorch.org/get-started/locally/) for more details.
|
|
|
|
| 71 |
First, some configuration:
|
| 72 |
|
| 73 |
- Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`.
|
| 74 |
+
- Some PDFs, even digital ones, have bad text in them. Set the `force_ocr` flag to ensure your PDF runs through OCR, or the `strip_existing_ocr` to keep all digital text, and strip out any existing OCR text.
|
| 75 |
|
| 76 |
## Interactive App
|
| 77 |
|
|
|
|
| 105 |
- `--languages TEXT`: Optionally specify which languages to use for OCR processing. Accepts a comma-separated list. Example: `--languages "en,fr,de"` for English, French, and German.
|
| 106 |
- `config --help`: List all available builders, processors, and converters, and their associated configuration. These values can be used to build a JSON configuration file for additional tweaking of marker defaults.
|
| 107 |
- `--converter_cls`: One of `marker.converters.pdf.PdfConverter` (default) or `marker.converters.table.TableConverter`. The `PdfConverter` will convert the whole PDF, the `TableConverter` will only extract and convert tables.
|
| 108 |
+
- `--llm_service`: Which llm service to use if `--use_llm` is passed. This defaults to `marker.services.gemini.GoogleGeminiService`.
|
| 109 |
+
- `--help`: see all of the flags that can be passed into marker. (it supports many more options then are listed above)
|
| 110 |
|
| 111 |
The list of supported languages for surya OCR is [here](https://github.com/VikParuchuri/surya/blob/master/surya/recognition/languages.py). If you don't need OCR, marker can work with any language.
|
| 112 |
|
|
|
|
| 148 |
|
| 149 |
### Custom configuration
|
| 150 |
|
| 151 |
+
You can pass configuration using the `ConfigParser`. To see all available options, do `marker_single --help`.
|
| 152 |
|
| 153 |
```python
|
| 154 |
from marker.converters.pdf import PdfConverter
|
|
|
|
| 165 |
config=config_parser.generate_config_dict(),
|
| 166 |
artifact_dict=create_model_dict(),
|
| 167 |
processor_list=config_parser.get_processors(),
|
| 168 |
+
renderer=config_parser.get_renderer(),
|
| 169 |
+
llm_service=config_parser.get_llm_service()
|
| 170 |
)
|
| 171 |
rendered = converter("FILEPATH")
|
| 172 |
```
|
|
|
|
| 211 |
text, _, images = text_from_rendered(rendered)
|
| 212 |
```
|
| 213 |
|
| 214 |
+
This takes all the same configuration as the PdfConverter. You can specify the configuration `force_layout_block=Table` to avoid layout detection and instead assume every page is a table. Set `output_format=json` to also get cell bounding boxes.
|
| 215 |
|
| 216 |
You can also run this via the CLI with
|
| 217 |
```shell
|
| 218 |
+
marker_single FILENAME --use_llm --force_layout_block Table --converter_cls marker.converters.table.TableConverter --output_format json
|
| 219 |
```
|
| 220 |
|
| 221 |
# Output Formats
|
|
|
|
| 313 |
}
|
| 314 |
```
|
| 315 |
|
| 316 |
+
# LLM Services
|
| 317 |
+
|
| 318 |
+
When running with the `--use_llm` flag, you have a choice of services you can use:
|
| 319 |
+
|
| 320 |
+
- `Gemini` - this will use the Gemini developer API by default. You'll need to pass `--gemini_api_key` to configuration.
|
| 321 |
+
- `Google Vertex` - this will use vertex, which can be more reliable. You'll need to pass `--vertex_project_id`. To use it, set `--llm_service=marker.services.vertex.GoogleVertexService`.
|
| 322 |
+
- `Ollama` - this will use local models. You can configure `--ollama_base_url` and `--ollama_model`. To use it, set `--llm_service=marker.services.ollama.OllamaService`.
|
| 323 |
+
|
| 324 |
+
These services may have additional optional configuration as well - you can see it by viewing the classes.
|
| 325 |
+
|
| 326 |
# Internals
|
| 327 |
|
| 328 |
Marker is easy to extend. The core units of marker are:
|
|
|
|
| 379 |
Pass the `debug` option to activate debug mode. This will save images of each page with detected layout and text, as well as output a json file with additional bounding box information.
|
| 380 |
|
| 381 |
# Benchmarks
|
| 382 |
+
|
| 383 |
## Overall PDF Conversion
|
|
|
|
| 384 |
|
| 385 |
+
We created a [benchmark set](https://huggingface.co/datasets/datalab-to/marker_benchmark) by extracting single PDF pages from common crawl. We scored based on a heuristic that aligns text with ground truth text segments, and an LLM as a judge scoring method.
|
| 386 |
+
|
| 387 |
+
| Method | Avg Time | Heuristic Score | LLM Score |
|
| 388 |
+
|------------|----------|-----------------|-----------|
|
| 389 |
+
| marker | 2.83837 | 95.6709 | 4.23916 |
|
| 390 |
+
| llamaparse | 23.348 | 84.2442 | 3.97619 |
|
| 391 |
+
| mathpix | 6.36223 | 86.4281 | 4.15626 |
|
| 392 |
+
| docling | 3.69949 | 86.7073 | 3.70429 |
|
| 393 |
|
| 394 |
+
Benchmarks were run on an H100 for markjer and docling - llamaparse and mathpix used their cloud services. We can also look at it by document type:
|
|
|
|
|
|
|
| 395 |
|
| 396 |
+
<img src="data/images/per_doc.png" width="1000px"/>
|
| 397 |
|
| 398 |
+
| Document Type | Marker heuristic | Marker LLM | Llamaparse Heuristic | Llamaparse LLM | Mathpix Heuristic | Mathpix LLM | Docling Heuristic | Docling LLM |
|
| 399 |
+
|----------------------|------------------|------------|----------------------|----------------|-------------------|-------------|-------------------|-------------|
|
| 400 |
+
| Scientific paper | 96.6737 | 4.34899 | 87.1651 | 3.96421 | 91.2267 | 4.46861 | 92.135 | 3.72422 |
|
| 401 |
+
| Book page | 97.1846 | 4.16168 | 90.9532 | 4.07186 | 93.8886 | 4.35329 | 90.0556 | 3.64671 |
|
| 402 |
+
| Other | 95.1632 | 4.25076 | 81.1385 | 4.01835 | 79.6231 | 4.00306 | 83.8223 | 3.76147 |
|
| 403 |
+
| Form | 88.0147 | 3.84663 | 66.3081 | 3.68712 | 64.7512 | 3.33129 | 68.3857 | 3.40491 |
|
| 404 |
+
| Presentation | 95.1562 | 4.13669 | 81.2261 | 4 | 83.6737 | 3.95683 | 84.8405 | 3.86331 |
|
| 405 |
+
| Financial document | 95.3697 | 4.39106 | 82.5812 | 4.16111 | 81.3115 | 4.05556 | 86.3882 | 3.8 |
|
| 406 |
+
| Letter | 98.4021 | 4.5 | 93.4477 | 4.28125 | 96.0383 | 4.45312 | 92.0952 | 4.09375 |
|
| 407 |
+
| Engineering document | 93.9244 | 4.04412 | 77.4854 | 3.72059 | 80.3319 | 3.88235 | 79.6807 | 3.42647 |
|
| 408 |
+
| Legal document | 96.689 | 4.27759 | 86.9769 | 3.87584 | 91.601 | 4.20805 | 87.8383 | 3.65552 |
|
| 409 |
+
| Newspaper page | 98.8733 | 4.25806 | 84.7492 | 3.90323 | 96.9963 | 4.45161 | 92.6496 | 3.51613 |
|
| 410 |
+
| Magazine page | 98.2145 | 4.38776 | 87.2902 | 3.97959 | 93.5934 | 4.16327 | 93.0892 | 4.02041 |
|
| 411 |
|
| 412 |
+
## Throughput
|
| 413 |
|
| 414 |
+
We benchmarked throughput using a [single long PDF](https://www.greenteapress.com/thinkpython/thinkpython.pdf).
|
| 415 |
|
| 416 |
+
| Method | Time per page | Time per document | VRAM used |
|
| 417 |
+
|---------|---------------|-------------------|---------- |
|
| 418 |
+
| marker | 0.18 | 43.42 | 3.17GB |
|
| 419 |
|
| 420 |
+
The projected throughput is 122 pages per second on an H100 - we can run 22 individual processes given the VRAM used.
|
| 421 |
|
| 422 |
## Table Conversion
|
| 423 |
+
|
| 424 |
Marker can extract tables from PDFs using `marker.converters.table.TableConverter`. The table extraction performance is measured by comparing the extracted HTML representation of tables against the original HTML representations using the test split of [FinTabNet](https://developer.ibm.com/exchanges/data/all/fintabnet/). The HTML representations are compared using a tree edit distance based metric to judge both structure and content. Marker detects and identifies the structure of all tables in a PDF page and achieves these scores:
|
| 425 |
|
| 426 |
+
| Method | Avg score | Total tables |
|
| 427 |
+
|------------------|-----------|--------------|
|
| 428 |
+
| marker | 0.816 | 99 |
|
| 429 |
+
| marker w/use_llm | 0.907 | 99 |
|
| 430 |
+
| gemini | 0.829 | 99 |
|
| 431 |
|
| 432 |
The `--use_llm` flag can significantly improve table recognition performance, as you can see.
|
| 433 |
|
|
|
|
| 447 |
Download the benchmark data [here](https://drive.google.com/file/d/1ZSeWDo2g1y0BRLT7KnbmytV2bjWARWba/view?usp=sharing) and unzip. Then run the overall benchmark like this:
|
| 448 |
|
| 449 |
```shell
|
| 450 |
+
python benchmarks/overall.py --methods marker --scores heuristic,llm
|
| 451 |
```
|
| 452 |
|
| 453 |
+
Options:
|
| 454 |
+
|
| 455 |
+
- `--use_llm` use an llm to improve the marker results.
|
| 456 |
+
- `--max_rows` how many rows to process for the benchmark.
|
| 457 |
+
- `--methods` can be `llamaparse`, `mathpix`, `docling`, `marker`. Comma separated.
|
| 458 |
+
- `--scores` which scoring functions to use, can be `llm`, `heuristic`. Comma separated.
|
| 459 |
+
|
| 460 |
### Table Conversion
|
| 461 |
The processed FinTabNet dataset is hosted [here](https://huggingface.co/datasets/datalab-to/fintabnet-test) and is automatically downloaded. Run the benchmark with:
|
| 462 |
|
| 463 |
```shell
|
| 464 |
+
python benchmarks/table/table.py --max_rows 100
|
| 465 |
```
|
| 466 |
|
| 467 |
+
Options:
|
| 468 |
+
|
| 469 |
+
- `--use_llm` uses an llm with marker to improve accuracy.
|
| 470 |
+
- `--use_gemini` also benchmarks gemini 2.0 flash.
|
| 471 |
+
|
| 472 |
+
# How it works
|
| 473 |
+
|
| 474 |
+
Marker is a pipeline of deep learning models:
|
| 475 |
+
|
| 476 |
+
- Extract text, OCR if necessary (heuristics, [surya](https://github.com/VikParuchuri/surya))
|
| 477 |
+
- Detect page layout and find reading order ([surya](https://github.com/VikParuchuri/surya))
|
| 478 |
+
- Clean and format each block (heuristics, [texify](https://github.com/VikParuchuri/texify), [surya](https://github.com/VikParuchuri/surya))
|
| 479 |
+
- Optionally use an LLM to improve quality
|
| 480 |
+
- Combine blocks and postprocess complete text
|
| 481 |
+
|
| 482 |
+
It only uses models where necessary, which improves speed and accuracy.
|
| 483 |
+
|
| 484 |
+
# Limitations
|
| 485 |
+
|
| 486 |
+
PDF is a tricky format, so marker will not always work perfectly. Here are some known limitations that are on the roadmap to address:
|
| 487 |
+
|
| 488 |
+
- Very complex layouts, with nested tables and forms, may not work
|
| 489 |
+
- Forms may not be rendered well
|
| 490 |
+
|
| 491 |
+
Note: Passing the `--use_llm` flag will mostly solve these issues.
|
| 492 |
+
|
| 493 |
# Thanks
|
| 494 |
|
| 495 |
This work would not have been possible without amazing open source models and datasets, including (but not limited to):
|
|
|
|
| 499 |
- Pypdfium2/pdfium
|
| 500 |
- DocLayNet from IBM
|
| 501 |
|
| 502 |
+
Thank you to the authors of these models and datasets for making them available to the community!
|
benchmarks/__init__.py
ADDED
|
File without changes
|
benchmarks/overall.py
DELETED
|
@@ -1,132 +0,0 @@
|
|
| 1 |
-
import tempfile
|
| 2 |
-
import time
|
| 3 |
-
from collections import defaultdict
|
| 4 |
-
|
| 5 |
-
import click
|
| 6 |
-
from tqdm import tqdm
|
| 7 |
-
import pypdfium2 as pdfium
|
| 8 |
-
|
| 9 |
-
from marker.config.parser import ConfigParser
|
| 10 |
-
from marker.converters.pdf import PdfConverter
|
| 11 |
-
from marker.logger import configure_logging
|
| 12 |
-
from marker.models import create_model_dict
|
| 13 |
-
from pdftext.extraction import plain_text_output
|
| 14 |
-
import json
|
| 15 |
-
import os
|
| 16 |
-
import subprocess
|
| 17 |
-
import shutil
|
| 18 |
-
from tabulate import tabulate
|
| 19 |
-
|
| 20 |
-
from marker.settings import settings
|
| 21 |
-
from scoring import score_text
|
| 22 |
-
|
| 23 |
-
configure_logging()
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def nougat_prediction(pdf_filename, batch_size=1):
|
| 27 |
-
out_dir = tempfile.mkdtemp()
|
| 28 |
-
subprocess.run(["nougat", pdf_filename, "-o", out_dir, "--no-skipping", "--recompute", "--batchsize", str(batch_size)], check=True)
|
| 29 |
-
md_file = os.listdir(out_dir)[0]
|
| 30 |
-
with open(os.path.join(out_dir, md_file), "r") as f:
|
| 31 |
-
data = f.read()
|
| 32 |
-
shutil.rmtree(out_dir)
|
| 33 |
-
return data
|
| 34 |
-
|
| 35 |
-
@click.command(help="Benchmark PDF to MD conversion.")
@click.argument("in_folder", type=str)
@click.argument("reference_folder", type=str)
@click.argument("out_file", type=str)
@click.option("--nougat", is_flag=True, help="Run nougat and compare")
@click.option("--md_out_path", type=str, default=None, help="Output path for generated markdown files")
def main(in_folder: str, reference_folder: str, out_file: str, nougat: bool, md_out_path: str):
    """Score marker (and optionally nougat) conversions against reference markdown.

    For every ``*.pdf`` in ``in_folder`` with a matching ``*.md`` in
    ``reference_folder``, converts the PDF with each method, scores the output
    against the reference, writes per-file and aggregate stats to ``out_file``
    as JSON, and prints summary tables.
    """
    methods = ["marker"]
    if nougat:
        methods.append("nougat")

    model_dict = create_model_dict()

    scores = defaultdict(dict)
    benchmark_files = [b for b in os.listdir(in_folder) if b.endswith(".pdf")]
    times = defaultdict(dict)
    pages = defaultdict(int)

    for fname in tqdm(benchmark_files, desc="Benchmarking PDFs"):
        md_filename = fname.rsplit(".", 1)[0] + ".md"

        reference_filename = os.path.join(reference_folder, md_filename)
        with open(reference_filename, "r") as f:
            reference = f.read()

        pdf_filename = os.path.join(in_folder, fname)
        doc = pdfium.PdfDocument(pdf_filename)
        try:
            pages[fname] = len(doc)

            config_parser = ConfigParser({"output_format": "markdown"})
            for method in methods:
                start = time.time()
                if method == "marker":
                    converter = PdfConverter(
                        config=config_parser.generate_config_dict(),
                        artifact_dict=model_dict,
                        processor_list=None,
                        renderer=config_parser.get_renderer()
                    )
                    full_text = converter(pdf_filename).markdown
                elif method == "nougat":
                    full_text = nougat_prediction(pdf_filename, batch_size=1)
                elif method == "naive":
                    full_text = plain_text_output(doc, workers=1)
                else:
                    raise ValueError(f"Unknown method {method}")

                times[method][fname] = time.time() - start

                score = score_text(full_text, reference)
                scores[method][fname] = score

                if md_out_path:
                    md_out_filename = f"{method}_{md_filename}"
                    with open(os.path.join(md_out_path, md_out_filename), "w+") as f:
                        f.write(full_text)
        finally:
            # Release the pdfium handle; the original leaked one per document.
            doc.close()

    total_pages = sum(pages.values())
    with open(out_file, "w+") as f:
        write_data = defaultdict(dict)
        for method in methods:
            total_time = sum(times[method].values())
            file_stats = {
                fname:
                {
                    "time": times[method][fname],
                    "score": scores[method][fname],
                    "pages": pages[fname]
                }
                for fname in benchmark_files
            }
            write_data[method] = {
                "files": file_stats,
                "avg_score": sum(scores[method].values()) / len(scores[method]),
                "time_per_page": total_time / total_pages,
                "time_per_doc": total_time / len(scores[method])
            }

        json.dump(write_data, f, indent=4)

    # Summary: one row per method; per-file scores in a second table.
    summary_table = []
    score_table = []
    score_headers = benchmark_files
    for method in methods:
        summary_table.append([method, write_data[method]["avg_score"], write_data[method]["time_per_page"], write_data[method]["time_per_doc"]])
        score_table.append([method, *[write_data[method]["files"][h]["score"] for h in score_headers]])

    print(tabulate(summary_table, headers=["Method", "Average Score", "Time per page", "Time per document"]))
    print("")
    print("Scores by file")
    print(tabulate(score_table, headers=["Method", *score_headers]))


if __name__ == "__main__":
    main()
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/overall/__init__.py
ADDED
|
File without changes
|
benchmarks/overall/display/__init__.py
ADDED
|
File without changes
|
benchmarks/overall/display/dataset.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import datasets
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
from benchmarks.overall.registry import METHOD_REGISTRY
|
| 8 |
+
from benchmarks.overall.schema import FullResult
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def build_dataset(bench_dataset: datasets.Dataset, result: FullResult, score_types: List[str], max_rows: int | None = None) -> datasets.Dataset:
    """Assemble a HF dataset joining benchmark samples with each method's markdown, rendered image, and scores.

    Rows without any markdown in ``result`` are skipped; missing scores are
    filled with -1.0 and missing score details with an empty string.
    """
    out_rows = []
    for idx, sample in tqdm(enumerate(bench_dataset), desc="Building dataset"):
        if idx not in result["markdown"]:
            continue

        if max_rows is not None and idx >= max_rows:
            break

        # Carry over the identifying fields from the benchmark sample.
        record = {key: sample[key] for key in ("uuid", "classification", "language", "img")}

        for method, md in result["markdown"][idx].items():
            if method == "gt":
                continue

            renderer = METHOD_REGISTRY[method]()
            record[f"{method}_md"] = md
            record[f"{method}_img"] = renderer.render(md)

            for score_type in score_types:
                try:
                    record[f"{method}_{score_type}"] = result["scores"][idx][method][score_type]["score"]
                except KeyError:
                    record[f"{method}_{score_type}"] = -1.0  # Missing score
                try:
                    record[f"{method}_{score_type}_detail"] = json.dumps(result["scores"][idx][method][score_type]["specific_scores"])
                except KeyError:
                    record[f"{method}_{score_type}_detail"] = ""  # Missing detail
        out_rows.append(record)

    return datasets.Dataset.from_list(out_rows)
|
| 48 |
+
|
benchmarks/overall/display/table.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Dict, List
|
| 3 |
+
|
| 4 |
+
import tabulate
|
| 5 |
+
|
| 6 |
+
from benchmarks.overall.schema import FullResult
|
| 7 |
+
|
| 8 |
+
def write_table(title: str, rows: list, headers: list, out_path: Path, filename: str):
    """Render rows as a GitHub-flavored markdown table, save it under out_path, and echo it to stdout."""
    rendered = tabulate.tabulate(rows, headers=headers, tablefmt="github")
    target = out_path / filename
    with target.open("w", encoding="utf-8") as fh:
        fh.write(f"# {title}\n")
        fh.write(rendered)
    print(title)
    print(rendered)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def print_scores(result: FullResult, out_path: Path, methods: List[str], score_types: List[str], default_score_type="heuristic", default_method="marker"):
    """Write and print three summary tables: averages by document type, by block type, and overall per method."""
    def _mean(values) -> float:
        # Guard against empty score lists; an empty list averages to 0.
        return sum(values) / max(1, len(values))

    by_type = result["averages_by_type"]
    document_types = list(by_type[default_method][default_score_type].keys())
    headers = ["Document Type"] + [f"{method} {score_type}" for method in methods for score_type in score_types]
    document_rows = [
        [doc_type] + [_mean(by_type[method][score_type][doc_type]) for method in methods for score_type in score_types]
        for doc_type in document_types
    ]
    write_table("Document Types", document_rows, headers, out_path, "document_types.md")

    by_block = result["averages_by_block_type"]
    block_types = list(by_block[default_method][default_score_type].keys())  # all possible blocks
    block_score_types = list(by_block[default_method].keys())
    headers = ["Block Type"] + [f"{method} {score_type}" for method in methods for score_type in block_score_types]
    block_rows = [
        [block_type] + [_mean(by_block[method][score_type][block_type]) for method in methods for score_type in block_score_types]
        for block_type in block_types
    ]
    write_table("Block types", block_rows, headers, out_path, "block_types.md")

    headers = ["Method", "Avg Time"] + score_types
    all_raw_scores = list(result["scores"].values())
    inference_rows = []
    for method in methods:
        row = [method, _mean(result["average_times"][method])]
        for score_type in score_types:
            collected = []
            for raw in all_raw_scores:
                try:
                    # Sometimes a few llm scores are missing
                    collected.append(raw[method][score_type]["score"])
                except KeyError:
                    continue
            row.append(_mean(collected))
        inference_rows.append(row)

    write_table("Overall Results", inference_rows, headers, out_path, "overall.md")

    print("Scores computed by aligning ground truth markdown blocks with predicted markdown for each method. The scores are 0-100 based on edit distance.")
|
benchmarks/overall/download/__init__.py
ADDED
|
File without changes
|
benchmarks/overall/download/base.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from json import JSONDecodeError
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import datasets
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Downloader:
    """Base class for fetching third-party conversion results for the marker benchmark.

    Subclasses implement :meth:`get_html` for a specific service. Results are
    cached to ``cache/{idx}.json`` so interrupted runs resume, then pushed to
    the HF hub as ``datalab-to/marker_benchmark_{service}``.
    """
    cache_path: Path = Path("cache")
    service: str  # short service name set by subclasses, used in cache/hub ids

    def __init__(self, api_key, app_id, max_rows: int = 2200):
        self.cache_path.mkdir(exist_ok=True)
        self.max_rows = max_rows
        self.api_key = api_key
        self.app_id = app_id
        self.ds = datasets.load_dataset("datalab-to/marker_benchmark", split="train")

    def get_html(self, pdf_bytes):
        """Convert one single-page PDF (bytes) and return {"md": str, "time": float}."""
        raise NotImplementedError

    def upload_ds(self):
        """Collect all cached JSON results and push them to the hub."""
        rows = []
        for file in self.cache_path.glob("*.json"):
            with open(file, "r") as f:
                data = json.load(f)
            rows.append(data)

        out_ds = datasets.Dataset.from_list(rows, features=datasets.Features({
            "md": datasets.Value("string"),
            "uuid": datasets.Value("string"),
            "time": datasets.Value("float"),
        }))
        out_ds.push_to_hub(f"datalab-to/marker_benchmark_{self.service}")

    def generate_data(self):
        """Run the service over benchmark PDFs, writing one JSON cache file per sample."""
        for idx, sample in tqdm(enumerate(self.ds), desc=f"Saving {self.service} results"):
            # Fix: honor the max_rows passed to the constructor. The original
            # shadowed it with a hard-coded local `max_rows = 2200`, so the
            # constructor argument was silently ignored.
            if idx >= self.max_rows:
                break

            cache_file = self.cache_path / f"{idx}.json"
            if cache_file.exists():
                continue

            pdf_bytes = sample["pdf"]  # This is a single page PDF
            try:
                out_data = self.get_html(pdf_bytes)
            except JSONDecodeError as e:
                print(f"Error with sample {idx}: {e}")
                continue
            out_data["uuid"] = sample["uuid"]

            with cache_file.open("w") as f:
                json.dump(out_data, f)

    def __call__(self):
        self.generate_data()
        self.upload_ds()
|
benchmarks/overall/download/llamaparse.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
import requests
|
| 6 |
+
|
| 7 |
+
from benchmarks.overall.download.base import Downloader
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class LlamaParseDownloader(Downloader):
    """Downloader that converts benchmark PDFs via the LlamaParse cloud API."""
    service = "llamaparse"

    def get_html(self, pdf_bytes):
        # The API requires a filename; a timestamp keeps uploads unique.
        upload_name = str(time.time()) + ".pdf"
        start = time.time()
        payload = io.BytesIO(pdf_bytes)
        md = upload_and_parse_file(self.api_key, upload_name, payload)
        elapsed = time.time() - start
        if isinstance(md, bytes):
            md = md.decode("utf-8")

        return {
            "md": md,
            "time": elapsed,
        }
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def upload_and_parse_file(api_key: str, fname: str, buff, max_retries: int = 180, delay: int = 1):
    """Upload a PDF to LlamaParse, poll until the job succeeds, and return its markdown.

    Polls up to ``max_retries`` times, sleeping ``delay`` seconds between
    checks; raises TimeoutError if the job never reaches SUCCESS.
    """
    auth_headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json"
    }
    base_url = "https://api.cloud.llamaindex.ai/api/v1/parsing"

    # Upload file
    upload_resp = requests.post(
        f"{base_url}/upload",
        headers=auth_headers,
        files={'file': (fname, buff, 'application/pdf')},
    )
    upload_resp.raise_for_status()
    job_id = upload_resp.json()['id']

    # Poll for completion
    for _ in range(max_retries):
        status_resp = requests.get(f"{base_url}/job/{job_id}", headers=auth_headers)
        status_resp.raise_for_status()
        if status_resp.json()['status'] == 'SUCCESS':
            # Get results
            result_resp = requests.get(
                f"{base_url}/job/{job_id}/result/markdown",
                headers=auth_headers,
            )
            result_resp.raise_for_status()
            return result_resp.json()['markdown']

        time.sleep(delay)

    raise TimeoutError("Job did not complete within the maximum retry attempts")
|
benchmarks/overall/download/main.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import click
|
| 2 |
+
|
| 3 |
+
from benchmarks.overall.download.llamaparse import LlamaParseDownloader
|
| 4 |
+
from benchmarks.overall.download.mathpix import MathpixDownloader
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@click.command(help="Download data from inference services")
@click.argument("service", type=click.Choice(["mathpix", "llamaparse"]))
@click.option("--max_rows", type=int, default=2200, help="Maximum number of rows to process")
@click.option("--api_key", type=str, default=None, help="API key for the chosen service")
@click.option("--app_id", type=str, default=None, help="App id (required for mathpix)")
def main(service: str, max_rows: int, api_key: str, app_id: str):
    """Generate conversion results with the chosen service and upload them to the hub.

    Fixes two click misuses: ``--max_rows``/``--api_key``/``--app_id`` were
    declared with ``@click.argument``, which treats "--max_rows" as a literal
    positional argument name (the flags could never be passed); they must be
    ``@click.option``. Also, a bare string passed to ``@click.command`` sets
    the command *name*, not its help text.
    """
    registry = {
        "mathpix": MathpixDownloader,
        "llamaparse": LlamaParseDownloader
    }
    downloader = registry[service](api_key, app_id, max_rows=max_rows)

    # Generate data and upload to hub
    downloader()


if __name__ == "__main__":
    main()
|
benchmarks/overall/download/mathpix.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import requests
|
| 5 |
+
|
| 6 |
+
from benchmarks.overall.download.base import Downloader
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class MathpixDownloader(Downloader):
    """Downloader that converts benchmark PDFs via the Mathpix v3 PDF API."""
    service = "mathpix"

    def get_html(self, pdf_bytes):
        auth_headers = {
            "app_id": self.app_id,
            "app_key": self.api_key,
        }
        start = time.time()
        pdf_id = mathpix_request(pdf_bytes, auth_headers)
        conversion_state = mathpix_status(pdf_id, auth_headers)
        # Only fetch results for completed conversions; otherwise return empty markdown.
        if conversion_state in ["processing", "error"]:
            md = ""
        else:
            md = mathpix_results(pdf_id, auth_headers)
        elapsed = time.time() - start
        if isinstance(md, bytes):
            md = md.decode("utf-8")

        return {
            "md": md,
            "time": elapsed
        }
|
| 32 |
+
|
| 33 |
+
def mathpix_request(buffer, headers):
    """Submit a PDF to Mathpix for md+html conversion; returns the job's pdf_id."""
    conversion_options = {
        "conversion_formats": {
            "md": True,
            "html": True
        }
    }
    response = requests.post(
        "https://api.mathpix.com/v3/pdf",
        headers=headers,
        data={"options_json": json.dumps(conversion_options)},
        files={"file": buffer},
    )
    return response.json()["pdf_id"]
|
| 53 |
+
|
| 54 |
+
def mathpix_status(pdf_id, headers):
    """Poll the Mathpix converter until both md and html conversions finish.

    Returns "completed" when both formats finished, otherwise "error"
    (including on timeout after ~120 polls at 1s intervals).
    """
    max_iters = 120
    status = "processing"
    status2 = "processing"
    # Fix: the original used `while i < max_iters` but never incremented i,
    # so a PDF stuck in "processing" polled forever. A bounded for-loop
    # cannot hang.
    for _ in range(max_iters):
        time.sleep(1)
        response = requests.get(f"https://api.mathpix.com/v3/converter/{pdf_id}",
            headers=headers
        )
        status_resp = response.json()
        if "conversion_status" not in status_resp:
            continue
        status = status_resp["conversion_status"]["md"]["status"]
        status2 = status_resp["conversion_status"]["html"]["status"]
        if status == "completed" and status2 == "completed":
            break
        elif status == "error" or status2 == "error":
            break
    out_status = "completed" if status == "completed" and status2 == "completed" else "error"
    return out_status
|
| 75 |
+
|
| 76 |
+
def mathpix_results(pdf_id, headers, ext="md"):
    """Fetch the converted document in the requested format; returns the raw response bytes."""
    url = f"https://api.mathpix.com/v3/converter/{pdf_id}.{ext}"
    response = requests.get(url, headers=headers)
    return response.content
|
benchmarks/overall/elo.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import random
|
| 3 |
+
import time
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import List, Dict, Tuple, Literal
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
import click
|
| 9 |
+
import datasets
|
| 10 |
+
from google import genai
|
| 11 |
+
from google.genai.errors import APIError
|
| 12 |
+
from pydantic import BaseModel
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from marker.settings import settings
|
| 16 |
+
|
| 17 |
+
# Raw string is required: in the original non-raw literal, "\f" in "$$\frac"
# was interpreted as a form-feed escape, sending corrupted LaTeX to the LLM.
# Placeholders are substituted with str.replace, so no .format escaping needed.
rating_prompt = r"""
You're a document analysis expert who is comparing two different markdown samples to an image to see which one represents the content of the image better. The markdown will be called version A and version B.

Here are some notes on the image and markdown:
- Some parts of the page may have been recognized as images and linked from the markdown, like ``.
- Tables will be formatted as Github flavored markdown.
- Block equations will be in LaTeX.
- The image and markdown may be in any language.
- The markdown is based on the text extracted from the document, and sometimes the document may have had bad OCR applied to it, resulting in gibberish text.

The markdown should fully capture the meaning and formatting of the text in the image. You'll evaluate the markdown based on the image provided.

**Instructions**
Follow this process to evaluate the markdown:
1. Carefully examine the image.
2. Carefully examine the first markdown input provided.
3. Describe how well version A represents the image.
4. Carefully examine the second markdown input provided.
5. Describe how well version B represents the image.
6. Compare version A and version B.
7. Decide which markdown representation is better, based on the criteria below. Output version_a if version A is better, and version_b if version B is better.

Use these criteria when judging the markdown:
- Overall - the overall quality of the markdown as compared to the image.
- Text quality - the quality of the text extraction from the image.
- Formatting quality - the quality of the formatting applied to the markdown, as compared to the image.
- Tables - how effectively the tables have been extracted and formatted.
- Forms - how effectively the forms have been extracted and formatted.
- Equations - how effectively block equations have been converted to LaTeX.
- Lists - if the lists have been properly extracted and formatted.
- Images - if images are identified and placed correctly.

Notes on scoring:
- Perfect markdown will include all of the important text from the image, and the formatting will be correct (minor mistakes okay). It's okay to omit some text that isn't important to the meaning, like page numbers and chapter headings. If the entire page is an image, it's okay if the markdown is just a link to the image, unless the image would be better represented as text.
- Bad markdown will have major missing text segments from the markdown or completely unreadable formatting.

Output json, like in the example below.

**Example**
Version A
```markdown
# *Section 1*
This is some *markdown* extracted from a document. Here is a block equation:
$$\frac{ab \cdot x^5 + x^2 + 2 \cdot x + 123}{t}$$
```
Version B
```markdown
# Section 1
This is some markdown extracted from a document. Here is a block equation:
$$\frac{ab \cdot x^5 + x^2 + 2 \cdot x + 123}{t}$$
```
Output
```json
{
    "image_description": "In the image, there is a section header 'Section 1', followed by some text and a block equation.",
    "version_a_description": "In the markdown, there is a section header 'Section 1', followed by some text and a block equation.",
    "version_b_description": "In the markdown, there is a section header 'Section 1', followed by some text and a block equation. The formatting in version b is slightly different from the image.",
    "comparison": "Version A is better than version B. The text and formatting in version A matches the image better than version B.",
    "winner": "version_a"
}
```
**Input**
Version A
```markdown
{{version_a}}
```
Version B
```markdown
{{version_b}}
```
**Output**
"""
|
| 89 |
+
|
| 90 |
+
class ComparerSchema(BaseModel):
    """Structured JSON response schema for the pairwise markdown comparison LLM call."""
    image_description: str  # what the model saw in the page image
    version_a_description: str  # how well version A matches the image
    version_b_description: str  # how well version B matches the image
    comparison: str  # head-to-head reasoning between the two versions
    winner: Literal["version_a", "version_b"]  # which version better represents the page
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class Comparer:
    """Asks Gemini which of two markdown conversions better matches a page image."""

    def __init__(self):
        pass

    def __call__(
        self,
        img: Image.Image,
        version_a: str,
        version_b: str
    ) -> str | None:
        """Return "version_a" or "version_b", or None if the LLM call failed."""
        hydrated_prompt = rating_prompt.replace("{{version_a}}", version_a).replace("{{version_b}}", version_b)
        try:
            rating = self.llm_rater(img, hydrated_prompt)
        except Exception as e:
            print(f"Error: {e}")
            return
        return rating


    def llm_rater(self, img: Image.Image, prompt: str):
        """Run the comparison prompt and return the winner key from the response."""
        response = self.llm_response_wrapper(
            [img, prompt],
            ComparerSchema
        )
        # Fix: llm_response_wrapper returns None on API errors; the original
        # ran `"winner" in response` on it and crashed with a TypeError
        # instead of letting __call__ skip the sample gracefully.
        if response is None:
            raise ValueError("No response from LLM")
        assert "winner" in response, f"Response missing 'winner' key: {response}"
        return response["winner"]

    def llm_response_wrapper(
        self,
        prompt,
        response_schema,
    ):
        """Call Gemini with a JSON response schema; returns the parsed dict or None on error."""
        client = genai.Client(
            api_key=settings.GOOGLE_API_KEY,
            http_options={"timeout": 60000}
        )
        try:
            responses = client.models.generate_content(
                model="gemini-2.0-flash",
                contents=prompt,
                config={
                    "temperature": 0,
                    "response_schema": response_schema,
                    "response_mime_type": "application/json",
                },
            )
            output = responses.candidates[0].content.parts[0].text
            return json.loads(output)
        except APIError as e:
            # Fix: include the exception detail; the original f-string had no
            # placeholder and discarded `e`.
            print(f"Hit Gemini rate limit: {e}")
            return
        except Exception as e:
            print(f"Error: {e}")
            return
|
| 152 |
+
|
| 153 |
+
@dataclass
class Method:
    """A competing conversion method and its current Elo state."""
    name: str
    rating: float = 1500
    k_factor: float = 32


class EloSystem:
    """Minimal Elo rating tracker over a fixed set of named methods."""

    def __init__(self, player_names: List[str]):
        self.methods = {name: Method(name) for name in player_names}

    def expected_score(self, rating_a: float, rating_b: float) -> float:
        """Probability that a player rated rating_a beats one rated rating_b."""
        exponent = (rating_b - rating_a) / 400
        return 1 / (1 + 10 ** exponent)

    def update_ratings(self, winner: str, loser: str) -> Tuple[float, float]:
        """Apply one game result and return the updated (winner, loser) ratings."""
        won = self.methods[winner]
        lost = self.methods[loser]

        win_expectation = self.expected_score(won.rating, lost.rating)
        loss_expectation = self.expected_score(lost.rating, won.rating)

        # Actual scores: 1 for the winner, 0 for the loser.
        won.rating += won.k_factor * (1 - win_expectation)
        lost.rating -= lost.k_factor * loss_expectation

        return won.rating, lost.rating
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
@click.command("Calculate ELO scores for document conversion methods")
@click.argument("dataset", type=str)
@click.option("--methods", type=str, help="List of methods to compare: comma separated like marker,mathpix")
@click.option("--row_samples", type=int, default=2, help="Number of samples per row")
@click.option("--max_rows", type=int, default=100, help="Maximum number of rows to process")
def main(
    dataset: str,
    methods: str,
    row_samples: int,
    max_rows: int
):
    """Run pairwise LLM comparisons between methods on each dataset row and track Elo ratings."""
    ds = datasets.load_dataset(dataset, split="train")
    method_lst = methods.split(",")
    elo = EloSystem(method_lst)
    comparer = Comparer()

    row_count = min(len(ds), max_rows)
    for i in tqdm(range(row_count), desc="Calculating ELO"):
        row = ds[i]
        # Avoid any bias in ordering
        random.shuffle(method_lst)

        for j, method_a in enumerate(method_lst[:-1]):
            # Pair method_a against every later method exactly once.
            for method_b in method_lst[j + 1:]:
                winner = comparer(row["img"], row[f"{method_a}_md"], row[f"{method_b}_md"])
                if not winner:
                    continue

                if winner == "version_a":
                    elo.update_ratings(method_a, method_b)
                else:
                    elo.update_ratings(method_b, method_a)
        if i % 10 == 0:
            print(elo.methods)

    # Print out ratings
    print(elo.methods)


if __name__ == "__main__":
    main()
|
benchmarks/overall/methods/__init__.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import random
|
| 3 |
+
import re
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
import markdown2
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from playwright.sync_api import sync_playwright
|
| 9 |
+
|
| 10 |
+
from benchmarks.overall.methods.schema import BenchmarkResult
|
| 11 |
+
from marker.renderers.markdown import MarkdownRenderer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BaseMethod:
    """Base class for benchmark conversion methods.

    Subclasses implement ``__call__`` to produce markdown (and timing) for one
    benchmark sample. Shared helpers convert HTML <-> markdown and render
    markdown to an image for side-by-side visual comparison.
    """

    def __init__(self, **kwargs):
        # Accept only kwargs matching declared class attributes; unknown keys
        # (artifacts intended for other method classes) are silently ignored.
        for kwarg in kwargs:
            if hasattr(self, kwarg):
                setattr(self, kwarg, kwargs[kwarg])

    @staticmethod
    def convert_to_md(html: str):
        """Convert an HTML fragment to markdown using marker's renderer.

        NOTE(review): assumes ``MarkdownRenderer.md_cls`` exposes an
        html->markdown ``convert`` method — confirm against marker's renderer.
        """
        md = MarkdownRenderer()
        markdown = md.md_cls.convert(html)
        return markdown

    def __call__(self, sample) -> BenchmarkResult:
        """Convert one benchmark sample; subclasses must override."""
        raise NotImplementedError()

    def render(self, markdown: str):
        """Render markdown to a screenshot image (markdown -> HTML -> PNG)."""
        return self.html_to_image(self.convert_to_html(markdown))

    @staticmethod
    def convert_to_html(md: str):
        """Convert markdown to HTML while protecting math from the converter.

        Math spans are swapped for unique placeholders before markdown2 runs,
        then restored afterwards so KaTeX can render them in the browser.
        """
        block_placeholders = []
        inline_placeholders = []

        # Add placeholders for the math
        def block_sub(match):
            content = match.group(1)
            placeholder = f"1BLOCKMATH{len(block_placeholders)}1"
            block_placeholders.append((placeholder, f"$${content}$$"))
            return placeholder

        def inline_sub(match):
            content = match.group(1)
            placeholder = f"1INLINEMATH{len(inline_placeholders)}1"
            inline_placeholders.append((placeholder, f"${content}$"))
            return placeholder

        # Block math first (DOTALL so $$...$$ can span lines), then inline
        md = re.sub(r'\${2}(.*?)\${2}', block_sub, md, flags=re.DOTALL)
        md = re.sub(r'\$(.*?)\$', inline_sub, md)

        html = markdown2.markdown(md, extras=['tables'])

        # Replace placeholders with the original math spans
        for placeholder, math_str in block_placeholders:
            html = html.replace(placeholder, math_str)
        for placeholder, math_str in inline_placeholders:
            html = html.replace(placeholder, math_str)

        return html

    def html_to_image(self, html: str) -> Image.Image:
        """Screenshot the given HTML fragment using headless Chromium.

        Loads KaTeX from a CDN and auto-renders $/$$ delimiters in the body
        before taking a full-page screenshot.
        NOTE(review): launches a fresh browser per call — expensive if called
        in a loop; consider reusing a browser instance.
        """
        with sync_playwright() as p:
            browser = p.chromium.launch()
            page = browser.new_page()
            html_str = f"""
            <!DOCTYPE html>
            <html>
            <head>
                <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.16.21/dist/katex.min.css" integrity="sha384-zh0CIslj+VczCZtlzBcjt5ppRcsAmDnRem7ESsYwWwg3m/OaJ2l4x7YBZl9Kxxib" crossorigin="anonymous">
                <!-- The loading of KaTeX is deferred to speed up page rendering -->
                <script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.21/dist/katex.min.js" integrity="sha384-Rma6DA2IPUwhNxmrB/7S3Tno0YY7sFu9WSYMCuulLhIqYSGZ2gKCJWIqhBWqMQfh" crossorigin="anonymous"></script>
                <!-- To automatically render math in text elements, include the auto-render extension: -->
                <script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.21/dist/contrib/auto-render.min.js" integrity="sha384-hCXGrW6PitJEwbkoStFjeJxv+fSOOQKOPbJxSfM6G5sWZjAyWhXiTIIAmQqnlLlh" crossorigin="anonymous"></script>
            </head>
            <body>
                {html}
                <script>
                    document.addEventListener("DOMContentLoaded", function() {{
                        renderMathInElement(document.body, {{
                            delimiters: [
                                {{left: '$$', right: '$$', display: true}},
                                {{left: '$', right: '$', display: false}}
                            ],
                            throwOnError : false
                        }});
                    }});
                </script>
            </body>
            </html>
            """.strip()
            page.set_viewport_size({"width": 1200, "height": 800})
            page.set_content(html_str)
            page.wait_for_load_state("domcontentloaded")
            page.wait_for_timeout(500)  # Wait for KaTeX to render
            screenshot_bytes = page.screenshot(full_page=True)
            browser.close()

        return Image.open(io.BytesIO(screenshot_bytes))
|
benchmarks/overall/methods/docling.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tempfile
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
from benchmarks.overall.methods import BaseMethod, BenchmarkResult
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class DoclingMethod(BaseMethod):
    """Benchmark method that converts a single-page PDF with docling."""
    # Unused by docling itself; accepted so shared artifact kwargs don't fail —
    # TODO confirm this is intentional interface parity with MarkerMethod.
    model_dict: dict = None
    use_llm: bool = False

    def __call__(self, sample) -> BenchmarkResult:
        """Convert the sample's PDF with docling, timing only the conversion."""
        from docling.document_converter import DocumentConverter
        pdf_bytes = sample["pdf"]  # This is a single page PDF
        converter = DocumentConverter()

        with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as f:
            f.write(pdf_bytes)
            # Flush buffered bytes to disk before docling re-opens the file by
            # name; without this the converter may read a truncated PDF.
            f.flush()
            start = time.time()
            result = converter.convert(f.name)
            total = time.time() - start

        return {
            "markdown": result.document.export_to_markdown(),
            "time": total
        }
benchmarks/overall/methods/gt.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
from benchmarks.overall.methods import BaseMethod, BenchmarkResult
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class GTMethod(BaseMethod):
    """Ground-truth 'method': emits the annotated markdown blocks verbatim."""

    def __call__(self, sample) -> BenchmarkResult:
        """Return the per-block ground-truth markdown with zero conversion time."""
        blocks = json.loads(sample["gt_blocks"])
        html_fragments = [b["html"] for b in blocks if len(b["html"]) > 0]
        return {
            "markdown": [self.convert_to_md(fragment) for fragment in html_fragments],
            "time": 0
        }

    def render(self, html: List[str]) -> Image.Image:
        """Render the ground-truth HTML fragments to a screenshot image."""
        body = "\n\n".join(html)
        page = f"""
        <html>
        <head></head>
        <body>
        {body}
        </body>
        </html>
        """.strip()
        return self.html_to_image(page)
|
benchmarks/overall/methods/llamaparse.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datasets
|
| 2 |
+
|
| 3 |
+
from benchmarks.overall.methods import BaseMethod, BenchmarkResult
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class LlamaParseMethod(BaseMethod):
    """Benchmark method that serves precomputed LlamaParse conversions.

    Results are looked up by the sample's uuid in a companion dataset whose
    rows contain "uuid", "md" (markdown) and "time" (conversion seconds).
    """
    llamaparse_ds: datasets.Dataset = None
    # Lazily-built uuid -> row cache, shared across instances. The original
    # code scanned the whole dataset for every sample, making the benchmark
    # O(rows^2) overall. NOTE(review): assumes llamaparse_ds is the same
    # dataset object for all instances in a run — confirm against overall.py.
    _uuid_lookup: dict = None

    def __call__(self, sample) -> BenchmarkResult:
        """Look up the precomputed result for this sample's uuid."""
        uuid = sample["uuid"]
        cls = type(self)
        if cls._uuid_lookup is None:
            cls._uuid_lookup = {str(row["uuid"]): row for row in self.llamaparse_ds}
        data = cls._uuid_lookup.get(str(uuid))
        if data is None:
            raise ValueError(f"Could not find data for uuid {uuid}")

        return {
            "markdown": data["md"],
            "time": data["time"]
        }
|
benchmarks/overall/methods/marker.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tempfile
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
from benchmarks.overall.methods import BaseMethod, BenchmarkResult
|
| 5 |
+
from marker.converters.pdf import PdfConverter
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MarkerMethod(BaseMethod):
    """Benchmark method that converts a single-page PDF with marker itself."""
    model_dict: dict = None  # Preloaded marker model artifacts (create_model_dict)
    use_llm: bool = False  # Enable LLM-assisted conversion for higher quality

    def __call__(self, sample) -> BenchmarkResult:
        """Convert the sample's PDF with marker, timing only the conversion."""
        pdf_bytes = sample["pdf"]  # This is a single page PDF
        block_converter = PdfConverter(
            artifact_dict=self.model_dict,
            config={"page_range": [0], "disable_tqdm": True, "use_llm": self.use_llm}
        )

        with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as f:
            f.write(pdf_bytes)
            # Flush buffered bytes to disk before the converter re-opens the
            # file by name; without this it may read a truncated PDF.
            f.flush()
            start = time.time()
            rendered = block_converter(f.name)
            total = time.time() - start

        return {
            "markdown": rendered.markdown,
            "time": total
        }
benchmarks/overall/methods/mathpix.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datasets
|
| 2 |
+
|
| 3 |
+
from benchmarks.overall.methods import BaseMethod, BenchmarkResult
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MathpixMethod(BaseMethod):
    """Benchmark method that serves precomputed Mathpix conversions.

    Results are looked up by the sample's uuid in a companion dataset whose
    rows contain "uuid", "md" (markdown) and "time" (conversion seconds).
    """
    mathpix_ds: datasets.Dataset = None
    # Lazily-built uuid -> row cache, shared across instances. The original
    # code scanned the whole dataset for every sample, making the benchmark
    # O(rows^2) overall. NOTE(review): assumes mathpix_ds is the same dataset
    # object for all instances in a run — confirm against overall.py.
    _uuid_lookup: dict = None

    def __call__(self, sample) -> BenchmarkResult:
        """Look up the precomputed result for this sample's uuid."""
        uuid = sample["uuid"]
        cls = type(self)
        if cls._uuid_lookup is None:
            cls._uuid_lookup = {str(row["uuid"]): row for row in self.mathpix_ds}
        data = cls._uuid_lookup.get(str(uuid))
        if data is None:
            raise ValueError(f"Could not find data for uuid {uuid}")

        return {
            "markdown": data["md"],
            "time": data["time"]
        }
|
benchmarks/overall/methods/schema.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TypedDict, List
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class BenchmarkResult(TypedDict):
    """Output of running a single conversion method on one benchmark sample."""
    # Converted markdown; GTMethod returns a list of per-block strings
    markdown: str | List[str]
    # Conversion wall time in seconds (0 for ground truth; None when unknown)
    time: float | None
|
benchmarks/overall/overall.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import List
|
| 6 |
+
|
| 7 |
+
import click
|
| 8 |
+
import datasets
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
from benchmarks.overall.display.dataset import build_dataset
|
| 12 |
+
from benchmarks.overall.registry import SCORE_REGISTRY, METHOD_REGISTRY
|
| 13 |
+
from benchmarks.overall.schema import FullResult
|
| 14 |
+
from marker.logger import configure_logging
|
| 15 |
+
from marker.models import create_model_dict
|
| 16 |
+
from marker.settings import settings
|
| 17 |
+
from benchmarks.overall.display.table import print_scores
|
| 18 |
+
|
| 19 |
+
configure_logging()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_method_scores(benchmark_dataset: datasets.Dataset, methods: List[str], score_types: List[str], artifacts: dict, max_rows=None) -> FullResult:
    """Run every method and scorer over the benchmark dataset.

    For each sample, the ground-truth markdown is computed first, then each
    method converts the sample and each scorer compares its output to GT.
    A failure inside a method aborts that sample entirely (its partial
    markdown is dropped); a failure inside a scorer only skips that score.

    Args:
        benchmark_dataset: rows with "classification", "gt_blocks" (JSON), etc.
        methods: method names; must be keys of METHOD_REGISTRY.
        score_types: scorer names; must be keys of SCORE_REGISTRY.
        artifacts: shared kwargs passed to every method constructor
            (model_dict, use_llm, companion datasets, ...).
        max_rows: optional cap on the number of samples processed.

    Returns:
        FullResult dict with raw per-sample scores, produced markdown, and
        score/time accumulators grouped by doc type and block type.
    """
    bench_scores = {}
    # method -> score_type -> doc classification -> list of sample scores
    averages_by_type = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    # method -> score_type -> GT block type -> list of per-block scores
    averages_by_block_type = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    average_times = defaultdict(list)
    markdown_by_method = defaultdict(dict)
    for idx, sample in tqdm(enumerate(benchmark_dataset), desc="Running benchmark"):
        if max_rows is not None and idx >= max_rows:
            break

        doc_type = sample["classification"]
        gt_cls = METHOD_REGISTRY["gt"]
        gt_blocks = json.loads(sample["gt_blocks"])
        # GT markdown is a list of per-block strings (see GTMethod)
        gt_md = gt_cls(**artifacts)(sample)["markdown"]
        markdown_by_method[idx]["gt"] = gt_md

        out_data = defaultdict(dict)

        try:
            for method in methods:
                # A fresh method instance per sample, configured from artifacts
                method_cls = METHOD_REGISTRY[method](**artifacts)
                method_info = method_cls(sample)
                method_md = method_info["markdown"]
                average_times[method].append(method_info["time"])
                markdown_by_method[idx][method] = method_md

                for score_type in score_types:
                    score_cls = SCORE_REGISTRY[score_type]()
                    try:
                        scores = score_cls(sample, gt_md, method_md)
                    except Exception as e:
                        # Some scorers can fail, like the LLM one
                        print(f"Failed to score {method} with {score_type}: {e}")
                        continue

                    out_data[method][score_type] = scores

                    averages_by_type[method][score_type][doc_type].append(scores["score"])

                    if "by_block" in scores["specific_scores"]: # Not all scorers support this
                        # Per-block scores are positionally aligned with gt_blocks
                        for score, gt_block in zip(scores["specific_scores"]["by_block"], gt_blocks):
                            averages_by_block_type[method][score_type][gt_block["block_type"]].append(score)
        except Exception as e:
            # A method failure invalidates the whole sample: drop its markdown
            # so partial results don't end up in the output dataset.
            print(f"Failed to process {idx}: {e}")
            if idx in markdown_by_method:
                del markdown_by_method[idx]
            continue

        bench_scores[idx] = out_data

    return {
        "scores": bench_scores,
        "markdown": markdown_by_method,
        "averages_by_type": averages_by_type,
        "averages_by_block_type": averages_by_block_type,
        "average_times": average_times,
    }
|
| 79 |
+
|
| 80 |
+
@click.command(help="Benchmark PDF to MD conversion.")
@click.option("--dataset", type=str, help="Path to the benchmark dataset", default="datalab-to/marker_benchmark")
@click.option("--out_dataset", type=str, help="Path to the output dataset", default=None)
@click.option("--methods", type=str, help="Comma separated list of other methods to compare against. Possible values: marker,mathpix,llamaparse,docling", default="marker")
@click.option("--scores", type=str, help="Comma separated list of scoring functions to use. Possible values: heuristic,llm", default="heuristic")
@click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "overall"), help="Output path for results.")
@click.option("--max_rows", type=int, default=None, help="Maximum number of rows to process.")
@click.option("--use_llm", is_flag=True, help="Use the LLM model for better marker quality.")
def main(
    dataset: str,
    out_dataset: str,
    methods: str,
    scores: str,
    result_path: str,
    max_rows: int,
    use_llm: bool
):
    """CLI entry point: run the overall benchmark, print score tables,
    dump raw results to JSON, and optionally push a display dataset to HF."""
    out_path = Path(result_path)
    out_path.mkdir(parents=True, exist_ok=True)

    methods = methods.split(",")
    for method in methods:
        if method not in METHOD_REGISTRY:
            raise ValueError(f"Method {method} not allowed. Allowed methods are {METHOD_REGISTRY.keys()}")

    # Ensure marker is always first
    # NOTE(review): set() also deduplicates and drops the user's ordering of
    # the remaining methods — confirm that's acceptable.
    all_methods = list(set(methods))
    methods = ["marker"] if "marker" in all_methods else []
    methods += [m for m in all_methods if m != "marker"]

    score_types = scores.split(",")
    for score_type in score_types:
        if score_type not in SCORE_REGISTRY:
            raise ValueError(f"Score type {score_type} not allowed. Allowed types are {SCORE_REGISTRY.keys()}")

    benchmark_dataset = datasets.load_dataset(dataset, split="train")
    # Shared kwargs handed to every method constructor; each method picks out
    # only the attributes it declares (see BaseMethod.__init__).
    artifacts = {
        "model_dict": create_model_dict(),
        "use_llm": use_llm,
        "mathpix_ds": None,
        "llamaparse_ds": None,
    }

    # Companion datasets with precomputed results are only fetched when needed
    if "mathpix" in methods:
        artifacts["mathpix_ds"] = datasets.load_dataset("datalab-to/marker_benchmark_mathpix", split="train")

    if "llamaparse" in methods:
        artifacts["llamaparse_ds"] = datasets.load_dataset("datalab-to/marker_benchmark_llamaparse", split="train")

    print(f"Running benchmark with methods: {methods} and scores: {score_types}")
    result = get_method_scores(benchmark_dataset, methods, score_types, artifacts, max_rows=max_rows)

    # Display benchmark scoring tables
    print_scores(result, out_path, methods, score_types, default_method=methods[0], default_score_type=score_types[0])

    # Write to json
    with open(out_path / "result.json", "w") as f:
        json.dump(result, f)

    if out_dataset:
        # Tag LLM-assisted runs so they don't overwrite the base dataset
        if use_llm:
            out_dataset += "_llm"
        dataset = build_dataset(benchmark_dataset, result, score_types, max_rows=max_rows)
        dataset.push_to_hub(out_dataset)


if __name__ == "__main__":
    main()
benchmarks/overall/registry.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from benchmarks.overall.methods.docling import DoclingMethod
|
| 2 |
+
from benchmarks.overall.methods.gt import GTMethod
|
| 3 |
+
from benchmarks.overall.methods.llamaparse import LlamaParseMethod
|
| 4 |
+
from benchmarks.overall.methods.marker import MarkerMethod
|
| 5 |
+
from benchmarks.overall.methods.mathpix import MathpixMethod
|
| 6 |
+
from benchmarks.overall.scorers.heuristic import HeuristicScorer
|
| 7 |
+
from benchmarks.overall.scorers.llm import LLMScorer
|
| 8 |
+
|
| 9 |
+
# Maps CLI score-type names (the --scores option) to scorer classes.
SCORE_REGISTRY = {
    "heuristic": HeuristicScorer,
    "llm": LLMScorer
}

# Maps CLI method names (the --methods option) to conversion method classes.
# "gt" is the ground-truth pseudo-method used as the scoring reference.
METHOD_REGISTRY = {
    "marker": MarkerMethod,
    "gt": GTMethod,
    "mathpix": MathpixMethod,
    "llamaparse": LlamaParseMethod,
    "docling": DoclingMethod
}
|
benchmarks/overall/schema.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TypedDict, List, Dict
|
| 2 |
+
|
| 3 |
+
from benchmarks.overall.scorers.schema import BlockScores
|
| 4 |
+
|
| 5 |
+
# method -> score_type -> grouping key (doc type or block type) -> raw scores
AVG_TYPE = Dict[str, Dict[str, Dict[str, List[float]]]]

class FullResult(TypedDict):
    """Aggregate output of one benchmark run (see get_method_scores)."""
    # sample idx -> method -> score_type -> that sample's scores
    scores: Dict[int, Dict[str, Dict[str, BlockScores]]]
    # Scores grouped by document classification
    averages_by_type: AVG_TYPE
    # Scores grouped by GT block type (only scorers with per-block output)
    averages_by_block_type: AVG_TYPE
    # method -> per-sample conversion times in seconds
    average_times: Dict[str, List[float]]
    # sample idx -> method (including "gt") -> produced markdown
    # NOTE(review): the "gt" entry is actually a list of block strings
    markdown: Dict[int, Dict[str, str]]
|
benchmarks/overall/scorers/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
from benchmarks.overall.scorers.schema import BlockScores
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class BaseScorer:
    """Abstract interface for benchmark scorers.

    A scorer compares one conversion method's markdown output against the
    per-block ground-truth markdown for a single benchmark sample.
    """
    def __init__(self):
        pass

    def __call__(self, sample, gt_markdown: List[str], method_markdown: str) -> BlockScores:
        """Score ``method_markdown`` against ``gt_markdown``; must be overridden."""
        raise NotImplementedError()
|
benchmarks/overall/scorers/clean.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import subprocess
|
| 3 |
+
import tempfile
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import latex2mathml.converter
|
| 7 |
+
|
| 8 |
+
class MarkdownCleaner:
    """Canonicalizes markdown from different converters so heuristic
    comparisons measure content rather than formatting quirks."""

    def __init__(self):
        pass

    def __call__(self, markdown):
        """Return a normalized, lower-cased version of ``markdown``."""
        markdown = self.normalize_markdown(markdown)  # Use pandoc to normalize

        # Standardize math: block math -> MathML, inline math -> cleaned latex
        markdown = re.sub(r'(?<!\\)\$(?:\$([^$]+)\$\$|\s*([^$\n]+?)\s*\$)', self.standardize_math, markdown)

        # Collapse image links to a generic tag so URLs don't affect scoring
        markdown = re.sub(r'!\[(.*?)\]\((https?://[^\s\)]+)\)', r'![link]', markdown)

        # Strip stray html tags, keeping their inner content
        markdown = markdown.replace("<br>", "\n")
        for stray_tag in (r"<sub>(.*?)</sub>", r"<sup>(.*?)</sup>", r"<span.*?>(.*?)</span>"):
            markdown = re.sub(stray_tag, r"\1", markdown)

        # Canonicalize whitespace and repeated punctuation
        markdown = re.sub(r"\s+", " ", markdown)
        markdown = re.sub(r"\n+", "\n", markdown)
        markdown = re.sub("\\.+", ".", markdown)  # e.g. dotted leaders in tables of contents
        markdown = re.sub("#+", "#", markdown)  # collapse repeated header markers
        # Decode escaped unicode sequences left over by some converters
        markdown = markdown.encode().decode('unicode-escape', errors="ignore")
        return markdown.strip().lower()

    @staticmethod
    def normalize_markdown(md_text: str) -> str:
        """Round-trip ``md_text`` through pandoc (md -> html -> md) so all
        inputs share one canonical markdown dialect."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            workdir = Path(tmp_dir)

            def run_pandoc(src, src_fmt, dst, dst_fmt):
                # Single pandoc conversion step; --quiet suppresses warnings.
                subprocess.run(
                    ['pandoc', str(src), '-f', src_fmt, '-t', dst_fmt, '-o', str(dst), '--quiet'],
                    check=True
                )

            md_in = workdir / 'input.md'
            md_in.write_text(md_text, encoding='utf-8')

            # Markdown to HTML
            html_tmp = workdir / 'temp.html'
            run_pandoc(md_in, 'markdown+tex_math_dollars', html_tmp, 'html')

            # HTML back to Markdown
            md_out = workdir / 'output.md'
            run_pandoc(html_tmp, 'html', md_out, 'markdown+tex_math_dollars')

            # Read back the normalized Markdown
            return md_out.read_text(encoding='utf-8')

    def standardize_math(self, match):
        """``re.sub`` callback: convert block math to MathML and clean inline
        latex; on any failure the original span is returned unchanged."""
        try:
            is_block = match.group(0).startswith('$$')
            delim = "$$" if is_block else "$"
            body = match.group(1) or match.group(2)
            if is_block:
                body = latex2mathml.converter.convert(body)
            else:
                body = self.clean_latex(body)
            return f'{delim}{body}{delim}'
        except Exception as e:
            print(f"Failed to standardize math expression: {match.group(0)} with error: {e}")
            return match.group(0)

    @staticmethod
    def clean_latex(latex_str):
        """Strip common formatting macros and replace operators with ASCII."""
        latex_str = re.sub(r'\s+', ' ', latex_str.strip())
        # Unwrap text-styling macros, keeping only their content
        for wrapper in (r'\\text', r'\\mathrm', r'\\mathbf', r'\\textbf'):
            latex_str = re.sub(wrapper + r'\{([^}]+)\}', r'\1', latex_str)

        # Operator macros -> ASCII equivalents (order matters: applied in sequence)
        for macro, ascii_form in (
            ('\\times', '*'),
            ('\\cdot', '*'),
            ('\\div', '/'),
            ('\\le', '<='),
            ('\\ge', '>='),
            ('\\neq', '!='),
            ('\\to', '\\rightarrow'),
        ):
            latex_str = latex_str.replace(macro, ascii_form)

        return latex_str
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
|
benchmarks/overall/scorers/heuristic.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
from rapidfuzz import fuzz
|
| 4 |
+
|
| 5 |
+
from benchmarks.overall.scorers.clean import MarkdownCleaner
|
| 6 |
+
from benchmarks.overall.scorers.schema import BlockScores
|
| 7 |
+
from benchmarks.overall.scorers import BaseScorer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class HeuristicScorer(BaseScorer):
    """Fuzzy-matching scorer: locates each ground-truth block inside the
    method output and combines per-block similarity with an ordering score."""

    def __call__(self, sample, gt_markdown: List[str], method_markdown: str) -> BlockScores:
        """Score ``method_markdown`` against the list of GT markdown blocks."""
        # Canonicalize both sides so formatting noise doesn't dominate
        gt_markdown = [self.clean_input(block) for block in gt_markdown]
        method_markdown = self.clean_input(method_markdown)

        alignments = self.find_fuzzy_alignments(method_markdown, gt_markdown)
        block_scores = [a["score"] for a in alignments]

        # Ordering score: compare matched start positions against GT order
        match_starts = [a["start"] for a in alignments]
        expected_order = list(range(len(gt_markdown)))
        observed_order = sorted(range(len(gt_markdown)), key=lambda b: match_starts[b])
        order_score = self.kendall_tau(expected_order, observed_order)

        # Length-weighted mean of per-block similarity
        weights = [len(block) for block in gt_markdown]
        weighted = [s * w for s, w in zip(block_scores, weights)]
        overall_score = sum(weighted) / max(1, sum(weights))

        # Blend: 80% content similarity, 20% ordering
        overall_score = overall_score * 0.8 + order_score * 0.2
        return {
            "score": overall_score,
            "specific_scores": {
                "order": order_score,
                "by_block": block_scores
            },
        }

    @staticmethod
    def kendall_tau(correct_order: List[int], actual_order: List[int]) -> float:
        """Kendall tau rank correlation between two orders, rescaled to 0-100."""
        n = len(correct_order)
        if n <= 1:
            return 100  # A single block is trivially in order

        concordant = 0
        discordant = 0
        for i in range(n):
            for j in range(i + 1, n):
                # Positive product: the pair is ranked the same way in both orders
                pair_sign = (correct_order[i] - correct_order[j]) * (actual_order[i] - actual_order[j])
                if pair_sign > 0:
                    concordant += 1
                elif pair_sign < 0:
                    discordant += 1

        total_pairs = (n * (n - 1)) // 2
        tau = (concordant - discordant) / total_pairs
        return ((tau + 1) / 2) * 100  # map [-1, 1] onto [0, 100]

    @staticmethod
    def find_fuzzy_alignments(
            main_string: str,
            substrings: List[str],
            threshold: int = 70
    ) -> List[dict]:
        """Fuzzily locate each GT substring inside ``main_string``.

        Returns one record per substring: matched span and similarity score
        (zeros when no alignment clears ``threshold``).
        """
        records = []
        for position, candidate in enumerate(substrings):
            found = fuzz.partial_ratio_alignment(candidate, main_string, score_cutoff=threshold)
            if found:
                records.append({
                    "string": candidate,
                    "start": found.dest_start,
                    "end": found.dest_end,
                    "score": found.score,
                    "idx": position
                })
            else:
                records.append({
                    "string": candidate,
                    "start": 0,
                    "end": 0,
                    "score": 0,
                    "idx": position
                })
        return records

    @staticmethod
    def clean_input(md: str):
        """Run the shared markdown canonicalizer over ``md``."""
        return MarkdownCleaner()(md)
|
benchmarks/overall/scorers/llm.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import tempfile
|
| 3 |
+
import time
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from google.genai.errors import APIError
|
| 8 |
+
from google import genai
|
| 9 |
+
import pypdfium2 as pdfium
|
| 10 |
+
|
| 11 |
+
from benchmarks.overall.scorers import BaseScorer, BlockScores
|
| 12 |
+
from marker.settings import settings
|
| 13 |
+
|
| 14 |
+
rating_prompt = """
|
| 15 |
+
You're a document analysis expert who is comparing some markdown to an image to make sure the markdown is correct. You're rating how effectively the provided markdown represents the full text and formatting in the image provided.
|
| 16 |
+
You're given an image, along with the extracted markdown:
|
| 17 |
+
- Some parts of the page may have been recognized as images and linked from the markdown, like ``.
|
| 18 |
+
- Tables will be formatted as Github flavored markdown.
|
| 19 |
+
- Block equations will be in LaTeX.
|
| 20 |
+
- The image and markdown may be in any language.
|
| 21 |
+
- The markdown is based on the text extracted from the document, and sometimes the document may have had bad OCR applied to it, resulting in gibberish text.
|
| 22 |
+
|
| 23 |
+
The markdown should fully capture the meaning and formatting of the text in the image. You'll evaluate the markdown based on the image provided.
|
| 24 |
+
|
| 25 |
+
**Instructions**
|
| 26 |
+
Follow this process to evaluate the markdown:
|
| 27 |
+
1. Carefully examine the image.
|
| 28 |
+
2. Carefully examine the markdown input provided.
|
| 29 |
+
3. Compare the image to the markdown representation. Does the markdown representation properly represent the important text and formatting in the image?
|
| 30 |
+
4. Assign component scores, as described below.
|
| 31 |
+
|
| 32 |
+
These are the primary scores:
|
| 33 |
+
- Overall - the overall quality of the markdown as compared to the image.
|
| 34 |
+
- Text quality - the quality of the text extraction from the image.
|
| 35 |
+
- Formatting quality - the quality of the formatting applied to the markdown, as compared to the image.
|
| 36 |
+
|
| 37 |
+
Depending on which elements are present in the markdown, you will assign element-specific scores.
|
| 38 |
+
- Tables - how effectively the tables have been extracted and formatted.
|
| 39 |
+
- Forms - how effectively the forms have extracted and formatted.
|
| 40 |
+
- Equations - how effectively block equations have been converted to LaTeX.
|
| 41 |
+
- Section headers - if all of the section headers have been detected, and the right levels set.
|
| 42 |
+
- Lists - if the lists have been properly extracted and formatted.
|
| 43 |
+
- Images - if images are identified and placed correctly.
|
| 44 |
+
|
| 45 |
+
Notes on scoring:
|
| 46 |
+
- To get a 5/5, all of the important text from the image must appear in the markdown, and the formatting should be correct (minor mistakes okay). It's okay to omit some text that isn't important to the meaning, like page numbers and chapter headings. If the entire page is an image, it's okay if the markdown is just a link to the image, unless the image would be better represented as text.
|
| 47 |
+
- A 3/5 may have small missing text elements from the markdown and/or moderate formatting issues.
|
| 48 |
+
- A 1/5 will have major missing text segments from the markdown or completely unreadable formatting.
|
| 49 |
+
- Use 0/5 if a field isn't applicable, like if the image doesn't contain a table.
|
| 50 |
+
|
| 51 |
+
Output json, like in the example below.
|
| 52 |
+
|
| 53 |
+
**Example**
|
| 54 |
+
Input
|
| 55 |
+
```markdown
|
| 56 |
+
# Section 1
|
| 57 |
+
This is some *markdown* extracted from a document. Here is a block equation:
|
| 58 |
+
$$\frac{ab \cdot x^5 + x^2 + 2 \cdot x + 123}{t}$$
|
| 59 |
+
```
|
| 60 |
+
Output
|
| 61 |
+
```json
|
| 62 |
+
{
|
| 63 |
+
"image_description": "In the image, there is a section header 'Section 1', followed by some text and a block equation.",
|
| 64 |
+
"markdown_description": "In the markdown, there is a section header 'Section 1', followed by some text and a block equation.",
|
| 65 |
+
"comparison": "The text and formatting matches the image. There are no formatting or text extraction issues. The equations and section headers are correct.",
|
| 66 |
+
"overall": 5,
|
| 67 |
+
"text": 5,
|
| 68 |
+
"formatting": 5,
|
| 69 |
+
"section_headers": 5,
|
| 70 |
+
"tables": 0,
|
| 71 |
+
"forms": 0,
|
| 72 |
+
"equations": 5,
|
| 73 |
+
"lists": 0,
|
| 74 |
+
"images": 0
|
| 75 |
+
}
|
| 76 |
+
```
|
| 77 |
+
**Input**
|
| 78 |
+
```markdown
|
| 79 |
+
{{markdown}}
|
| 80 |
+
```
|
| 81 |
+
**Output**
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
comparison_keys = ["comparison"]
|
| 85 |
+
description_keys = ["image_description", "markdown_description"]
|
| 86 |
+
text_keys = comparison_keys + description_keys
|
| 87 |
+
score_keys = ["overall", "text", "formatting", "section_headers", "tables", "forms", "equations",
|
| 88 |
+
"lists", "images"]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class LLMScorer(BaseScorer):
|
| 92 |
+
def __call__(self, sample, gt_markdown: List[str], markdown: str) -> BlockScores:
|
| 93 |
+
pdf_bytes = sample["pdf"]
|
| 94 |
+
with tempfile.NamedTemporaryFile(suffix=".pdf") as f:
|
| 95 |
+
f.write(pdf_bytes)
|
| 96 |
+
f.flush()
|
| 97 |
+
f.seek(0)
|
| 98 |
+
doc = pdfium.PdfDocument(f.name)
|
| 99 |
+
img = doc[0].render(scale=96/72).to_pil()
|
| 100 |
+
doc.close()
|
| 101 |
+
|
| 102 |
+
return self.llm_rater(img, markdown)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def llm_rater(self, img: Image.Image, markdown: str) -> BlockScores:
|
| 106 |
+
req_keys = text_keys + score_keys
|
| 107 |
+
properties = {}
|
| 108 |
+
for key in req_keys:
|
| 109 |
+
content_type = "INTEGER" if key in score_keys else "STRING"
|
| 110 |
+
properties[key] = {"type": content_type}
|
| 111 |
+
|
| 112 |
+
response_schema = {
|
| 113 |
+
"required": req_keys,
|
| 114 |
+
"properties": properties,
|
| 115 |
+
"type": "OBJECT"
|
| 116 |
+
}
|
| 117 |
+
prompt = rating_prompt.replace("{{markdown}}", markdown)
|
| 118 |
+
response = self.llm_response_wrapper([img, prompt], response_schema)
|
| 119 |
+
assert all([k in response for k in req_keys]), f"Missing keys in response: {response}"
|
| 120 |
+
return {
|
| 121 |
+
"score": response["overall"],
|
| 122 |
+
"specific_scores": response,
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
def llm_response_wrapper(self, prompt, response_schema, depth=0):
|
| 126 |
+
client = genai.Client(
|
| 127 |
+
api_key=settings.GOOGLE_API_KEY,
|
| 128 |
+
http_options={"timeout": 60000}
|
| 129 |
+
)
|
| 130 |
+
try:
|
| 131 |
+
responses = client.models.generate_content(
|
| 132 |
+
model="gemini-2.0-flash",
|
| 133 |
+
contents=prompt,
|
| 134 |
+
config={
|
| 135 |
+
"temperature": 0,
|
| 136 |
+
"response_schema": response_schema,
|
| 137 |
+
"response_mime_type": "application/json",
|
| 138 |
+
},
|
| 139 |
+
)
|
| 140 |
+
output = responses.candidates[0].content.parts[0].text
|
| 141 |
+
return json.loads(output)
|
| 142 |
+
except APIError as e:
|
| 143 |
+
print(f"Hit Gemini rate limit, waiting 120 seconds")
|
| 144 |
+
time.sleep(120)
|
| 145 |
+
if depth > 2:
|
| 146 |
+
raise e
|
| 147 |
+
return self.llm_response_wrapper(prompt, response_schema, depth + 1)
|
benchmarks/overall/scorers/schema.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TypedDict, List, Optional, Dict
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class BlockScores(TypedDict):
    """Result of scoring one extracted-markdown block."""
    # Headline score for the block (LLMScorer sets this to the "overall" rating).
    score: float
    # Per-category breakdown; each value is a single score or a list of scores.
    specific_scores: Dict[str, float | List[float]]
|
benchmarks/scoring.py
DELETED
|
@@ -1,36 +0,0 @@
|
|
| 1 |
-
from rapidfuzz import fuzz
|
| 2 |
-
from statistics import mean
|
| 3 |
-
|
| 4 |
-
CHUNK_MIN_CHARS = 25
|
| 5 |
-
|
| 6 |
-
def chunk_text(text, chunk_len=500):
|
| 7 |
-
chunks = [text[i:i+chunk_len] for i in range(0, len(text), chunk_len)]
|
| 8 |
-
chunks = [c for c in chunks if c.strip() and len(c) > CHUNK_MIN_CHARS]
|
| 9 |
-
return chunks
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def overlap_score(hypothesis_chunks, reference_chunks):
|
| 13 |
-
length_modifier = len(hypothesis_chunks) / len(reference_chunks)
|
| 14 |
-
search_distance = max(len(reference_chunks) // 5, 10)
|
| 15 |
-
chunk_scores = []
|
| 16 |
-
for i, hyp_chunk in enumerate(hypothesis_chunks):
|
| 17 |
-
max_score = 0
|
| 18 |
-
total_len = 0
|
| 19 |
-
i_offset = int(i * length_modifier)
|
| 20 |
-
chunk_range = range(max(0, i_offset-search_distance), min(len(reference_chunks), i_offset+search_distance))
|
| 21 |
-
for j in chunk_range:
|
| 22 |
-
ref_chunk = reference_chunks[j]
|
| 23 |
-
score = fuzz.ratio(hyp_chunk, ref_chunk, score_cutoff=30) / 100
|
| 24 |
-
if score > max_score:
|
| 25 |
-
max_score = score
|
| 26 |
-
total_len = len(ref_chunk)
|
| 27 |
-
chunk_scores.append(max_score)
|
| 28 |
-
return chunk_scores
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def score_text(hypothesis, reference):
|
| 32 |
-
# Returns a 0-1 alignment score
|
| 33 |
-
hypothesis_chunks = chunk_text(hypothesis)
|
| 34 |
-
reference_chunks = chunk_text(reference)
|
| 35 |
-
chunk_scores = overlap_score(hypothesis_chunks, reference_chunks)
|
| 36 |
-
return mean(chunk_scores)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/table/__init__.py
ADDED
|
File without changes
|
benchmarks/table/gemini.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
| 1 |
import json
|
| 2 |
from PIL import Image
|
| 3 |
-
|
| 4 |
-
from google.
|
|
|
|
|
|
|
|
|
|
| 5 |
from marker.settings import settings
|
| 6 |
|
| 7 |
prompt = """
|
|
@@ -19,30 +22,26 @@ Guidelines:
|
|
| 19 |
3. Output only the HTML for the table, starting with the <table> tag and ending with the </table> tag.
|
| 20 |
""".strip()
|
| 21 |
|
| 22 |
-
|
|
|
|
| 23 |
|
| 24 |
def gemini_table_rec(image: Image.Image):
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
properties={
|
| 29 |
-
"table_html": content.Schema(
|
| 30 |
-
type=content.Type.STRING,
|
| 31 |
-
)
|
| 32 |
-
}
|
| 33 |
)
|
| 34 |
|
| 35 |
-
|
|
|
|
| 36 |
|
| 37 |
-
responses =
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
"temperature": 0,
|
| 42 |
-
"response_schema":
|
| 43 |
"response_mime_type": "application/json",
|
| 44 |
},
|
| 45 |
-
request_options={'timeout': 60}
|
| 46 |
)
|
| 47 |
|
| 48 |
output = responses.candidates[0].content.parts[0].text
|
|
|
|
| 1 |
import json
|
| 2 |
from PIL import Image
|
| 3 |
+
from google import genai
|
| 4 |
+
from google.genai import types
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
|
| 8 |
from marker.settings import settings
|
| 9 |
|
| 10 |
prompt = """
|
|
|
|
| 22 |
3. Output only the HTML for the table, starting with the <table> tag and ending with the </table> tag.
|
| 23 |
""".strip()
|
| 24 |
|
| 25 |
+
class TableSchema(BaseModel):
|
| 26 |
+
table_html: str
|
| 27 |
|
| 28 |
def gemini_table_rec(image: Image.Image):
|
| 29 |
+
client = genai.Client(
|
| 30 |
+
api_key=settings.GOOGLE_API_KEY,
|
| 31 |
+
http_options={"timeout": 60000}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
)
|
| 33 |
|
| 34 |
+
image_bytes = BytesIO()
|
| 35 |
+
image.save(image_bytes, format="PNG")
|
| 36 |
|
| 37 |
+
responses = client.models.generate_content(
|
| 38 |
+
model="gemini-2.0-flash",
|
| 39 |
+
contents=[types.Part.from_bytes(data=image_bytes.getvalue(), mime_type="image/png"), prompt], # According to gemini docs, it performs better if the image is the first element
|
| 40 |
+
config={
|
| 41 |
"temperature": 0,
|
| 42 |
+
"response_schema": TableSchema,
|
| 43 |
"response_mime_type": "application/json",
|
| 44 |
},
|
|
|
|
| 45 |
)
|
| 46 |
|
| 47 |
output = responses.candidates[0].content.parts[0].text
|
benchmarks/table/inference.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from bs4 import BeautifulSoup
|
| 5 |
+
import pypdfium2 as pdfium
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import base64
|
| 8 |
+
import tempfile
|
| 9 |
+
|
| 10 |
+
from benchmarks.table.gemini import gemini_table_rec
|
| 11 |
+
from marker.config.parser import ConfigParser
|
| 12 |
+
from marker.converters.table import TableConverter
|
| 13 |
+
from marker.models import create_model_dict
|
| 14 |
+
from marker.processors.llm.llm_table import LLMTableProcessor
|
| 15 |
+
from marker.processors.table import TableProcessor
|
| 16 |
+
from marker.renderers.json import JSONBlockOutput
|
| 17 |
+
from marker.schema.polygon import PolygonBox
|
| 18 |
+
from marker.util import matrix_intersection_area
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def extract_tables(children: List[JSONBlockOutput]):
    """Collect every Table block from a marker JSON block tree, in document order.

    A Table block's own children are not searched further.
    """
    found = []
    for block in children:
        if block.block_type == 'Table':
            found.append(block)
            continue
        if block.children:
            found.extend(extract_tables(block.children))
    return found
|
| 29 |
+
|
| 30 |
+
def fix_table_html(table_html: str) -> str:
    """Normalize marker table HTML so it matches fintabnet conventions.

    Unwraps <tbody>, converts <th> cells to <td>, removes <br> tags, and
    collapses newlines to spaces (fintabnet uses spaces instead of newlines).
    """
    soup = BeautifulSoup(table_html, 'html.parser')
    body = soup.find('tbody')
    if body:
        body.unwrap()
    for header_cell in soup.find_all('th'):
        header_cell.name = 'td'
    for line_break in soup.find_all('br'):
        line_break.replace_with(soup.new_string(''))
    return str(soup).replace("\n", " ")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, max_rows: int, use_gemini: bool):
    """Run marker (and optionally Gemini) table recognition over a fintabnet-style dataset.

    Args:
        dataset: rows with a base64 "pdf" field and a "tables" list (each table has
            "html" and "normalized_bbox"); assumed sorted by reading order — TODO confirm.
        use_llm: enable the LLM table processor in marker.
        table_rec_batch_size: batch size for table recognition (None = default).
        max_rows: cap on the number of dataset rows to process (None = all).
        use_gemini: also run Gemini table recognition on each aligned table crop.

    Returns:
        (results, total_unaligned): results is a list of dicts with normalized
        "marker_table", "gt_table", and "gemini_table" HTML; total_unaligned counts
        ground-truth tables that could not be matched to a marker table.
    """
    models = create_model_dict()
    config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size, "disable_tqdm": True})
    total_unaligned = 0
    results = []

    iterations = len(dataset)
    if max_rows is not None:
        iterations = min(max_rows, len(dataset))

    for i in tqdm(range(iterations), desc='Converting Tables'):
        try:
            row = dataset[i]
            pdf_binary = base64.b64decode(row['pdf'])
            gt_tables = row['tables']  # Already sorted by reading order, which is what marker returns

            # Only use the basic table processors
            converter = TableConverter(
                config=config_parser.generate_config_dict(),
                artifact_dict=models,
                processor_list=[
                    "marker.processors.table.TableProcessor",
                    "marker.processors.llm.llm_table.LLMTableProcessor",
                ],
                renderer=config_parser.get_renderer()
            )

            # Write the PDF to a temp file so both marker and pdfium can read it by path.
            with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as temp_pdf_file:
                temp_pdf_file.write(pdf_binary)
                temp_pdf_file.seek(0)
                marker_json = converter(temp_pdf_file.name).children

                # Render page 0 at 96 DPI for cropping table images.
                doc = pdfium.PdfDocument(temp_pdf_file.name)
                page_image = doc[0].render(scale=96/72).to_pil()
                doc.close()

            if len(marker_json) == 0 or len(gt_tables) == 0:
                print(f'No tables detected, skipping...')
                total_unaligned += len(gt_tables)
                continue

            marker_tables = extract_tables(marker_json)
            marker_table_boxes = [table.bbox for table in marker_tables]
            page_bbox = marker_json[0].bbox

            if len(marker_tables) != len(gt_tables):
                print(f'Number of tables do not match, skipping...')
                total_unaligned += len(gt_tables)
                continue

            # Crop each detected table from the rendered page image; bboxes are in
            # page coordinates and must be rescaled to image pixels first.
            table_images = [
                page_image.crop(
                    PolygonBox.from_bbox(bbox)
                    .rescale(
                        (page_bbox[2], page_bbox[3]), (page_image.width, page_image.height)
                    ).bbox
                )
                for bbox
                in marker_table_boxes
            ]

            # Normalize the bboxes
            # NOTE: mutates marker_table_boxes in place; the area/alignment math
            # below relies on these now being in normalized (0-1) coordinates.
            for bbox in marker_table_boxes:
                bbox[0] = bbox[0] / page_bbox[2]
                bbox[1] = bbox[1] / page_bbox[3]
                bbox[2] = bbox[2] / page_bbox[2]
                bbox[3] = bbox[3] / page_bbox[3]

            gt_boxes = [table['normalized_bbox'] for table in gt_tables]
            gt_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in gt_boxes]
            marker_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in marker_table_boxes]
            # Pairwise intersection areas: row = gt table, column = marker table.
            table_alignments = matrix_intersection_area(gt_boxes, marker_table_boxes)

            # Greedily match each gt table to the marker table with the largest overlap,
            # rejecting matches that are too small or already claimed.
            aligned_tables = []
            used_tables = set()
            unaligned_tables = set()
            for table_idx, alignment in enumerate(table_alignments):
                try:
                    max_area = np.max(alignment)
                    aligned_idx = np.argmax(alignment)
                except ValueError:
                    # No alignment found
                    unaligned_tables.add(table_idx)
                    continue

                if max_area <= .01:
                    # No alignment found
                    unaligned_tables.add(table_idx)
                    continue

                if aligned_idx in used_tables:
                    # Marker table already aligned with another gt table
                    unaligned_tables.add(table_idx)
                    continue

                # Gt table doesn't align well with any marker table
                gt_table_pct = gt_areas[table_idx] / max_area
                if not .85 < gt_table_pct < 1.15:
                    unaligned_tables.add(table_idx)
                    continue

                # Marker table doesn't align with gt table
                marker_table_pct = marker_areas[aligned_idx] / max_area
                if not .85 < marker_table_pct < 1.15:
                    unaligned_tables.add(table_idx)
                    continue

                gemini_html = ""
                if use_gemini:
                    # Gemini failure is non-fatal: keep the marker/gt pair either way.
                    try:
                        gemini_html = gemini_table_rec(table_images[aligned_idx])
                    except Exception as e:
                        print(f'Gemini failed: {e}')

                aligned_tables.append(
                    (marker_tables[aligned_idx], gt_tables[table_idx], gemini_html)
                )
                used_tables.add(aligned_idx)

            total_unaligned += len(unaligned_tables)

            for marker_table, gt_table, gemini_table in aligned_tables:
                gt_table_html = gt_table['html']

                # marker wraps the table in <tbody> which fintabnet data doesn't
                # Fintabnet doesn't use th tags, need to be replaced for fair comparison
                marker_table_html = fix_table_html(marker_table.html)
                gemini_table_html = fix_table_html(gemini_table)

                results.append({
                    "marker_table": marker_table_html,
                    "gt_table": gt_table_html,
                    "gemini_table": gemini_table_html
                })
        except pdfium.PdfiumError:
            print('Broken PDF, Skipping...')
            continue
    return results, total_unaligned
|
benchmarks/table/table.py
CHANGED
|
@@ -1,32 +1,22 @@
|
|
| 1 |
import os
|
| 2 |
-
|
| 3 |
-
from tkinter import Image
|
| 4 |
-
|
| 5 |
-
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
|
| 6 |
|
|
|
|
|
|
|
| 7 |
from typing import List
|
| 8 |
-
|
| 9 |
-
import base64
|
| 10 |
import time
|
| 11 |
import datasets
|
| 12 |
from tqdm import tqdm
|
| 13 |
-
import tempfile
|
| 14 |
import click
|
| 15 |
from tabulate import tabulate
|
| 16 |
import json
|
| 17 |
-
from bs4 import BeautifulSoup
|
| 18 |
from concurrent.futures import ProcessPoolExecutor
|
| 19 |
-
from pypdfium2._helpers.misc import PdfiumError
|
| 20 |
-
import pypdfium2 as pdfium
|
| 21 |
-
from marker.util import matrix_intersection_area
|
| 22 |
-
from marker.renderers.json import JSONOutput, JSONBlockOutput
|
| 23 |
|
| 24 |
-
from marker.
|
| 25 |
-
from
|
| 26 |
-
from marker.models import create_model_dict
|
| 27 |
|
| 28 |
from scoring import wrap_table_html, similarity_eval_html
|
| 29 |
-
from gemini import gemini_table_rec
|
| 30 |
|
| 31 |
def update_teds_score(result, prefix: str = "marker"):
|
| 32 |
prediction, ground_truth = result[f'{prefix}_table'], result['gt_table']
|
|
@@ -36,26 +26,16 @@ def update_teds_score(result, prefix: str = "marker"):
|
|
| 36 |
return result
|
| 37 |
|
| 38 |
|
| 39 |
-
def extract_tables(children: List[JSONBlockOutput]):
|
| 40 |
-
tables = []
|
| 41 |
-
for child in children:
|
| 42 |
-
if child.block_type == 'Table':
|
| 43 |
-
tables.append(child)
|
| 44 |
-
elif child.children:
|
| 45 |
-
tables.extend(extract_tables(child.children))
|
| 46 |
-
return tables
|
| 47 |
-
|
| 48 |
-
|
| 49 |
@click.command(help="Benchmark Table to HTML Conversion")
|
| 50 |
-
@click.
|
| 51 |
-
@click.option("--dataset", type=str, default="datalab-to/
|
| 52 |
@click.option("--max_rows", type=int, default=None, help="Maximum number of PDFs to process")
|
| 53 |
@click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
|
| 54 |
@click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.")
|
| 55 |
@click.option("--table_rec_batch_size", type=int, default=None, help="Batch size for table recognition.")
|
| 56 |
@click.option("--use_gemini", is_flag=True, help="Evaluate Gemini for table recognition.")
|
| 57 |
def main(
|
| 58 |
-
|
| 59 |
dataset: str,
|
| 60 |
max_rows: int,
|
| 61 |
max_workers: int,
|
|
@@ -63,130 +43,13 @@ def main(
|
|
| 63 |
table_rec_batch_size: int | None,
|
| 64 |
use_gemini: bool = False
|
| 65 |
):
|
| 66 |
-
models = create_model_dict()
|
| 67 |
-
config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size})
|
| 68 |
start = time.time()
|
| 69 |
|
| 70 |
|
| 71 |
dataset = datasets.load_dataset(dataset, split='train')
|
| 72 |
dataset = dataset.shuffle(seed=0)
|
| 73 |
|
| 74 |
-
|
| 75 |
-
if max_rows is not None:
|
| 76 |
-
iterations = min(max_rows, len(dataset))
|
| 77 |
-
|
| 78 |
-
results = []
|
| 79 |
-
total_unaligned = 0
|
| 80 |
-
for i in tqdm(range(iterations), desc='Converting Tables'):
|
| 81 |
-
try:
|
| 82 |
-
row = dataset[i]
|
| 83 |
-
pdf_binary = base64.b64decode(row['pdf'])
|
| 84 |
-
gt_tables = row['tables'] #Already sorted by reading order, which is what marker returns
|
| 85 |
-
|
| 86 |
-
converter = TableConverter(
|
| 87 |
-
config=config_parser.generate_config_dict(),
|
| 88 |
-
artifact_dict=models,
|
| 89 |
-
processor_list=config_parser.get_processors(),
|
| 90 |
-
renderer=config_parser.get_renderer()
|
| 91 |
-
)
|
| 92 |
-
|
| 93 |
-
with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as temp_pdf_file:
|
| 94 |
-
temp_pdf_file.write(pdf_binary)
|
| 95 |
-
temp_pdf_file.seek(0)
|
| 96 |
-
tqdm.disable = True
|
| 97 |
-
marker_json = converter(temp_pdf_file.name).children
|
| 98 |
-
tqdm.disable = False
|
| 99 |
-
|
| 100 |
-
doc = pdfium.PdfDocument(temp_pdf_file.name)
|
| 101 |
-
page_image = doc[0].render(scale=92/72).to_pil()
|
| 102 |
-
|
| 103 |
-
if len(marker_json) == 0 or len(gt_tables) == 0:
|
| 104 |
-
print(f'No tables detected, skipping...')
|
| 105 |
-
total_unaligned += len(gt_tables)
|
| 106 |
-
continue
|
| 107 |
-
|
| 108 |
-
marker_tables = extract_tables(marker_json)
|
| 109 |
-
marker_table_boxes = [table.bbox for table in marker_tables]
|
| 110 |
-
page_bbox = marker_json[0].bbox
|
| 111 |
-
w_scaler, h_scaler = page_image.width / page_bbox[2], page_image.height / page_bbox[3]
|
| 112 |
-
table_images = [page_image.crop([bbox[0] * w_scaler, bbox[1] * h_scaler, bbox[2] * w_scaler, bbox[3] * h_scaler]) for bbox in marker_table_boxes]
|
| 113 |
-
|
| 114 |
-
# Normalize the bboxes
|
| 115 |
-
for bbox in marker_table_boxes:
|
| 116 |
-
bbox[0] = bbox[0] / page_bbox[2]
|
| 117 |
-
bbox[1] = bbox[1] / page_bbox[3]
|
| 118 |
-
bbox[2] = bbox[2] / page_bbox[2]
|
| 119 |
-
bbox[3] = bbox[3] / page_bbox[3]
|
| 120 |
-
|
| 121 |
-
gt_boxes = [table['normalized_bbox'] for table in gt_tables]
|
| 122 |
-
gt_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in gt_boxes]
|
| 123 |
-
marker_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in marker_table_boxes]
|
| 124 |
-
table_alignments = matrix_intersection_area(gt_boxes, marker_table_boxes)
|
| 125 |
-
|
| 126 |
-
aligned_tables = []
|
| 127 |
-
used_tables = set()
|
| 128 |
-
unaligned_tables = set()
|
| 129 |
-
for table_idx, alignment in enumerate(table_alignments):
|
| 130 |
-
try:
|
| 131 |
-
max_area = np.max(alignment)
|
| 132 |
-
aligned_idx = np.argmax(alignment)
|
| 133 |
-
except ValueError:
|
| 134 |
-
# No alignment found
|
| 135 |
-
unaligned_tables.add(table_idx)
|
| 136 |
-
continue
|
| 137 |
-
|
| 138 |
-
if aligned_idx in used_tables:
|
| 139 |
-
# Marker table already aligned with another gt table
|
| 140 |
-
unaligned_tables.add(table_idx)
|
| 141 |
-
continue
|
| 142 |
-
|
| 143 |
-
# Gt table doesn't align well with any marker table
|
| 144 |
-
gt_table_pct = gt_areas[table_idx] / max_area
|
| 145 |
-
if not .75 < gt_table_pct < 1.25:
|
| 146 |
-
unaligned_tables.add(table_idx)
|
| 147 |
-
continue
|
| 148 |
-
|
| 149 |
-
# Marker table doesn't align with gt table
|
| 150 |
-
marker_table_pct = marker_areas[aligned_idx] / max_area
|
| 151 |
-
if not .75 < marker_table_pct < 1.25:
|
| 152 |
-
unaligned_tables.add(table_idx)
|
| 153 |
-
continue
|
| 154 |
-
|
| 155 |
-
gemini_html = ""
|
| 156 |
-
if use_gemini:
|
| 157 |
-
gemini_html = gemini_table_rec(table_images[aligned_idx])
|
| 158 |
-
|
| 159 |
-
aligned_tables.append(
|
| 160 |
-
(marker_tables[aligned_idx], gt_tables[table_idx], gemini_html)
|
| 161 |
-
)
|
| 162 |
-
used_tables.add(aligned_idx)
|
| 163 |
-
|
| 164 |
-
total_unaligned += len(unaligned_tables)
|
| 165 |
-
|
| 166 |
-
for marker_table, gt_table, gemini_table in aligned_tables:
|
| 167 |
-
gt_table_html = gt_table['html']
|
| 168 |
-
|
| 169 |
-
#marker wraps the table in <tbody> which fintabnet data doesn't
|
| 170 |
-
#Fintabnet doesn't use th tags, need to be replaced for fair comparison
|
| 171 |
-
marker_table_soup = BeautifulSoup(marker_table.html, 'html.parser')
|
| 172 |
-
tbody = marker_table_soup.find('tbody')
|
| 173 |
-
if tbody:
|
| 174 |
-
tbody.unwrap()
|
| 175 |
-
for th_tag in marker_table_soup.find_all('th'):
|
| 176 |
-
th_tag.name = 'td'
|
| 177 |
-
marker_table_html = str(marker_table_soup)
|
| 178 |
-
marker_table_html = marker_table_html.replace("<br>", " ") # Fintabnet uses spaces instead of newlines
|
| 179 |
-
marker_table_html = marker_table_html.replace("\n", " ") # Fintabnet uses spaces instead of newlines
|
| 180 |
-
gemini_table_html = gemini_table.replace("\n", " ") # Fintabnet uses spaces instead of newlines
|
| 181 |
-
|
| 182 |
-
results.append({
|
| 183 |
-
"marker_table": marker_table_html,
|
| 184 |
-
"gt_table": gt_table_html,
|
| 185 |
-
"gemini_table": gemini_table_html
|
| 186 |
-
})
|
| 187 |
-
except PdfiumError:
|
| 188 |
-
print('Broken PDF, Skipping...')
|
| 189 |
-
continue
|
| 190 |
|
| 191 |
print(f"Total time: {time.time() - start}.")
|
| 192 |
print(f"Could not align {total_unaligned} tables from fintabnet.")
|
|
@@ -223,8 +86,12 @@ def main(
|
|
| 223 |
"gemini": gemini_results
|
| 224 |
}
|
| 225 |
|
| 226 |
-
|
|
|
|
|
|
|
| 227 |
json.dump(results, f, indent=2)
|
| 228 |
|
|
|
|
|
|
|
| 229 |
if __name__ == '__main__':
|
| 230 |
main()
|
|
|
|
| 1 |
import os
|
| 2 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for an op, which is not supported on MPS
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from itertools import repeat
|
| 6 |
from typing import List
|
| 7 |
+
|
|
|
|
| 8 |
import time
|
| 9 |
import datasets
|
| 10 |
from tqdm import tqdm
|
|
|
|
| 11 |
import click
|
| 12 |
from tabulate import tabulate
|
| 13 |
import json
|
|
|
|
| 14 |
from concurrent.futures import ProcessPoolExecutor
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
from marker.settings import settings
|
| 17 |
+
from benchmarks.table.inference import inference_tables
|
|
|
|
| 18 |
|
| 19 |
from scoring import wrap_table_html, similarity_eval_html
|
|
|
|
| 20 |
|
| 21 |
def update_teds_score(result, prefix: str = "marker"):
|
| 22 |
prediction, ground_truth = result[f'{prefix}_table'], result['gt_table']
|
|
|
|
| 26 |
return result
|
| 27 |
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
@click.command(help="Benchmark Table to HTML Conversion")
|
| 30 |
+
@click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "table"), help="Output path for results.")
|
| 31 |
+
@click.option("--dataset", type=str, default="datalab-to/fintabnet_bench_marker", help="Dataset to use")
|
| 32 |
@click.option("--max_rows", type=int, default=None, help="Maximum number of PDFs to process")
|
| 33 |
@click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
|
| 34 |
@click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.")
|
| 35 |
@click.option("--table_rec_batch_size", type=int, default=None, help="Batch size for table recognition.")
|
| 36 |
@click.option("--use_gemini", is_flag=True, help="Evaluate Gemini for table recognition.")
|
| 37 |
def main(
|
| 38 |
+
result_path: str,
|
| 39 |
dataset: str,
|
| 40 |
max_rows: int,
|
| 41 |
max_workers: int,
|
|
|
|
| 43 |
table_rec_batch_size: int | None,
|
| 44 |
use_gemini: bool = False
|
| 45 |
):
|
|
|
|
|
|
|
| 46 |
start = time.time()
|
| 47 |
|
| 48 |
|
| 49 |
dataset = datasets.load_dataset(dataset, split='train')
|
| 50 |
dataset = dataset.shuffle(seed=0)
|
| 51 |
|
| 52 |
+
results, total_unaligned = inference_tables(dataset, use_llm, table_rec_batch_size, max_rows, use_gemini)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
print(f"Total time: {time.time() - start}.")
|
| 55 |
print(f"Could not align {total_unaligned} tables from fintabnet.")
|
|
|
|
| 86 |
"gemini": gemini_results
|
| 87 |
}
|
| 88 |
|
| 89 |
+
out_path = Path(result_path)
|
| 90 |
+
out_path.mkdir(parents=True, exist_ok=True)
|
| 91 |
+
with open(out_path / "table.json", "w+") as f:
|
| 92 |
json.dump(results, f, indent=2)
|
| 93 |
|
| 94 |
+
print(f"Results saved to {out_path}.")
|
| 95 |
+
|
| 96 |
if __name__ == '__main__':
|
| 97 |
main()
|
benchmarks/throughput/__init__.py
ADDED
|
File without changes
|
benchmarks/throughput/main.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
import click
|
| 5 |
+
import pypdfium2 as pdfium
|
| 6 |
+
|
| 7 |
+
from marker.converters.pdf import PdfConverter
|
| 8 |
+
from marker.models import create_model_dict
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@click.command(help="Benchmark PDF to MD conversion throughput.")
|
| 12 |
+
@click.argument("pdf_path", type=str)
|
| 13 |
+
def main(pdf_path):
|
| 14 |
+
print(f"Converting {pdf_path} to markdown...")
|
| 15 |
+
pdf = pdfium.PdfDocument(pdf_path)
|
| 16 |
+
page_count = len(pdf)
|
| 17 |
+
pdf.close()
|
| 18 |
+
model_dict = create_model_dict()
|
| 19 |
+
torch.cuda.reset_peak_memory_stats()
|
| 20 |
+
|
| 21 |
+
times = []
|
| 22 |
+
for i in range(10):
|
| 23 |
+
block_converter = PdfConverter(
|
| 24 |
+
artifact_dict=model_dict,
|
| 25 |
+
config={"disable_tqdm": True}
|
| 26 |
+
)
|
| 27 |
+
start = time.time()
|
| 28 |
+
block_converter(pdf_path)
|
| 29 |
+
total = time.time() - start
|
| 30 |
+
times.append(total)
|
| 31 |
+
|
| 32 |
+
max_gpu_vram = torch.cuda.max_memory_allocated() / 1024 ** 3
|
| 33 |
+
|
| 34 |
+
print(f"Converted {page_count} pages in {sum(times)/len(times):.2f} seconds.")
|
| 35 |
+
print(f"Max GPU VRAM: {max_gpu_vram:.2f} GB")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
if __name__ == "__main__":
|
| 39 |
+
main()
|
benchmarks/verify_scores.py
CHANGED
|
@@ -6,18 +6,18 @@ def verify_scores(file_path):
|
|
| 6 |
with open(file_path, 'r') as file:
|
| 7 |
data = json.load(file)
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
if
|
| 13 |
-
raise ValueError("
|
| 14 |
|
| 15 |
|
| 16 |
def verify_table_scores(file_path):
|
| 17 |
with open(file_path, 'r') as file:
|
| 18 |
data = json.load(file)
|
| 19 |
|
| 20 |
-
avg = sum([r["
|
| 21 |
if avg < 0.7:
|
| 22 |
raise ValueError("Average score is below the required threshold of 0.7")
|
| 23 |
|
|
|
|
| 6 |
with open(file_path, 'r') as file:
|
| 7 |
data = json.load(file)
|
| 8 |
|
| 9 |
+
raw_scores = [data["scores"][k] for k in data["scores"]]
|
| 10 |
+
marker_scores = [r["marker"]["heuristic"]["score"] for r in raw_scores]
|
| 11 |
+
marker_score = sum(marker_scores) / len(marker_scores)
|
| 12 |
+
if marker_score < 90:
|
| 13 |
+
raise ValueError("Marker score below 90")
|
| 14 |
|
| 15 |
|
| 16 |
def verify_table_scores(file_path):
|
| 17 |
with open(file_path, 'r') as file:
|
| 18 |
data = json.load(file)
|
| 19 |
|
| 20 |
+
avg = sum([r["marker_score"] for r in data["marker"]]) / len(data)
|
| 21 |
if avg < 0.7:
|
| 22 |
raise ValueError("Average score is below the required threshold of 0.7")
|
| 23 |
|
marker/builders/document.py
CHANGED
|
@@ -2,6 +2,7 @@ from typing import Annotated
|
|
| 2 |
|
| 3 |
from marker.builders import BaseBuilder
|
| 4 |
from marker.builders.layout import LayoutBuilder
|
|
|
|
| 5 |
from marker.builders.ocr import OcrBuilder
|
| 6 |
from marker.providers.pdf import PdfProvider
|
| 7 |
from marker.schema import BlockTypes
|
|
@@ -27,9 +28,10 @@ class DocumentBuilder(BaseBuilder):
|
|
| 27 |
"Disable OCR processing.",
|
| 28 |
] = False
|
| 29 |
|
| 30 |
-
def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder, ocr_builder: OcrBuilder):
|
| 31 |
document = self.build_document(provider)
|
| 32 |
layout_builder(document, provider)
|
|
|
|
| 33 |
if not self.disable_ocr:
|
| 34 |
ocr_builder(document, provider)
|
| 35 |
return document
|
|
|
|
| 2 |
|
| 3 |
from marker.builders import BaseBuilder
|
| 4 |
from marker.builders.layout import LayoutBuilder
|
| 5 |
+
from marker.builders.line import LineBuilder
|
| 6 |
from marker.builders.ocr import OcrBuilder
|
| 7 |
from marker.providers.pdf import PdfProvider
|
| 8 |
from marker.schema import BlockTypes
|
|
|
|
| 28 |
"Disable OCR processing.",
|
| 29 |
] = False
|
| 30 |
|
| 31 |
+
def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder, line_builder: LineBuilder, ocr_builder: OcrBuilder):
|
| 32 |
document = self.build_document(provider)
|
| 33 |
layout_builder(document, provider)
|
| 34 |
+
line_builder(document, provider)
|
| 35 |
if not self.disable_ocr:
|
| 36 |
ocr_builder(document, provider)
|
| 37 |
return document
|
marker/builders/layout.py
CHANGED
|
@@ -1,13 +1,9 @@
|
|
| 1 |
-
from typing import Annotated, List, Optional
|
| 2 |
|
| 3 |
-
import numpy as np
|
| 4 |
from surya.layout import LayoutPredictor
|
| 5 |
from surya.layout.schema import LayoutResult, LayoutBox
|
| 6 |
-
from surya.ocr_error import OCRErrorPredictor
|
| 7 |
-
from surya.ocr_error.schema import OCRErrorDetectionResult
|
| 8 |
|
| 9 |
from marker.builders import BaseBuilder
|
| 10 |
-
from marker.providers import ProviderOutput, ProviderPageLines
|
| 11 |
from marker.providers.pdf import PdfProvider
|
| 12 |
from marker.schema import BlockTypes
|
| 13 |
from marker.schema.document import Document
|
|
@@ -15,45 +11,28 @@ from marker.schema.groups.page import PageGroup
|
|
| 15 |
from marker.schema.polygon import PolygonBox
|
| 16 |
from marker.schema.registry import get_block_class
|
| 17 |
from marker.settings import settings
|
| 18 |
-
from marker.util import matrix_intersection_area
|
| 19 |
|
| 20 |
|
| 21 |
class LayoutBuilder(BaseBuilder):
|
| 22 |
"""
|
| 23 |
A builder for performing layout detection on PDF pages and merging the results into the document.
|
| 24 |
"""
|
| 25 |
-
|
| 26 |
Optional[int],
|
| 27 |
"The batch size to use for the layout model.",
|
| 28 |
"Default is None, which will use the default batch size for the model."
|
| 29 |
] = None
|
| 30 |
-
layout_coverage_min_lines: Annotated[
|
| 31 |
-
int,
|
| 32 |
-
"The minimum number of PdfProvider lines that must be covered by the layout model",
|
| 33 |
-
"to consider the lines from the PdfProvider valid.",
|
| 34 |
-
] = 1
|
| 35 |
-
layout_coverage_threshold: Annotated[
|
| 36 |
-
float,
|
| 37 |
-
"The minimum coverage ratio required for the layout model to consider",
|
| 38 |
-
"the lines from the PdfProvider valid.",
|
| 39 |
-
] = .1
|
| 40 |
-
document_ocr_threshold: Annotated[
|
| 41 |
-
float,
|
| 42 |
-
"The minimum ratio of pages that must pass the layout coverage check",
|
| 43 |
-
"to avoid OCR.",
|
| 44 |
-
] = .8
|
| 45 |
-
excluded_for_coverage: Annotated[
|
| 46 |
-
Tuple[BlockTypes],
|
| 47 |
-
"A list of block types to exclude from the layout coverage check.",
|
| 48 |
-
] = (BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup)
|
| 49 |
force_layout_block: Annotated[
|
| 50 |
str,
|
| 51 |
"Skip layout and force every page to be treated as a specific block type.",
|
| 52 |
] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
def __init__(self, layout_model: LayoutPredictor,
|
| 55 |
self.layout_model = layout_model
|
| 56 |
-
self.ocr_error_model = ocr_error_model
|
| 57 |
|
| 58 |
super().__init__(config)
|
| 59 |
|
|
@@ -64,11 +43,10 @@ class LayoutBuilder(BaseBuilder):
|
|
| 64 |
else:
|
| 65 |
layout_results = self.surya_layout(document.pages)
|
| 66 |
self.add_blocks_to_pages(document.pages, layout_results)
|
| 67 |
-
self.merge_blocks(document.pages, provider.page_lines)
|
| 68 |
|
| 69 |
def get_batch_size(self):
|
| 70 |
-
if self.
|
| 71 |
-
return self.
|
| 72 |
elif settings.TORCH_DEVICE_MODEL == "cuda":
|
| 73 |
return 6
|
| 74 |
return 6
|
|
@@ -94,26 +72,13 @@ class LayoutBuilder(BaseBuilder):
|
|
| 94 |
|
| 95 |
|
| 96 |
def surya_layout(self, pages: List[PageGroup]) -> List[LayoutResult]:
|
|
|
|
| 97 |
layout_results = self.layout_model(
|
| 98 |
[p.get_image(highres=False) for p in pages],
|
| 99 |
batch_size=int(self.get_batch_size())
|
| 100 |
)
|
| 101 |
return layout_results
|
| 102 |
|
| 103 |
-
def surya_ocr_error_detection(self, pages:List[PageGroup], provider_page_lines: ProviderPageLines) -> OCRErrorDetectionResult:
|
| 104 |
-
page_texts = []
|
| 105 |
-
for document_page in pages:
|
| 106 |
-
page_text = ''
|
| 107 |
-
provider_lines = provider_page_lines.get(document_page.page_id, [])
|
| 108 |
-
page_text = '\n'.join(' '.join(s.text for s in line.spans) for line in provider_lines)
|
| 109 |
-
page_texts.append(page_text)
|
| 110 |
-
|
| 111 |
-
ocr_error_detection_results = self.ocr_error_model(
|
| 112 |
-
page_texts,
|
| 113 |
-
batch_size=int(self.get_batch_size()) #TODO Better Multiplier
|
| 114 |
-
)
|
| 115 |
-
return ocr_error_detection_results
|
| 116 |
-
|
| 117 |
def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[LayoutResult]):
|
| 118 |
for page, layout_result in zip(pages, layout_results):
|
| 119 |
layout_page_size = PolygonBox.from_bbox(layout_result.image_bbox).size
|
|
@@ -132,57 +97,4 @@ class LayoutBuilder(BaseBuilder):
|
|
| 132 |
|
| 133 |
# Ensure page has non-empty children
|
| 134 |
if page.children is None:
|
| 135 |
-
page.children = []
|
| 136 |
-
|
| 137 |
-
def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: ProviderPageLines):
|
| 138 |
-
ocr_error_detection_labels = self.surya_ocr_error_detection(document_pages, provider_page_lines).labels
|
| 139 |
-
|
| 140 |
-
good_pages = []
|
| 141 |
-
for (document_page, ocr_error_detection_label) in zip(document_pages, ocr_error_detection_labels):
|
| 142 |
-
provider_lines = provider_page_lines.get(document_page.page_id, [])
|
| 143 |
-
good_pages.append(bool(provider_lines) and self.check_layout_coverage(document_page, provider_lines) and (ocr_error_detection_label != "bad"))
|
| 144 |
-
|
| 145 |
-
ocr_document = sum(good_pages) / len(good_pages) < self.document_ocr_threshold
|
| 146 |
-
for idx, document_page in enumerate(document_pages):
|
| 147 |
-
provider_lines = provider_page_lines.get(document_page.page_id, [])
|
| 148 |
-
needs_ocr = not good_pages[idx]
|
| 149 |
-
if needs_ocr and ocr_document:
|
| 150 |
-
document_page.text_extraction_method = "surya"
|
| 151 |
-
continue
|
| 152 |
-
document_page.merge_blocks(provider_lines, text_extraction_method="pdftext")
|
| 153 |
-
document_page.text_extraction_method = "pdftext"
|
| 154 |
-
|
| 155 |
-
def check_layout_coverage(
|
| 156 |
-
self,
|
| 157 |
-
document_page: PageGroup,
|
| 158 |
-
provider_lines: List[ProviderOutput],
|
| 159 |
-
):
|
| 160 |
-
covered_blocks = 0
|
| 161 |
-
total_blocks = 0
|
| 162 |
-
large_text_blocks = 0
|
| 163 |
-
|
| 164 |
-
layout_blocks = [document_page.get_block(block) for block in document_page.structure]
|
| 165 |
-
layout_blocks = [b for b in layout_blocks if b.block_type not in self.excluded_for_coverage]
|
| 166 |
-
|
| 167 |
-
layout_bboxes = [block.polygon.bbox for block in layout_blocks]
|
| 168 |
-
provider_bboxes = [line.line.polygon.bbox for line in provider_lines]
|
| 169 |
-
|
| 170 |
-
intersection_matrix = matrix_intersection_area(layout_bboxes, provider_bboxes)
|
| 171 |
-
|
| 172 |
-
for idx, layout_block in enumerate(layout_blocks):
|
| 173 |
-
total_blocks += 1
|
| 174 |
-
intersecting_lines = np.count_nonzero(intersection_matrix[idx] > 0)
|
| 175 |
-
|
| 176 |
-
if intersecting_lines >= self.layout_coverage_min_lines:
|
| 177 |
-
covered_blocks += 1
|
| 178 |
-
|
| 179 |
-
if layout_block.polygon.intersection_pct(document_page.polygon) > 0.8 and layout_block.block_type == BlockTypes.Text:
|
| 180 |
-
large_text_blocks += 1
|
| 181 |
-
|
| 182 |
-
coverage_ratio = covered_blocks / total_blocks if total_blocks > 0 else 1
|
| 183 |
-
text_okay = coverage_ratio >= self.layout_coverage_threshold
|
| 184 |
-
|
| 185 |
-
# Model will sometimes say there is a single block of text on the page when it is blank
|
| 186 |
-
if not text_okay and (total_blocks == 1 and large_text_blocks == 1):
|
| 187 |
-
text_okay = True
|
| 188 |
-
return text_okay
|
|
|
|
| 1 |
+
from typing import Annotated, List, Optional
|
| 2 |
|
|
|
|
| 3 |
from surya.layout import LayoutPredictor
|
| 4 |
from surya.layout.schema import LayoutResult, LayoutBox
|
|
|
|
|
|
|
| 5 |
|
| 6 |
from marker.builders import BaseBuilder
|
|
|
|
| 7 |
from marker.providers.pdf import PdfProvider
|
| 8 |
from marker.schema import BlockTypes
|
| 9 |
from marker.schema.document import Document
|
|
|
|
| 11 |
from marker.schema.polygon import PolygonBox
|
| 12 |
from marker.schema.registry import get_block_class
|
| 13 |
from marker.settings import settings
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
class LayoutBuilder(BaseBuilder):
|
| 17 |
"""
|
| 18 |
A builder for performing layout detection on PDF pages and merging the results into the document.
|
| 19 |
"""
|
| 20 |
+
layout_batch_size: Annotated[
|
| 21 |
Optional[int],
|
| 22 |
"The batch size to use for the layout model.",
|
| 23 |
"Default is None, which will use the default batch size for the model."
|
| 24 |
] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
force_layout_block: Annotated[
|
| 26 |
str,
|
| 27 |
"Skip layout and force every page to be treated as a specific block type.",
|
| 28 |
] = None
|
| 29 |
+
disable_tqdm: Annotated[
|
| 30 |
+
bool,
|
| 31 |
+
"Disable tqdm progress bars.",
|
| 32 |
+
] = False
|
| 33 |
|
| 34 |
+
def __init__(self, layout_model: LayoutPredictor, config=None):
|
| 35 |
self.layout_model = layout_model
|
|
|
|
| 36 |
|
| 37 |
super().__init__(config)
|
| 38 |
|
|
|
|
| 43 |
else:
|
| 44 |
layout_results = self.surya_layout(document.pages)
|
| 45 |
self.add_blocks_to_pages(document.pages, layout_results)
|
|
|
|
| 46 |
|
| 47 |
def get_batch_size(self):
|
| 48 |
+
if self.layout_batch_size is not None:
|
| 49 |
+
return self.layout_batch_size
|
| 50 |
elif settings.TORCH_DEVICE_MODEL == "cuda":
|
| 51 |
return 6
|
| 52 |
return 6
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
def surya_layout(self, pages: List[PageGroup]) -> List[LayoutResult]:
|
| 75 |
+
self.layout_model.disable_tqdm = self.disable_tqdm
|
| 76 |
layout_results = self.layout_model(
|
| 77 |
[p.get_image(highres=False) for p in pages],
|
| 78 |
batch_size=int(self.get_batch_size())
|
| 79 |
)
|
| 80 |
return layout_results
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[LayoutResult]):
|
| 83 |
for page, layout_result in zip(pages, layout_results):
|
| 84 |
layout_page_size = PolygonBox.from_bbox(layout_result.image_bbox).size
|
|
|
|
| 97 |
|
| 98 |
# Ensure page has non-empty children
|
| 99 |
if page.children is None:
|
| 100 |
+
page.children = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
marker/builders/line.py
ADDED
|
@@ -0,0 +1,512 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
from typing import Annotated, List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from ftfy import fix_text
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
from surya.detection import DetectionPredictor, InlineDetectionPredictor, TextDetectionResult
|
| 9 |
+
from surya.ocr_error import OCRErrorPredictor
|
| 10 |
+
|
| 11 |
+
from marker.builders import BaseBuilder
|
| 12 |
+
from marker.providers import ProviderOutput, ProviderPageLines
|
| 13 |
+
from marker.providers.pdf import PdfProvider
|
| 14 |
+
from marker.schema import BlockTypes
|
| 15 |
+
from marker.schema.document import Document
|
| 16 |
+
from marker.schema.groups.page import PageGroup
|
| 17 |
+
from marker.schema.polygon import PolygonBox
|
| 18 |
+
from marker.schema.registry import get_block_class
|
| 19 |
+
from marker.schema.text.line import Line
|
| 20 |
+
from marker.settings import settings
|
| 21 |
+
from marker.util import matrix_intersection_area, sort_text_lines
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TextBox(PolygonBox):
|
| 25 |
+
math: bool = False
|
| 26 |
+
|
| 27 |
+
def __hash__(self):
|
| 28 |
+
return hash(tuple(self.bbox))
|
| 29 |
+
|
| 30 |
+
class LineBuilder(BaseBuilder):
|
| 31 |
+
"""
|
| 32 |
+
A builder for detecting text lines, and inline math. Merges the detected lines with the lines from the provider
|
| 33 |
+
"""
|
| 34 |
+
detection_batch_size: Annotated[
|
| 35 |
+
Optional[int],
|
| 36 |
+
"The batch size to use for the detection model.",
|
| 37 |
+
"Default is None, which will use the default batch size for the model."
|
| 38 |
+
] = None
|
| 39 |
+
ocr_error_batch_size: Annotated[
|
| 40 |
+
Optional[int],
|
| 41 |
+
"The batch size to use for the ocr error detection model.",
|
| 42 |
+
"Default is None, which will use the default batch size for the model."
|
| 43 |
+
] = None
|
| 44 |
+
enable_table_ocr: Annotated[
|
| 45 |
+
bool,
|
| 46 |
+
"Whether to skip OCR on tables. The TableProcessor will re-OCR them. Only enable if the TableProcessor is not running.",
|
| 47 |
+
] = False
|
| 48 |
+
layout_coverage_min_lines: Annotated[
|
| 49 |
+
int,
|
| 50 |
+
"The minimum number of PdfProvider lines that must be covered by the layout model",
|
| 51 |
+
"to consider the lines from the PdfProvider valid.",
|
| 52 |
+
] = 1
|
| 53 |
+
layout_coverage_threshold: Annotated[
|
| 54 |
+
float,
|
| 55 |
+
"The minimum coverage ratio required for the layout model to consider",
|
| 56 |
+
"the lines from the PdfProvider valid.",
|
| 57 |
+
] = .25
|
| 58 |
+
min_document_ocr_threshold: Annotated[
|
| 59 |
+
float,
|
| 60 |
+
"If less pages than this threshold are good, OCR will happen in the document. Otherwise it will not."
|
| 61 |
+
] = 0.85
|
| 62 |
+
span_inline_math_overlap_threshold: Annotated[
|
| 63 |
+
float,
|
| 64 |
+
"The minimum overlap of a span with an inline math box to consider for removal"
|
| 65 |
+
] = .5
|
| 66 |
+
char_inline_math_overlap_threshold: Annotated[
|
| 67 |
+
float,
|
| 68 |
+
"The minimum overlap of a character with an inline math box to consider for removal"
|
| 69 |
+
] = .5
|
| 70 |
+
line_inline_math_overlap_threshold: Annotated[
|
| 71 |
+
float,
|
| 72 |
+
"The minimum overlap of a line with an inline math box to consider as a match"
|
| 73 |
+
] = 0.
|
| 74 |
+
line_text_overlap_threshold: Annotated[
|
| 75 |
+
float,
|
| 76 |
+
"The minimum overlap of an equation with a text line to consider as a match"
|
| 77 |
+
] = .5
|
| 78 |
+
inline_math_minimum_area: Annotated[
|
| 79 |
+
float,
|
| 80 |
+
"The minimum area for an inline math block, in pixels."
|
| 81 |
+
] = 20
|
| 82 |
+
inline_math_line_vertical_merge_threshold: Annotated[
|
| 83 |
+
int,
|
| 84 |
+
"The maximum pixel distance between y1s for two lines to be merged"
|
| 85 |
+
] = 8
|
| 86 |
+
excluded_for_coverage: Annotated[
|
| 87 |
+
Tuple[BlockTypes],
|
| 88 |
+
"A list of block types to exclude from the layout coverage check.",
|
| 89 |
+
] = (BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup)
|
| 90 |
+
use_llm: Annotated[
|
| 91 |
+
bool,
|
| 92 |
+
"Whether to use the LLM model for advanced processing."
|
| 93 |
+
] = False
|
| 94 |
+
texify_inline_spans: Annotated[
|
| 95 |
+
bool,
|
| 96 |
+
"Whether to run texify on inline math spans."
|
| 97 |
+
] = False
|
| 98 |
+
ocr_remove_blocks: Tuple[BlockTypes, ...] = (BlockTypes.Table, BlockTypes.Form, BlockTypes.TableOfContents, BlockTypes.Equation)
|
| 99 |
+
disable_tqdm: Annotated[
|
| 100 |
+
bool,
|
| 101 |
+
"Disable tqdm progress bars.",
|
| 102 |
+
] = False
|
| 103 |
+
|
| 104 |
+
def __init__(self, detection_model: DetectionPredictor, inline_detection_model: InlineDetectionPredictor, ocr_error_model: OCRErrorPredictor, config=None):
|
| 105 |
+
super().__init__(config)
|
| 106 |
+
|
| 107 |
+
self.detection_model = detection_model
|
| 108 |
+
self.inline_detection_model = inline_detection_model
|
| 109 |
+
self.ocr_error_model = ocr_error_model
|
| 110 |
+
|
| 111 |
+
def __call__(self, document: Document, provider: PdfProvider):
|
| 112 |
+
# Disable Inline Detection for documents where layout model doesn't detect any equations
|
| 113 |
+
# Also disable if we won't use the inline detections (if we aren't using the LLM or texify)
|
| 114 |
+
do_inline_math_detection = document.contained_blocks([BlockTypes.Equation]) and (self.texify_inline_spans or self.use_llm)
|
| 115 |
+
provider_lines, ocr_lines = self.get_all_lines(document, provider, do_inline_math_detection)
|
| 116 |
+
self.merge_blocks(document, provider_lines, ocr_lines)
|
| 117 |
+
|
| 118 |
+
def get_detection_batch_size(self):
|
| 119 |
+
if self.detection_batch_size is not None:
|
| 120 |
+
return self.detection_batch_size
|
| 121 |
+
elif settings.TORCH_DEVICE_MODEL == "cuda":
|
| 122 |
+
return 4
|
| 123 |
+
return 4
|
| 124 |
+
|
| 125 |
+
def get_ocr_error_batch_size(self):
|
| 126 |
+
if self.ocr_error_batch_size is not None:
|
| 127 |
+
return self.ocr_error_batch_size
|
| 128 |
+
elif settings.TORCH_DEVICE_MODEL == "cuda":
|
| 129 |
+
return 4
|
| 130 |
+
return 4
|
| 131 |
+
|
| 132 |
+
def get_detection_results(self, page_images: List[Image.Image], run_detection: List[bool], do_inline_math_detection: bool):
|
| 133 |
+
self.detection_model.disable_tqdm = self.disable_tqdm
|
| 134 |
+
page_detection_results = self.detection_model(
|
| 135 |
+
images=page_images,
|
| 136 |
+
batch_size=self.get_detection_batch_size()
|
| 137 |
+
)
|
| 138 |
+
inline_detection_results = [None] * len(page_detection_results)
|
| 139 |
+
if do_inline_math_detection:
|
| 140 |
+
self.inline_detection_model.disable_tqdm = self.disable_tqdm
|
| 141 |
+
inline_detection_results = self.inline_detection_model(
|
| 142 |
+
images=page_images,
|
| 143 |
+
text_boxes=[[b.bbox for b in det_result.bboxes] for det_result in page_detection_results],
|
| 144 |
+
batch_size=self.get_detection_batch_size()
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
detection_results = []
|
| 148 |
+
inline_results = []
|
| 149 |
+
idx = 0
|
| 150 |
+
for good in run_detection:
|
| 151 |
+
if good:
|
| 152 |
+
detection_results.append(page_detection_results[idx])
|
| 153 |
+
inline_results.append(inline_detection_results[idx])
|
| 154 |
+
idx += 1
|
| 155 |
+
else:
|
| 156 |
+
detection_results.append(None)
|
| 157 |
+
inline_results.append(None)
|
| 158 |
+
assert idx == len(page_images)
|
| 159 |
+
|
| 160 |
+
assert len(run_detection) == len(detection_results) == len(inline_results)
|
| 161 |
+
return detection_results, inline_results
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def get_all_lines(self, document: Document, provider: PdfProvider, do_inline_math_detection: bool):
|
| 165 |
+
ocr_error_detection_results = self.ocr_error_detection(document.pages, provider.page_lines)
|
| 166 |
+
|
| 167 |
+
boxes_to_ocr = {page.page_id: [] for page in document.pages}
|
| 168 |
+
page_lines = {page.page_id: [] for page in document.pages}
|
| 169 |
+
|
| 170 |
+
LineClass: Line = get_block_class(BlockTypes.Line)
|
| 171 |
+
|
| 172 |
+
layout_good = []
|
| 173 |
+
for document_page, ocr_error_detection_label in zip(document.pages, ocr_error_detection_results.labels):
|
| 174 |
+
provider_lines: List[ProviderOutput] = provider.page_lines.get(document_page.page_id, [])
|
| 175 |
+
provider_lines_good = all([
|
| 176 |
+
bool(provider_lines),
|
| 177 |
+
ocr_error_detection_label != 'bad',
|
| 178 |
+
self.check_layout_coverage(document_page, provider_lines)
|
| 179 |
+
])
|
| 180 |
+
layout_good.append(provider_lines_good)
|
| 181 |
+
|
| 182 |
+
# Don't OCR if only a few pages are bad
|
| 183 |
+
if sum(layout_good) > len(document.pages) * self.min_document_ocr_threshold:
|
| 184 |
+
layout_good = [True] * len(document.pages)
|
| 185 |
+
|
| 186 |
+
run_detection = [not good or do_inline_math_detection for good in layout_good]
|
| 187 |
+
page_images = [page.get_image(highres=False, remove_blocks=self.ocr_remove_blocks) for page, good in zip(document.pages, run_detection) if good]
|
| 188 |
+
|
| 189 |
+
# Note: run_detection is longer than page_images, since it has a value for each page, not just good ones
|
| 190 |
+
# Detection results and inline detection results are for every page (we use run_detection to make the list full length)
|
| 191 |
+
detection_results, inline_detection_results = self.get_detection_results(page_images, run_detection, do_inline_math_detection)
|
| 192 |
+
|
| 193 |
+
assert len(detection_results) == len(inline_detection_results) == len(layout_good) == len(document.pages)
|
| 194 |
+
for document_page, detection_result, inline_detection_result, provider_lines_good in zip(
|
| 195 |
+
document.pages,
|
| 196 |
+
detection_results,
|
| 197 |
+
inline_detection_results,
|
| 198 |
+
layout_good
|
| 199 |
+
):
|
| 200 |
+
provider_lines: List[ProviderOutput] = provider.page_lines.get(document_page.page_id, [])
|
| 201 |
+
page_size = provider.get_page_bbox(document_page.page_id).size
|
| 202 |
+
image_size = PolygonBox.from_bbox(detection_result.image_bbox).size if detection_result else page_size
|
| 203 |
+
|
| 204 |
+
# Merge text and inline math detection results
|
| 205 |
+
merged_detection_boxes = self.determine_math_lines(text_result=detection_result, inline_result=inline_detection_result)
|
| 206 |
+
# Sort the lines to ensure that the order is preserved
|
| 207 |
+
merged_detection_boxes = sort_text_lines(merged_detection_boxes)
|
| 208 |
+
|
| 209 |
+
math_detection_boxes = [(i, box) for i, box in enumerate(merged_detection_boxes) if box.math]
|
| 210 |
+
nonmath_detection_boxes = [(i, box) for i, box in enumerate(merged_detection_boxes) if not box.math]
|
| 211 |
+
|
| 212 |
+
if provider_lines_good:
|
| 213 |
+
# Merge inline math blocks into the provider lines, only persist new detected text lines which do not overlap with existing provider lines
|
| 214 |
+
# The missing lines are not from a table, so we can safely set this - The attribute for individual blocks is overridden by OCRBuilder
|
| 215 |
+
document_page.text_extraction_method = 'pdftext'
|
| 216 |
+
|
| 217 |
+
# Add in the provider lines - merge ones that get broken by inline math
|
| 218 |
+
page_lines[document_page.page_id].extend(
|
| 219 |
+
self.merge_provider_lines_inline_math(
|
| 220 |
+
provider_lines,
|
| 221 |
+
[b for _,b in math_detection_boxes],
|
| 222 |
+
image_size,
|
| 223 |
+
page_size
|
| 224 |
+
)
|
| 225 |
+
)
|
| 226 |
+
else:
|
| 227 |
+
document_page.text_extraction_method = 'surya'
|
| 228 |
+
|
| 229 |
+
# Sort lines properly
|
| 230 |
+
full_lines = nonmath_detection_boxes + math_detection_boxes
|
| 231 |
+
full_lines = sorted(full_lines, key=lambda x: x[0])
|
| 232 |
+
full_lines = [b for _, b in full_lines]
|
| 233 |
+
|
| 234 |
+
# Skip inline math merging if no provider lines are good; OCR all text lines and all inline math lines
|
| 235 |
+
boxes_to_ocr[document_page.page_id].extend(full_lines)
|
| 236 |
+
|
| 237 |
+
# Dummy lines to merge into the document - Contains no spans, will be filled in later by OCRBuilder
|
| 238 |
+
ocr_lines = {document_page.page_id: [] for document_page in document.pages}
|
| 239 |
+
for page_id, page_ocr_boxes in boxes_to_ocr.items():
|
| 240 |
+
page_size = provider.get_page_bbox(page_id).size
|
| 241 |
+
image_size = document.get_page(page_id).get_image(highres=False).size
|
| 242 |
+
for box_to_ocr in page_ocr_boxes:
|
| 243 |
+
line_polygon = PolygonBox(polygon=box_to_ocr.polygon).rescale(image_size, page_size)
|
| 244 |
+
format = ["math"] if box_to_ocr.math else None
|
| 245 |
+
ocr_lines[page_id].append(
|
| 246 |
+
ProviderOutput(
|
| 247 |
+
line=LineClass(
|
| 248 |
+
polygon=line_polygon,
|
| 249 |
+
page_id=page_id,
|
| 250 |
+
text_extraction_method='surya',
|
| 251 |
+
formats=format
|
| 252 |
+
),
|
| 253 |
+
spans=[]
|
| 254 |
+
)
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
return page_lines, ocr_lines
|
| 258 |
+
|
| 259 |
+
def ocr_error_detection(self, pages:List[PageGroup], provider_page_lines: ProviderPageLines):
|
| 260 |
+
page_texts = []
|
| 261 |
+
for document_page in pages:
|
| 262 |
+
provider_lines = provider_page_lines.get(document_page.page_id, [])
|
| 263 |
+
page_text = '\n'.join(' '.join(s.text for s in line.spans) for line in provider_lines)
|
| 264 |
+
page_texts.append(page_text)
|
| 265 |
+
|
| 266 |
+
self.ocr_error_model.disable_tqdm = self.disable_tqdm
|
| 267 |
+
ocr_error_detection_results = self.ocr_error_model(
|
| 268 |
+
page_texts,
|
| 269 |
+
batch_size=int(self.get_ocr_error_batch_size())
|
| 270 |
+
)
|
| 271 |
+
return ocr_error_detection_results
|
| 272 |
+
|
| 273 |
+
def check_layout_coverage(
    self,
    document_page: PageGroup,
    provider_lines: List[ProviderOutput],
):
    """Decide whether the provider text sufficiently covers the page layout.

    A layout block counts as covered when at least
    ``layout_coverage_min_lines`` provider lines intersect it. Returns True
    when the covered fraction reaches ``layout_coverage_threshold``, or when
    the page's only block is a single near-page-sized Text block (the layout
    model sometimes emits one on blank pages).
    """
    blocks = [document_page.get_block(block_id) for block_id in document_page.structure]
    blocks = [b for b in blocks if b.block_type not in self.excluded_for_coverage]

    layout_bboxes = [b.polygon.bbox for b in blocks]
    provider_bboxes = [pl.line.polygon.bbox for pl in provider_lines]

    # No relevant layout blocks -> nothing to cover, trivially OK.
    if not layout_bboxes:
        return True

    # Layout blocks exist but there is no provider text at all.
    if not provider_bboxes:
        return False

    intersection_matrix = matrix_intersection_area(layout_bboxes, provider_bboxes)

    total_blocks = len(blocks)
    covered_blocks = 0
    large_text_blocks = 0
    for idx, block in enumerate(blocks):
        if np.count_nonzero(intersection_matrix[idx] > 0) >= self.layout_coverage_min_lines:
            covered_blocks += 1

        if block.block_type == BlockTypes.Text and block.polygon.intersection_pct(document_page.polygon) > 0.8:
            large_text_blocks += 1

    coverage_ratio = covered_blocks / total_blocks if total_blocks > 0 else 1
    text_okay = coverage_ratio >= self.layout_coverage_threshold

    # Model will sometimes say there is a single block of text on the page when it is blank
    if not text_okay and (total_blocks == 1 and large_text_blocks == 1):
        text_okay = True
    return text_okay
|
| 313 |
+
|
| 314 |
+
def merge_blocks(self, document: Document, page_provider_lines: ProviderPageLines, page_ocr_lines: ProviderPageLines):
    """Merge provider lines and OCR placeholder lines into each page.

    For any given page only one of the two sources actually contains lines,
    so simple concatenation is safe. The extraction method tag is overridden
    later for OCRed documents.
    """
    for page in document.pages:
        combined = page_provider_lines[page.page_id] + page_ocr_lines[page.page_id]
        page.merge_blocks(combined, text_extraction_method='pdftext')
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def determine_math_lines(
    self,
    text_result: TextDetectionResult,
    inline_result: TextDetectionResult,
) -> List[TextBox]:
    """
    Convert detected text lines into TextBoxes, marking as math those that
    overlap a detected inline-math region.
    """
    if not text_result:
        return []

    text_boxes = [TextBox(polygon=detected.polygon) for detected in text_result.bboxes]

    # Skip if no inline math was detected
    if not inline_result:
        return text_boxes

    math_bboxes = [m.bbox for m in inline_result.bboxes]
    line_bboxes = [t.bbox for t in text_boxes]

    if not math_bboxes:
        return text_boxes

    if not text_boxes:
        return []

    overlaps = matrix_intersection_area(math_bboxes, line_bboxes)

    # Mark text boxes as math if they overlap with an inline math box
    for row_idx, math_box in enumerate(inline_result.bboxes):
        row = overlaps[row_idx]
        best_idx = np.argmax(row)
        best_box = text_boxes[best_idx]

        best_overlap_pct = np.max(row) / math_box.area

        # Avoid small or nonoverlapping inline math regions
        if best_overlap_pct <= self.line_inline_math_overlap_threshold or math_box.area < self.inline_math_minimum_area:
            continue

        # Ignore vertical lines
        if best_box.height > best_box.width * 2:
            continue

        best_box.math = True

    return text_boxes
|
| 378 |
+
|
| 379 |
+
# Add appropriate formats to math spans added by inline math detection
|
| 380 |
+
def add_math_span_format(self, provider_line):
    """Ensure the provider line's format list includes "math" exactly once."""
    formats = provider_line.line.formats
    if not formats:
        provider_line.line.formats = ["math"]
    elif "math" not in formats:
        formats.append("math")
|
| 385 |
+
|
| 386 |
+
def merge_provider_lines_inline_math(
    self,
    provider_lines: List[ProviderOutput],
    inline_math_lines: List[TextBox],
    image_size,
    page_size
):
    """Merge provider text lines that overlap detected inline-math regions.

    Provider lines overlapping the same math region (and vertically close to
    the best-overlapping line) are merged into one line tagged "math".
    Returns the provider lines with merges applied, in original order.
    """
    # When provider lines is empty or no inline math detected, return provider lines
    if not provider_lines or not inline_math_lines:
        return provider_lines

    horizontal_provider_lines = [
        (j, provider_line) for j, provider_line in enumerate(provider_lines)
        if provider_line.line.polygon.height < provider_line.line.polygon.width * 3  # Multiply to account for small blocks inside equations, but filter out big vertical lines
    ]
    provider_line_boxes = [p.line.polygon.bbox for _, p in horizontal_provider_lines]
    # Math boxes come from image space; rescale to page space before comparing.
    math_line_boxes = [PolygonBox(polygon=m.polygon).rescale(image_size, page_size).bbox for m in inline_math_lines]

    overlaps = matrix_intersection_area(math_line_boxes, provider_line_boxes)

    # Find potential merges
    merge_lines = []
    for i in range(len(math_line_boxes)):
        merge_line = []
        math_line_polygon = PolygonBox(polygon=inline_math_lines[i].polygon).rescale(image_size, page_size)
        max_overlap = np.max(overlaps[i])
        if max_overlap <= self.line_inline_math_overlap_threshold:
            continue

        # Anchor on the provider line with the largest overlap; other
        # candidates merge only when they start at (nearly) the same y.
        best_overlap = np.argmax(overlaps[i])
        best_overlap_line = horizontal_provider_lines[best_overlap]
        best_overlap_y1 = best_overlap_line[1].line.polygon.y_start

        nonzero_idxs = np.nonzero(overlaps[i] > self.line_inline_math_overlap_threshold)[0]
        for idx in nonzero_idxs:
            provider_idx, provider_line = horizontal_provider_lines[idx]
            provider_line_y1 = provider_line.line.polygon.y_start

            should_merge_line = False
            if abs(provider_line_y1 - best_overlap_y1) <= self.inline_math_line_vertical_merge_threshold:
                should_merge_line = True

            # When the line will not be merged, strip its overlapping math
            # characters instead (remove_chars=True).
            line_overlaps = self.find_overlapping_math_chars(provider_line, math_line_polygon, remove_chars=not should_merge_line)

            # Do not merge if too far above/below (but remove characters)
            if line_overlaps and should_merge_line:
                # Add the index of the provider line to the merge line
                merge_line.append(provider_idx)

        if len(merge_line) > 0:
            merge_lines.append(merge_line)

    # Handle the merging
    already_merged = set()
    potential_merges = set([m for merge_line in merge_lines for m in merge_line])
    # Keep every line that is not a merge candidate, keyed by original index
    # so the original ordering can be restored at the end.
    out_provider_lines = [(i, p) for i, p in enumerate(provider_lines) if i not in potential_merges]
    for merge_section in merge_lines:
        merge_section = [m for m in merge_section if m not in already_merged]
        if len(merge_section) == 0:
            continue
        elif len(merge_section) == 1:
            # A single line just gets tagged as math; no merging needed.
            line_idx = merge_section[0]
            merged_line = provider_lines[line_idx]
            self.add_math_span_format(merged_line)
            out_provider_lines.append((line_idx, merged_line))
            already_merged.add(merge_section[0])
            continue

        merge_section = sorted(merge_section)
        merged_line = None
        min_idx = min(merge_section)
        for idx in merge_section:
            provider_line = deepcopy(provider_lines[idx])
            if merged_line is None:
                merged_line = provider_line
            else:
                # Combine the spans of the provider line with the merged line
                merged_line = merged_line.merge(provider_line)
            self.add_math_span_format(merged_line)
            already_merged.add(idx)  # Prevent double merging
        out_provider_lines.append((min_idx, merged_line))

    # Sort to preserve original order
    out_provider_lines = sorted(out_provider_lines, key=lambda x: x[0])
    out_provider_lines = [p for _, p in out_provider_lines]
    return out_provider_lines
|
| 472 |
+
|
| 473 |
+
def clear_line_text(self, provider_line):
    """Blank out the text of every span on the given provider line."""
    for line_span in provider_line.spans:
        line_span.text = ""
|
| 476 |
+
|
| 477 |
+
def find_overlapping_math_chars(self, provider_line, math_line_polygon, remove_chars=False):
    """Return True when the provider line overlaps the inline-math region.

    When character geometry is available, overlap is decided per character;
    with ``remove_chars=True``, overlapping characters are also stripped from
    the span text. Without characters, whole-span overlap is used instead.
    """
    # Identify if a character in the provider line overlaps with the inline math line - meaning that the line can be treated as math
    spans = provider_line.spans
    math_overlaps = False

    # For providers which do not surface characters
    if provider_line.chars is None:
        for span in spans:
            if span.polygon.intersection_pct(math_line_polygon) > self.span_inline_math_overlap_threshold:
                math_overlaps = True
        return math_overlaps

    # For providers which surface characters - find line overlap based on characters
    assert len(spans) == len(provider_line.chars), "Number of spans and characters in provider line do not match"
    for span, span_chars in zip(spans, provider_line.chars):
        if len(span_chars) == 0:
            continue

        # Fraction of each character's area covered by the math region.
        char_intersections_areas = matrix_intersection_area([char.polygon.bbox for char in span_chars], [math_line_polygon.bbox]).max(axis=-1)
        char_intersections = char_intersections_areas / np.array([char.polygon.area for char in span_chars])

        new_span_chars = []
        span_overlaps = False
        for char, intersection_pct in zip(span_chars, char_intersections):
            if intersection_pct >= self.char_inline_math_overlap_threshold:
                span_overlaps = True
            else:
                new_span_chars.append(char)

        # Remove stray characters that overlap with math lines
        if span_overlaps and remove_chars:
            span.text = fix_text(''.join(c.char for c in new_span_chars))

        math_overlaps = math_overlaps or span_overlaps

    return math_overlaps
|
marker/builders/llm_layout.py
CHANGED
|
@@ -1,13 +1,12 @@
|
|
| 1 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 2 |
from typing import Annotated
|
| 3 |
|
| 4 |
-
from google.ai.generativelanguage_v1beta.types import content
|
| 5 |
from surya.layout import LayoutPredictor
|
| 6 |
-
from surya.ocr_error import OCRErrorPredictor
|
| 7 |
from tqdm import tqdm
|
|
|
|
| 8 |
|
| 9 |
from marker.builders.layout import LayoutBuilder
|
| 10 |
-
from marker.
|
| 11 |
from marker.providers.pdf import PdfProvider
|
| 12 |
from marker.schema import BlockTypes
|
| 13 |
from marker.schema.blocks import Block
|
|
@@ -37,19 +36,15 @@ class LLMLayoutBuilder(LayoutBuilder):
|
|
| 37 |
model_name: Annotated[
|
| 38 |
str,
|
| 39 |
"The name of the Gemini model to use.",
|
| 40 |
-
] = "gemini-
|
| 41 |
-
max_retries: Annotated[
|
| 42 |
-
int,
|
| 43 |
-
"The maximum number of retries to use for the Gemini model.",
|
| 44 |
-
] = 3
|
| 45 |
max_concurrency: Annotated[
|
| 46 |
int,
|
| 47 |
"The maximum number of concurrent requests to make to the Gemini model.",
|
| 48 |
] = 3
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
"
|
| 52 |
-
] =
|
| 53 |
topk_relabelling_prompt: Annotated[
|
| 54 |
str,
|
| 55 |
"The prompt to use for relabelling blocks.",
|
|
@@ -94,10 +89,10 @@ Potential labels:
|
|
| 94 |
Respond only with one of `Figure`, `Picture`, `ComplexRegion`, `Table`, or `Form`.
|
| 95 |
"""
|
| 96 |
|
| 97 |
-
def __init__(self, layout_model: LayoutPredictor,
|
| 98 |
-
super().__init__(layout_model,
|
| 99 |
|
| 100 |
-
self.
|
| 101 |
|
| 102 |
def __call__(self, document: Document, provider: PdfProvider):
|
| 103 |
super().__call__(document, provider)
|
|
@@ -107,7 +102,7 @@ Respond only with one of `Figure`, `Picture`, `ComplexRegion`, `Table`, or `Form
|
|
| 107 |
print(f"Error relabelling blocks: {e}")
|
| 108 |
|
| 109 |
def relabel_blocks(self, document: Document):
|
| 110 |
-
pbar = tqdm(desc="LLM layout relabelling")
|
| 111 |
with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor:
|
| 112 |
futures = []
|
| 113 |
for page in document.pages:
|
|
@@ -154,21 +149,13 @@ Respond only with one of `Figure`, `Picture`, `ComplexRegion`, `Table`, or `Form
|
|
| 154 |
|
| 155 |
def process_block_relabeling(self, document: Document, page: PageGroup, block: Block, prompt: str):
|
| 156 |
image = self.extract_image(document, block)
|
| 157 |
-
response_schema = content.Schema(
|
| 158 |
-
type=content.Type.OBJECT,
|
| 159 |
-
enum=[],
|
| 160 |
-
required=["image_description", "label"],
|
| 161 |
-
properties={
|
| 162 |
-
"image_description": content.Schema(
|
| 163 |
-
type=content.Type.STRING,
|
| 164 |
-
),
|
| 165 |
-
"label": content.Schema(
|
| 166 |
-
type=content.Type.STRING,
|
| 167 |
-
),
|
| 168 |
-
},
|
| 169 |
-
)
|
| 170 |
|
| 171 |
-
response = self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
generated_label = None
|
| 173 |
if response and "label" in response:
|
| 174 |
generated_label = response["label"]
|
|
@@ -184,3 +171,8 @@ Respond only with one of `Figure`, `Picture`, `ComplexRegion`, `Table`, or `Form
|
|
| 184 |
|
| 185 |
def extract_image(self, document: Document, image_block: Block, expand: float = 0.01):
|
| 186 |
return image_block.get_image(document, highres=False, expansion=(expand, expand))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 2 |
from typing import Annotated
|
| 3 |
|
|
|
|
| 4 |
from surya.layout import LayoutPredictor
|
|
|
|
| 5 |
from tqdm import tqdm
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
|
| 8 |
from marker.builders.layout import LayoutBuilder
|
| 9 |
+
from marker.services import BaseService
|
| 10 |
from marker.providers.pdf import PdfProvider
|
| 11 |
from marker.schema import BlockTypes
|
| 12 |
from marker.schema.blocks import Block
|
|
|
|
| 36 |
model_name: Annotated[
|
| 37 |
str,
|
| 38 |
"The name of the Gemini model to use.",
|
| 39 |
+
] = "gemini-2.0-flash"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
max_concurrency: Annotated[
|
| 41 |
int,
|
| 42 |
"The maximum number of concurrent requests to make to the Gemini model.",
|
| 43 |
] = 3
|
| 44 |
+
disable_tqdm: Annotated[
|
| 45 |
+
bool,
|
| 46 |
+
"Whether to disable the tqdm progress bar.",
|
| 47 |
+
] = False
|
| 48 |
topk_relabelling_prompt: Annotated[
|
| 49 |
str,
|
| 50 |
"The prompt to use for relabelling blocks.",
|
|
|
|
| 89 |
Respond only with one of `Figure`, `Picture`, `ComplexRegion`, `Table`, or `Form`.
|
| 90 |
"""
|
| 91 |
|
| 92 |
+
def __init__(self, layout_model: LayoutPredictor, llm_service: BaseService, config=None):
|
| 93 |
+
super().__init__(layout_model, config)
|
| 94 |
|
| 95 |
+
self.llm_service = llm_service
|
| 96 |
|
| 97 |
def __call__(self, document: Document, provider: PdfProvider):
|
| 98 |
super().__call__(document, provider)
|
|
|
|
| 102 |
print(f"Error relabelling blocks: {e}")
|
| 103 |
|
| 104 |
def relabel_blocks(self, document: Document):
|
| 105 |
+
pbar = tqdm(desc="LLM layout relabelling", disable=self.disable_tqdm)
|
| 106 |
with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor:
|
| 107 |
futures = []
|
| 108 |
for page in document.pages:
|
|
|
|
| 149 |
|
| 150 |
def process_block_relabeling(self, document: Document, page: PageGroup, block: Block, prompt: str):
|
| 151 |
image = self.extract_image(document, block)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
+
response = self.llm_service(
|
| 154 |
+
prompt,
|
| 155 |
+
image,
|
| 156 |
+
block,
|
| 157 |
+
LayoutSchema
|
| 158 |
+
)
|
| 159 |
generated_label = None
|
| 160 |
if response and "label" in response:
|
| 161 |
generated_label = response["label"]
|
|
|
|
| 171 |
|
| 172 |
def extract_image(self, document: Document, image_block: Block, expand: float = 0.01):
|
| 173 |
return image_block.get_image(document, highres=False, expansion=(expand, expand))
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class LayoutSchema(BaseModel):
|
| 177 |
+
image_description: str
|
| 178 |
+
label: str
|
marker/builders/ocr.py
CHANGED
|
@@ -1,21 +1,19 @@
|
|
|
|
|
| 1 |
from typing import Annotated, List, Optional
|
| 2 |
|
| 3 |
from ftfy import fix_text
|
| 4 |
-
from surya.detection import DetectionPredictor
|
| 5 |
from surya.recognition import RecognitionPredictor
|
| 6 |
|
| 7 |
from marker.builders import BaseBuilder
|
| 8 |
-
from marker.providers import ProviderOutput, ProviderPageLines
|
| 9 |
from marker.providers.pdf import PdfProvider
|
| 10 |
from marker.schema import BlockTypes
|
|
|
|
| 11 |
from marker.schema.document import Document
|
| 12 |
-
from marker.schema.
|
| 13 |
from marker.schema.registry import get_block_class
|
| 14 |
-
from marker.schema.text.line import Line
|
| 15 |
from marker.schema.text.span import Span
|
| 16 |
from marker.settings import settings
|
| 17 |
|
| 18 |
-
|
| 19 |
class OcrBuilder(BaseBuilder):
|
| 20 |
"""
|
| 21 |
A builder for performing OCR on PDF pages and merging the results into the document.
|
|
@@ -25,30 +23,25 @@ class OcrBuilder(BaseBuilder):
|
|
| 25 |
"The batch size to use for the recognition model.",
|
| 26 |
"Default is None, which will use the default batch size for the model."
|
| 27 |
] = None
|
| 28 |
-
detection_batch_size: Annotated[
|
| 29 |
-
Optional[int],
|
| 30 |
-
"The batch size to use for the detection model.",
|
| 31 |
-
"Default is None, which will use the default batch size for the model."
|
| 32 |
-
] = None
|
| 33 |
languages: Annotated[
|
| 34 |
Optional[List[str]],
|
| 35 |
"A list of languages to use for OCR.",
|
| 36 |
"Default is None."
|
| 37 |
] = None
|
| 38 |
-
|
| 39 |
bool,
|
| 40 |
-
"
|
| 41 |
] = False
|
| 42 |
|
| 43 |
-
def __init__(self,
|
| 44 |
super().__init__(config)
|
| 45 |
|
| 46 |
-
self.detection_model = detection_model
|
| 47 |
self.recognition_model = recognition_model
|
| 48 |
|
| 49 |
def __call__(self, document: Document, provider: PdfProvider):
|
| 50 |
-
|
| 51 |
-
self.
|
|
|
|
| 52 |
|
| 53 |
def get_recognition_batch_size(self):
|
| 54 |
if self.recognition_batch_size is not None:
|
|
@@ -59,64 +52,62 @@ class OcrBuilder(BaseBuilder):
|
|
| 59 |
return 32
|
| 60 |
return 32
|
| 61 |
|
| 62 |
-
def
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
recognition_results = self.recognition_model(
|
| 74 |
-
images=
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
detection_batch_size=int(self.get_detection_batch_size()),
|
| 78 |
recognition_batch_size=int(self.get_recognition_batch_size()),
|
| 79 |
-
|
| 80 |
)
|
| 81 |
|
| 82 |
-
page_lines = {}
|
| 83 |
-
|
| 84 |
SpanClass: Span = get_block_class(BlockTypes.Span)
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
| 99 |
)
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
text=fix_text(ocr_line.text) + "\n",
|
| 103 |
-
formats=['plain'],
|
| 104 |
-
page_id=page_id,
|
| 105 |
-
polygon=polygon,
|
| 106 |
-
minimum_position=0,
|
| 107 |
-
maximum_position=0,
|
| 108 |
-
font='Unknown',
|
| 109 |
-
font_weight=0,
|
| 110 |
-
font_size=0,
|
| 111 |
-
)
|
| 112 |
-
]
|
| 113 |
-
|
| 114 |
-
page_lines[page_id].append(ProviderOutput(line=line, spans=spans))
|
| 115 |
-
|
| 116 |
-
return page_lines
|
| 117 |
-
|
| 118 |
-
def merge_blocks(self, document: Document, page_lines: ProviderPageLines):
|
| 119 |
-
ocred_pages = [page for page in document.pages if page.text_extraction_method == "surya"]
|
| 120 |
-
for document_page in ocred_pages:
|
| 121 |
-
lines = page_lines[document_page.page_id]
|
| 122 |
-
document_page.merge_blocks(lines, text_extraction_method="surya")
|
|
|
|
| 1 |
+
import copy
|
| 2 |
from typing import Annotated, List, Optional
|
| 3 |
|
| 4 |
from ftfy import fix_text
|
|
|
|
| 5 |
from surya.recognition import RecognitionPredictor
|
| 6 |
|
| 7 |
from marker.builders import BaseBuilder
|
|
|
|
| 8 |
from marker.providers.pdf import PdfProvider
|
| 9 |
from marker.schema import BlockTypes
|
| 10 |
+
from marker.schema.blocks import BlockId
|
| 11 |
from marker.schema.document import Document
|
| 12 |
+
from marker.schema.groups import PageGroup
|
| 13 |
from marker.schema.registry import get_block_class
|
|
|
|
| 14 |
from marker.schema.text.span import Span
|
| 15 |
from marker.settings import settings
|
| 16 |
|
|
|
|
| 17 |
class OcrBuilder(BaseBuilder):
|
| 18 |
"""
|
| 19 |
A builder for performing OCR on PDF pages and merging the results into the document.
|
|
|
|
| 23 |
"The batch size to use for the recognition model.",
|
| 24 |
"Default is None, which will use the default batch size for the model."
|
| 25 |
] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
languages: Annotated[
|
| 27 |
Optional[List[str]],
|
| 28 |
"A list of languages to use for OCR.",
|
| 29 |
"Default is None."
|
| 30 |
] = None
|
| 31 |
+
disable_tqdm: Annotated[
|
| 32 |
bool,
|
| 33 |
+
"Disable tqdm progress bars.",
|
| 34 |
] = False
|
| 35 |
|
| 36 |
+
def __init__(self, recognition_model: RecognitionPredictor, config=None):
|
| 37 |
super().__init__(config)
|
| 38 |
|
|
|
|
| 39 |
self.recognition_model = recognition_model
|
| 40 |
|
| 41 |
def __call__(self, document: Document, provider: PdfProvider):
|
| 42 |
+
pages_to_ocr = [page for page in document.pages if page.text_extraction_method == 'surya']
|
| 43 |
+
images, line_boxes, line_ids = self.get_ocr_images_boxes_ids(document, pages_to_ocr, provider)
|
| 44 |
+
self.ocr_extraction(document, pages_to_ocr, provider, images, line_boxes, line_ids)
|
| 45 |
|
| 46 |
def get_recognition_batch_size(self):
|
| 47 |
if self.recognition_batch_size is not None:
|
|
|
|
| 52 |
return 32
|
| 53 |
return 32
|
| 54 |
|
| 55 |
+
def get_ocr_images_boxes_ids(self, document: Document, pages: List[PageGroup], provider: PdfProvider):
    """Collect, per page, the highres image, the line bboxes to OCR, and their line ids.

    Returns three parallel lists with one entry per page in ``pages``.
    """
    highres_images, highres_boxes, line_ids = [], [], []
    for document_page in pages:
        page_highres_image = document_page.get_image(highres=True)
        page_highres_boxes = []
        page_line_ids = []

        page_size = provider.get_page_bbox(document_page.page_id).size
        image_size = page_highres_image.size
        for block in document_page.contained_blocks(document):
            block_lines = block.contained_blocks(document, [BlockTypes.Line])
            block_detected_lines = [block_line for block_line in block_lines if block_line.text_extraction_method == 'surya']

            # NOTE(review): every block is marked 'surya', even blocks with no
            # detected lines — confirm that is intended.
            block.text_extraction_method = 'surya'
            for line in block_detected_lines:
                # Copy before rescaling — presumably rescale mutates the polygon; TODO confirm.
                line_polygon = copy.deepcopy(line.polygon)
                page_highres_boxes.append(line_polygon.rescale(page_size, image_size).bbox)
                page_line_ids.append(line.id)

        highres_images.append(page_highres_image)
        highres_boxes.append(page_highres_boxes)
        line_ids.append(page_line_ids)

    return highres_images, highres_boxes, line_ids
|
| 79 |
+
|
| 80 |
+
def ocr_extraction(self, document: Document, pages: List[PageGroup], provider: PdfProvider, images: List[any], line_boxes: List[List[float]], line_ids: List[List[BlockId]]):
    """Run recognition over the collected line boxes and attach the text as new spans.

    ``images``, ``line_boxes`` and ``line_ids`` are the parallel per-page
    lists produced by ``get_ocr_images_boxes_ids``. Recognized lines whose
    cleaned-up text is empty are skipped.
    """
    # Nothing to OCR on any page.
    if sum(len(b) for b in line_boxes)==0:
        return

    self.recognition_model.disable_tqdm = self.disable_tqdm
    recognition_results = self.recognition_model(
        images=images,
        bboxes=line_boxes,
        langs=[self.languages] * len(pages),
        recognition_batch_size=int(self.get_recognition_batch_size()),
        sort_lines=False
    )

    SpanClass: Span = get_block_class(BlockTypes.Span)
    for document_page, page_recognition_result, page_line_ids in zip(pages, recognition_results, line_ids):
        for line_id, ocr_line in zip(page_line_ids, page_recognition_result.text_lines):
            # Skip lines whose text is empty after cleanup.
            if not fix_text(ocr_line.text):
                continue

            line = document_page.get_block(line_id)
            # Lines selected for OCR are expected to have no existing children.
            assert line.structure is None
            new_span = SpanClass(
                text=fix_text(ocr_line.text) + '\n',
                formats=['plain'],
                page_id=document_page.page_id,
                polygon=copy.deepcopy(line.polygon),
                minimum_position=0,
                maximum_position=0,
                font='Unknown',
                font_weight=0,
                font_size=0,
            )
            document_page.add_full_block(new_span)
            line.add_structure(new_span)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
marker/config/crawler.py
CHANGED
|
@@ -9,10 +9,11 @@ from marker.converters import BaseConverter
|
|
| 9 |
from marker.processors import BaseProcessor
|
| 10 |
from marker.providers import BaseProvider
|
| 11 |
from marker.renderers import BaseRenderer
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class ConfigCrawler:
|
| 15 |
-
def __init__(self, base_classes=(BaseBuilder, BaseProcessor, BaseConverter, BaseProvider, BaseRenderer)):
|
| 16 |
self.base_classes = base_classes
|
| 17 |
self.class_config_map = {}
|
| 18 |
|
|
|
|
| 9 |
from marker.processors import BaseProcessor
|
| 10 |
from marker.providers import BaseProvider
|
| 11 |
from marker.renderers import BaseRenderer
|
| 12 |
+
from marker.services import BaseService
|
| 13 |
|
| 14 |
|
| 15 |
class ConfigCrawler:
|
| 16 |
+
def __init__(self, base_classes=(BaseBuilder, BaseProcessor, BaseConverter, BaseProvider, BaseRenderer, BaseService)):
|
| 17 |
self.base_classes = base_classes
|
| 18 |
self.class_config_map = {}
|
| 19 |
|
marker/config/parser.py
CHANGED
|
@@ -39,9 +39,9 @@ class ConfigParser:
|
|
| 39 |
fn = click.option("--languages", type=str, default=None, help="Comma separated list of languages to use for OCR.")(fn)
|
| 40 |
|
| 41 |
# we put common options here
|
| 42 |
-
fn = click.option("--google_api_key", type=str, default=None, help="Google API key for using LLMs.")(fn)
|
| 43 |
fn = click.option("--use_llm", is_flag=True, default=False, help="Enable higher quality processing with LLMs.")(fn)
|
| 44 |
fn = click.option("--converter_cls", type=str, default=None, help="Converter class to use. Defaults to PDF converter.")(fn)
|
|
|
|
| 45 |
|
| 46 |
# enum options
|
| 47 |
fn = click.option("--force_layout_block", type=click.Choice(choices=[t.name for t in BlockTypes]), default=None,)(fn)
|
|
@@ -74,8 +74,23 @@ class ConfigParser:
|
|
| 74 |
case _:
|
| 75 |
if k in crawler.attr_set:
|
| 76 |
config[k] = v
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
return config
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
def get_renderer(self):
|
| 80 |
match self.cli_options["output_format"]:
|
| 81 |
case "json":
|
|
@@ -122,3 +137,4 @@ class ConfigParser:
|
|
| 122 |
def get_base_filename(self, filepath: str):
|
| 123 |
basename = os.path.basename(filepath)
|
| 124 |
return os.path.splitext(basename)[0]
|
|
|
|
|
|
| 39 |
fn = click.option("--languages", type=str, default=None, help="Comma separated list of languages to use for OCR.")(fn)
|
| 40 |
|
| 41 |
# we put common options here
|
|
|
|
| 42 |
fn = click.option("--use_llm", is_flag=True, default=False, help="Enable higher quality processing with LLMs.")(fn)
|
| 43 |
fn = click.option("--converter_cls", type=str, default=None, help="Converter class to use. Defaults to PDF converter.")(fn)
|
| 44 |
+
fn = click.option("--llm_service", type=str, default=None, help="LLM service to use - should be full import path, like marker.services.gemini.GoogleGeminiService")(fn)
|
| 45 |
|
| 46 |
# enum options
|
| 47 |
fn = click.option("--force_layout_block", type=click.Choice(choices=[t.name for t in BlockTypes]), default=None,)(fn)
|
|
|
|
| 74 |
case _:
|
| 75 |
if k in crawler.attr_set:
|
| 76 |
config[k] = v
|
| 77 |
+
|
| 78 |
+
# Backward compatibility for google_api_key
|
| 79 |
+
if settings.GOOGLE_API_KEY:
|
| 80 |
+
config["gemini_api_key"] = settings.GOOGLE_API_KEY
|
| 81 |
+
|
| 82 |
return config
|
| 83 |
|
| 84 |
+
def get_llm_service(self):
    """Return the import path of the LLM service to use, or None when LLMs are disabled."""
    # Only return an LLM service when use_llm is enabled
    if not self.cli_options.get("use_llm", False):
        return None

    configured = self.cli_options.get("llm_service", None)
    if configured is not None:
        return configured
    return "marker.services.gemini.GoogleGeminiService"
|
| 93 |
+
|
| 94 |
def get_renderer(self):
|
| 95 |
match self.cli_options["output_format"]:
|
| 96 |
case "json":
|
|
|
|
| 137 |
def get_base_filename(self, filepath: str):
|
| 138 |
basename = os.path.basename(filepath)
|
| 139 |
return os.path.splitext(basename)[0]
|
| 140 |
+
|
marker/converters/__init__.py
CHANGED
|
@@ -1,7 +1,11 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
|
| 3 |
from pydantic import BaseModel
|
| 4 |
|
|
|
|
|
|
|
|
|
|
| 5 |
from marker.util import assign_config
|
| 6 |
|
| 7 |
|
|
@@ -9,6 +13,48 @@ class BaseConverter:
|
|
| 9 |
def __init__(self, config: Optional[BaseModel | dict] = None):
|
| 10 |
assign_config(self, config)
|
| 11 |
self.config = config
|
|
|
|
| 12 |
|
| 13 |
def __call__(self, *args, **kwargs):
|
| 14 |
-
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
from typing import Optional, List, Type
|
| 3 |
|
| 4 |
from pydantic import BaseModel
|
| 5 |
|
| 6 |
+
from marker.processors import BaseProcessor
|
| 7 |
+
from marker.processors.llm import BaseLLMSimpleBlockProcessor
|
| 8 |
+
from marker.processors.llm.llm_meta import LLMSimpleBlockMetaProcessor
|
| 9 |
from marker.util import assign_config
|
| 10 |
|
| 11 |
|
|
|
|
| 13 |
def __init__(self, config: Optional[BaseModel | dict] = None):
|
| 14 |
assign_config(self, config)
|
| 15 |
self.config = config
|
| 16 |
+
self.llm_service = None
|
| 17 |
|
| 18 |
def __call__(self, *args, **kwargs):
|
| 19 |
+
raise NotImplementedError
|
| 20 |
+
|
| 21 |
+
def resolve_dependencies(self, cls):
    """Instantiate ``cls`` by resolving its constructor parameters.

    Resolution order per parameter: ``config`` maps to ``self.config``, a
    matching key in ``self.artifact_dict`` wins next, then the parameter's
    declared default.

    Raises:
        ValueError: if a parameter cannot be resolved by any of the above.
    """
    init_signature = inspect.signature(cls.__init__)
    parameters = init_signature.parameters

    resolved_kwargs = {}
    for param_name, param in parameters.items():
        if param_name == 'self':
            continue
        elif param_name == 'config':
            resolved_kwargs[param_name] = self.config
        elif param_name in self.artifact_dict:
            resolved_kwargs[param_name] = self.artifact_dict[param_name]
        # `is not`: Parameter.empty is a sentinel object, so identity (not
        # equality) is the correct comparison.
        elif param.default is not inspect.Parameter.empty:
            resolved_kwargs[param_name] = param.default
        else:
            raise ValueError(f"Cannot resolve dependency for parameter: {param_name}")

    return cls(**resolved_kwargs)
|
| 39 |
+
|
| 40 |
+
def initialize_processors(self, processor_cls_lst: List[Type[BaseProcessor]]) -> List[BaseProcessor]:
|
| 41 |
+
processors = []
|
| 42 |
+
for processor_cls in processor_cls_lst:
|
| 43 |
+
processors.append(self.resolve_dependencies(processor_cls))
|
| 44 |
+
|
| 45 |
+
simple_llm_processors = [p for p in processors if issubclass(type(p), BaseLLMSimpleBlockProcessor)]
|
| 46 |
+
other_processors = [p for p in processors if not issubclass(type(p), BaseLLMSimpleBlockProcessor)]
|
| 47 |
+
|
| 48 |
+
if not simple_llm_processors:
|
| 49 |
+
return processors
|
| 50 |
+
|
| 51 |
+
llm_positions = [i for i, p in enumerate(processors) if issubclass(type(p), BaseLLMSimpleBlockProcessor)]
|
| 52 |
+
insert_position = max(0, llm_positions[-1] - len(simple_llm_processors) + 1)
|
| 53 |
+
|
| 54 |
+
meta_processor = LLMSimpleBlockMetaProcessor(
|
| 55 |
+
processor_lst=simple_llm_processors,
|
| 56 |
+
llm_service=self.llm_service,
|
| 57 |
+
config=self.config,
|
| 58 |
+
)
|
| 59 |
+
other_processors.insert(insert_position, meta_processor)
|
| 60 |
+
return other_processors
|
marker/converters/pdf.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning
|
| 3 |
|
| 4 |
-
import inspect
|
| 5 |
from collections import defaultdict
|
| 6 |
-
from functools import cache
|
| 7 |
from typing import Annotated, Any, Dict, List, Optional, Type, Tuple
|
| 8 |
|
| 9 |
from marker.processors import BaseProcessor
|
|
@@ -12,6 +11,7 @@ from marker.providers.registry import provider_from_filepath
|
|
| 12 |
from marker.builders.document import DocumentBuilder
|
| 13 |
from marker.builders.layout import LayoutBuilder
|
| 14 |
from marker.builders.llm_layout import LLMLayoutBuilder
|
|
|
|
| 15 |
from marker.builders.ocr import OcrBuilder
|
| 16 |
from marker.builders.structure import StructureBuilder
|
| 17 |
from marker.converters import BaseConverter
|
|
@@ -41,6 +41,8 @@ from marker.schema.blocks import Block
|
|
| 41 |
from marker.schema.registry import register_block_class
|
| 42 |
from marker.util import strings_to_classes
|
| 43 |
from marker.processors.llm.llm_handwriting import LLMHandwritingProcessor
|
|
|
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
class PdfConverter(BaseConverter):
|
|
@@ -59,6 +61,7 @@ class PdfConverter(BaseConverter):
|
|
| 59 |
"Enable higher quality processing with LLMs.",
|
| 60 |
] = False
|
| 61 |
default_processors: Tuple[BaseProcessor, ...] = (
|
|
|
|
| 62 |
BlockquoteProcessor,
|
| 63 |
CodeProcessor,
|
| 64 |
DocumentTOCProcessor,
|
|
@@ -83,9 +86,19 @@ class PdfConverter(BaseConverter):
|
|
| 83 |
DebugProcessor,
|
| 84 |
)
|
| 85 |
|
| 86 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
super().__init__(config)
|
| 88 |
|
|
|
|
|
|
|
|
|
|
| 89 |
for block_type, override_block_type in self.override_map.items():
|
| 90 |
register_block_class(block_type, override_block_type)
|
| 91 |
|
|
@@ -99,44 +112,37 @@ class PdfConverter(BaseConverter):
|
|
| 99 |
else:
|
| 100 |
renderer = MarkdownRenderer
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
self.artifact_dict = artifact_dict
|
| 103 |
-
self.processor_list = processor_list
|
| 104 |
self.renderer = renderer
|
| 105 |
|
|
|
|
|
|
|
|
|
|
| 106 |
self.layout_builder_class = LayoutBuilder
|
| 107 |
if self.use_llm:
|
| 108 |
self.layout_builder_class = LLMLayoutBuilder
|
| 109 |
|
| 110 |
-
def resolve_dependencies(self, cls):
|
| 111 |
-
init_signature = inspect.signature(cls.__init__)
|
| 112 |
-
parameters = init_signature.parameters
|
| 113 |
-
|
| 114 |
-
resolved_kwargs = {}
|
| 115 |
-
for param_name, param in parameters.items():
|
| 116 |
-
if param_name == 'self':
|
| 117 |
-
continue
|
| 118 |
-
elif param_name == 'config':
|
| 119 |
-
resolved_kwargs[param_name] = self.config
|
| 120 |
-
elif param.name in self.artifact_dict:
|
| 121 |
-
resolved_kwargs[param_name] = self.artifact_dict[param_name]
|
| 122 |
-
elif param.default != inspect.Parameter.empty:
|
| 123 |
-
resolved_kwargs[param_name] = param.default
|
| 124 |
-
else:
|
| 125 |
-
raise ValueError(f"Cannot resolve dependency for parameter: {param_name}")
|
| 126 |
-
|
| 127 |
-
return cls(**resolved_kwargs)
|
| 128 |
-
|
| 129 |
-
@cache
|
| 130 |
def build_document(self, filepath: str):
|
| 131 |
provider_cls = provider_from_filepath(filepath)
|
| 132 |
layout_builder = self.resolve_dependencies(self.layout_builder_class)
|
|
|
|
| 133 |
ocr_builder = self.resolve_dependencies(OcrBuilder)
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
|
|
|
| 137 |
|
| 138 |
-
for
|
| 139 |
-
processor = self.resolve_dependencies(processor_cls)
|
| 140 |
processor(document)
|
| 141 |
|
| 142 |
return document
|
|
|
|
| 1 |
import os
|
| 2 |
+
|
| 3 |
os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning
|
| 4 |
|
|
|
|
| 5 |
from collections import defaultdict
|
|
|
|
| 6 |
from typing import Annotated, Any, Dict, List, Optional, Type, Tuple
|
| 7 |
|
| 8 |
from marker.processors import BaseProcessor
|
|
|
|
| 11 |
from marker.builders.document import DocumentBuilder
|
| 12 |
from marker.builders.layout import LayoutBuilder
|
| 13 |
from marker.builders.llm_layout import LLMLayoutBuilder
|
| 14 |
+
from marker.builders.line import LineBuilder
|
| 15 |
from marker.builders.ocr import OcrBuilder
|
| 16 |
from marker.builders.structure import StructureBuilder
|
| 17 |
from marker.converters import BaseConverter
|
|
|
|
| 41 |
from marker.schema.registry import register_block_class
|
| 42 |
from marker.util import strings_to_classes
|
| 43 |
from marker.processors.llm.llm_handwriting import LLMHandwritingProcessor
|
| 44 |
+
from marker.processors.order import OrderProcessor
|
| 45 |
+
from marker.services.gemini import GoogleGeminiService
|
| 46 |
|
| 47 |
|
| 48 |
class PdfConverter(BaseConverter):
|
|
|
|
| 61 |
"Enable higher quality processing with LLMs.",
|
| 62 |
] = False
|
| 63 |
default_processors: Tuple[BaseProcessor, ...] = (
|
| 64 |
+
OrderProcessor,
|
| 65 |
BlockquoteProcessor,
|
| 66 |
CodeProcessor,
|
| 67 |
DocumentTOCProcessor,
|
|
|
|
| 86 |
DebugProcessor,
|
| 87 |
)
|
| 88 |
|
| 89 |
+
def __init__(
|
| 90 |
+
self,
|
| 91 |
+
artifact_dict: Dict[str, Any],
|
| 92 |
+
processor_list: Optional[List[str]] = None,
|
| 93 |
+
renderer: str | None = None,
|
| 94 |
+
llm_service: str | None = None,
|
| 95 |
+
config=None
|
| 96 |
+
):
|
| 97 |
super().__init__(config)
|
| 98 |
|
| 99 |
+
if config is None:
|
| 100 |
+
config = {}
|
| 101 |
+
|
| 102 |
for block_type, override_block_type in self.override_map.items():
|
| 103 |
register_block_class(block_type, override_block_type)
|
| 104 |
|
|
|
|
| 112 |
else:
|
| 113 |
renderer = MarkdownRenderer
|
| 114 |
|
| 115 |
+
if llm_service:
|
| 116 |
+
llm_service_cls = strings_to_classes([llm_service])[0]
|
| 117 |
+
llm_service = self.resolve_dependencies(llm_service_cls)
|
| 118 |
+
elif config.get("use_llm", False):
|
| 119 |
+
llm_service = self.resolve_dependencies(GoogleGeminiService)
|
| 120 |
+
|
| 121 |
+
# Inject llm service into artifact_dict so it can be picked up by processors, etc.
|
| 122 |
+
artifact_dict["llm_service"] = llm_service
|
| 123 |
+
self.llm_service = llm_service
|
| 124 |
+
|
| 125 |
self.artifact_dict = artifact_dict
|
|
|
|
| 126 |
self.renderer = renderer
|
| 127 |
|
| 128 |
+
processor_list = self.initialize_processors(processor_list)
|
| 129 |
+
self.processor_list = processor_list
|
| 130 |
+
|
| 131 |
self.layout_builder_class = LayoutBuilder
|
| 132 |
if self.use_llm:
|
| 133 |
self.layout_builder_class = LLMLayoutBuilder
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
def build_document(self, filepath: str):
|
| 136 |
provider_cls = provider_from_filepath(filepath)
|
| 137 |
layout_builder = self.resolve_dependencies(self.layout_builder_class)
|
| 138 |
+
line_builder = self.resolve_dependencies(LineBuilder)
|
| 139 |
ocr_builder = self.resolve_dependencies(OcrBuilder)
|
| 140 |
+
provider = provider_cls(filepath, self.config)
|
| 141 |
+
document = DocumentBuilder(self.config)(provider, layout_builder, line_builder, ocr_builder)
|
| 142 |
+
structure_builder_cls = self.resolve_dependencies(StructureBuilder)
|
| 143 |
+
structure_builder_cls(document)
|
| 144 |
|
| 145 |
+
for processor in self.processor_list:
|
|
|
|
| 146 |
processor(document)
|
| 147 |
|
| 148 |
return document
|
marker/converters/table.py
CHANGED
|
@@ -2,6 +2,7 @@ from functools import cache
|
|
| 2 |
from typing import Tuple, List
|
| 3 |
|
| 4 |
from marker.builders.document import DocumentBuilder
|
|
|
|
| 5 |
from marker.builders.ocr import OcrBuilder
|
| 6 |
from marker.converters.pdf import PdfConverter
|
| 7 |
from marker.processors import BaseProcessor
|
|
@@ -24,21 +25,21 @@ class TableConverter(PdfConverter):
|
|
| 24 |
)
|
| 25 |
converter_block_types: List[BlockTypes] = (BlockTypes.Table, BlockTypes.Form, BlockTypes.TableOfContents)
|
| 26 |
|
| 27 |
-
@cache
|
| 28 |
def build_document(self, filepath: str):
|
| 29 |
provider_cls = provider_from_filepath(filepath)
|
| 30 |
layout_builder = self.resolve_dependencies(self.layout_builder_class)
|
|
|
|
| 31 |
ocr_builder = self.resolve_dependencies(OcrBuilder)
|
| 32 |
document_builder = DocumentBuilder(self.config)
|
| 33 |
document_builder.disable_ocr = True
|
| 34 |
-
|
| 35 |
-
|
|
|
|
| 36 |
|
| 37 |
for page in document.pages:
|
| 38 |
page.structure = [p for p in page.structure if p.block_type in self.converter_block_types]
|
| 39 |
|
| 40 |
-
for
|
| 41 |
-
processor = self.resolve_dependencies(processor_cls)
|
| 42 |
processor(document)
|
| 43 |
|
| 44 |
return document
|
|
|
|
| 2 |
from typing import Tuple, List
|
| 3 |
|
| 4 |
from marker.builders.document import DocumentBuilder
|
| 5 |
+
from marker.builders.line import LineBuilder
|
| 6 |
from marker.builders.ocr import OcrBuilder
|
| 7 |
from marker.converters.pdf import PdfConverter
|
| 8 |
from marker.processors import BaseProcessor
|
|
|
|
| 25 |
)
|
| 26 |
converter_block_types: List[BlockTypes] = (BlockTypes.Table, BlockTypes.Form, BlockTypes.TableOfContents)
|
| 27 |
|
|
|
|
| 28 |
def build_document(self, filepath: str):
|
| 29 |
provider_cls = provider_from_filepath(filepath)
|
| 30 |
layout_builder = self.resolve_dependencies(self.layout_builder_class)
|
| 31 |
+
line_builder = self.resolve_dependencies(LineBuilder)
|
| 32 |
ocr_builder = self.resolve_dependencies(OcrBuilder)
|
| 33 |
document_builder = DocumentBuilder(self.config)
|
| 34 |
document_builder.disable_ocr = True
|
| 35 |
+
|
| 36 |
+
provider = provider_cls(filepath, self.config)
|
| 37 |
+
document = document_builder(provider, layout_builder, line_builder, ocr_builder)
|
| 38 |
|
| 39 |
for page in document.pages:
|
| 40 |
page.structure = [p for p in page.structure if p.block_type in self.converter_block_types]
|
| 41 |
|
| 42 |
+
for processor in self.processor_list:
|
|
|
|
| 43 |
processor(document)
|
| 44 |
|
| 45 |
return document
|
marker/logger.py
CHANGED
|
@@ -7,3 +7,7 @@ def configure_logging():
|
|
| 7 |
|
| 8 |
logging.getLogger('PIL').setLevel(logging.ERROR)
|
| 9 |
warnings.simplefilter(action='ignore', category=FutureWarning)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
logging.getLogger('PIL').setLevel(logging.ERROR)
|
| 9 |
warnings.simplefilter(action='ignore', category=FutureWarning)
|
| 10 |
+
|
| 11 |
+
logging.getLogger('fontTools.subset').setLevel(logging.ERROR)
|
| 12 |
+
logging.getLogger('fontTools.ttLib.ttFont').setLevel(logging.ERROR)
|
| 13 |
+
logging.getLogger('weasyprint').setLevel(logging.CRITICAL)
|