Vik Paruchuri commited on
Commit
93c1274
·
2 Parent(s): 333b95b 85e05d9

Merge pull request #472 from VikParuchuri/vik_dev

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .github/workflows/scripts.yml +29 -0
  2. README.md +55 -10
  3. benchmarks/table/scoring.py +109 -0
  4. benchmarks/table/table.py +187 -0
  5. chunk_convert.py +2 -20
  6. convert.py +2 -115
  7. convert_single.py +2 -41
  8. marker/builders/document.py +6 -1
  9. marker/builders/layout.py +37 -15
  10. marker/builders/llm_layout.py +59 -44
  11. marker/builders/ocr.py +7 -11
  12. marker/config/parser.py +17 -0
  13. marker/config/printer.py +42 -16
  14. marker/converters/pdf.py +34 -27
  15. marker/converters/table.py +49 -0
  16. marker/models.py +34 -71
  17. marker/processors/debug.py +2 -2
  18. marker/processors/equation.py +4 -6
  19. marker/processors/ignoretext.py +1 -3
  20. marker/processors/llm/__init__.py +4 -13
  21. marker/processors/llm/llm_complex.py +17 -11
  22. marker/processors/llm/llm_equation.py +82 -0
  23. marker/processors/llm/llm_form.py +57 -35
  24. marker/processors/llm/llm_handwriting.py +86 -0
  25. marker/processors/llm/llm_image_description.py +5 -2
  26. marker/processors/llm/llm_table.py +55 -23
  27. marker/processors/llm/llm_table_merge.py +318 -0
  28. marker/processors/llm/llm_text.py +8 -5
  29. marker/processors/llm/utils.py +5 -2
  30. marker/processors/table.py +204 -44
  31. marker/providers/__init__.py +13 -1
  32. marker/providers/image.py +52 -0
  33. marker/providers/pdf.py +21 -6
  34. marker/providers/registry.py +12 -0
  35. marker/renderers/__init__.py +14 -8
  36. marker/renderers/html.py +16 -8
  37. marker/renderers/json.py +3 -0
  38. marker/renderers/markdown.py +84 -8
  39. marker/schema/__init__.py +1 -0
  40. marker/schema/blocks/__init__.py +1 -0
  41. marker/schema/blocks/base.py +25 -5
  42. marker/schema/blocks/basetable.py +39 -0
  43. marker/schema/blocks/caption.py +2 -0
  44. marker/schema/blocks/code.py +2 -1
  45. marker/schema/blocks/complexregion.py +3 -2
  46. marker/schema/blocks/equation.py +4 -3
  47. marker/schema/blocks/figure.py +3 -2
  48. marker/schema/blocks/footnote.py +1 -0
  49. marker/schema/blocks/form.py +4 -15
  50. marker/schema/blocks/handwriting.py +8 -0
.github/workflows/scripts.yml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Test CLI scripts
2
+
3
+ on: [push]
4
+
5
+ env:
6
+ TORCH_DEVICE: "cpu"
7
+ OCR_ENGINE: "surya"
8
+
9
+ jobs:
10
+ tests:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v3
14
+ - name: Set up Python 3.11
15
+ uses: actions/setup-python@v4
16
+ with:
17
+ python-version: 3.11
18
+ - name: Install python dependencies
19
+ run: |
20
+ pip install poetry
21
+ poetry install
22
+ - name: Download benchmark data
23
+ run: |
24
+ wget -O benchmark_data.zip "https://drive.google.com/uc?export=download&id=1NHrdYatR1rtqs2gPVfdvO0BAvocH8CJi"
25
+ unzip -o benchmark_data.zip
26
+ - name: Test single script
27
+ run: poetry run marker_single benchmark_data/pdfs/switch_trans.pdf --page_range 0
28
+ - name: Test convert script
29
+ run: poetry run marker benchmark_data/pdfs --max_files 1 --workers 1 --page_range 0
README.md CHANGED
@@ -1,13 +1,11 @@
1
  # Marker
2
 
3
- Marker converts PDFs to markdown, JSON, and HTML quickly and accurately.
4
 
5
- - Supports a wide range of documents
6
- - Supports all languages
7
  - Removes headers/footers/other artifacts
8
- - Formats tables, forms, and code blocks
9
  - Extracts and saves images along with the markdown
10
- - Converts equations to latex
11
  - Easily extensible with your own formatting and logic
12
  - Optionally boost accuracy with an LLM
13
  - Works on GPU, CPU, or MPS
@@ -63,11 +61,11 @@ There's a hosted API for marker available [here](https://www.datalab.to/):
63
  PDF is a tricky format, so marker will not always work perfectly. Here are some known limitations that are on the roadmap to address:
64
 
65
  - Marker will only convert block equations
66
- - Tables are not always formatted 100% correctly - multiline cells are sometimes split into multiple rows.
67
  - Forms are not converted optimally
68
  - Very complex layouts, with nested tables and forms, may not work
69
 
70
- Note: Passing the `--use_llm` flag will mostly solve all of these issues.
71
 
72
  # Installation
73
 
@@ -84,7 +82,7 @@ pip install marker-pdf
84
  First, some configuration:
85
 
86
  - Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`.
87
- - Some PDFs, even digital ones, have bad text in them. Set the `force_ocr` flag on the CLI or via configuration to ensure your PDF runs through OCR.
88
 
89
  ## Interactive App
90
 
@@ -101,9 +99,12 @@ marker_gui
101
  marker_single /path/to/file.pdf
102
  ```
103
 
 
 
104
  Options:
105
  - `--output_dir PATH`: Directory where output files will be saved. Defaults to the value specified in settings.OUTPUT_DIR.
106
  - `--output_format [markdown|json|html]`: Specify the format for the output results.
 
107
  - `--use_llm`: Uses an LLM to improve accuracy. You must set your Gemini API key using the `GOOGLE_API_KEY` env var.
108
  - `--disable_image_extraction`: Don't extract images from the PDF. If you also specify `--use_llm`, then images will be replaced with a description.
109
  - `--page_range TEXT`: Specify which pages to process. Accepts comma-separated page numbers and ranges. Example: `--page_range "0,5-10,20"` will process pages 0, 5 through 10, and page 20.
@@ -114,6 +115,7 @@ Options:
114
  - `--config_json PATH`: Path to a JSON configuration file containing additional settings.
115
  - `--languages TEXT`: Optionally specify which languages to use for OCR processing. Accepts a comma-separated list. Example: `--languages "en,fr,de"` for English, French, and German.
116
  - `config --help`: List all available builders, processors, and converters, and their associated configuration. These values can be used to build a JSON configuration file for additional tweaking of marker defaults.
 
117
 
118
  The list of supported languages for surya OCR is [here](https://github.com/VikParuchuri/surya/blob/master/surya/languages.py). If you don't need OCR, marker can work with any language.
119
 
@@ -179,7 +181,7 @@ rendered = converter("FILEPATH")
179
 
180
  ### Extract blocks
181
 
182
- Each document consists of one or more pages. Pages contain blocks, which can themselves contain other blocks. It's possible to programatically manipulate these blocks.
183
 
184
  Here's an example of extracting all forms from a document:
185
 
@@ -197,6 +199,28 @@ forms = document.contained_blocks((BlockTypes.Form,))
197
 
198
  Look at the processors for more examples of extracting and manipulating blocks.
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  # Output Formats
201
 
202
  ## Markdown
@@ -348,7 +372,7 @@ There are some settings that you may find useful if things aren't working the wa
348
  Pass the `debug` option to activate debug mode. This will save images of each page with detected layout and text, as well as output a json file with additional bounding box information.
349
 
350
  # Benchmarks
351
-
352
  Benchmarking PDF extraction quality is hard. I've created a test set by finding books and scientific papers that have a pdf version and a latex source. I convert the latex to text, and compare the reference to the output of text extraction methods. It's noisy, but at least directionally correct.
353
 
354
  **Speed**
@@ -371,6 +395,18 @@ Marker takes about 6GB of VRAM on average per task, so you can convert 8 documen
371
 
372
  ![Benchmark results](data/images/per_doc.png)
373
 
 
 
 
 
 
 
 
 
 
 
 
 
374
  ## Running your own benchmarks
375
 
376
  You can benchmark the performance of marker on your machine. Install marker manually with:
@@ -380,12 +416,21 @@ git clone https://github.com/VikParuchuri/marker.git
380
  poetry install
381
  ```
382
 
 
 
383
  Download the benchmark data [here](https://drive.google.com/file/d/1ZSeWDo2g1y0BRLT7KnbmytV2bjWARWba/view?usp=sharing) and unzip. Then run the overall benchmark like this:
384
 
385
  ```shell
386
  python benchmarks/overall.py data/pdfs data/references report.json
387
  ```
388
 
 
 
 
 
 
 
 
389
  # Thanks
390
 
391
  This work would not have been possible without amazing open source models and datasets, including (but not limited to):
 
1
  # Marker
2
 
3
+ Marker converts PDFs and images to markdown, JSON, and HTML quickly and accurately.
4
 
5
+ - Supports a range of documents in all languages
 
6
  - Removes headers/footers/other artifacts
7
+ - Formats tables, forms, equations, links, and code blocks
8
  - Extracts and saves images along with the markdown
 
9
  - Easily extensible with your own formatting and logic
10
  - Optionally boost accuracy with an LLM
11
  - Works on GPU, CPU, or MPS
 
61
  PDF is a tricky format, so marker will not always work perfectly. Here are some known limitations that are on the roadmap to address:
62
 
63
  - Marker will only convert block equations
64
+ - Tables are not always formatted 100% correctly
65
  - Forms are not converted optimally
66
  - Very complex layouts, with nested tables and forms, may not work
67
 
68
+ Note: Passing the `--use_llm` flag will mostly solve these issues.
69
 
70
  # Installation
71
 
 
82
  First, some configuration:
83
 
84
  - Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`.
85
+ - Some PDFs, even digital ones, have bad text in them. Set the `force_ocr` flag on the CLI or via configuration to ensure your PDF runs through OCR, or the `strip_existing_ocr` to keep all digital text, and only strip out any existing OCR text.
86
 
87
  ## Interactive App
88
 
 
99
  marker_single /path/to/file.pdf
100
  ```
101
 
102
+ You can pass in PDFs or images.
103
+
104
  Options:
105
  - `--output_dir PATH`: Directory where output files will be saved. Defaults to the value specified in settings.OUTPUT_DIR.
106
  - `--output_format [markdown|json|html]`: Specify the format for the output results.
107
+ - `--paginate_output`: Paginates the output, separating pages with `\n\n{PAGE_NUMBER}` followed by 48 `-` characters, then `\n\n`
108
  - `--use_llm`: Uses an LLM to improve accuracy. You must set your Gemini API key using the `GOOGLE_API_KEY` env var.
109
  - `--disable_image_extraction`: Don't extract images from the PDF. If you also specify `--use_llm`, then images will be replaced with a description.
110
  - `--page_range TEXT`: Specify which pages to process. Accepts comma-separated page numbers and ranges. Example: `--page_range "0,5-10,20"` will process pages 0, 5 through 10, and page 20.
 
115
  - `--config_json PATH`: Path to a JSON configuration file containing additional settings.
116
  - `--languages TEXT`: Optionally specify which languages to use for OCR processing. Accepts a comma-separated list. Example: `--languages "en,fr,de"` for English, French, and German.
117
  - `config --help`: List all available builders, processors, and converters, and their associated configuration. These values can be used to build a JSON configuration file for additional tweaking of marker defaults.
118
+ - `--converter_cls`: One of `marker.converters.pdf.PdfConverter` (default) or `marker.converters.table.TableConverter`. The `PdfConverter` will convert the whole PDF, the `TableConverter` will only extract and convert tables.
119
 
120
  The list of supported languages for surya OCR is [here](https://github.com/VikParuchuri/surya/blob/master/surya/languages.py). If you don't need OCR, marker can work with any language.
121
 
 
181
 
182
  ### Extract blocks
183
 
184
+ Each document consists of one or more pages. Pages contain blocks, which can themselves contain other blocks. It's possible to programmatically manipulate these blocks.
185
 
186
  Here's an example of extracting all forms from a document:
187
 
 
199
 
200
  Look at the processors for more examples of extracting and manipulating blocks.
201
 
202
+ ## Other converters
203
+
204
+ You can also use other converters that define different conversion pipelines:
205
+
206
+ ### Extract tables
207
+
208
+ The `TableConverter` will only convert and extract tables:
209
+
210
+ ```python
211
+ from marker.converters.table import TableConverter
212
+ from marker.models import create_model_dict
213
+ from marker.output import text_from_rendered
214
+
215
+ converter = TableConverter(
216
+ artifact_dict=create_model_dict(),
217
+ )
218
+ rendered = converter("FILEPATH")
219
+ text, _, images = text_from_rendered(rendered)
220
+ ```
221
+
222
+ This takes all the same configuration as the PdfConverter. You can specify the configuration `force_layout_block=Table` to avoid layout detection and instead assume every page is a table.
223
+
224
  # Output Formats
225
 
226
  ## Markdown
 
372
  Pass the `debug` option to activate debug mode. This will save images of each page with detected layout and text, as well as output a json file with additional bounding box information.
373
 
374
  # Benchmarks
375
+ ## Overall PDF Conversion
376
  Benchmarking PDF extraction quality is hard. I've created a test set by finding books and scientific papers that have a pdf version and a latex source. I convert the latex to text, and compare the reference to the output of text extraction methods. It's noisy, but at least directionally correct.
377
 
378
  **Speed**
 
395
 
396
  ![Benchmark results](data/images/per_doc.png)
397
 
398
+ ## Table Conversion
399
+ Marker can extract tables from PDFs using `marker.converters.table.TableConverter`. The table extraction performance is measured by comparing the extracted HTML representation of tables against the original HTML representations using the test split of [FinTabNet](https://developer.ibm.com/exchanges/data/all/fintabnet/). The HTML representations are compared using a tree edit distance based metric to judge both structure and content. Marker detects and identifies the structure of all tables in a PDF page and achieves these scores:
400
+
401
+ | Avg score | Total tables | use_llm |
402
+ |-----------|--------------|---------|
403
+ | 0.824 | 54 | False |
404
+ | 0.873 | 54 | True |
405
+
406
+ The `--use_llm` flag can significantly improve table recognition performance, as you can see.
407
+
408
+ We filter out tables that we cannot align with the ground truth, since fintabnet and our layout model have slightly different detection methods (this results in some tables being split/merged).
409
+
410
  ## Running your own benchmarks
411
 
412
  You can benchmark the performance of marker on your machine. Install marker manually with:
 
416
  poetry install
417
  ```
418
 
419
+ ### Overall PDF Conversion
420
+
421
  Download the benchmark data [here](https://drive.google.com/file/d/1ZSeWDo2g1y0BRLT7KnbmytV2bjWARWba/view?usp=sharing) and unzip. Then run the overall benchmark like this:
422
 
423
  ```shell
424
  python benchmarks/overall.py data/pdfs data/references report.json
425
  ```
426
 
427
+ ### Table Conversion
428
+ The processed FinTabNet dataset is hosted [here](https://huggingface.co/datasets/datalab-to/fintabnet-test) and is automatically downloaded. Run the benchmark with:
429
+
430
+ ```shell
431
+ python benchmarks/table/table.py table_report.json --max_rows 1000
432
+ ```
433
+
434
  # Thanks
435
 
436
  This work would not have been possible without amazing open source models and datasets, including (but not limited to):
benchmarks/table/scoring.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TEDS Code Adapted from https://github.com/ibm-aur-nlp/EDD
3
+ """
4
+
5
+ import distance
6
+ from apted import APTED, Config
7
+ from apted.helpers import Tree
8
+ from lxml import html
9
+ from collections import deque
10
+
11
def wrap_table_html(table_html: str) -> str:
    """Embed a bare ``<table>`` fragment in a minimal HTML document.

    The TEDS scoring helpers look tables up via the ``body/table`` XPath,
    so fragments are wrapped in ``<html><body>...</body></html>`` first.
    """
    return "<html><body>" + table_html + "</body></html>"
13
+
14
class TableTree(Tree):
    """APTED tree node carrying table metadata (tag, row/col span, cell text)."""

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content
        # Tree.__init__ populates self.name and self.children.
        super().__init__(tag, *children)

    def bracket(self):
        """Render the subtree in the bracket notation used by apted."""
        if self.tag == 'td':
            # Cell nodes carry span and text so the edit distance can
            # compare content, not just structure.
            label = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
                self.tag, self.colspan, self.rowspan, self.content)
        else:
            label = '"tag": %s' % self.tag
        parts = [label]
        parts.extend(child.bracket() for child in self.children)
        return "{{{}}}".format("".join(parts))
34
+
35
class CustomConfig(Config):
    """APTED cost configuration: rename cost mixes tag/span equality with
    a normalized Levenshtein distance over cell text."""

    @staticmethod
    def maximum(*sequences):
        """Length of the longest of the given sequences."""
        return max(len(seq) for seq in sequences)

    def normalized_distance(self, *sequences):
        """Levenshtein distance scaled to [0, 1] by the longest sequence."""
        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)

    def rename(self, node1, node2):
        """Cost of relabeling node1 as node2 (0.0 identical, 1.0 different)."""
        attrs_differ = (
            node1.tag != node2.tag
            or node1.colspan != node2.colspan
            or node1.rowspan != node2.rowspan
        )
        if attrs_differ:
            return 1.
        # Same tag and spans: for cells, fall back to text similarity.
        if node1.tag == 'td' and (node1.content or node2.content):
            return self.normalized_distance(node1.content, node2.content)
        return 0.
50
+
51
def tokenize(node):
    """Flatten a cell's HTML subtree into the module-level ``__tokens__`` list.

    Emits an opening tag token, the node text one character at a time, the
    tokenized children, a closing tag token (skipped for 'unk' nodes), and
    finally the tail text (skipped after 'td' nodes).
    """
    global __tokens__
    tag = node.tag
    __tokens__.append('<%s>' % tag)
    if node.text is not None:
        __tokens__.extend(node.text)
    for child in node.getchildren():
        tokenize(child)
    if tag != 'unk':
        __tokens__.append('</%s>' % tag)
    if tag != 'td' and node.tail is not None:
        __tokens__.extend(node.tail)
65
+
66
def tree_convert_html(node, convert_cell=False, parent=None):
    """Convert an lxml HTML tree into the TableTree form required by apted.

    'td' nodes become leaves carrying colspan/rowspan and (optionally)
    tokenized cell text; all other nodes are recursed into. Returns the
    root TableTree on the outermost call (``parent is None``).
    """
    global __tokens__
    is_cell = node.tag == 'td'
    if is_cell:
        cell = []
        if convert_cell:
            __tokens__ = []
            tokenize(node)
            # Strip the wrapping '<td>'/'</td>' tokens; keep only the text.
            cell = __tokens__[1:-1].copy()
        new_node = TableTree(
            node.tag,
            int(node.attrib.get('colspan', '1')),
            int(node.attrib.get('rowspan', '1')),
            cell,
            *deque(),
        )
    else:
        new_node = TableTree(node.tag, None, None, None, *deque())
    if parent is not None:
        parent.children.append(new_node)
    if not is_cell:
        for child in node.getchildren():
            tree_convert_html(child, convert_cell, new_node)
    if parent is None:
        return new_node
91
+
92
def similarity_eval_html(pred, true, structure_only=False):
    """Compute the TEDS similarity between predicted and ground-truth HTML.

    Both inputs must be full documents (see ``wrap_table_html``) containing
    a ``body/table`` element; otherwise the score is 0.0.

    Args:
        pred: Predicted table HTML document string.
        true: Ground-truth table HTML document string.
        structure_only: If True, compare only tree structure and ignore
            cell text.

    Returns:
        A float in [0, 1]; 1.0 means the trees are identical.
    """
    pred, true = html.fromstring(pred), html.fromstring(true)
    if not (pred.xpath('body/table') and true.xpath('body/table')):
        # At least one document has no table to compare against.
        return 0.0
    pred = pred.xpath('body/table')[0]
    true = true.xpath('body/table')[0]
    n_nodes_pred = len(pred.xpath(".//*"))
    n_nodes_true = len(true.xpath(".//*"))
    tree_pred = tree_convert_html(pred, convert_cell=not structure_only)
    tree_true = tree_convert_html(true, convert_cell=not structure_only)
    # Normalize the edit distance by the larger tree so the score is in [0, 1].
    n_nodes = max(n_nodes_pred, n_nodes_true)
    # Renamed local from `distance` to avoid shadowing the imported
    # `distance` module used elsewhere in this file.
    edit_distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
    return 1.0 - (float(edit_distance) / n_nodes)
109
+
benchmarks/table/table.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+
6
+ from marker.renderers.json import JSONOutput, JSONBlockOutput
7
+
8
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
9
+
10
+ import base64
11
+ import time
12
+ import datasets
13
+ from tqdm import tqdm
14
+ import tempfile
15
+ import click
16
+ from tabulate import tabulate
17
+ import json
18
+ from bs4 import BeautifulSoup
19
+ from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
20
+ from pypdfium2._helpers.misc import PdfiumError
21
+ from marker.util import matrix_intersection_area
22
+
23
+ from marker.config.parser import ConfigParser
24
+ from marker.converters.table import TableConverter
25
+ from marker.models import create_model_dict
26
+
27
+ from scoring import wrap_table_html, similarity_eval_html
28
+
29
def update_teds_score(result):
    """Score one marker/ground-truth table pair with TEDS.

    Wraps both HTML fragments in full documents, computes the similarity,
    stores it under ``result['score']`` (mutating in place), and returns
    the updated dict so it can be used with ``executor.map``.
    """
    pred_html = wrap_table_html(result['marker_table'])
    gt_html = wrap_table_html(result['gt_table'])
    result['score'] = similarity_eval_html(pred_html, gt_html)
    return result
35
+
36
+
37
def extract_tables(children: "List[JSONBlockOutput]"):
    """Recursively collect every Table block from a JSON block tree.

    Args:
        children: Block outputs from the JSON renderer; each block may
            carry nested children.

    Returns:
        A flat list of blocks whose ``block_type`` is 'Table', in
        document order.
    """
    tables = []
    for block in children:
        if block.block_type == 'Table':
            tables.append(block)
        elif block.children:
            tables.extend(extract_tables(block.children))
    return tables
45
+
46
+
47
@click.command(help="Benchmark Table to HTML Conversion")
@click.argument("out_file", type=str)
@click.option("--dataset", type=str, default="datalab-to/fintabnet-test", help="Dataset to use")
@click.option("--max_rows", type=int, default=None, help="Maximum number of PDFs to process")
@click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
@click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.")
@click.option("--table_rec_batch_size", type=int, default=None, help="Batch size for table recognition.")
def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm: bool, table_rec_batch_size: int | None):
    """Benchmark marker's table extraction against FinTabNet ground truth.

    Converts each PDF with TableConverter, aligns detected tables to
    ground-truth tables by bounding-box overlap, scores each aligned pair
    with TEDS, prints a summary table, and writes the raw results to
    ``out_file`` as JSON.
    """
    models = create_model_dict()
    config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size})
    start = time.time()

    dataset = datasets.load_dataset(dataset, split='train')
    dataset = dataset.shuffle(seed=0)

    iterations = len(dataset)
    if max_rows is not None:
        iterations = min(max_rows, len(dataset))

    results = []
    total_unaligned = 0
    for i in tqdm(range(iterations), desc='Converting Tables'):
        try:
            row = dataset[i]
            pdf_binary = base64.b64decode(row['pdf'])
            gt_tables = row['tables']  # Already sorted by reading order, which is what marker returns

            converter = TableConverter(
                config=config_parser.generate_config_dict(),
                artifact_dict=models,
                processor_list=config_parser.get_processors(),
                renderer=config_parser.get_renderer()
            )

            with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as temp_pdf_file:
                temp_pdf_file.write(pdf_binary)
                temp_pdf_file.seek(0)
                tqdm.disable = True  # Silence nested progress bars during conversion
                marker_json = converter(temp_pdf_file.name).children
                tqdm.disable = False

            if len(marker_json) == 0 or len(gt_tables) == 0:
                # Plain string; the original used an f-string with no placeholders.
                print('No tables detected, skipping...')
                total_unaligned += len(gt_tables)
                continue

            marker_tables = extract_tables(marker_json)
            marker_table_boxes = [table.bbox for table in marker_tables]
            page_bbox = marker_json[0].bbox

            # Normalize marker bboxes to page-relative coordinates so they
            # are comparable to fintabnet's normalized ground-truth boxes.
            for bbox in marker_table_boxes:
                bbox[0] = bbox[0] / page_bbox[2]
                bbox[1] = bbox[1] / page_bbox[3]
                bbox[2] = bbox[2] / page_bbox[2]
                bbox[3] = bbox[3] / page_bbox[3]

            gt_boxes = [table['normalized_bbox'] for table in gt_tables]
            gt_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in gt_boxes]
            marker_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in marker_table_boxes]
            table_alignments = matrix_intersection_area(gt_boxes, marker_table_boxes)

            aligned_tables = []
            used_tables = set()
            unaligned_tables = set()
            for table_idx, alignment in enumerate(table_alignments):
                try:
                    max_area = np.max(alignment)
                    aligned_idx = np.argmax(alignment)
                except ValueError:
                    # No alignment found
                    unaligned_tables.add(table_idx)
                    continue

                if aligned_idx in used_tables:
                    # Marker table already aligned with another gt table
                    unaligned_tables.add(table_idx)
                    continue

                # Gt table doesn't align well with any marker table
                gt_table_pct = gt_areas[table_idx] / max_area
                if not .75 < gt_table_pct < 1.25:
                    unaligned_tables.add(table_idx)
                    continue

                # Marker table doesn't align with gt table
                marker_table_pct = marker_areas[aligned_idx] / max_area
                if not .75 < marker_table_pct < 1.25:
                    unaligned_tables.add(table_idx)
                    continue

                aligned_tables.append(
                    (marker_tables[aligned_idx], gt_tables[table_idx])
                )
                used_tables.add(aligned_idx)

            total_unaligned += len(unaligned_tables)

            for marker_table, gt_table in aligned_tables:
                gt_table_html = gt_table['html']

                # marker wraps the table in <tbody> which fintabnet data doesn't;
                # fintabnet doesn't use th tags, so replace them for a fair comparison.
                marker_table_soup = BeautifulSoup(marker_table.html, 'html.parser')
                marker_table_soup.find('tbody').unwrap()
                for th_tag in marker_table_soup.find_all('th'):
                    th_tag.name = 'td'
                marker_table_html = str(marker_table_soup)
                marker_table_html = marker_table_html.replace("\n", " ")  # Fintabnet uses spaces instead of newlines

                results.append({
                    "marker_table": marker_table_html,
                    "gt_table": gt_table_html
                })
        except PdfiumError:
            print('Broken PDF, Skipping...')
            continue

    print(f"Total time: {time.time() - start}.")
    print(f"Could not align {total_unaligned} tables from fintabnet.")

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        results = list(
            tqdm(
                executor.map(update_teds_score, results), desc='Computing alignment scores', total=len(results)
            )
        )

    if results:
        avg_score = sum(r["score"] for r in results) / len(results)
        headers = ["Avg score", "Total tables"]
        data = [f"{avg_score:.3f}", len(results)]
        table = tabulate([data], headers=headers, tablefmt="github")
        print(table)
        print("Avg score computed by comparing marker predicted HTML with original HTML")
    else:
        # Guard: previously this divided by len(results) and crashed with
        # ZeroDivisionError when no tables could be aligned.
        print("No aligned tables to score.")

    with open(out_file, "w+") as f:
        json.dump(results, f, indent=2)

if __name__ == '__main__':
    main()
chunk_convert.py CHANGED
@@ -1,22 +1,4 @@
1
- import argparse
2
- import subprocess
3
- import pkg_resources
4
-
5
-
6
- def main():
7
- parser = argparse.ArgumentParser(description="Convert a folder of PDFs to a folder of markdown files in chunks.")
8
- parser.add_argument("in_folder", help="Input folder with pdfs.")
9
- parser.add_argument("out_folder", help="Output folder")
10
- args = parser.parse_args()
11
-
12
- script_path = pkg_resources.resource_filename(__name__, 'chunk_convert.sh')
13
-
14
- # Construct the command
15
- cmd = f"{script_path} {args.in_folder} {args.out_folder}"
16
-
17
- # Execute the shell script
18
- subprocess.run(cmd, shell=True, check=True)
19
-
20
 
21
  if __name__ == "__main__":
22
- main()
 
1
+ from marker.scripts import chunk_convert_cli
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  if __name__ == "__main__":
4
+ chunk_convert_cli()
convert.py CHANGED
@@ -1,117 +1,4 @@
1
- import os
2
-
3
- os.environ["GRPC_VERBOSITY"] = "ERROR"
4
- os.environ["GLOG_minloglevel"] = "2"
5
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
6
- os.environ["IN_STREAMLIT"] = "true" # Avoid multiprocessing inside surya
7
-
8
- import math
9
- import traceback
10
-
11
- import click
12
- import torch.multiprocessing as mp
13
- from tqdm import tqdm
14
-
15
- from marker.config.parser import ConfigParser
16
- from marker.config.printer import CustomClickPrinter
17
- from marker.converters.pdf import PdfConverter
18
- from marker.logger import configure_logging
19
- from marker.models import create_model_dict
20
- from marker.output import output_exists, save_output
21
- from marker.settings import settings
22
-
23
- configure_logging()
24
-
25
-
26
- def worker_init(model_dict):
27
- if model_dict is None:
28
- model_dict = create_model_dict()
29
-
30
- global model_refs
31
- model_refs = model_dict
32
-
33
-
34
- def worker_exit():
35
- global model_refs
36
- del model_refs
37
-
38
-
39
- def process_single_pdf(args):
40
- fpath, cli_options = args
41
- config_parser = ConfigParser(cli_options)
42
-
43
- out_folder = config_parser.get_output_folder(fpath)
44
- base_name = config_parser.get_base_filename(fpath)
45
- if cli_options.get('skip_existing') and output_exists(out_folder, base_name):
46
- return
47
-
48
- try:
49
- converter = PdfConverter(
50
- config=config_parser.generate_config_dict(),
51
- artifact_dict=model_refs,
52
- processor_list=config_parser.get_processors(),
53
- renderer=config_parser.get_renderer()
54
- )
55
- rendered = converter(fpath)
56
- out_folder = config_parser.get_output_folder(fpath)
57
- save_output(rendered, out_folder, base_name)
58
- except Exception as e:
59
- print(f"Error converting {fpath}: {e}")
60
- print(traceback.format_exc())
61
-
62
-
63
- @click.command(cls=CustomClickPrinter)
64
- @click.argument("in_folder", type=str)
65
- @ConfigParser.common_options
66
- @click.option("--chunk_idx", type=int, default=0, help="Chunk index to convert")
67
- @click.option("--num_chunks", type=int, default=1, help="Number of chunks being processed in parallel")
68
- @click.option("--max_files", type=int, default=None, help="Maximum number of pdfs to convert")
69
- @click.option("--workers", type=int, default=5, help="Number of worker processes to use.")
70
- @click.option("--skip_existing", is_flag=True, default=False, help="Skip existing converted files.")
71
- def main(in_folder: str, **kwargs):
72
- in_folder = os.path.abspath(in_folder)
73
- files = [os.path.join(in_folder, f) for f in os.listdir(in_folder)]
74
- files = [f for f in files if os.path.isfile(f)]
75
-
76
- # Handle chunks if we're processing in parallel
77
- # Ensure we get all files into a chunk
78
- chunk_size = math.ceil(len(files) / kwargs["num_chunks"])
79
- start_idx = kwargs["chunk_idx"] * chunk_size
80
- end_idx = start_idx + chunk_size
81
- files_to_convert = files[start_idx:end_idx]
82
-
83
- # Limit files converted if needed
84
- if kwargs["max_files"]:
85
- files_to_convert = files_to_convert[:kwargs["max_files"]]
86
-
87
- # Disable nested multiprocessing
88
- kwargs["disable_multiprocessing"] = True
89
-
90
- total_processes = min(len(files_to_convert), kwargs["workers"])
91
-
92
- try:
93
- mp.set_start_method('spawn') # Required for CUDA, forkserver doesn't work
94
- except RuntimeError:
95
- raise RuntimeError("Set start method to spawn twice. This may be a temporary issue with the script. Please try running it again.")
96
-
97
- if settings.TORCH_DEVICE == "mps" or settings.TORCH_DEVICE_MODEL == "mps":
98
- model_dict = None
99
- else:
100
- model_dict = create_model_dict()
101
- for k, v in model_dict.items():
102
- v.share_memory()
103
-
104
- print(f"Converting {len(files_to_convert)} pdfs in chunk {kwargs['chunk_idx'] + 1}/{kwargs['num_chunks']} with {total_processes} processes and saving to {kwargs['output_dir']}")
105
- task_args = [(f, kwargs) for f in files_to_convert]
106
-
107
- with mp.Pool(processes=total_processes, initializer=worker_init, initargs=(model_dict,)) as pool:
108
- list(tqdm(pool.imap(process_single_pdf, task_args), total=len(task_args), desc="Processing PDFs", unit="pdf"))
109
-
110
- pool._worker_handler.terminate = worker_exit
111
-
112
- # Delete all CUDA tensors
113
- del model_dict
114
-
115
 
116
  if __name__ == "__main__":
117
- main()
 
1
+ from marker.scripts import convert_cli
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  if __name__ == "__main__":
4
+ convert_cli()
convert_single.py CHANGED
@@ -1,43 +1,4 @@
1
- import os
2
-
3
- os.environ["GRPC_VERBOSITY"] = "ERROR"
4
- os.environ["GLOG_minloglevel"] = "2"
5
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
6
-
7
- import time
8
- import click
9
-
10
- from marker.config.parser import ConfigParser
11
- from marker.config.printer import CustomClickPrinter
12
- from marker.converters.pdf import PdfConverter
13
- from marker.logger import configure_logging
14
- from marker.models import create_model_dict
15
- from marker.output import save_output
16
-
17
- configure_logging()
18
-
19
-
20
- @click.command(cls=CustomClickPrinter, help="Convert a single PDF to markdown.")
21
- @click.argument("fpath", type=str)
22
- @ConfigParser.common_options
23
- def main(fpath: str, **kwargs):
24
- models = create_model_dict()
25
- start = time.time()
26
- config_parser = ConfigParser(kwargs)
27
-
28
- converter = PdfConverter(
29
- config=config_parser.generate_config_dict(),
30
- artifact_dict=models,
31
- processor_list=config_parser.get_processors(),
32
- renderer=config_parser.get_renderer()
33
- )
34
- rendered = converter(fpath)
35
- out_folder = config_parser.get_output_folder(fpath)
36
- save_output(rendered, out_folder, config_parser.get_base_filename(fpath))
37
-
38
- print(f"Saved markdown to {out_folder}")
39
- print(f"Total time: {time.time() - start}")
40
-
41
 
42
  if __name__ == "__main__":
43
- main()
 
1
+ from marker.scripts import convert_single_cli
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  if __name__ == "__main__":
4
+ convert_single_cli()
marker/builders/document.py CHANGED
@@ -22,11 +22,16 @@ class DocumentBuilder(BaseBuilder):
22
  int,
23
  "DPI setting for high-resolution page images used for OCR.",
24
  ] = 192
 
 
 
 
25
 
26
  def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder, ocr_builder: OcrBuilder):
27
  document = self.build_document(provider)
28
  layout_builder(document, provider)
29
- ocr_builder(document, provider)
 
30
  return document
31
 
32
  def build_document(self, provider: PdfProvider):
 
22
  int,
23
  "DPI setting for high-resolution page images used for OCR.",
24
  ] = 192
25
+ disable_ocr: Annotated[
26
+ bool,
27
+ "Disable OCR processing.",
28
+ ] = False
29
 
30
  def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder, ocr_builder: OcrBuilder):
31
  document = self.build_document(provider)
32
  layout_builder(document, provider)
33
+ if not self.disable_ocr:
34
+ ocr_builder(document, provider)
35
  return document
36
 
37
  def build_document(self, provider: PdfProvider):
marker/builders/layout.py CHANGED
@@ -1,11 +1,10 @@
1
  from typing import Annotated, List, Optional, Tuple
2
 
3
  import numpy as np
4
- from surya.layout import batch_layout_detection
5
- from surya.model.layout.encoderdecoder import SuryaLayoutModel
6
- from surya.model.ocr_error.model import DistilBertForSequenceClassification
7
- from surya.ocr_error import batch_ocr_error_detection
8
- from surya.schema import LayoutResult, OCRErrorDetectionResult
9
 
10
  from marker.builders import BaseBuilder
11
  from marker.providers import ProviderOutput, ProviderPageLines
@@ -51,15 +50,23 @@ class LayoutBuilder(BaseBuilder):
51
  Tuple[BlockTypes],
52
  "A list of block types to exclude from the layout coverage check.",
53
  ] = (BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup)
 
 
 
 
54
 
55
- def __init__(self, layout_model: SuryaLayoutModel, ocr_error_model: DistilBertForSequenceClassification, config=None):
56
  self.layout_model = layout_model
57
  self.ocr_error_model = ocr_error_model
58
 
59
  super().__init__(config)
60
 
61
  def __call__(self, document: Document, provider: PdfProvider):
62
- layout_results = self.surya_layout(document.pages)
 
 
 
 
63
  self.add_blocks_to_pages(document.pages, layout_results)
64
  self.merge_blocks(document.pages, provider.page_lines)
65
 
@@ -70,12 +77,29 @@ class LayoutBuilder(BaseBuilder):
70
  return 6
71
  return 6
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def surya_layout(self, pages: List[PageGroup]) -> List[LayoutResult]:
74
- processor = self.layout_model.processor
75
- layout_results = batch_layout_detection(
76
- [p.lowres_image for p in pages],
77
- self.layout_model,
78
- processor,
79
  batch_size=int(self.get_batch_size())
80
  )
81
  return layout_results
@@ -97,10 +121,8 @@ class LayoutBuilder(BaseBuilder):
97
 
98
  page_texts.append(page_text)
99
 
100
- ocr_error_detection_results = batch_ocr_error_detection(
101
  page_texts,
102
- self.ocr_error_model,
103
- self.ocr_error_model.tokenizer,
104
  batch_size=int(self.get_batch_size()) # TODO Better Multiplier
105
  )
106
  return ocr_error_detection_results
 
1
  from typing import Annotated, List, Optional, Tuple
2
 
3
  import numpy as np
4
+ from surya.layout import LayoutPredictor
5
+ from surya.layout.schema import LayoutResult, LayoutBox
6
+ from surya.ocr_error import OCRErrorPredictor
7
+ from surya.ocr_error.schema import OCRErrorDetectionResult
 
8
 
9
  from marker.builders import BaseBuilder
10
  from marker.providers import ProviderOutput, ProviderPageLines
 
50
  Tuple[BlockTypes],
51
  "A list of block types to exclude from the layout coverage check.",
52
  ] = (BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup)
53
+ force_layout_block: Annotated[
54
+ str,
55
+ "Skip layout and force every page to be treated as a specific block type.",
56
+ ] = None
57
 
58
+ def __init__(self, layout_model: LayoutPredictor, ocr_error_model: OCRErrorPredictor, config=None):
59
  self.layout_model = layout_model
60
  self.ocr_error_model = ocr_error_model
61
 
62
  super().__init__(config)
63
 
64
  def __call__(self, document: Document, provider: PdfProvider):
65
+ if self.force_layout_block is not None:
66
+ # Assign the full content of every page to a single layout type
67
+ layout_results = self.forced_layout(document.pages)
68
+ else:
69
+ layout_results = self.surya_layout(document.pages)
70
  self.add_blocks_to_pages(document.pages, layout_results)
71
  self.merge_blocks(document.pages, provider.page_lines)
72
 
 
77
  return 6
78
  return 6
79
 
80
+ def forced_layout(self, pages: List[PageGroup]) -> List[LayoutResult]:
81
+ layout_results = []
82
+ for page in pages:
83
+ layout_results.append(
84
+ LayoutResult(
85
+ image_bbox=page.polygon.bbox,
86
+ bboxes=[
87
+ LayoutBox(
88
+ label=self.force_layout_block,
89
+ position=0,
90
+ top_k={self.force_layout_block: 1},
91
+ polygon=page.polygon.polygon,
92
+ ),
93
+ ],
94
+ sliced=False
95
+ )
96
+ )
97
+ return layout_results
98
+
99
+
100
  def surya_layout(self, pages: List[PageGroup]) -> List[LayoutResult]:
101
+ layout_results = self.layout_model(
102
+ [p.get_image(highres=False) for p in pages],
 
 
 
103
  batch_size=int(self.get_batch_size())
104
  )
105
  return layout_results
 
121
 
122
  page_texts.append(page_text)
123
 
124
+ ocr_error_detection_results = self.ocr_error_model(
125
  page_texts,
 
 
126
  batch_size=int(self.get_batch_size()) # TODO Better Multiplier
127
  )
128
  return ocr_error_detection_results
marker/builders/llm_layout.py CHANGED
@@ -1,10 +1,9 @@
1
- import json
2
  from concurrent.futures import ThreadPoolExecutor, as_completed
3
- from typing import Annotated, Optional
4
 
5
  from google.ai.generativelanguage_v1beta.types import content
6
- from surya.model.layout.encoderdecoder import SuryaLayoutModel
7
- from surya.model.ocr_error.model import DistilBertForSequenceClassification
8
  from tqdm import tqdm
9
 
10
  from marker.builders.layout import LayoutBuilder
@@ -24,16 +23,16 @@ class LLMLayoutBuilder(LayoutBuilder):
24
  """
25
 
26
  google_api_key: Annotated[
27
- Optional[str],
28
  "The Google API key to use for the Gemini model.",
29
  ] = settings.GOOGLE_API_KEY
30
  confidence_threshold: Annotated[
31
  float,
32
- "The confidence threshold to use for relabeling.",
33
- ] = 0.75
34
  picture_height_threshold: Annotated[
35
  float,
36
- "The height threshold for pictures that may actually be complex regions.",
37
  ] = 0.8
38
  model_name: Annotated[
39
  str,
@@ -55,43 +54,47 @@ class LLMLayoutBuilder(LayoutBuilder):
55
  str,
56
  "The prompt to use for relabelling blocks.",
57
  "Default is a string containing the Gemini relabelling prompt."
58
- ] = """You are a layout expert specializing in document analysis.
59
  Your task is to relabel layout blocks in images to improve the accuracy of an existing layout model.
60
- You will be provided with an image of a layout block and the top k predictions from the current model, along with their confidence scores.
61
  Your job is to analyze the image and choose the single most appropriate label from the provided top k predictions.
62
  Do not invent any new labels.
63
- Carefully examine the image and consider the provided predictions.
64
- Choose the label you believe is the most accurate representation of the layout block.
 
 
 
 
 
 
 
65
 
66
- Here are the top k predictions from the model followed by the image:
67
 
 
68
  """
69
  complex_relabeling_prompt: Annotated[
70
  str,
71
  "The prompt to use for complex relabelling blocks.",
72
  "Default is a string containing the complex relabelling prompt."
73
- ] = """You are a layout expert specializing in document analysis.
74
  Your task is to relabel layout blocks in images to improve the accuracy of an existing layout model.
75
- You will be provided with an image of a layout block and some potential labels.
76
  Your job is to analyze the image and choose the single most appropriate label from the provided labels.
77
  Do not invent any new labels.
78
- Carefully examine the image and consider the provided predictions.
79
- Choose the label you believe is the most accurate representation of the layout block.
 
 
80
 
81
  Potential labels:
82
 
83
- - Picture
84
- - Table
85
- - Form
86
- - Figure - A graph or diagram with text.
87
- - ComplexRegion - a complex region containing multiple text and other elements.
88
 
89
  Respond only with one of `Figure`, `Picture`, `ComplexRegion`, `Table`, or `Form`.
90
-
91
- Here is the image of the layout block:
92
  """
93
 
94
- def __init__(self, layout_model: SuryaLayoutModel, ocr_error_model: DistilBertForSequenceClassification, config=None):
95
  super().__init__(layout_model, ocr_error_model, config)
96
 
97
  self.model = GoogleModel(self.google_api_key, self.model_name)
@@ -114,10 +117,10 @@ Here is the image of the layout block:
114
  confidence = block.top_k.get(block.block_type)
115
  # Case when the block is detected as a different type with low confidence
116
  if confidence < self.confidence_threshold:
117
- futures.append(executor.submit(self.process_block_topk_relabeling, page, block))
118
  # Case when the block is detected as a picture or figure, but is actually complex
119
  elif block.block_type in (BlockTypes.Picture, BlockTypes.Figure, BlockTypes.SectionHeader) and block.polygon.height > page.polygon.height * self.picture_height_threshold:
120
- futures.append(executor.submit(self.process_block_complex_relabeling, page, block))
121
 
122
  for future in as_completed(futures):
123
  future.result() # Raise exceptions if any occurred
@@ -125,23 +128,40 @@ Here is the image of the layout block:
125
 
126
  pbar.close()
127
 
128
- def process_block_topk_relabeling(self, page: PageGroup, block: Block):
129
- topk = {str(k): round(v, 3) for k, v in block.top_k.items()}
 
 
 
 
 
 
 
 
130
 
131
- prompt = self.topk_relabelling_prompt + '```json' + json.dumps(topk) + '```\n'
132
- return self.process_block_relabeling(page, block, prompt)
133
 
134
- def process_block_complex_relabeling(self, page: PageGroup, block: Block):
135
- complex_prompt = self.complex_relabeling_prompt
136
- return self.process_block_relabeling(page, block, complex_prompt)
137
 
138
- def process_block_relabeling(self, page: PageGroup, block: Block, prompt: str):
139
- image = self.extract_image(page, block)
 
 
 
 
 
 
 
 
 
140
  response_schema = content.Schema(
141
  type=content.Type.OBJECT,
142
  enum=[],
143
- required=["label"],
144
  properties={
 
 
 
145
  "label": content.Schema(
146
  type=content.Type.STRING,
147
  ),
@@ -162,10 +182,5 @@ Here is the image of the layout block:
162
  )
163
  page.replace_block(block, generated_block)
164
 
165
- def extract_image(self, page: PageGroup, image_block: Block, expand: float = 0.01):
166
- page_img = page.lowres_image
167
- image_box = image_block.polygon\
168
- .rescale(page.polygon.size, page_img.size)\
169
- .expand(expand, expand)
170
- cropped = page_img.crop(image_box.bbox)
171
- return cropped
 
 
1
  from concurrent.futures import ThreadPoolExecutor, as_completed
2
+ from typing import Annotated
3
 
4
  from google.ai.generativelanguage_v1beta.types import content
5
+ from surya.layout import LayoutPredictor
6
+ from surya.ocr_error import OCRErrorPredictor
7
  from tqdm import tqdm
8
 
9
  from marker.builders.layout import LayoutBuilder
 
23
  """
24
 
25
  google_api_key: Annotated[
26
+ str,
27
  "The Google API key to use for the Gemini model.",
28
  ] = settings.GOOGLE_API_KEY
29
  confidence_threshold: Annotated[
30
  float,
31
+ "The confidence threshold to use for relabeling (anything below is relabeled).",
32
+ ] = 0.7
33
  picture_height_threshold: Annotated[
34
  float,
35
+ "The height threshold for pictures that may actually be complex regions. (anything above this ratio against the page is relabeled)",
36
  ] = 0.8
37
  model_name: Annotated[
38
  str,
 
54
  str,
55
  "The prompt to use for relabelling blocks.",
56
  "Default is a string containing the Gemini relabelling prompt."
57
+ ] = """You're a layout expert specializing in document analysis.
58
  Your task is to relabel layout blocks in images to improve the accuracy of an existing layout model.
59
+ You will be provided with an image of a layout block and the top k predictions from the current model, along with the per-label confidence scores.
60
  Your job is to analyze the image and choose the single most appropriate label from the provided top k predictions.
61
  Do not invent any new labels.
62
+ Carefully examine the image and consider the provided predictions. Take the model confidence scores into account. The confidence is reported on a 0-1 scale, with 1 being 100% confident. If the existing label is the most appropriate, you should not change it.
63
+ **Instructions**
64
+ 1. Analyze the image and consider the provided top k predictions.
65
+ 2. Write a short description of the image, and which of the potential labels you believe is the most accurate representation of the layout block.
66
+ 3. Choose the single most appropriate label from the provided top k predictions.
67
+
68
+ Here are descriptions of the layout blocks you can choose from:
69
+
70
+ {potential_labels}
71
 
72
+ Here are the top k predictions from the model:
73
 
74
+ {top_k}
75
  """
76
  complex_relabeling_prompt: Annotated[
77
  str,
78
  "The prompt to use for complex relabelling blocks.",
79
  "Default is a string containing the complex relabelling prompt."
80
+ ] = """You're a layout expert specializing in document analysis.
81
  Your task is to relabel layout blocks in images to improve the accuracy of an existing layout model.
82
+ You will be provided with an image of a layout block and some potential labels that might be appropriate.
83
  Your job is to analyze the image and choose the single most appropriate label from the provided labels.
84
  Do not invent any new labels.
85
+ **Instructions**
86
+ 1. Analyze the image and consider the potential labels.
87
+ 2. Write a short description of the image, and which of the potential labels you believe is the most accurate representation of the layout block.
88
+ 3. Choose the single most appropriate label from the provided labels.
89
 
90
  Potential labels:
91
 
92
+ {potential_labels}
 
 
 
 
93
 
94
  Respond only with one of `Figure`, `Picture`, `ComplexRegion`, `Table`, or `Form`.
 
 
95
  """
96
 
97
+ def __init__(self, layout_model: LayoutPredictor, ocr_error_model: OCRErrorPredictor, config=None):
98
  super().__init__(layout_model, ocr_error_model, config)
99
 
100
  self.model = GoogleModel(self.google_api_key, self.model_name)
 
117
  confidence = block.top_k.get(block.block_type)
118
  # Case when the block is detected as a different type with low confidence
119
  if confidence < self.confidence_threshold:
120
+ futures.append(executor.submit(self.process_block_topk_relabeling, document, page, block))
121
  # Case when the block is detected as a picture or figure, but is actually complex
122
  elif block.block_type in (BlockTypes.Picture, BlockTypes.Figure, BlockTypes.SectionHeader) and block.polygon.height > page.polygon.height * self.picture_height_threshold:
123
+ futures.append(executor.submit(self.process_block_complex_relabeling, document, page, block))
124
 
125
  for future in as_completed(futures):
126
  future.result() # Raise exceptions if any occurred
 
128
 
129
  pbar.close()
130
 
131
+ def process_block_topk_relabeling(self, document: Document, page: PageGroup, block: Block):
132
+ topk_types = list(block.top_k.keys())
133
+ potential_labels = ""
134
+ for block_type in topk_types:
135
+ label_cls = get_block_class(block_type)
136
+ potential_labels += f"- `{block_type}` - {label_cls.model_fields['block_description'].default}\n"
137
+
138
+ topk = ""
139
+ for k,v in block.top_k.items():
140
+ topk += f"- `{k}` - Confidence {round(v, 3)}\n"
141
 
142
+ prompt = self.topk_relabelling_prompt.replace("{potential_labels}", potential_labels).replace("{top_k}", topk)
 
143
 
144
+ return self.process_block_relabeling(document, page, block, prompt)
 
 
145
 
146
+ def process_block_complex_relabeling(self, document: Document, page: PageGroup, block: Block):
147
+ potential_labels = ""
148
+ for block_type in [BlockTypes.Figure, BlockTypes.Picture, BlockTypes.ComplexRegion, BlockTypes.Table, BlockTypes.Form]:
149
+ label_cls = get_block_class(block_type)
150
+ potential_labels += f"- `{block_type}` - {label_cls.model_fields['block_description'].default}\n"
151
+
152
+ complex_prompt = self.complex_relabeling_prompt.replace("{potential_labels}", potential_labels)
153
+ return self.process_block_relabeling(document, page, block, complex_prompt)
154
+
155
+ def process_block_relabeling(self, document: Document, page: PageGroup, block: Block, prompt: str):
156
+ image = self.extract_image(document, block)
157
  response_schema = content.Schema(
158
  type=content.Type.OBJECT,
159
  enum=[],
160
+ required=["image_description", "label"],
161
  properties={
162
+ "image_description": content.Schema(
163
+ type=content.Type.STRING,
164
+ ),
165
  "label": content.Schema(
166
  type=content.Type.STRING,
167
  ),
 
182
  )
183
  page.replace_block(block, generated_block)
184
 
185
+ def extract_image(self, document: Document, image_block: Block, expand: float = 0.01):
186
+ return image_block.get_image(document, highres=False, expansion=(expand, expand))
 
 
 
 
 
marker/builders/ocr.py CHANGED
@@ -1,9 +1,8 @@
1
  from typing import Annotated, List, Optional
2
 
3
  from ftfy import fix_text
4
- from surya.model.detection.model import EfficientViTForSemanticSegmentation
5
- from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel
6
- from surya.ocr import run_ocr
7
 
8
  from marker.builders import BaseBuilder
9
  from marker.providers import ProviderOutput, ProviderPageLines
@@ -37,7 +36,7 @@ class OcrBuilder(BaseBuilder):
37
  "Default is None."
38
  ] = None
39
 
40
- def __init__(self, detection_model: EfficientViTForSemanticSegmentation, recognition_model: OCREncoderDecoderModel, config=None):
41
  super().__init__(config)
42
 
43
  self.detection_model = detection_model
@@ -65,16 +64,13 @@ class OcrBuilder(BaseBuilder):
65
 
66
  def ocr_extraction(self, document: Document, provider: PdfProvider) -> ProviderPageLines:
67
  page_list = [page for page in document.pages if page.text_extraction_method == "surya"]
68
- recognition_results = run_ocr(
69
- images=[page.lowres_image for page in page_list],
70
  langs=[self.languages] * len(page_list),
71
- det_model=self.detection_model,
72
- det_processor=self.detection_model.processor,
73
- rec_model=self.recognition_model,
74
- rec_processor=self.recognition_model.processor,
75
  detection_batch_size=int(self.get_detection_batch_size()),
76
  recognition_batch_size=int(self.get_recognition_batch_size()),
77
- highres_images=[page.highres_image for page in page_list]
78
  )
79
 
80
  page_lines = {}
 
1
  from typing import Annotated, List, Optional
2
 
3
  from ftfy import fix_text
4
+ from surya.detection import DetectionPredictor
5
+ from surya.recognition import RecognitionPredictor
 
6
 
7
  from marker.builders import BaseBuilder
8
  from marker.providers import ProviderOutput, ProviderPageLines
 
36
  "Default is None."
37
  ] = None
38
 
39
+ def __init__(self, detection_model: DetectionPredictor, recognition_model: RecognitionPredictor, config=None):
40
  super().__init__(config)
41
 
42
  self.detection_model = detection_model
 
64
 
65
  def ocr_extraction(self, document: Document, provider: PdfProvider) -> ProviderPageLines:
66
  page_list = [page for page in document.pages if page.text_extraction_method == "surya"]
67
+ recognition_results = self.recognition_model(
68
+ images=[page.get_image(highres=False) for page in page_list],
69
  langs=[self.languages] * len(page_list),
70
+ det_predictor=self.detection_model,
 
 
 
71
  detection_batch_size=int(self.get_detection_batch_size()),
72
  recognition_batch_size=int(self.get_recognition_batch_size()),
73
+ highres_images=[page.get_image(highres=True) for page in page_list]
74
  )
75
 
76
  page_lines = {}
marker/config/parser.py CHANGED
@@ -5,11 +5,13 @@ from typing import Dict
5
  import click
6
 
7
  from marker.config.crawler import crawler
 
8
  from marker.renderers.html import HTMLRenderer
9
  from marker.renderers.json import JSONRenderer
10
  from marker.renderers.markdown import MarkdownRenderer
11
  from marker.settings import settings
12
  from marker.util import classes_to_strings, parse_range_str, strings_to_classes
 
13
 
14
 
15
  class ConfigParser:
@@ -39,6 +41,10 @@ class ConfigParser:
39
  # we put common options here
40
  fn = click.option("--google_api_key", type=str, default=None, help="Google API key for using LLMs.")(fn)
41
  fn = click.option("--use_llm", is_flag=True, default=False, help="Enable higher quality processing with LLMs.")(fn)
 
 
 
 
42
  return fn
43
 
44
  def generate_config_dict(self) -> Dict[str, any]:
@@ -95,6 +101,17 @@ class ConfigParser:
95
 
96
  return processors
97
 
 
 
 
 
 
 
 
 
 
 
 
98
  def get_output_folder(self, filepath: str):
99
  output_dir = self.cli_options.get("output_dir", settings.OUTPUT_DIR)
100
  fname_base = os.path.splitext(os.path.basename(filepath))[0]
 
5
  import click
6
 
7
  from marker.config.crawler import crawler
8
+ from marker.converters.pdf import PdfConverter
9
  from marker.renderers.html import HTMLRenderer
10
  from marker.renderers.json import JSONRenderer
11
  from marker.renderers.markdown import MarkdownRenderer
12
  from marker.settings import settings
13
  from marker.util import classes_to_strings, parse_range_str, strings_to_classes
14
+ from marker.schema import BlockTypes
15
 
16
 
17
  class ConfigParser:
 
41
  # we put common options here
42
  fn = click.option("--google_api_key", type=str, default=None, help="Google API key for using LLMs.")(fn)
43
  fn = click.option("--use_llm", is_flag=True, default=False, help="Enable higher quality processing with LLMs.")(fn)
44
+ fn = click.option("--converter_cls", type=str, default=None, help="Converter class to use. Defaults to PDF converter.")(fn)
45
+
46
+ # enum options
47
+ fn = click.option("--force_layout_block", type=click.Choice(choices=[t.name for t in BlockTypes]), default=None,)(fn)
48
  return fn
49
 
50
  def generate_config_dict(self) -> Dict[str, any]:
 
101
 
102
  return processors
103
 
104
+ def get_converter_cls(self):
105
+ converter_cls = self.cli_options.get("converter_cls", None)
106
+ if converter_cls is not None:
107
+ try:
108
+ return strings_to_classes([converter_cls])[0]
109
+ except Exception as e:
110
+ print(f"Error loading converter: {converter_cls} with error: {e}")
111
+ raise
112
+
113
+ return PdfConverter
114
+
115
  def get_output_folder(self, filepath: str):
116
  output_dir = self.cli_options.get("output_dir", settings.OUTPUT_DIR)
117
  fname_base = os.path.splitext(os.path.basename(filepath))[0]
marker/config/printer.py CHANGED
@@ -6,19 +6,47 @@ from marker.config.crawler import crawler
6
 
7
 
8
  class CustomClickPrinter(click.Command):
9
- def get_help(self, ctx):
10
- additional_help = (
11
- "\n\nTip: Use 'config --help' to display all the attributes of the Builders, Processors, and Converters in Marker."
12
- )
13
- help_text = super().get_help(ctx)
14
- help_text = help_text + additional_help
15
- click.echo(help_text)
16
-
17
  def parse_args(self, ctx, args):
 
18
  display_help = 'config' in args and '--help' in args
19
  if display_help:
20
- click.echo("Here is a list of all the Builders, Processors, Converters, Providers and Renderers in Marker along with their attributes:")
 
 
 
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  for base_type, base_type_dict in crawler.class_config_map.items():
23
  if display_help:
24
  click.echo(f"{base_type}s:")
@@ -32,16 +60,14 @@ class CustomClickPrinter(click.Command):
32
  if display_help:
33
  click.echo(" " * 8 + f"{attr} ({formatted_type}):")
34
  click.echo("\n".join([f'{" " * 12}' + desc for desc in metadata]))
35
- if attr_type in [str, int, float, bool, Optional[int], Optional[float], Optional[str]]:
 
36
  is_flag = attr_type in [bool, Optional[bool]] and not default
37
- if crawler.attr_counts.get(attr) > 1:
38
- options = ["--" + class_name_attr]
39
- else:
40
- options = ["--" + attr, "--" + class_name_attr]
41
- options.append(class_name_attr)
42
  ctx.command.params.append(
43
  click.Option(
44
- options,
45
  type=attr_type,
46
  help=" ".join(metadata),
47
  is_flag=is_flag,
 
6
 
7
 
8
  class CustomClickPrinter(click.Command):
 
 
 
 
 
 
 
 
9
  def parse_args(self, ctx, args):
10
+
11
  display_help = 'config' in args and '--help' in args
12
  if display_help:
13
+ click.echo(
14
+ "Here is a list of all the Builders, Processors, Converters, Providers and Renderers in Marker along with their attributes:")
15
+
16
+ # Keep track of shared attributes and their types
17
+ shared_attrs = {}
18
 
19
+ # First pass: identify shared attributes and verify compatibility
20
+ for base_type, base_type_dict in crawler.class_config_map.items():
21
+ for class_name, class_map in base_type_dict.items():
22
+ for attr, (attr_type, formatted_type, default, metadata) in class_map['config'].items():
23
+ if attr not in shared_attrs:
24
+ shared_attrs[attr] = {
25
+ 'classes': [],
26
+ 'type': attr_type,
27
+ 'is_flag': attr_type in [bool, Optional[bool]] and not default,
28
+ 'metadata': metadata,
29
+ 'default': default
30
+ }
31
+ shared_attrs[attr]['classes'].append(class_name)
32
+
33
+ # These are the types of attrs that can be set from the command line
34
+ attr_types = [str, int, float, bool, Optional[int], Optional[float], Optional[str]]
35
+
36
+ # Add shared attribute options first
37
+ for attr, info in shared_attrs.items():
38
+ if info['type'] in attr_types:
39
+ ctx.command.params.append(
40
+ click.Option(
41
+ ["--" + attr],
42
+ type=info['type'],
43
+ help=" ".join(info['metadata']) + f" (Applies to: {', '.join(info['classes'])})",
44
+ default=info['default'],
45
+ is_flag=info['is_flag'],
46
+ )
47
+ )
48
+
49
+ # Second pass: create class-specific options
50
  for base_type, base_type_dict in crawler.class_config_map.items():
51
  if display_help:
52
  click.echo(f"{base_type}s:")
 
60
  if display_help:
61
  click.echo(" " * 8 + f"{attr} ({formatted_type}):")
62
  click.echo("\n".join([f'{" " * 12}' + desc for desc in metadata]))
63
+
64
+ if attr_type in attr_types:
65
  is_flag = attr_type in [bool, Optional[bool]] and not default
66
+
67
+ # Only add class-specific options
 
 
 
68
  ctx.command.params.append(
69
  click.Option(
70
+ ["--" + class_name_attr, class_name_attr],
71
  type=attr_type,
72
  help=" ".join(metadata),
73
  is_flag=is_flag,
marker/converters/pdf.py CHANGED
@@ -1,12 +1,14 @@
1
  import os
2
-
3
  os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning
4
 
5
  import inspect
6
  from collections import defaultdict
7
  from functools import cache
8
- from typing import Annotated, Any, Dict, List, Optional, Type
9
 
 
 
 
10
  from marker.builders.document import DocumentBuilder
11
  from marker.builders.layout import LayoutBuilder
12
  from marker.builders.llm_layout import LLMLayoutBuilder
@@ -32,12 +34,13 @@ from marker.processors.reference import ReferenceProcessor
32
  from marker.processors.sectionheader import SectionHeaderProcessor
33
  from marker.processors.table import TableProcessor
34
  from marker.processors.text import TextProcessor
35
- from marker.providers.pdf import PdfProvider
36
  from marker.renderers.markdown import MarkdownRenderer
37
  from marker.schema import BlockTypes
38
  from marker.schema.blocks import Block
39
  from marker.schema.registry import register_block_class
40
  from marker.util import strings_to_classes
 
41
 
42
 
43
  class PdfConverter(BaseConverter):
@@ -55,6 +58,30 @@ class PdfConverter(BaseConverter):
55
  bool,
56
  "Enable higher quality processing with LLMs.",
57
  ] = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  def __init__(self, artifact_dict: Dict[str, Any], processor_list: Optional[List[str]] = None, renderer: str | None = None, config=None):
60
  super().__init__(config)
@@ -65,27 +92,7 @@ class PdfConverter(BaseConverter):
65
  if processor_list:
66
  processor_list = strings_to_classes(processor_list)
67
  else:
68
- processor_list = [
69
- BlockquoteProcessor,
70
- CodeProcessor,
71
- DocumentTOCProcessor,
72
- EquationProcessor,
73
- FootnoteProcessor,
74
- IgnoreTextProcessor,
75
- LineNumbersProcessor,
76
- ListProcessor,
77
- PageHeaderProcessor,
78
- SectionHeaderProcessor,
79
- TableProcessor,
80
- LLMTableProcessor,
81
- LLMFormProcessor,
82
- TextProcessor,
83
- LLMTextProcessor,
84
- LLMComplexRegionProcessor,
85
- LLMImageDescriptionProcessor,
86
- ReferenceProcessor,
87
- DebugProcessor,
88
- ]
89
 
90
  if renderer:
91
  renderer = strings_to_classes([renderer])[0]
@@ -121,11 +128,11 @@ class PdfConverter(BaseConverter):
121
 
122
  @cache
123
  def build_document(self, filepath: str):
 
124
  layout_builder = self.resolve_dependencies(self.layout_builder_class)
125
  ocr_builder = self.resolve_dependencies(OcrBuilder)
126
-
127
- with PdfProvider(filepath, self.config) as pdf_provider:
128
- document = DocumentBuilder(self.config)(pdf_provider, layout_builder, ocr_builder)
129
  StructureBuilder(self.config)(document)
130
 
131
  for processor_cls in self.processor_list:
 
1
  import os
 
2
  os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning
3
 
4
  import inspect
5
  from collections import defaultdict
6
  from functools import cache
7
+ from typing import Annotated, Any, Dict, List, Optional, Type, Tuple
8
 
9
+ from marker.processors import BaseProcessor
10
+ from marker.processors.llm.llm_table_merge import LLMTableMergeProcessor
11
+ from marker.providers.registry import provider_from_filepath
12
  from marker.builders.document import DocumentBuilder
13
  from marker.builders.layout import LayoutBuilder
14
  from marker.builders.llm_layout import LLMLayoutBuilder
 
34
  from marker.processors.sectionheader import SectionHeaderProcessor
35
  from marker.processors.table import TableProcessor
36
  from marker.processors.text import TextProcessor
37
+ from marker.processors.llm.llm_equation import LLMEquationProcessor
38
  from marker.renderers.markdown import MarkdownRenderer
39
  from marker.schema import BlockTypes
40
  from marker.schema.blocks import Block
41
  from marker.schema.registry import register_block_class
42
  from marker.util import strings_to_classes
43
+ from marker.processors.llm.llm_handwriting import LLMHandwritingProcessor
44
 
45
 
46
  class PdfConverter(BaseConverter):
 
58
  bool,
59
  "Enable higher quality processing with LLMs.",
60
  ] = False
61
+ default_processors: Tuple[BaseProcessor, ...] = (
62
+ BlockquoteProcessor,
63
+ CodeProcessor,
64
+ DocumentTOCProcessor,
65
+ EquationProcessor,
66
+ FootnoteProcessor,
67
+ IgnoreTextProcessor,
68
+ LineNumbersProcessor,
69
+ ListProcessor,
70
+ PageHeaderProcessor,
71
+ SectionHeaderProcessor,
72
+ TableProcessor,
73
+ LLMTableProcessor,
74
+ LLMTableMergeProcessor,
75
+ LLMFormProcessor,
76
+ TextProcessor,
77
+ LLMTextProcessor,
78
+ LLMComplexRegionProcessor,
79
+ LLMImageDescriptionProcessor,
80
+ LLMEquationProcessor,
81
+ LLMHandwritingProcessor,
82
+ ReferenceProcessor,
83
+ DebugProcessor,
84
+ )
85
 
86
  def __init__(self, artifact_dict: Dict[str, Any], processor_list: Optional[List[str]] = None, renderer: str | None = None, config=None):
87
  super().__init__(config)
 
92
  if processor_list:
93
  processor_list = strings_to_classes(processor_list)
94
  else:
95
+ processor_list = self.default_processors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  if renderer:
98
  renderer = strings_to_classes([renderer])[0]
 
128
 
129
  @cache
130
  def build_document(self, filepath: str):
131
+ provider_cls = provider_from_filepath(filepath)
132
  layout_builder = self.resolve_dependencies(self.layout_builder_class)
133
  ocr_builder = self.resolve_dependencies(OcrBuilder)
134
+ with provider_cls(filepath, self.config) as provider:
135
+ document = DocumentBuilder(self.config)(provider, layout_builder, ocr_builder)
 
136
  StructureBuilder(self.config)(document)
137
 
138
  for processor_cls in self.processor_list:
marker/converters/table.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import cache
2
+ from typing import Tuple, List
3
+
4
+ from marker.builders.document import DocumentBuilder
5
+ from marker.builders.ocr import OcrBuilder
6
+ from marker.converters.pdf import PdfConverter
7
+ from marker.processors import BaseProcessor
8
+ from marker.processors.llm.llm_complex import LLMComplexRegionProcessor
9
+ from marker.processors.llm.llm_form import LLMFormProcessor
10
+ from marker.processors.llm.llm_table import LLMTableProcessor
11
+ from marker.processors.llm.llm_table_merge import LLMTableMergeProcessor
12
+ from marker.processors.table import TableProcessor
13
+ from marker.providers.registry import provider_from_filepath
14
+ from marker.schema import BlockTypes
15
+
16
+
17
+ class TableConverter(PdfConverter):
18
+ default_processors: Tuple[BaseProcessor, ...] = (
19
+ TableProcessor,
20
+ LLMTableProcessor,
21
+ LLMTableMergeProcessor,
22
+ LLMFormProcessor,
23
+ LLMComplexRegionProcessor,
24
+ )
25
+ converter_block_types: List[BlockTypes] = (BlockTypes.Table, BlockTypes.Form, BlockTypes.TableOfContents)
26
+
27
+ @cache
28
+ def build_document(self, filepath: str):
29
+ provider_cls = provider_from_filepath(filepath)
30
+ layout_builder = self.resolve_dependencies(self.layout_builder_class)
31
+ ocr_builder = self.resolve_dependencies(OcrBuilder)
32
+ document_builder = DocumentBuilder(self.config)
33
+ document_builder.disable_ocr = True
34
+ with provider_cls(filepath, self.config) as provider:
35
+ document = document_builder(provider, layout_builder, ocr_builder)
36
+
37
+ for page in document.pages:
38
+ page.structure = [p for p in page.structure if p.block_type in self.converter_block_types]
39
+
40
+ for processor_cls in self.processor_list:
41
+ processor = self.resolve_dependencies(processor_cls)
42
+ processor(document)
43
+
44
+ return document
45
+
46
+ def __call__(self, filepath: str):
47
+ document = self.build_document(filepath)
48
+ renderer = self.resolve_dependencies(self.renderer)
49
+ return renderer(document)
marker/models.py CHANGED
@@ -1,86 +1,49 @@
1
  import os
2
 
3
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
4
-
5
- from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
6
- from surya.model.layout.model import load_model as load_layout_model
7
- from surya.model.layout.processor import load_processor as load_layout_processor
8
- from texify.model.model import load_model as load_texify_model
9
- from texify.model.processor import load_processor as load_texify_processor
10
  from marker.settings import settings
11
- from surya.model.recognition.model import load_model as load_recognition_model
12
- from surya.model.recognition.processor import load_processor as load_recognition_processor
13
- from surya.model.table_rec.model import load_model as load_table_model
14
- from surya.model.table_rec.processor import load_processor as load_table_processor
15
- from surya.model.ocr_error.model import load_model as load_ocr_error_model
16
- from surya.model.ocr_error.model import load_tokenizer as load_ocr_error_tokenizer
17
-
18
- from texify.model.model import GenerateVisionEncoderDecoderModel
19
- from surya.model.layout.encoderdecoder import SuryaLayoutModel
20
- from surya.model.detection.model import EfficientViTForSemanticSegmentation
21
- from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel
22
- from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
23
- from surya.model.ocr_error.model import DistilBertForSequenceClassification
24
-
25
 
26
- def setup_table_rec_model(device=None, dtype=None) -> TableRecEncoderDecoderModel:
27
- if device:
28
- table_model = load_table_model(device=device, dtype=dtype)
29
- else:
30
- table_model = load_table_model()
31
- table_model.processor = load_table_processor()
32
- return table_model
33
-
34
-
35
- def setup_recognition_model(device=None, dtype=None) -> OCREncoderDecoderModel:
36
- if device:
37
- rec_model = load_recognition_model(device=device, dtype=dtype)
38
- else:
39
- rec_model = load_recognition_model()
40
- rec_model.processor = load_recognition_processor()
41
- return rec_model
42
 
 
 
43
 
44
- def setup_detection_model(device=None, dtype=None) -> EfficientViTForSemanticSegmentation:
45
- if device:
46
- model = load_detection_model(device=device, dtype=dtype)
47
- else:
48
- model = load_detection_model()
49
- model.processor = load_detection_processor()
50
- return model
51
 
 
 
 
52
 
53
- def setup_texify_model(device=None, dtype=None) -> GenerateVisionEncoderDecoderModel:
54
- if device:
55
- texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=device, dtype=dtype)
56
- else:
57
- texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE)
58
- texify_model.processor = load_texify_processor()
59
- return texify_model
60
 
 
 
 
 
61
 
62
- def setup_layout_model(device=None, dtype=None) -> SuryaLayoutModel:
63
- if device:
64
- model = load_layout_model(device=device, dtype=dtype)
65
- else:
66
- model = load_layout_model()
67
- model.processor = load_layout_processor()
68
- return model
69
 
70
- def setup_ocr_error_model(device=None, dtype=None) -> DistilBertForSequenceClassification:
71
- if device:
72
- model = load_ocr_error_model(device=device, dtype=dtype)
73
- else:
74
- model = load_ocr_error_model()
75
- model.tokenizer = load_ocr_error_tokenizer()
76
- return model
77
 
78
  def create_model_dict(device=None, dtype=None) -> dict:
79
  return {
80
- "layout_model": setup_layout_model(device, dtype),
81
- "texify_model": setup_texify_model(device, dtype),
82
- "recognition_model": setup_recognition_model(device, dtype),
83
- "table_rec_model": setup_table_rec_model(device, dtype),
84
- "detection_model": setup_detection_model(device, dtype),
85
- "ocr_error_model": setup_ocr_error_model(device,dtype)
86
  }
 
1
  import os
2
 
 
 
 
 
 
 
 
3
  from marker.settings import settings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ from typing import List
8
+ from PIL import Image
9
 
10
+ from surya.detection import DetectionPredictor
11
+ from surya.layout import LayoutPredictor
12
+ from surya.ocr_error import OCRErrorPredictor
13
+ from surya.recognition import RecognitionPredictor
14
+ from surya.table_rec import TableRecPredictor
 
 
15
 
16
+ from texify.model.model import load_model as load_texify_model
17
+ from texify.model.processor import load_processor as load_texify_processor
18
+ from texify.inference import batch_inference
19
 
20
+ class TexifyPredictor:
21
+ def __init__(self, device=None, dtype=None):
22
+ if not device:
23
+ device = settings.TORCH_DEVICE_MODEL
24
+ if not dtype:
25
+ dtype = settings.TEXIFY_DTYPE
 
26
 
27
+ self.model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=device, dtype=dtype)
28
+ self.processor = load_texify_processor()
29
+ self.device = device
30
+ self.dtype = dtype
31
 
32
+ def __call__(self, batch_images: List[Image.Image], max_tokens: int):
33
+ return batch_inference(
34
+ batch_images,
35
+ self.model,
36
+ self.processor,
37
+ max_tokens=max_tokens
38
+ )
39
 
 
 
 
 
 
 
 
40
 
41
  def create_model_dict(device=None, dtype=None) -> dict:
42
  return {
43
+ "layout_model": LayoutPredictor(device=device, dtype=dtype),
44
+ "texify_model": TexifyPredictor(device=device, dtype=dtype),
45
+ "recognition_model": RecognitionPredictor(device=device, dtype=dtype),
46
+ "table_rec_model": TableRecPredictor(device=device, dtype=dtype),
47
+ "detection_model": DetectionPredictor(device=device, dtype=dtype),
48
+ "ocr_error_model": OCRErrorPredictor(device=device, dtype=dtype)
49
  }
marker/processors/debug.py CHANGED
@@ -68,7 +68,7 @@ class DebugProcessor(BaseProcessor):
68
 
69
  def draw_pdf_debug_images(self, document: Document):
70
  for page in document.pages:
71
- png_image = page.highres_image.copy()
72
 
73
  line_bboxes = []
74
  span_bboxes = []
@@ -90,7 +90,7 @@ class DebugProcessor(BaseProcessor):
90
 
91
  def draw_layout_debug_images(self, document: Document, pdf_mode=False):
92
  for page in document.pages:
93
- img_size = page.highres_image.size
94
  png_image = Image.new("RGB", img_size, color="white")
95
 
96
  line_bboxes = []
 
68
 
69
  def draw_pdf_debug_images(self, document: Document):
70
  for page in document.pages:
71
+ png_image = page.get_image(highres=True).copy()
72
 
73
  line_bboxes = []
74
  span_bboxes = []
 
90
 
91
  def draw_layout_debug_images(self, document: Document, pdf_mode=False):
92
  for page in document.pages:
93
+ img_size = page.get_image(highres=True).size
94
  png_image = Image.new("RGB", img_size, color="white")
95
 
96
  line_bboxes = []
marker/processors/equation.py CHANGED
@@ -4,6 +4,7 @@ from texify.inference import batch_inference
4
  from texify.model.model import GenerateVisionEncoderDecoderModel
5
  from tqdm import tqdm
6
 
 
7
  from marker.processors import BaseProcessor
8
  from marker.schema import BlockTypes
9
  from marker.schema.document import Document
@@ -32,7 +33,7 @@ class EquationProcessor(BaseProcessor):
32
  "The number of tokens to buffer above max for the Texify model.",
33
  ] = 256
34
 
35
- def __init__(self, texify_model: GenerateVisionEncoderDecoderModel, config=None):
36
  super().__init__(config)
37
 
38
  self.texify_model = texify_model
@@ -42,8 +43,7 @@ class EquationProcessor(BaseProcessor):
42
 
43
  for page in document.pages:
44
  for block in page.contained_blocks(document, self.block_types):
45
- image_poly = block.polygon.rescale((page.polygon.width, page.polygon.height), page.lowres_image.size)
46
- image = page.lowres_image.crop(image_poly.bbox).convert("RGB")
47
  raw_text = block.raw_text(document)
48
  token_count = self.get_total_texify_tokens(raw_text)
49
 
@@ -92,10 +92,8 @@ class EquationProcessor(BaseProcessor):
92
 
93
  batch_images = [eq["image"] for eq in batch_equations]
94
 
95
- model_output = batch_inference(
96
  batch_images,
97
- self.texify_model,
98
- self.texify_model.processor,
99
  max_tokens=max_length
100
  )
101
 
 
4
  from texify.model.model import GenerateVisionEncoderDecoderModel
5
  from tqdm import tqdm
6
 
7
+ from marker.models import TexifyPredictor
8
  from marker.processors import BaseProcessor
9
  from marker.schema import BlockTypes
10
  from marker.schema.document import Document
 
33
  "The number of tokens to buffer above max for the Texify model.",
34
  ] = 256
35
 
36
+ def __init__(self, texify_model: TexifyPredictor, config=None):
37
  super().__init__(config)
38
 
39
  self.texify_model = texify_model
 
43
 
44
  for page in document.pages:
45
  for block in page.contained_blocks(document, self.block_types):
46
+ image = block.get_image(document, highres=False).convert("RGB")
 
47
  raw_text = block.raw_text(document)
48
  token_count = self.get_total_texify_tokens(raw_text)
49
 
 
92
 
93
  batch_images = [eq["image"] for eq in batch_equations]
94
 
95
+ model_output = self.texify_model(
96
  batch_images,
 
 
97
  max_tokens=max_length
98
  )
99
 
marker/processors/ignoretext.py CHANGED
@@ -17,8 +17,7 @@ class IgnoreTextProcessor(BaseProcessor):
17
  These blocks often represent repetitive or non-essential elements, such as headers, footers, or page numbers.
18
  """
19
  block_types = (
20
- BlockTypes.Text, BlockTypes.PageHeader,
21
- BlockTypes.PageFooter, BlockTypes.SectionHeader,
22
  BlockTypes.TextInlineMath
23
  )
24
  common_element_threshold: Annotated[
@@ -47,7 +46,6 @@ class IgnoreTextProcessor(BaseProcessor):
47
  last_blocks = []
48
  for page in document.pages:
49
  initial_block = None
50
- block = None
51
  last_block = None
52
  for block in page.contained_blocks(document, self.block_types):
53
  if block.structure is not None:
 
17
  These blocks often represent repetitive or non-essential elements, such as headers, footers, or page numbers.
18
  """
19
  block_types = (
20
+ BlockTypes.Text, BlockTypes.SectionHeader,
 
21
  BlockTypes.TextInlineMath
22
  )
23
  common_element_threshold: Annotated[
 
46
  last_blocks = []
47
  for page in document.pages:
48
  initial_block = None
 
49
  last_block = None
50
  for block in page.contained_blocks(document, self.block_types):
51
  if block.structure is not None:
marker/processors/llm/__init__.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from concurrent.futures import ThreadPoolExecutor, as_completed
2
  from typing import Annotated, Optional
3
 
@@ -16,7 +17,7 @@ class BaseLLMProcessor(BaseProcessor):
16
  A processor for using LLMs to convert blocks.
17
  """
18
  google_api_key: Annotated[
19
- Optional[str],
20
  "The Google API key to use for the Gemini model.",
21
  ] = settings.GOOGLE_API_KEY
22
  model_name: Annotated[
@@ -39,11 +40,6 @@ class BaseLLMProcessor(BaseProcessor):
39
  float,
40
  "The ratio to expand the image by when cropping.",
41
  ] = 0.01
42
- gemini_rewriting_prompt: Annotated[
43
- str,
44
- "The prompt to use for rewriting text.",
45
- "Default is a string containing the Gemini rewriting prompt."
46
- ] = ''
47
  use_llm: Annotated[
48
  bool,
49
  "Whether to use the LLM model.",
@@ -84,10 +80,5 @@ class BaseLLMProcessor(BaseProcessor):
84
 
85
  pbar.close()
86
 
87
- def extract_image(self, page: PageGroup, image_block: Block):
88
- page_img = page.lowres_image
89
- image_box = image_block.polygon\
90
- .rescale(page.polygon.size, page_img.size)\
91
- .expand(self.image_expansion_ratio, self.image_expansion_ratio)
92
- cropped = page_img.crop(image_box.bbox)
93
- return cropped
 
1
+ import traceback
2
  from concurrent.futures import ThreadPoolExecutor, as_completed
3
  from typing import Annotated, Optional
4
 
 
17
  A processor for using LLMs to convert blocks.
18
  """
19
  google_api_key: Annotated[
20
+ str,
21
  "The Google API key to use for the Gemini model.",
22
  ] = settings.GOOGLE_API_KEY
23
  model_name: Annotated[
 
40
  float,
41
  "The ratio to expand the image by when cropping.",
42
  ] = 0.01
 
 
 
 
 
43
  use_llm: Annotated[
44
  bool,
45
  "Whether to use the LLM model.",
 
80
 
81
  pbar.close()
82
 
83
+ def extract_image(self, document: Document, image_block: Block):
84
+ return image_block.get_image(document, highres=False, expansion=(self.image_expansion_ratio, self.image_expansion_ratio))
 
 
 
 
 
marker/processors/llm/llm_complex.py CHANGED
@@ -12,9 +12,9 @@ from marker.schema.groups.page import PageGroup
12
 
13
  class LLMComplexRegionProcessor(BaseLLMProcessor):
14
  block_types = (BlockTypes.ComplexRegion,)
15
- gemini_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
16
  You will receive an image of a text block and the text that can be extracted from the image.
17
- Your task is to correct any errors in the text, and format it properly.
18
 
19
  Formatting should be in markdown, with the following rules:
20
  - * for italics, ** for bold, and ` for inline code.
@@ -29,27 +29,32 @@ Formatting should be in markdown, with the following rules:
29
 
30
  **Instructions:**
31
  1. Carefully examine the provided block image.
32
- 2. Analyze the text representation
33
- 3. If the text representation is largely correct, then write "No corrections needed."
34
- 4. If the text representation contains errors, generate the corrected markdown representation.
35
- 5. Output only either the corrected markdown representation or "No corrections needed."
36
  **Example:**
37
  Input:
38
  ```text
39
- This is an example text block.
40
  ```
41
  Output:
42
  ```markdown
43
- No corrections needed.
 
 
 
 
 
44
  ```
45
  **Input:**
 
 
 
46
  """
47
 
48
  def process_rewriting(self, document: Document, page: PageGroup, block: Block):
49
  text = block.raw_text(document)
50
-
51
- prompt = self.gemini_rewriting_prompt + '```text\n`' + text + '`\n```\n'
52
- image = self.extract_image(page, block)
53
  response_schema = content.Schema(
54
  type=content.Type.OBJECT,
55
  enum=[],
@@ -79,4 +84,5 @@ No corrections needed.
79
  return
80
 
81
  # Convert LLM markdown to html
 
82
  block.html = markdown2.markdown(corrected_markdown)
 
12
 
13
  class LLMComplexRegionProcessor(BaseLLMProcessor):
14
  block_types = (BlockTypes.ComplexRegion,)
15
+ complex_region_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
16
  You will receive an image of a text block and the text that can be extracted from the image.
17
+ Your task is to generate markdown to properly represent the content of the image. Do not omit any text present in the image - make sure everything is included in the markdown representation. The markdown representation should be as faithful to the original image as possible.
18
 
19
  Formatting should be in markdown, with the following rules:
20
  - * for italics, ** for bold, and ` for inline code.
 
29
 
30
  **Instructions:**
31
  1. Carefully examine the provided block image.
32
+ 2. Analyze the existing text representation.
33
+ 3. Generate the markdown representation of the content in the image.
 
 
34
  **Example:**
35
  Input:
36
  ```text
37
+ Table 1: Car Sales
38
  ```
39
  Output:
40
  ```markdown
41
+ ## Table 1: Car Sales
42
+
43
+ | Car | Sales |
44
+ | --- | --- |
45
+ | Honda | 100 |
46
+ | Toyota | 200 |
47
  ```
48
  **Input:**
49
+ ```text
50
+ {extracted_text}
51
+ ```
52
  """
53
 
54
  def process_rewriting(self, document: Document, page: PageGroup, block: Block):
55
  text = block.raw_text(document)
56
+ prompt = self.complex_region_prompt.replace("{extracted_text}", text)
57
+ image = self.extract_image(document, block)
 
58
  response_schema = content.Schema(
59
  type=content.Type.OBJECT,
60
  enum=[],
 
84
  return
85
 
86
  # Convert LLM markdown to html
87
+ corrected_markdown = corrected_markdown.strip().lstrip("```markdown").rstrip("```").strip()
88
  block.html = markdown2.markdown(corrected_markdown)
marker/processors/llm/llm_equation.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from marker.processors.llm import BaseLLMProcessor
2
+
3
+ from google.ai.generativelanguage_v1beta.types import content
4
+
5
+ from marker.schema import BlockTypes
6
+ from marker.schema.blocks import Equation
7
+ from marker.schema.document import Document
8
+ from marker.schema.groups.page import PageGroup
9
+
10
+ from typing import Annotated
11
+
12
+
13
+ class LLMEquationProcessor(BaseLLMProcessor):
14
+ block_types = (BlockTypes.Equation,)
15
+ min_equation_height: Annotated[
16
+ float,
17
+ "The minimum ratio between equation height and page height to consider for processing.",
18
+ ] = 0.1
19
+ equation_latex_prompt: Annotated[
20
+ str,
21
+ "The prompt to use for generating LaTeX from equations.",
22
+ "Default is a string containing the Gemini prompt."
23
+ ] = """You're an expert mathematician who is good at writing LaTeX code for equations'.
24
+ You will receive an image of a math block that may contain one or more equations. Your job is to write the LaTeX code for the equation, along with markdown for any other text.
25
+
26
+ Some guidelines:
27
+ - Keep the LaTeX code simple and concise.
28
+ - Make it KaTeX compatible.
29
+ - Use $$ as a block equation delimiter and $ for inline equations. Block equations should also be on their own line. Do not use any other delimiters.
30
+ - You can include text in between equation blocks as needed. Try to put long text segments into plain text and not inside the equations.
31
+
32
+ **Instructions:**
33
+ 1. Carefully examine the provided image.
34
+ 2. Analyze the existing markdown, which may include LaTeX code.
35
+ 3. If the markdown and LaTeX are correct, write "No corrections needed."
36
+ 4. If the markdown and LaTeX are incorrect, generate the corrected markdown and LaTeX.
37
+ 5. Output only the corrected text or "No corrections needed."
38
+ **Example:**
39
+ Input:
40
+ ```markdown
41
+ Equation 1:
42
+ $$x^2 + y^2 = z2$$
43
+ ```
44
+ Output:
45
+ ```markdown
46
+ Equation 1:
47
+ $$x^2 + y^2 = z^2$$
48
+ ```
49
+ **Input:**
50
+ ```markdown
51
+ {equation}
52
+ ```
53
+ """
54
+
55
+ def process_rewriting(self, document: Document, page: PageGroup, block: Equation):
56
+ text = block.latex if block.latex else block.raw_text(document)
57
+ prompt = self.equation_latex_prompt.replace("{equation}", text)
58
+
59
+ image = self.extract_image(document, block)
60
+ response_schema = content.Schema(
61
+ type=content.Type.OBJECT,
62
+ enum=[],
63
+ required=["markdown_equation"],
64
+ properties={
65
+ "markdown_equation": content.Schema(
66
+ type=content.Type.STRING
67
+ )
68
+ },
69
+ )
70
+
71
+ response = self.model.generate_response(prompt, image, block, response_schema)
72
+
73
+ if not response or "markdown_equation" not in response:
74
+ block.update_metadata(llm_error_count=1)
75
+ return
76
+
77
+ markdown_equation = response["markdown_equation"]
78
+ if len(markdown_equation) < len(text) * .5:
79
+ block.update_metadata(llm_error_count=1)
80
+ return
81
+
82
+ block.latex = markdown_equation
marker/processors/llm/llm_form.py CHANGED
@@ -1,9 +1,6 @@
1
- import markdown2
2
-
3
  from marker.processors.llm import BaseLLMProcessor
4
 
5
  from google.ai.generativelanguage_v1beta.types import content
6
- from tabled.formats import markdown_format
7
 
8
  from marker.schema import BlockTypes
9
  from marker.schema.blocks import Block
@@ -13,48 +10,75 @@ from marker.schema.groups.page import PageGroup
13
 
14
  class LLMFormProcessor(BaseLLMProcessor):
15
  block_types = (BlockTypes.Form,)
16
- gemini_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
17
- You will receive an image of a text block and a markdown representation of the form in the image.
18
- Your task is to correct any errors in the markdown representation, and format it properly.
19
- Values and labels should appear in markdown tables, with the labels on the left side, and values on the right. The headers should be "Labels" and "Values". Other text in the form can appear between the tables.
20
  **Instructions:**
21
  1. Carefully examine the provided form block image.
22
- 2. Analyze the markdown representation of the form.
23
- 3. If the markdown representation is largely correct, then write "No corrections needed."
24
- 4. If the markdown representation contains errors, generate the corrected markdown representation.
25
- 5. Output only either the corrected markdown representation or "No corrections needed."
26
  **Example:**
27
  Input:
28
- ```markdown
29
- | Label 1 | Label 2 | Label 3 |
30
- |----------|----------|----------|
31
- | Value 1 | Value 2 | Value 3 |
 
 
 
 
 
 
 
 
 
32
  ```
33
  Output:
34
- ```markdown
35
- | Labels | Values |
36
- |--------|--------|
37
- | Label 1 | Value 1 |
38
- | Label 2 | Value 2 |
39
- | Label 3 | Value 3 |
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  ```
41
  **Input:**
 
 
 
42
  """
43
 
44
  def process_rewriting(self, document: Document, page: PageGroup, block: Block):
45
- cells = block.cells
46
- if cells is None:
47
  # Happens if table/form processors didn't run
48
  return
49
 
50
- prompt = self.gemini_rewriting_prompt + '```markdown\n`' + markdown_format(cells) + '`\n```\n'
51
- image = self.extract_image(page, block)
 
 
52
  response_schema = content.Schema(
53
  type=content.Type.OBJECT,
54
  enum=[],
55
- required=["corrected_markdown"],
56
  properties={
57
- "corrected_markdown": content.Schema(
58
  type=content.Type.STRING
59
  )
60
  },
@@ -62,22 +86,20 @@ Output:
62
 
63
  response = self.model.generate_response(prompt, image, block, response_schema)
64
 
65
- if not response or "corrected_markdown" not in response:
66
  block.update_metadata(llm_error_count=1)
67
  return
68
 
69
- corrected_markdown = response["corrected_markdown"]
70
 
71
  # The original table is okay
72
- if "no corrections" in corrected_markdown.lower():
73
  return
74
 
75
- orig_cell_text = "".join([cell.text for cell in cells])
76
-
77
  # Potentially a partial response
78
- if len(corrected_markdown) < len(orig_cell_text) * .5:
79
  block.update_metadata(llm_error_count=1)
80
  return
81
 
82
- # Convert LLM markdown to html
83
- block.html = markdown2.markdown(corrected_markdown)
 
 
 
1
  from marker.processors.llm import BaseLLMProcessor
2
 
3
  from google.ai.generativelanguage_v1beta.types import content
 
4
 
5
  from marker.schema import BlockTypes
6
  from marker.schema.blocks import Block
 
10
 
11
  class LLMFormProcessor(BaseLLMProcessor):
12
  block_types = (BlockTypes.Form,)
13
+ form_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
14
+ You will receive an image of a text block and an html representation of the form in the image.
15
+ Your task is to correct any errors in the html representation, and format it properly.
16
+ Values and labels should appear in html tables, with the labels on the left side, and values on the right. The headers should be "Labels" and "Values". Other text in the form can appear between the tables. Only use the tags `table, p, span, i, b, th, td, tr, and div`. Do not omit any text from the form - make sure everything is included in the html representation. It should be as faithful to the original form as possible.
17
  **Instructions:**
18
  1. Carefully examine the provided form block image.
19
+ 2. Analyze the html representation of the form.
20
+ 3. If the html representation is largely correct, then write "No corrections needed."
21
+ 4. If the html representation contains errors, generate the corrected html representation.
22
+ 5. Output only either the corrected html representation or "No corrections needed."
23
  **Example:**
24
  Input:
25
+ ```html
26
+ <table>
27
+ <tr>
28
+ <td>Label 1</td>
29
+ <td>Label 2</td>
30
+ <td>Label 3</td>
31
+ </tr>
32
+ <tr>
33
+ <td>Value 1</td>
34
+ <td>Value 2</td>
35
+ <td>Value 3</td>
36
+ </tr>
37
+ </table>
38
  ```
39
  Output:
40
+ ```html
41
+ <table>
42
+ <tr>
43
+ <th>Labels</th>
44
+ <th>Values</th>
45
+ </tr>
46
+ <tr>
47
+ <td>Label 1</td>
48
+ <td>Value 1</td>
49
+ </tr>
50
+ <tr>
51
+ <td>Label 2</td>
52
+ <td>Value 2</td>
53
+ </tr>
54
+ <tr>
55
+ <td>Label 3</td>
56
+ <td>Value 3</td>
57
+ </tr>
58
+ </table>
59
  ```
60
  **Input:**
61
+ ```html
62
+ {block_html}
63
+ ```
64
  """
65
 
66
  def process_rewriting(self, document: Document, page: PageGroup, block: Block):
67
+ children = block.contained_blocks(document, (BlockTypes.TableCell,))
68
+ if not children:
69
  # Happens if table/form processors didn't run
70
  return
71
 
72
+ block_html = block.render(document).html
73
+ prompt = self.form_rewriting_prompt.replace("{block_html}", block_html)
74
+
75
+ image = self.extract_image(document, block)
76
  response_schema = content.Schema(
77
  type=content.Type.OBJECT,
78
  enum=[],
79
+ required=["corrected_html"],
80
  properties={
81
+ "corrected_html": content.Schema(
82
  type=content.Type.STRING
83
  )
84
  },
 
86
 
87
  response = self.model.generate_response(prompt, image, block, response_schema)
88
 
89
+ if not response or "corrected_html" not in response:
90
  block.update_metadata(llm_error_count=1)
91
  return
92
 
93
+ corrected_html = response["corrected_html"]
94
 
95
  # The original table is okay
96
+ if "no corrections" in corrected_html.lower():
97
  return
98
 
 
 
99
  # Potentially a partial response
100
+ if len(corrected_html) < len(block_html) * .33:
101
  block.update_metadata(llm_error_count=1)
102
  return
103
 
104
+ corrected_html = corrected_html.strip().lstrip("```html").rstrip("```").strip()
105
+ block.html = corrected_html
marker/processors/llm/llm_handwriting.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import markdown2

from marker.processors.llm import BaseLLMProcessor

from google.ai.generativelanguage_v1beta.types import content

from marker.schema import BlockTypes
from marker.schema.blocks import Equation
from marker.schema.document import Document
from marker.schema.groups.page import PageGroup

from typing import Annotated


class LLMHandwritingProcessor(BaseLLMProcessor):
    """OCRs a block's handwriting with the LLM and stores the result as HTML.

    The block image (plus whatever text was already extracted) is sent to the
    model, which returns markdown; the markdown is rendered with markdown2 and
    written back to ``block.html``.
    """

    # NOTE(review): this targets Equation blocks even though the class handles
    # handwriting -- confirm this is intentional and not a copy-paste from the
    # equation processor.
    block_types = (BlockTypes.Equation,)
    # NOTE(review): min_handwriting_height is not referenced in this file;
    # presumably consumed by a caller or filter -- verify before removing.
    min_handwriting_height: Annotated[
        float,
        "The minimum ratio between handwriting height and page height to consider for processing.",
    ] = 0.1
    handwriting_generation_prompt: Annotated[
        str,
        "The prompt to use for OCRing handwriting.",
        "Default is a string containing the Gemini prompt."
    ] = """You are an expert editor specializing in accurately reproducing text from images.
You will receive an image of a text block, along with the text that can be extracted. Your task is to generate markdown to properly represent the content of the image. Do not omit any text present in the image - make sure everything is included in the markdown representation. The markdown representation should be as faithful to the original image as possible.

Formatting should be in markdown, with the following rules:
- * for italics, ** for bold, and ` for inline code.
- Headers should be formatted with #, with one # for the largest header, and up to 6 for the smallest.
- Lists should be formatted with either - or 1. for unordered and ordered lists, respectively.
- Links should be formatted with [text](url).
- Use ``` for code blocks.
- Inline math should be formatted with <math>math expression</math>.
- Display math should be formatted with <math display="block">math expression</math>.
- Values and labels should be extracted from forms, and put into markdown tables, with the labels on the left side, and values on the right.  The headers should be "Labels" and "Values".  Other text in the form can appear between the tables.
- Tables should be formatted with markdown tables, with the headers bolded.

**Instructions:**
1. Carefully examine the provided block image.
2. Analyze the existing text representation.
3. Output the markdown representing the content of the image.
**Example:**
Input:
```text
This i sm handwritting.
```
Output:
```markdown
This is some *handwriting*.
```
**Input:**
```text
{extracted_text}
```
"""

    def process_rewriting(self, document: Document, page: PageGroup, block: Equation):
        """Replace ``block.html`` with LLM-generated markdown rendered to HTML.

        Increments the block's ``llm_error_count`` metadata and leaves the
        block unchanged when the model response is missing or suspiciously
        short relative to the extracted text.
        """
        text = block.raw_text(document)
        # Bug fix: the template's placeholder is {extracted_text}; the old code
        # replaced "{handwriting_text}", which never occurs in the prompt, so
        # the placeholder was sent to the model unsubstituted.
        prompt = self.handwriting_generation_prompt.replace("{extracted_text}", text)

        image = self.extract_image(document, block)
        response_schema = content.Schema(
            type=content.Type.OBJECT,
            enum=[],
            required=["markdown"],
            properties={
                "markdown": content.Schema(
                    type=content.Type.STRING
                )
            },
        )

        response = self.model.generate_response(prompt, image, block, response_schema)

        if not response or "markdown" not in response:
            block.update_metadata(llm_error_count=1)
            return

        markdown = response["markdown"]
        # Guard against a truncated/partial model response.
        if len(markdown) < len(text) * .5:
            block.update_metadata(llm_error_count=1)
            return

        # Bug fix: str.lstrip/rstrip treat their argument as a character SET,
        # so lstrip("```markdown") would also eat leading letters such as "m"
        # or "a" from real content. removeprefix/removesuffix strip only a
        # literal code fence.
        markdown = markdown.strip()
        markdown = markdown.removeprefix("```markdown").removesuffix("```").strip()
        block.html = markdown2.markdown(markdown)
marker/processors/llm/llm_image_description.py CHANGED
@@ -36,6 +36,9 @@ Apples, Bananas, Oranges
36
  Output:
37
  In this figure, a bar chart titled "Fruit Preference Survey" is showing the number of people who prefer different types of fruits. The x-axis shows the types of fruits, and the y-axis shows the number of people. The bar chart shows that most people prefer apples, followed by bananas and oranges. 20 people prefer apples, 15 people prefer bananas, and 10 people prefer oranges.
38
  **Input:**
 
 
 
39
  """
40
 
41
  def process_rewriting(self, document: Document, page: PageGroup, block: Block):
@@ -44,8 +47,8 @@ In this figure, a bar chart titled "Fruit Preference Survey" is showing the numb
44
  # Since this processor replaces images with descriptions
45
  return
46
 
47
- prompt = self.image_description_prompt + '```text\n`' + block.raw_text(document) + '`\n```\n'
48
- image = self.extract_image(page, block)
49
  response_schema = content.Schema(
50
  type=content.Type.OBJECT,
51
  enum=[],
 
36
  Output:
37
  In this figure, a bar chart titled "Fruit Preference Survey" is showing the number of people who prefer different types of fruits. The x-axis shows the types of fruits, and the y-axis shows the number of people. The bar chart shows that most people prefer apples, followed by bananas and oranges. 20 people prefer apples, 15 people prefer bananas, and 10 people prefer oranges.
38
  **Input:**
39
+ ```text
40
+ {raw_text}
41
+ ```
42
  """
43
 
44
  def process_rewriting(self, document: Document, page: PageGroup, block: Block):
 
47
  # Since this processor replaces images with descriptions
48
  return
49
 
50
+ prompt = self.image_description_prompt.replace("{raw_text}", block.raw_text(document))
51
+ image = self.extract_image(document, block)
52
  response_schema = content.Schema(
53
  type=content.Type.OBJECT,
54
  enum=[],
marker/processors/llm/llm_table.py CHANGED
@@ -2,12 +2,10 @@ from typing import Annotated, List, Tuple
2
 
3
  from bs4 import BeautifulSoup
4
  from google.ai.generativelanguage_v1beta.types import content
5
- from tabled.formats import html_format
6
- from tabled.schema import SpanTableCell
7
 
8
  from marker.processors.llm import BaseLLMProcessor
9
  from marker.schema import BlockTypes
10
- from marker.schema.blocks import Block
11
  from marker.schema.document import Document
12
  from marker.schema.groups.page import PageGroup
13
  from marker.schema.polygon import PolygonBox
@@ -17,19 +15,26 @@ class LLMTableProcessor(BaseLLMProcessor):
17
  block_types: Annotated[
18
  Tuple[BlockTypes],
19
  "The block types to process.",
20
- ] = (BlockTypes.Table,)
21
- gemini_rewriting_prompt: Annotated[
22
  str,
23
  "The prompt to use for rewriting text.",
24
  "Default is a string containing the Gemini rewriting prompt."
25
  ] = """You are a text correction expert specializing in accurately reproducing text from images.
26
  You will receive an image of a text block and an html representation of the table in the image.
27
  Your task is to correct any errors in the html representation. The html representation should be as faithful to the original table as possible.
 
 
 
 
 
 
 
28
  **Instructions:**
29
  1. Carefully examine the provided text block image.
30
  2. Analyze the html representation of the table.
31
  3. If the html representation is largely correct, then write "No corrections needed."
32
- 4. If the html representation contains errors, generate the corrected html representation. Only use the tags th, td, tr, and table. Only use the attributes colspan and rowspan if necessary.
33
  5. Output only either the corrected html representation or "No corrections needed."
34
  **Example:**
35
  Input:
@@ -52,16 +57,21 @@ Output:
52
  No corrections needed.
53
  ```
54
  **Input:**
 
 
 
55
  """
56
 
57
  def process_rewriting(self, document: Document, page: PageGroup, block: Block):
58
- cells = block.cells
59
- if cells is None:
60
  # Happens if table/form processors didn't run
61
  return
62
 
63
- prompt = self.gemini_rewriting_prompt + '```html\n`' + html_format(cells) + '`\n```\n'
64
- image = self.extract_image(page, block)
 
 
65
  response_schema = content.Schema(
66
  type=content.Type.OBJECT,
67
  enum=[],
@@ -85,31 +95,49 @@ No corrections needed.
85
  if "no corrections" in corrected_html.lower():
86
  return
87
 
88
- parsed_cells = self.parse_html_table(corrected_html, block)
 
89
  if len(parsed_cells) <= 1:
90
  block.update_metadata(llm_error_count=1)
91
  return
92
 
93
  parsed_cell_text = "".join([cell.text for cell in parsed_cells])
94
- orig_cell_text = "".join([cell.text for cell in cells])
95
-
96
  # Potentially a partial response
97
  if len(parsed_cell_text) < len(orig_cell_text) * .5:
98
  block.update_metadata(llm_error_count=1)
99
  return
100
 
101
- block.cells = parsed_cells
 
 
 
102
 
103
- def parse_html_table(self, html_text: str, block: Block) -> List[SpanTableCell]:
 
 
 
 
 
 
 
104
  soup = BeautifulSoup(html_text, 'html.parser')
105
  table = soup.find('table')
106
 
107
  # Initialize grid
108
  rows = table.find_all('tr')
109
  cells = []
110
- max_cols = max(len(row.find_all(['td', 'th'])) for row in rows)
111
- if max_cols == 0:
112
- return []
 
 
 
 
 
 
 
 
113
 
114
  grid = [[True] * max_cols for _ in range(len(rows))]
115
 
@@ -124,7 +152,7 @@ No corrections needed.
124
  print("Table parsing warning: too many columns found")
125
  break
126
 
127
- cell_text = cell.text.strip()
128
  rowspan = min(int(cell.get('rowspan', 1)), len(rows) - i)
129
  colspan = min(int(cell.get('colspan', 1)), max_cols - cur_col)
130
  cell_rows = list(range(i, i + rowspan))
@@ -146,11 +174,15 @@ No corrections needed.
146
  ]
147
  cell_polygon = PolygonBox.from_bbox(cell_bbox)
148
 
149
- cell_obj = SpanTableCell(
150
  text=cell_text,
151
- row_ids=cell_rows,
152
- col_ids=cell_cols,
153
- bbox=cell_polygon.bbox
 
 
 
 
154
  )
155
  cells.append(cell_obj)
156
  cur_col += colspan
 
2
 
3
  from bs4 import BeautifulSoup
4
  from google.ai.generativelanguage_v1beta.types import content
 
 
5
 
6
  from marker.processors.llm import BaseLLMProcessor
7
  from marker.schema import BlockTypes
8
+ from marker.schema.blocks import Block, TableCell
9
  from marker.schema.document import Document
10
  from marker.schema.groups.page import PageGroup
11
  from marker.schema.polygon import PolygonBox
 
15
  block_types: Annotated[
16
  Tuple[BlockTypes],
17
  "The block types to process.",
18
+ ] = (BlockTypes.Table, BlockTypes.TableOfContents)
19
+ table_rewriting_prompt: Annotated[
20
  str,
21
  "The prompt to use for rewriting text.",
22
  "Default is a string containing the Gemini rewriting prompt."
23
  ] = """You are a text correction expert specializing in accurately reproducing text from images.
24
  You will receive an image of a text block and an html representation of the table in the image.
25
  Your task is to correct any errors in the html representation. The html representation should be as faithful to the original table as possible.
26
+
27
+ Some guidelines:
28
+ - Make sure to reproduce the original values as faithfully as possible.
29
+ - If you see any math in a table cell, fence it with the <math display="inline"> tag. Block math should be fenced with <math display="block">.
30
+ - Replace any images with a description, like "Image: [description]".
31
+ - Only use the tags th, td, tr, span, i, b, math, and table. Only use the attributes display, style, colspan, and rowspan if necessary.
32
+
33
  **Instructions:**
34
  1. Carefully examine the provided text block image.
35
  2. Analyze the html representation of the table.
36
  3. If the html representation is largely correct, then write "No corrections needed."
37
+ 4. If the html representation contains errors, generate the corrected html representation.
38
  5. Output only either the corrected html representation or "No corrections needed."
39
  **Example:**
40
  Input:
 
57
  No corrections needed.
58
  ```
59
  **Input:**
60
+ ```html
61
+ {block_html}
62
+ ```
63
  """
64
 
65
  def process_rewriting(self, document: Document, page: PageGroup, block: Block):
66
+ children = block.contained_blocks(document, (BlockTypes.TableCell,))
67
+ if not children:
68
  # Happens if table/form processors didn't run
69
  return
70
 
71
+ block_html = block.render(document).html
72
+ prompt = self.table_rewriting_prompt.replace("{block_html}", block_html)
73
+
74
+ image = self.extract_image(document, block)
75
  response_schema = content.Schema(
76
  type=content.Type.OBJECT,
77
  enum=[],
 
95
  if "no corrections" in corrected_html.lower():
96
  return
97
 
98
+ corrected_html = corrected_html.strip().lstrip("```html").rstrip("```").strip()
99
+ parsed_cells = self.parse_html_table(corrected_html, block, page)
100
  if len(parsed_cells) <= 1:
101
  block.update_metadata(llm_error_count=1)
102
  return
103
 
104
  parsed_cell_text = "".join([cell.text for cell in parsed_cells])
105
+ orig_cell_text = "".join([cell.text for cell in children])
 
106
  # Potentially a partial response
107
  if len(parsed_cell_text) < len(orig_cell_text) * .5:
108
  block.update_metadata(llm_error_count=1)
109
  return
110
 
111
+ block.structure = []
112
+ for cell in parsed_cells:
113
+ page.add_full_block(cell)
114
+ block.add_structure(cell)
115
 
116
+ @staticmethod
117
+ def get_cell_text(element, keep_tags=('br',)):
118
+ for tag in element.find_all(True):
119
+ if tag.name not in keep_tags:
120
+ tag.unwrap()
121
+ return element.decode_contents().replace("<br>", "\n")
122
+
123
+ def parse_html_table(self, html_text: str, block: Block, page: PageGroup) -> List[TableCell]:
124
  soup = BeautifulSoup(html_text, 'html.parser')
125
  table = soup.find('table')
126
 
127
  # Initialize grid
128
  rows = table.find_all('tr')
129
  cells = []
130
+
131
+ # Find maximum number of columns in colspan-aware way
132
+ max_cols = 0
133
+ for row in rows:
134
+ row_tds = row.find_all(['td', 'th'])
135
+ curr_cols = 0
136
+ for cell in row_tds:
137
+ colspan = int(cell.get('colspan', 1))
138
+ curr_cols += colspan
139
+ if curr_cols > max_cols:
140
+ max_cols = curr_cols
141
 
142
  grid = [[True] * max_cols for _ in range(len(rows))]
143
 
 
152
  print("Table parsing warning: too many columns found")
153
  break
154
 
155
+ cell_text = self.get_cell_text(cell).strip()
156
  rowspan = min(int(cell.get('rowspan', 1)), len(rows) - i)
157
  colspan = min(int(cell.get('colspan', 1)), max_cols - cur_col)
158
  cell_rows = list(range(i, i + rowspan))
 
174
  ]
175
  cell_polygon = PolygonBox.from_bbox(cell_bbox)
176
 
177
+ cell_obj = TableCell(
178
  text=cell_text,
179
+ row_id=i,
180
+ col_id=cur_col,
181
+ rowspan=rowspan,
182
+ colspan=colspan,
183
+ is_header=cell.name == 'th',
184
+ polygon=cell_polygon,
185
+ page_id=page.page_id,
186
  )
187
  cells.append(cell_obj)
188
  cur_col += colspan
marker/processors/llm/llm_table_merge.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Annotated, List, Tuple, Literal

from google.ai.generativelanguage_v1beta.types import content
from tqdm import tqdm
from PIL import Image

from marker.processors.llm import BaseLLMProcessor
from marker.schema import BlockTypes
from marker.schema.blocks import Block, TableCell
from marker.schema.document import Document


class LLMTableMergeProcessor(BaseLLMProcessor):
    """Finds runs of tables that look like fragments of one larger table
    (continued on the next page, or split side-by-side on one page) and asks
    the LLM whether each pair should be merged, stitching the cells and the
    lowres images together when it says yes.
    """

    block_types: Annotated[
        Tuple[BlockTypes],
        "The block types to process.",
    ] = (BlockTypes.Table, BlockTypes.TableOfContents)
    table_height_threshold: Annotated[
        float,
        "The minimum height ratio relative to the page for the first table in a pair to be considered for merging.",
    ] = 0.6
    # NOTE(review): table_start_threshold is not referenced in this file --
    # confirm it is still needed before removing.
    table_start_threshold: Annotated[
        float,
        "The maximum percentage down the page the second table can start to be considered for merging."
    ] = 0.2
    vertical_table_height_threshold: Annotated[
        float,
        "The height tolerance for 2 adjacent tables to be merged into one."
    ] = 0.25
    vertical_table_distance_threshold: Annotated[
        int,
        "The maximum distance between table edges for adjacency."
    ] = 20
    column_gap_threshold: Annotated[
        int,
        "The maximum gap between columns to merge tables"
    ] = 50
    # Bug fix: the template below was missing the closing ``` fence after the
    # {{table1}} placeholder, which made the example and real input fences
    # unbalanced in the prompt sent to the model.
    table_merge_prompt: Annotated[
        str,
        "The prompt to use for rewriting text.",
        "Default is a string containing the Gemini rewriting prompt."
    ] = """You're a text correction expert specializing in accurately reproducing tables from PDFs.
You'll receive two images of tables from successive pages of a PDF.  Table 1 is from the first page, and Table 2 is from the second page.  Both tables may actually be part of the same larger table.  Your job is to decide if Table 2 should be merged with Table 1, and how they should be joined.  The should only be merged if they're part of the same larger table, and Table 2 cannot be interpreted without merging.

You'll specify your judgement in json format - first whether Table 2 should be merged with Table 1, then the direction of the merge, either `bottom` or `right`.  A bottom merge means that the rows of Table 2 are joined to the rows of Table 1.  A right merge means that the columns of Table 2 are joined to the columns of Table 1.  (bottom merge is equal to np.vstack, right merge is equal to np.hstack)

Table 2 should be merged at the bottom of Table 1 if Table 2 has no headers, and the rows have similar values, meaning that Table 2 continues Table 1.  Table 2 should be merged to the right of Table 1 if each row in Table 2 matches a row in Table 1, meaning that Table 2 contains additional columns that augment Table 1.

Only merge Table 1 and Table 2 if Table 2 cannot be interpreted without merging.

**Instructions:**
1. Carefully examine the provided table images.  Table 1 is the first image, and Table 2 is the second image.
2. Examine the provided html representations of Table 1 and Table 2.
3. Write a description of Table 1.
4. Write a description of Table 2.
5. Analyze whether Table 2 should be merged into Table 1, and write an explanation.
6. Output your decision on whether they should be merged, and merge direction.
**Example:**
Input:
Table 1
```html
<table>
<tr>
<th>Name</th>
<th>Age</th>
<th>City</th>
<th>State</th>
</tr>
<tr>
<td>John</td>
<td>25</td>
<td>Chicago</td>
<td>IL</td>
</tr>
```
Table 2
```html
<table>
<tr>
<td>Jane</td>
<td>30</td>
<td>Los Angeles</td>
<td>CA</td>
</tr>
```
Output:
```json
{
    "table1_description": "Table 1 has 4 headers, and 1 row.  The headers are Name, Age, City, and State.",
    "table2_description": "Table 2 has no headers, but the values appear to represent a person's name, age, city, and state.",
    "explanation": "The values in Table 2 match the headers in Table 1, and Table 2 has no headers.  Table 2 should be merged to the bottom of Table 1.",
    "merge": "true",
    "direction": "bottom"
}
```
**Input:**
Table 1
```html
{{table1}}
```
Table 2
```html
{{table2}}
```
"""

    @staticmethod
    def get_row_count(cells: List[TableCell]):
        """Return the effective row count: the maximum total rowspan found in
        any single column.  Returns None for an empty cell list."""
        max_rows = None
        for col_id in set([cell.col_id for cell in cells]):
            col_cells = [cell for cell in cells if cell.col_id == col_id]
            rows = 0
            for cell in col_cells:
                rows += cell.rowspan
            if max_rows is None or rows > max_rows:
                max_rows = rows
        return max_rows

    @staticmethod
    def get_column_count(cells: List[TableCell]):
        """Return the effective column count: the maximum total colspan found
        in any single row.  Returns None for an empty cell list."""
        max_cols = None
        for row_id in set([cell.row_id for cell in cells]):
            row_cells = [cell for cell in cells if cell.row_id == row_id]
            cols = 0
            for cell in row_cells:
                cols += cell.colspan
            if max_cols is None or cols > max_cols:
                max_cols = cols
        return max_cols

    def rewrite_blocks(self, document: Document):
        """Group merge-candidate tables into runs, then process each run
        concurrently via process_rewriting."""
        pbar = tqdm(desc=f"{self.__class__.__name__} running")
        table_runs = []
        table_run = []
        prev_block = None
        prev_page_block_count = None
        for page in document.pages:
            page_blocks = page.contained_blocks(document, self.block_types)
            for block in page_blocks:
                merge_condition = False
                if prev_block is not None:
                    prev_cells = prev_block.contained_blocks(document, (BlockTypes.TableCell,))
                    curr_cells = block.contained_blocks(document, (BlockTypes.TableCell,))
                    # Guard: get_row_count/get_column_count return None for
                    # empty cell lists (e.g. when the table processor didn't
                    # run), which would make the abs() arithmetic crash.
                    if prev_cells and curr_cells:
                        # Bug fix: the original assignment ended with a
                        # trailing comma, making row_match a 1-tuple that was
                        # always truthy regardless of the comparison result.
                        row_match = abs(self.get_row_count(prev_cells) - self.get_row_count(curr_cells)) < 5  # Similar number of rows
                        col_match = abs(self.get_column_count(prev_cells) - self.get_column_count(curr_cells)) < 2

                        subsequent_page_table = all([
                            prev_block.page_id == block.page_id - 1,  # Subsequent pages
                            max(prev_block.polygon.height / page.polygon.height,
                                block.polygon.height / page.polygon.height) > self.table_height_threshold,  # Take up most of the page height
                            (len(page_blocks) == 1 or prev_page_block_count == 1),  # Only table on the page
                            (row_match or col_match)
                        ])

                        same_page_vertical_table = all([
                            prev_block.page_id == block.page_id,  # On the same page
                            (1 - self.vertical_table_height_threshold) < prev_block.polygon.height / block.polygon.height < (1 + self.vertical_table_height_threshold),  # Similar height
                            abs(block.polygon.x_start - prev_block.polygon.x_end) < self.vertical_table_distance_threshold,  # Close together in x
                            abs(block.polygon.y_start - prev_block.polygon.y_start) < self.vertical_table_distance_threshold,  # Close together in y
                            row_match
                        ])

                        same_page_new_column = all([
                            prev_block.page_id == block.page_id,  # On the same page
                            abs(block.polygon.x_start - prev_block.polygon.x_end) < self.column_gap_threshold,
                            block.polygon.y_start < prev_block.polygon.y_end,
                            block.polygon.width * (1 - self.vertical_table_height_threshold) < prev_block.polygon.width < block.polygon.width * (1 + self.vertical_table_height_threshold),  # Similar width
                            col_match
                        ])
                        merge_condition = any([subsequent_page_table, same_page_vertical_table, same_page_new_column])

                if prev_block is not None and merge_condition:
                    if prev_block not in table_run:
                        table_run.append(prev_block)
                    table_run.append(block)
                else:
                    if table_run:
                        table_runs.append(table_run)
                    table_run = []
                prev_block = block
                prev_page_block_count = len(page_blocks)

        if table_run:
            table_runs.append(table_run)

        with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor:
            for future in as_completed([
                executor.submit(self.process_rewriting, document, blocks)
                for blocks in table_runs
            ]):
                future.result()  # Raise exceptions if any occurred
                pbar.update(1)

        pbar.close()

    def process_rewriting(self, document: Document, blocks: List[Block]):
        """Walk a run of candidate tables, asking the LLM pairwise whether the
        current table should be merged into the accumulated start table.

        On a confirmed merge, the current block's cells are re-parented into
        the start block and the images are stitched; otherwise the current
        block becomes the new start of comparison.
        """
        if len(blocks) < 2:
            # Can't merge single tables
            return

        start_block = blocks[0]
        for i in range(1, len(blocks)):
            curr_block = blocks[i]
            children = start_block.contained_blocks(document, (BlockTypes.TableCell,))
            children_curr = curr_block.contained_blocks(document, (BlockTypes.TableCell,))
            if not children or not children_curr:
                # Happens if table/form processors didn't run
                break

            start_image = start_block.get_image(document, highres=False)
            curr_image = curr_block.get_image(document, highres=False)
            start_html = start_block.render(document).html
            curr_html = curr_block.render(document).html

            prompt = self.table_merge_prompt.replace("{{table1}}", start_html).replace("{{table2}}", curr_html)

            response_schema = content.Schema(
                type=content.Type.OBJECT,
                enum=[],
                required=["table1_description", "table2_description", "explanation", "merge", "direction"],
                properties={
                    "table1_description": content.Schema(
                        type=content.Type.STRING
                    ),
                    "table2_description": content.Schema(
                        type=content.Type.STRING
                    ),
                    "explanation": content.Schema(
                        type=content.Type.STRING
                    ),
                    "merge": content.Schema(
                        type=content.Type.STRING,
                        enum=["true", "false"]
                    ),
                    "direction": content.Schema(
                        type=content.Type.STRING,
                        enum=["bottom", "right"]
                    ),
                },
            )

            response = self.model.generate_response(
                prompt,
                [start_image, curr_image],  # Per Gemini docs, images go first
                curr_block,
                response_schema
            )

            if not response or ("direction" not in response or "merge" not in response):
                curr_block.update_metadata(llm_error_count=1)
                break

            merge = response["merge"]

            # The original table is okay
            if "true" not in merge:
                start_block = curr_block
                continue

            # Merge the cells and images of the tables
            direction = response["direction"]
            if not self.validate_merge(children, children_curr, direction):
                start_block = curr_block
                continue

            merged_image = self.join_images(start_image, curr_image, direction)
            merged_cells = self.join_cells(children, children_curr, direction)
            curr_block.structure = []
            start_block.structure = [b.id for b in merged_cells]
            start_block.lowres_image = merged_image

    def validate_merge(self, cells1: List[TableCell], cells2: List[TableCell], direction: Literal['right', 'bottom'] = 'right'):
        """Sanity-check a proposed merge: right merges need matching row
        counts; bottom merges need matching column counts."""
        if direction == "right":
            # Check if the number of rows is the same
            cells1_row_count = self.get_row_count(cells1)
            cells2_row_count = self.get_row_count(cells2)
            return abs(cells1_row_count - cells2_row_count) < 5
        elif direction == "bottom":
            # Check if the number of columns is the same
            cells1_col_count = self.get_column_count(cells1)
            cells2_col_count = self.get_column_count(cells2)
            return abs(cells1_col_count - cells2_col_count) < 2
        # Bug fix: previously fell through and returned None for an unexpected
        # direction; make the rejection explicit.
        return False

    def join_cells(self, cells1: List[TableCell], cells2: List[TableCell], direction: Literal['right', 'bottom'] = 'right') -> List[TableCell]:
        """Concatenate two cell lists, offsetting the second table's col_ids
        (right merge) or row_ids (bottom merge) past the first table's extent.

        Note: mutates the cells of ``cells2`` in place.
        """
        if direction == 'right':
            # Shift columns right
            col_count = self.get_column_count(cells1)
            for cell in cells2:
                cell.col_id += col_count
            new_cells = cells1 + cells2
        else:
            # Shift rows down, below table 1
            row_count = self.get_row_count(cells1)
            for cell in cells2:
                cell.row_id += row_count
            new_cells = cells1 + cells2
        return new_cells

    @staticmethod
    def join_images(image1: Image.Image, image2: Image.Image, direction: Literal['right', 'bottom'] = 'right') -> Image.Image:
        """Paste the two table images side by side (right) or stacked
        (bottom) onto a white canvas sized to fit both."""
        # Get dimensions
        w1, h1 = image1.size
        w2, h2 = image2.size

        if direction == 'right':
            new_height = max(h1, h2)
            new_width = w1 + w2
            new_img = Image.new('RGB', (new_width, new_height), 'white')
            new_img.paste(image1, (0, 0))
            new_img.paste(image2, (w1, 0))
        else:
            new_width = max(w1, w2)
            new_height = h1 + h2
            new_img = Image.new('RGB', (new_width, new_height), 'white')
            new_img.paste(image1, (0, 0))
            new_img.paste(image2, (0, h1))
        return new_img
marker/processors/llm/llm_text.py CHANGED
@@ -1,4 +1,5 @@
1
  import json
 
2
 
3
  from marker.processors.llm import BaseLLMProcessor
4
  from bs4 import BeautifulSoup
@@ -13,10 +14,10 @@ from marker.schema.text.span import Span
13
 
14
  class LLMTextProcessor(BaseLLMProcessor):
15
  block_types = (BlockTypes.TextInlineMath, BlockTypes.Handwriting)
16
- gemini_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
17
  You will receive an image of a text block and a set of extracted lines corresponding to the text in the image.
18
  Your task is to correct any errors in the extracted lines, including math, formatting, and other inaccuracies, and output the corrected lines in a JSON format.
19
- The number of output lines MUST match the number of input lines.
20
 
21
  **Instructions:**
22
 
@@ -64,7 +65,9 @@ Output:
64
  ```
65
 
66
  **Input:**
67
-
 
 
68
  """
69
 
70
  def process_rewriting(self, document: Document, page: PageGroup, block: Block):
@@ -73,8 +76,8 @@ Output:
73
  text_lines = block.contained_blocks(document, (BlockTypes.Line,))
74
  extracted_lines = [line.formatted_text(document) for line in text_lines]
75
 
76
- prompt = self.gemini_rewriting_prompt + '```json\n`' + json.dumps({"extracted_lines": extracted_lines}, indent=2) + '`\n```\n'
77
- image = self.extract_image(page, block)
78
  response_schema = content.Schema(
79
  type=content.Type.OBJECT,
80
  enum=[],
 
1
  import json
2
+ import textwrap
3
 
4
  from marker.processors.llm import BaseLLMProcessor
5
  from bs4 import BeautifulSoup
 
14
 
15
  class LLMTextProcessor(BaseLLMProcessor):
16
  block_types = (BlockTypes.TextInlineMath, BlockTypes.Handwriting)
17
+ text_math_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
18
  You will receive an image of a text block and a set of extracted lines corresponding to the text in the image.
19
  Your task is to correct any errors in the extracted lines, including math, formatting, and other inaccuracies, and output the corrected lines in a JSON format.
20
+ The number of output lines MUST match the number of input lines. Stay as faithful to the original text as possible.
21
 
22
  **Instructions:**
23
 
 
65
  ```
66
 
67
  **Input:**
68
+ ```json
69
+ {extracted_lines}
70
+ ```
71
  """
72
 
73
  def process_rewriting(self, document: Document, page: PageGroup, block: Block):
 
76
  text_lines = block.contained_blocks(document, (BlockTypes.Line,))
77
  extracted_lines = [line.formatted_text(document) for line in text_lines]
78
 
79
+ prompt = self.text_math_rewriting_prompt.replace("{extracted_lines}", json.dumps({"extracted_lines": extracted_lines}, indent=2))
80
+ image = self.extract_image(document, block)
81
  response_schema = content.Schema(
82
  type=content.Type.OBJECT,
83
  enum=[],
marker/processors/llm/utils.py CHANGED
@@ -1,5 +1,6 @@
1
  import json
2
  import time
 
3
 
4
  import PIL
5
  import google.generativeai as genai
@@ -25,17 +26,19 @@ class GoogleModel:
25
  def generate_response(
26
  self,
27
  prompt: str,
28
- image: PIL.Image.Image,
29
  block: Block,
30
  response_schema: content.Schema,
31
  max_retries: int = 3,
32
  timeout: int = 60
33
  ):
 
 
34
  tries = 0
35
  while tries < max_retries:
36
  try:
37
  responses = self.model.generate_content(
38
- [prompt, image],
39
  stream=False,
40
  generation_config={
41
  "temperature": 0,
 
1
  import json
2
  import time
3
+ from typing import List
4
 
5
  import PIL
6
  import google.generativeai as genai
 
26
  def generate_response(
27
  self,
28
  prompt: str,
29
+ image: PIL.Image.Image | List[PIL.Image.Image],
30
  block: Block,
31
  response_schema: content.Schema,
32
  max_retries: int = 3,
33
  timeout: int = 60
34
  ):
35
+ if not isinstance(image, list):
36
+ image = [image]
37
  tries = 0
38
  while tries < max_retries:
39
  try:
40
  responses = self.model.generate_content(
41
+ image + [prompt], # According to gemini docs, it performs better if the image is the first element
42
  stream=False,
43
  generation_config={
44
  "temperature": 0,
marker/processors/table.py CHANGED
@@ -1,18 +1,22 @@
1
-
2
- from typing import Annotated
 
 
3
 
4
  from ftfy import fix_text
5
- from surya.input.pdflines import get_page_text_lines
6
- from surya.model.detection.model import EfficientViTForSemanticSegmentation
7
- from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel
8
- from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
9
- from tabled.assignment import assign_rows_columns
10
- from tabled.inference.recognition import get_cells, recognize_tables
11
 
12
  from marker.processors import BaseProcessor
13
  from marker.schema import BlockTypes
 
14
  from marker.schema.document import Document
 
15
  from marker.settings import settings
 
16
 
17
 
18
  class TableProcessor(BaseProcessor):
@@ -42,9 +46,9 @@ class TableProcessor(BaseProcessor):
42
 
43
  def __init__(
44
  self,
45
- detection_model: EfficientViTForSemanticSegmentation,
46
- recognition_model: OCREncoderDecoderModel,
47
- table_rec_model: TableRecEncoderDecoderModel,
48
  config=None
49
  ):
50
  super().__init__(config)
@@ -59,51 +63,207 @@ class TableProcessor(BaseProcessor):
59
  table_data = []
60
  for page in document.pages:
61
  for block in page.contained_blocks(document, self.block_types):
62
- image_poly = block.polygon.rescale((page.polygon.width, page.polygon.height), page.highres_image.size)
63
- image = page.highres_image.crop(image_poly.bbox).convert("RGB")
64
-
65
- if block.text_extraction_method == "surya":
66
- text_lines = None
67
- else:
68
- text_lines = get_page_text_lines(
69
- filepath,
70
- [page.page_id],
71
- [page.highres_image.size],
72
- flatten_pdf=True
73
- )[0]
74
 
75
  table_data.append({
76
  "block_id": block.id,
 
77
  "table_image": image,
78
  "table_bbox": image_poly.bbox,
79
- "text_lines": text_lines,
80
- "img_size": page.highres_image.size
81
  })
82
 
83
- lst_format = [[t[key] for t in table_data] for key in ["table_image", "table_bbox", "img_size", "text_lines"]]
 
84
 
85
- cells, needs_ocr = get_cells(
86
- *lst_format,
87
- [self.detection_model, self.detection_model.processor],
88
- detect_boxes=self.detect_boxes,
89
- detector_batch_size=self.get_detector_batch_size()
90
- )
91
 
92
- tables = recognize_tables(
93
  [t["table_image"] for t in table_data],
94
- cells,
95
- needs_ocr,
96
- [self.table_rec_model, self.table_rec_model.processor, self.recognition_model, self.recognition_model.processor],
97
- table_rec_batch_size=self.get_table_rec_batch_size(),
98
- ocr_batch_size=self.get_recognition_batch_size()
99
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- for table_d, table_res in zip(table_data, tables):
102
- block = document.get_block(table_d["block_id"])
103
- cells = assign_rows_columns(table_res, table_d["img_size"])
104
- for cell in cells:
105
- cell.text = fix_text(cell.text)
106
- block.cells = cells
107
 
108
  def get_detector_batch_size(self):
109
  if self.detector_batch_size is not None:
 
1
+ import re
2
+ from collections import defaultdict
3
+ from copy import deepcopy
4
+ from typing import Annotated, List
5
 
6
  from ftfy import fix_text
7
+ from surya.detection import DetectionPredictor
8
+ from surya.recognition import RecognitionPredictor, OCRResult
9
+ from surya.table_rec import TableRecPredictor
10
+ from surya.table_rec.schema import TableResult, TableCell as SuryaTableCell
11
+ from pdftext.extraction import table_output
 
12
 
13
  from marker.processors import BaseProcessor
14
  from marker.schema import BlockTypes
15
+ from marker.schema.blocks.tablecell import TableCell
16
  from marker.schema.document import Document
17
+ from marker.schema.polygon import PolygonBox
18
  from marker.settings import settings
19
+ from marker.util import matrix_intersection_area
20
 
21
 
22
  class TableProcessor(BaseProcessor):
 
46
 
47
  def __init__(
48
  self,
49
+ detection_model: DetectionPredictor,
50
+ recognition_model: RecognitionPredictor,
51
+ table_rec_model: TableRecPredictor,
52
  config=None
53
  ):
54
  super().__init__(config)
 
63
  table_data = []
64
  for page in document.pages:
65
  for block in page.contained_blocks(document, self.block_types):
66
+ image = block.get_image(document, highres=True, expansion=(.01, .01))
67
+ image_poly = block.polygon.rescale((page.polygon.width, page.polygon.height), page.get_image(highres=True).size)
 
 
 
 
 
 
 
 
 
 
68
 
69
  table_data.append({
70
  "block_id": block.id,
71
+ "page_id": page.page_id,
72
  "table_image": image,
73
  "table_bbox": image_poly.bbox,
74
+ "img_size": page.get_image(highres=True).size,
75
+ "ocr_block": page.text_extraction_method == "surya",
76
  })
77
 
78
+ extract_blocks = [t for t in table_data if not t["ocr_block"]]
79
+ self.assign_pdftext_lines(extract_blocks, filepath) # Handle tables where good text exists in the PDF
80
 
81
+ ocr_blocks = [t for t in table_data if t["ocr_block"]]
82
+ self.assign_ocr_lines(ocr_blocks) # Handle tables where OCR is needed
83
+ assert all("table_text_lines" in t for t in table_data), "All table data must have table cells"
 
 
 
84
 
85
+ tables: List[TableResult] = self.table_rec_model(
86
  [t["table_image"] for t in table_data],
87
+ batch_size=self.get_table_rec_batch_size()
 
 
 
 
88
  )
89
+ self.assign_text_to_cells(tables, table_data)
90
+ self.split_combined_rows(tables) # Split up rows that were combined
91
+
92
+ # Assign table cells to the table
93
+ table_idx = 0
94
+ for page in document.pages:
95
+ for block in page.contained_blocks(document, self.block_types):
96
+ block.structure = [] # Remove any existing lines, spans, etc.
97
+ cells: List[SuryaTableCell] = tables[table_idx].cells
98
+ for cell in cells:
99
+ # Rescale the cell polygon to the page size
100
+ cell_polygon = PolygonBox(polygon=cell.polygon).rescale(page.get_image(highres=True).size, page.polygon.size)
101
+ cell_block = TableCell(
102
+ polygon=cell_polygon,
103
+ text=self.finalize_cell_text(cell),
104
+ rowspan=cell.rowspan,
105
+ colspan=cell.colspan,
106
+ row_id=cell.row_id,
107
+ col_id=cell.col_id,
108
+ is_header=bool(cell.is_header),
109
+ page_id=page.page_id,
110
+ )
111
+ page.add_full_block(cell_block)
112
+ block.add_structure(cell_block)
113
+ table_idx += 1
114
+
115
+ def finalize_cell_text(self, cell: SuryaTableCell):
116
+ text = "\n".join([t["text"].strip() for t in cell.text_lines]) if cell.text_lines else ""
117
+ text = re.sub(r"(\s\.){2,}", "", text) # Replace . . .
118
+ text = re.sub(r"\.{2,}", "", text) # Replace ..., like in table of contents
119
+ return self.normalize_spaces(fix_text(text))
120
+
121
+ @staticmethod
122
+ def normalize_spaces(text):
123
+ space_chars = [
124
+ '\u2003', # em space
125
+ '\u2002', # en space
126
+ '\u00A0', # non-breaking space
127
+ '\u200B', # zero-width space
128
+ '\u3000', # ideographic space
129
+ ]
130
+ for space in space_chars:
131
+ text = text.replace(space, ' ')
132
+ return text
133
+
134
+ def split_combined_rows(self, tables: List[TableResult]):
135
+ for table in tables:
136
+ if len(table.cells) == 0:
137
+ # Skip empty tables
138
+ continue
139
+ unique_rows = sorted(list(set([c.row_id for c in table.cells])))
140
+ new_cells = []
141
+ shift_up = 0
142
+ max_cell_id = max([c.cell_id for c in table.cells])
143
+ new_cell_count = 0
144
+ for row in unique_rows:
145
+ # Cells in this row
146
+ # Deepcopy is because we do an in-place mutation later, and that can cause rows to shift to match rows in unique_rows
147
+ # making them be processed twice
148
+ row_cells = deepcopy([c for c in table.cells if c.row_id == row])
149
+ rowspans = [c.rowspan for c in row_cells]
150
+ line_lens = [len(c.text_lines) if isinstance(c.text_lines, list) else 1 for c in row_cells]
151
+
152
+ # Other cells that span into this row
153
+ rowspan_cells = [c for c in table.cells if c.row_id != row and c.row_id + c.rowspan > row > c.row_id]
154
+ should_split = all([
155
+ len(row_cells) > 0,
156
+ len(rowspan_cells) == 0,
157
+ all([r == 1 for r in rowspans]),
158
+ all([l > 1 for l in line_lens]),
159
+ all([l == line_lens[0] for l in line_lens])
160
+ ])
161
+ if should_split:
162
+ for i in range(0, line_lens[0]):
163
+ for cell in row_cells:
164
+ line = cell.text_lines[i]
165
+ cell_id = max_cell_id + new_cell_count
166
+ new_cells.append(
167
+ SuryaTableCell(
168
+ polygon=line["bbox"],
169
+ text_lines=[line],
170
+ rowspan=1,
171
+ colspan=cell.colspan,
172
+ row_id=cell.row_id + shift_up + i,
173
+ col_id=cell.col_id,
174
+ is_header=cell.is_header and i == 0, # Only first line is header
175
+ within_row_id=cell.within_row_id,
176
+ cell_id=cell_id
177
+ )
178
+ )
179
+ new_cell_count += 1
180
+
181
+ # For each new row we add, shift up subsequent rows
182
+ shift_up += line_lens[0] - 1
183
+ else:
184
+ for cell in row_cells:
185
+ cell.row_id += shift_up
186
+ new_cells.append(cell)
187
+
188
+ # Only update the cells if we added new cells
189
+ if len(new_cells) > len(table.cells):
190
+ table.cells = new_cells
191
+
192
+ def assign_text_to_cells(self, tables: List[TableResult], table_data: list):
193
+ for table_result, table_page_data in zip(tables, table_data):
194
+ table_text_lines = table_page_data["table_text_lines"]
195
+ table_cells: List[SuryaTableCell] = table_result.cells
196
+ text_line_bboxes = [t["bbox"] for t in table_text_lines]
197
+ table_cell_bboxes = [c.bbox for c in table_cells]
198
+
199
+ intersection_matrix = matrix_intersection_area(text_line_bboxes, table_cell_bboxes)
200
+
201
+ cell_text = defaultdict(list)
202
+ for text_line_idx, table_text_line in enumerate(table_text_lines):
203
+ intersections = intersection_matrix[text_line_idx]
204
+ if intersections.sum() == 0:
205
+ continue
206
+
207
+ max_intersection = intersections.argmax()
208
+ cell_text[max_intersection].append(table_text_line)
209
+
210
+ for k in cell_text:
211
+ # TODO: see if the text needs to be sorted (based on rotation)
212
+ text = cell_text[k]
213
+ assert all("text" in t for t in text), "All text lines must have text"
214
+ assert all("bbox" in t for t in text), "All text lines must have a bbox"
215
+ table_cells[k].text_lines = text
216
+
217
+ def assign_pdftext_lines(self, extract_blocks: list, filepath: str):
218
+ table_inputs = []
219
+ unique_pages = list(set([t["page_id"] for t in extract_blocks]))
220
+ if len(unique_pages) == 0:
221
+ return
222
+
223
+ for page in unique_pages:
224
+ tables = []
225
+ img_size = None
226
+ for block in extract_blocks:
227
+ if block["page_id"] == page:
228
+ tables.append(block["table_bbox"])
229
+ img_size = block["img_size"]
230
+
231
+ table_inputs.append({
232
+ "tables": tables,
233
+ "img_size": img_size
234
+ })
235
+ cell_text = table_output(filepath, table_inputs, page_range=unique_pages)
236
+ assert len(cell_text) == len(unique_pages), "Number of pages and table inputs must match"
237
+
238
+ for pidx, (page_tables, pnum) in enumerate(zip(cell_text, unique_pages)):
239
+ table_idx = 0
240
+ for block in extract_blocks:
241
+ if block["page_id"] == pnum:
242
+ block["table_text_lines"] = page_tables[table_idx]
243
+ table_idx += 1
244
+ assert table_idx == len(page_tables), "Number of tables and table inputs must match"
245
+
246
+ def assign_ocr_lines(self, ocr_blocks: list):
247
+ det_images = [t["table_image"] for t in ocr_blocks]
248
+ ocr_results: List[OCRResult] = self.recognition_model(
249
+ det_images,
250
+ [None] * len(det_images),
251
+ self.detection_model,
252
+ recognition_batch_size=self.get_recognition_batch_size(),
253
+ detection_batch_size=self.get_detector_batch_size()
254
+ )
255
+
256
+ for block, ocr_res in zip(ocr_blocks, ocr_results):
257
+ table_cells = []
258
+ for line in ocr_res.text_lines:
259
+ # Don't need to correct back to image size
260
+ # Table rec boxes are relative to the table
261
+ table_cells.append({
262
+ "bbox": line.bbox,
263
+ "text": line.text
264
+ })
265
+ block["table_text_lines"] = table_cells
266
 
 
 
 
 
 
 
267
 
268
  def get_detector_batch_size(self):
269
  if self.detector_batch_size is not None:
marker/providers/__init__.py CHANGED
@@ -3,6 +3,9 @@ from typing import List, Optional, Dict
3
  from PIL import Image
4
  from pydantic import BaseModel
5
 
 
 
 
6
  from marker.schema.text import Span
7
  from marker.schema.text.line import Line
8
  from marker.util import assign_config
@@ -29,8 +32,17 @@ class BaseProvider:
29
  def get_images(self, idxs: List[int], dpi: int) -> List[Image.Image]:
30
  pass
31
 
32
- def get_page_bbox(self, idx: int) -> List[float]:
33
  pass
34
 
35
  def get_page_lines(self, idx: int) -> List[Line]:
36
  pass
 
 
 
 
 
 
 
 
 
 
3
  from PIL import Image
4
  from pydantic import BaseModel
5
 
6
+ from pdftext.schema import Reference
7
+
8
+ from marker.schema.polygon import PolygonBox
9
  from marker.schema.text import Span
10
  from marker.schema.text.line import Line
11
  from marker.util import assign_config
 
32
  def get_images(self, idxs: List[int], dpi: int) -> List[Image.Image]:
33
  pass
34
 
35
+ def get_page_bbox(self, idx: int) -> PolygonBox | None:
36
  pass
37
 
38
  def get_page_lines(self, idx: int) -> List[Line]:
39
  pass
40
+
41
+ def get_page_refs(self, idx: int) -> List[Reference]:
42
+ pass
43
+
44
+ def __enter__(self):
45
+ return self
46
+
47
+ def __exit__(self, exc_type, exc_value, traceback):
48
+ raise NotImplementedError
marker/providers/image.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Annotated, Optional
2
+ from PIL import Image
3
+
4
+ from marker.providers import ProviderPageLines, BaseProvider
5
+ from marker.schema.polygon import PolygonBox
6
+ from marker.schema.text import Line
7
+ from pdftext.schema import Reference
8
+
9
+
10
+ class ImageProvider(BaseProvider):
11
+ page_range: Annotated[
12
+ Optional[List[int]],
13
+ "The range of pages to process.",
14
+ "Default is None, which will process all pages."
15
+ ] = None
16
+
17
+ image_count: int = 1
18
+
19
+ def __init__(self, filepath: str, config=None):
20
+ super().__init__(filepath, config)
21
+
22
+ self.images = [Image.open(filepath)]
23
+ self.page_lines: ProviderPageLines = {i: [] for i in range(self.image_count)}
24
+
25
+ if self.page_range is None:
26
+ self.page_range = range(self.image_count)
27
+
28
+ assert max(self.page_range) < self.image_count and min(self.page_range) >= 0, \
29
+ f"Invalid page range, values must be between 0 and {len(self.doc) - 1}. Min of provided page range is {min(self.page_range)} and max is {max(self.page_range)}."
30
+
31
+ self.page_bboxes = {i: [0, 0, self.images[i].size[0], self.images[i].size[1]] for i in self.page_range}
32
+
33
+ def __len__(self):
34
+ return self.image_count
35
+
36
+ def __exit__(self, exc_type, exc_value, traceback):
37
+ pass
38
+
39
+ def get_images(self, idxs: List[int], dpi: int) -> List[Image.Image]:
40
+ return [self.images[i] for i in idxs]
41
+
42
+ def get_page_bbox(self, idx: int) -> PolygonBox | None:
43
+ bbox = self.page_bboxes[idx]
44
+ if bbox:
45
+ return PolygonBox.from_bbox(bbox)
46
+
47
+
48
+ def get_page_lines(self, idx: int) -> List[Line]:
49
+ return self.page_lines[idx]
50
+
51
+ def get_page_refs(self, idx: int) -> List[Reference]:
52
+ return []
marker/providers/pdf.py CHANGED
@@ -9,6 +9,7 @@ from ftfy import fix_text
9
  from pdftext.extraction import dictionary_output
10
  from pdftext.schema import Reference
11
  from PIL import Image
 
12
 
13
  from marker.providers import BaseProvider, ProviderOutput, ProviderPageLines
14
  from marker.providers.utils import alphanum_ratio
@@ -91,9 +92,6 @@ class PdfProvider(BaseProvider):
91
 
92
  atexit.register(self.cleanup_pdf_doc)
93
 
94
- def __enter__(self):
95
- return self
96
-
97
  def __exit__(self, exc_type, exc_value, traceback):
98
  self.cleanup_pdf_doc()
99
 
@@ -155,6 +153,19 @@ class PdfProvider(BaseProvider):
155
  formats.add("italic")
156
  return formats
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  def pdftext_extraction(self) -> ProviderPageLines:
159
  page_lines: ProviderPageLines = {}
160
  page_char_blocks = dictionary_output(
@@ -191,7 +202,7 @@ class PdfProvider(BaseProvider):
191
  spans.append(
192
  SpanClass(
193
  polygon=polygon,
194
- text=fix_text(span["text"]),
195
  font=font_name,
196
  font_weight=font_weight,
197
  font_size=font_size,
@@ -234,7 +245,11 @@ class PdfProvider(BaseProvider):
234
  def check_page(self, page_id: int) -> bool:
235
  page = self.doc.get_page(page_id)
236
  page_bbox = PolygonBox.from_bbox(page.get_bbox())
237
- page_objs = list(page.get_objects(filter=[pdfium_c.FPDF_PAGEOBJ_TEXT, pdfium_c.FPDF_PAGEOBJ_IMAGE]))
 
 
 
 
238
 
239
  # if we do not see any text objects in the pdf, we can skip this page
240
  if not any([obj.type == pdfium_c.FPDF_PAGEOBJ_TEXT for obj in page_objs]):
@@ -313,7 +328,7 @@ class PdfProvider(BaseProvider):
313
  def get_page_lines(self, idx: int) -> List[ProviderOutput]:
314
  return self.page_lines[idx]
315
 
316
- def get_page_refs(self, idx: int):
317
  return self.page_refs[idx]
318
 
319
  @staticmethod
 
9
  from pdftext.extraction import dictionary_output
10
  from pdftext.schema import Reference
11
  from PIL import Image
12
+ from pypdfium2 import PdfiumError
13
 
14
  from marker.providers import BaseProvider, ProviderOutput, ProviderPageLines
15
  from marker.providers.utils import alphanum_ratio
 
92
 
93
  atexit.register(self.cleanup_pdf_doc)
94
 
 
 
 
95
  def __exit__(self, exc_type, exc_value, traceback):
96
  self.cleanup_pdf_doc()
97
 
 
153
  formats.add("italic")
154
  return formats
155
 
156
+ @staticmethod
157
+ def normalize_spaces(text):
158
+ space_chars = [
159
+ '\u2003', # em space
160
+ '\u2002', # en space
161
+ '\u00A0', # non-breaking space
162
+ '\u200B', # zero-width space
163
+ '\u3000', # ideographic space
164
+ ]
165
+ for space in space_chars:
166
+ text = text.replace(space, ' ')
167
+ return text
168
+
169
  def pdftext_extraction(self) -> ProviderPageLines:
170
  page_lines: ProviderPageLines = {}
171
  page_char_blocks = dictionary_output(
 
202
  spans.append(
203
  SpanClass(
204
  polygon=polygon,
205
+ text=self.normalize_spaces(fix_text(span["text"])),
206
  font=font_name,
207
  font_weight=font_weight,
208
  font_size=font_size,
 
245
  def check_page(self, page_id: int) -> bool:
246
  page = self.doc.get_page(page_id)
247
  page_bbox = PolygonBox.from_bbox(page.get_bbox())
248
+ try:
249
+ page_objs = list(page.get_objects(filter=[pdfium_c.FPDF_PAGEOBJ_TEXT, pdfium_c.FPDF_PAGEOBJ_IMAGE]))
250
+ except PdfiumError:
251
+ # Happens when pdfium fails to get the number of page objects
252
+ return False
253
 
254
  # if we do not see any text objects in the pdf, we can skip this page
255
  if not any([obj.type == pdfium_c.FPDF_PAGEOBJ_TEXT for obj in page_objs]):
 
328
  def get_page_lines(self, idx: int) -> List[ProviderOutput]:
329
  return self.page_lines[idx]
330
 
331
+ def get_page_refs(self, idx: int) -> List[Reference]:
332
  return self.page_refs[idx]
333
 
334
  @staticmethod
marker/providers/registry.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import filetype
2
+
3
+ from marker.providers.image import ImageProvider
4
+ from marker.providers.pdf import PdfProvider
5
+
6
+
7
+ def provider_from_filepath(filepath: str):
8
+ kind = filetype.image_match(filepath)
9
+ if kind is not None:
10
+ return ImageProvider
11
+
12
+ return PdfProvider
marker/renderers/__init__.py CHANGED
@@ -2,7 +2,7 @@ import base64
2
  import io
3
  import re
4
  from collections import Counter
5
- from typing import Annotated, Optional, Tuple
6
 
7
  from bs4 import BeautifulSoup
8
  from pydantic import BaseModel
@@ -17,6 +17,11 @@ from marker.util import assign_config
17
  class BaseRenderer:
18
  image_blocks: Annotated[Tuple[BlockTypes, ...], "The block types to consider as images."] = (BlockTypes.Picture, BlockTypes.Figure)
19
  extract_images: Annotated[bool, "Extract images from the document."] = True
 
 
 
 
 
20
 
21
  def __init__(self, config: Optional[BaseModel | dict] = None):
22
  assign_config(self, config)
@@ -25,13 +30,10 @@ class BaseRenderer:
25
  # Children are in reading order
26
  raise NotImplementedError
27
 
28
- @staticmethod
29
- def extract_image(document: Document, image_id, to_base64=False):
30
  image_block = document.get_block(image_id)
31
- page = document.get_page(image_block.page_id)
32
- page_img = page.highres_image
33
- image_box = image_block.polygon.rescale(page.polygon.size, page_img.size)
34
- cropped = page_img.crop(image_box.bbox)
35
  if to_base64:
36
  image_buffer = io.BytesIO()
37
  cropped.save(image_buffer, format=settings.OUTPUT_IMAGE_FORMAT)
@@ -44,7 +46,11 @@ class BaseRenderer:
44
  return html
45
 
46
  def replace_whitespace(match):
47
- return match.group(1)
 
 
 
 
48
 
49
  pattern = fr'</{tag}>(\s*)<{tag}>'
50
 
 
2
  import io
3
  import re
4
  from collections import Counter
5
+ from typing import Annotated, Optional, Tuple, Literal
6
 
7
  from bs4 import BeautifulSoup
8
  from pydantic import BaseModel
 
17
  class BaseRenderer:
18
  image_blocks: Annotated[Tuple[BlockTypes, ...], "The block types to consider as images."] = (BlockTypes.Picture, BlockTypes.Figure)
19
  extract_images: Annotated[bool, "Extract images from the document."] = True
20
+ image_extraction_mode: Annotated[
21
+ Literal["lowres", "highres"],
22
+ "The mode to use for extracting images.",
23
+ ] = "highres"
24
+
25
 
26
  def __init__(self, config: Optional[BaseModel | dict] = None):
27
  assign_config(self, config)
 
30
  # Children are in reading order
31
  raise NotImplementedError
32
 
33
+ def extract_image(self, document: Document, image_id, to_base64=False):
 
34
  image_block = document.get_block(image_id)
35
+ cropped = image_block.get_image(document, highres=self.image_extraction_mode == "highres")
36
+
 
 
37
  if to_base64:
38
  image_buffer = io.BytesIO()
39
  cropped.save(image_buffer, format=settings.OUTPUT_IMAGE_FORMAT)
 
46
  return html
47
 
48
  def replace_whitespace(match):
49
+ whitespace = match.group(1)
50
+ if len(whitespace) == 0:
51
+ return ""
52
+ else:
53
+ return " "
54
 
55
  pattern = fr'</{tag}>(\s*)<{tag}>'
56
 
marker/renderers/html.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from PIL import Image
2
  from typing import Annotated, Literal, Tuple
3
 
@@ -35,17 +37,10 @@ class HTMLRenderer(BaseRenderer):
35
  bool,
36
  "Whether to paginate the output.",
37
  ] = False
38
- image_extraction_mode: Annotated[
39
- Literal["lowres", "highres"],
40
- "The mode to use for extracting images.",
41
- ] = "highres"
42
 
43
  def extract_image(self, document, image_id):
44
  image_block = document.get_block(image_id)
45
- page = document.get_page(image_block.page_id)
46
- page_img = page.lowres_image if self.image_extraction_mode == "lowres" else page.highres_image
47
- image_box = image_block.polygon.rescale(page.polygon.size, page_img.size)
48
- cropped = page_img.crop(image_box.bbox)
49
  return cropped
50
 
51
  def extract_html(self, document, document_output, level=0):
@@ -87,12 +82,25 @@ class HTMLRenderer(BaseRenderer):
87
  if level == 0:
88
  output = self.merge_consecutive_tags(output, 'b')
89
  output = self.merge_consecutive_tags(output, 'i')
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  return output, images
92
 
93
  def __call__(self, document) -> HTMLOutput:
94
  document_output = document.render()
95
  full_html, images = self.extract_html(document, document_output)
 
 
96
  return HTMLOutput(
97
  html=full_html,
98
  images=images,
 
1
+ import textwrap
2
+
3
  from PIL import Image
4
  from typing import Annotated, Literal, Tuple
5
 
 
37
  bool,
38
  "Whether to paginate the output.",
39
  ] = False
 
 
 
 
40
 
41
  def extract_image(self, document, image_id):
42
  image_block = document.get_block(image_id)
43
+ cropped = image_block.get_image(document, highres=self.image_extraction_mode == "highres")
 
 
 
44
  return cropped
45
 
46
  def extract_html(self, document, document_output, level=0):
 
82
  if level == 0:
83
  output = self.merge_consecutive_tags(output, 'b')
84
  output = self.merge_consecutive_tags(output, 'i')
85
+ output = textwrap.dedent(f"""
86
+ <!DOCTYPE html>
87
+ <html>
88
+ <head>
89
+ <meta charset="utf-8" />
90
+ </head>
91
+ <body>
92
+ {output}
93
+ </body>
94
+ </html>
95
+ """)
96
 
97
  return output, images
98
 
99
  def __call__(self, document) -> HTMLOutput:
100
  document_output = document.render()
101
  full_html, images = self.extract_html(document, document_output)
102
+ soup = BeautifulSoup(full_html, 'html.parser')
103
+ full_html = soup.prettify() # Add indentation to the HTML
104
  return HTMLOutput(
105
  html=full_html,
106
  images=images,
marker/renderers/json.py CHANGED
@@ -14,6 +14,7 @@ class JSONBlockOutput(BaseModel):
14
  block_type: str
15
  html: str
16
  polygon: List[List[float]]
 
17
  children: List['JSONBlockOutput'] | None = None
18
  section_hierarchy: Dict[int, str] | None = None
19
  images: dict | None = None
@@ -52,6 +53,7 @@ class JSONRenderer(BaseRenderer):
52
  return JSONBlockOutput(
53
  html=html,
54
  polygon=block_output.polygon.polygon,
 
55
  id=str(block_output.id),
56
  block_type=str(block_output.id.block_type),
57
  images=images,
@@ -66,6 +68,7 @@ class JSONRenderer(BaseRenderer):
66
  return JSONBlockOutput(
67
  html=block_output.html,
68
  polygon=block_output.polygon.polygon,
 
69
  id=str(block_output.id),
70
  block_type=str(block_output.id.block_type),
71
  children=children,
 
14
  block_type: str
15
  html: str
16
  polygon: List[List[float]]
17
+ bbox: List[float]
18
  children: List['JSONBlockOutput'] | None = None
19
  section_hierarchy: Dict[int, str] | None = None
20
  images: dict | None = None
 
53
  return JSONBlockOutput(
54
  html=html,
55
  polygon=block_output.polygon.polygon,
56
+ bbox=block_output.polygon.bbox,
57
  id=str(block_output.id),
58
  block_type=str(block_output.id.block_type),
59
  images=images,
 
68
  return JSONBlockOutput(
69
  html=block_output.html,
70
  polygon=block_output.polygon.polygon,
71
+ bbox=block_output.polygon.bbox,
72
  id=str(block_output.id),
73
  block_type=str(block_output.id.block_type),
74
  children=children,
marker/renderers/markdown.py CHANGED
@@ -1,4 +1,5 @@
1
  import re
 
2
  from typing import Annotated, Tuple
3
 
4
  import regex
@@ -13,7 +14,7 @@ from marker.schema.document import Document
13
  def cleanup_text(full_text):
14
  full_text = re.sub(r'\n{3,}', '\n\n', full_text)
15
  full_text = re.sub(r'(\n\s){3,}', '\n\n', full_text)
16
- return full_text
17
 
18
 
19
  class Markdownify(MarkdownConverter):
@@ -53,13 +54,88 @@ class Markdownify(MarkdownConverter):
53
  else:
54
  return "\n" + self.block_math_delimiters[0] + text + self.block_math_delimiters[1] + "\n"
55
 
56
- def convert_td(self, el, text, convert_as_inline):
57
- text = text.replace("|", " ").replace("\n", " ")
58
- return super().convert_td(el, text, convert_as_inline)
59
-
60
- def convert_th(self, el, text, convert_as_inline):
61
- text = text.replace("|", " ").replace("\n", " ")
62
- return super().convert_th(el, text, convert_as_inline)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  def convert_a(self, el, text, convert_as_inline):
65
  text = self.escape(text)
 
1
  import re
2
+ from collections import defaultdict
3
  from typing import Annotated, Tuple
4
 
5
  import regex
 
14
  def cleanup_text(full_text):
15
  full_text = re.sub(r'\n{3,}', '\n\n', full_text)
16
  full_text = re.sub(r'(\n\s){3,}', '\n\n', full_text)
17
+ return full_text.strip()
18
 
19
 
20
  class Markdownify(MarkdownConverter):
 
54
  else:
55
  return "\n" + self.block_math_delimiters[0] + text + self.block_math_delimiters[1] + "\n"
56
 
57
+ def convert_table(self, el, text, convert_as_inline):
58
+ total_rows = len(el.find_all('tr'))
59
+ colspans = []
60
+ rowspan_cols = defaultdict(int)
61
+ for i, row in enumerate(el.find_all('tr')):
62
+ row_cols = rowspan_cols[i]
63
+ for cell in row.find_all(['td', 'th']):
64
+ colspan = int(cell.get('colspan', 1))
65
+ row_cols += colspan
66
+ for r in range(int(cell.get('rowspan', 1)) - 1):
67
+ rowspan_cols[i + r] += colspan # Add the colspan to the next rows, so they get the correct number of columns
68
+ colspans.append(row_cols)
69
+ total_cols = max(colspans)
70
+
71
+ grid = [[None for _ in range(total_cols)] for _ in range(total_rows)]
72
+
73
+ for row_idx, tr in enumerate(el.find_all('tr')):
74
+ col_idx = 0
75
+ for cell in tr.find_all(['td', 'th']):
76
+ # Skip filled positions
77
+ while col_idx < total_cols and grid[row_idx][col_idx] is not None:
78
+ col_idx += 1
79
+
80
+ # Fill in grid
81
+ value = cell.get_text(strip=True).replace("\n", " ").replace("|", " ")
82
+ rowspan = int(cell.get('rowspan', 1))
83
+ colspan = int(cell.get('colspan', 1))
84
+
85
+ if col_idx >= total_cols:
86
+ # Skip this cell if we're out of bounds
87
+ continue
88
+
89
+ for r in range(rowspan):
90
+ for c in range(colspan):
91
+ try:
92
+ if r == 0 and c == 0:
93
+ grid[row_idx][col_idx] = value
94
+ else:
95
+ grid[row_idx + r][col_idx + c] = ''
96
+ except IndexError:
97
+ # Sometimes the colspan/rowspan predictions can overflow
98
+ print(f"Overflow in columns: {col_idx + c} >= {total_cols}")
99
+ continue
100
+
101
+ col_idx += colspan
102
+
103
+ markdown_lines = []
104
+ col_widths = [0] * total_cols
105
+ for row in grid:
106
+ for col_idx, cell in enumerate(row):
107
+ if cell is not None:
108
+ col_widths[col_idx] = max(col_widths[col_idx], len(str(cell)))
109
+
110
+ add_header_line = lambda: markdown_lines.append('|' + '|'.join('-' * (width + 2) for width in col_widths) + '|')
111
+
112
+ # Generate markdown rows
113
+ added_header = False
114
+ for i, row in enumerate(grid):
115
+ is_empty_line = all(not cell for cell in row)
116
+ if is_empty_line and not added_header:
117
+ # Skip leading blank lines
118
+ continue
119
+
120
+ line = []
121
+ for col_idx, cell in enumerate(row):
122
+ if cell is None:
123
+ cell = ''
124
+ padding = col_widths[col_idx] - len(str(cell))
125
+ line.append(f" {cell}{' ' * padding} ")
126
+ markdown_lines.append('|' + '|'.join(line) + '|')
127
+
128
+ if not added_header:
129
+ # Skip empty lines when adding the header row
130
+ add_header_line()
131
+ added_header = True
132
+
133
+ # Handle one row tables
134
+ if total_rows == 1:
135
+ add_header_line()
136
+
137
+ table_md = '\n'.join(markdown_lines)
138
+ return "\n\n" + table_md + "\n\n"
139
 
140
  def convert_a(self, el, text, convert_as_inline):
141
  text = self.escape(text)
marker/schema/__init__.py CHANGED
@@ -27,6 +27,7 @@ class BlockTypes(str, Enum):
27
  TableOfContents = auto()
28
  Document = auto()
29
  ComplexRegion = auto()
 
30
  Reference = auto()
31
 
32
  def __str__(self):
 
27
  TableOfContents = auto()
28
  Document = auto()
29
  ComplexRegion = auto()
30
+ TableCell = auto()
31
  Reference = auto()
32
 
33
  def __str__(self):
marker/schema/blocks/__init__.py CHANGED
@@ -18,4 +18,5 @@ from marker.schema.blocks.table import Table
18
  from marker.schema.blocks.text import Text
19
  from marker.schema.blocks.toc import TableOfContents
20
  from marker.schema.blocks.complexregion import ComplexRegion
 
21
  from marker.schema.blocks.reference import Reference
 
18
  from marker.schema.blocks.text import Text
19
  from marker.schema.blocks.toc import TableOfContents
20
  from marker.schema.blocks.complexregion import ComplexRegion
21
+ from marker.schema.blocks.tablecell import TableCell
22
  from marker.schema.blocks.reference import Reference
marker/schema/blocks/base.py CHANGED
@@ -1,8 +1,9 @@
1
  from __future__ import annotations
2
 
3
- from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Sequence
4
 
5
  from pydantic import BaseModel, ConfigDict, field_validator
 
6
 
7
  from marker.schema import BlockTypes
8
  from marker.schema.polygon import PolygonBox
@@ -71,6 +72,7 @@ class BlockId(BaseModel):
71
 
72
  class Block(BaseModel):
73
  polygon: PolygonBox
 
74
  block_type: Optional[BlockTypes] = None
75
  block_id: Optional[int] = None
76
  page_id: Optional[int] = None
@@ -81,6 +83,8 @@ class Block(BaseModel):
81
  source: Literal['layout', 'heuristics', 'processor'] = 'layout'
82
  top_k: Optional[Dict[BlockTypes, float]] = None
83
  metadata: BlockMetadata | None = None
 
 
84
 
85
  model_config = ConfigDict(arbitrary_types_allowed=True)
86
 
@@ -97,6 +101,21 @@ class Block(BaseModel):
97
  block_attrs = block.model_dump(exclude=["id", "block_id", "block_type"])
98
  return cls(**block_attrs)
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def structure_blocks(self, document_page: Document | PageGroup) -> List[Block]:
101
  if self.structure is None:
102
  return []
@@ -163,7 +182,7 @@ class Block(BaseModel):
163
  text += "\n"
164
  return text
165
 
166
- def assemble_html(self, child_blocks: List[BlockOutput], parent_structure: Optional[List[str]] = None):
167
  if self.ignore_for_output:
168
  return ""
169
 
@@ -172,7 +191,8 @@ class Block(BaseModel):
172
  template += f"<content-ref src='{c.id}'></content-ref>"
173
 
174
  if self.replace_output_newlines:
175
- template = "<p>" + template.replace("\n", " ") + "</p>"
 
176
 
177
  return template
178
 
@@ -205,7 +225,7 @@ class Block(BaseModel):
205
  self.structure[i] = new_block.id
206
  break
207
 
208
- def render(self, document: Document, parent_structure: Optional[List[str]], section_hierarchy=None):
209
  child_content = []
210
  if section_hierarchy is None:
211
  section_hierarchy = {}
@@ -219,7 +239,7 @@ class Block(BaseModel):
219
  child_content.append(rendered)
220
 
221
  return BlockOutput(
222
- html=self.assemble_html(child_content, parent_structure),
223
  polygon=self.polygon,
224
  id=self.id,
225
  children=child_content,
 
1
  from __future__ import annotations
2
 
3
+ from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Sequence, Tuple
4
 
5
  from pydantic import BaseModel, ConfigDict, field_validator
6
+ from PIL import Image
7
 
8
  from marker.schema import BlockTypes
9
  from marker.schema.polygon import PolygonBox
 
72
 
73
  class Block(BaseModel):
74
  polygon: PolygonBox
75
+ block_description: str
76
  block_type: Optional[BlockTypes] = None
77
  block_id: Optional[int] = None
78
  page_id: Optional[int] = None
 
83
  source: Literal['layout', 'heuristics', 'processor'] = 'layout'
84
  top_k: Optional[Dict[BlockTypes, float]] = None
85
  metadata: BlockMetadata | None = None
86
+ lowres_image: Image.Image | None = None
87
+ highres_image: Image.Image | None = None
88
 
89
  model_config = ConfigDict(arbitrary_types_allowed=True)
90
 
 
101
  block_attrs = block.model_dump(exclude=["id", "block_id", "block_type"])
102
  return cls(**block_attrs)
103
 
104
+ def get_image(self, document: Document, highres: bool = False, expansion: Tuple[float, float] | None = None) -> Image.Image | None:
105
+ image = self.highres_image if highres else self.lowres_image
106
+ if image is None:
107
+ page = document.get_page(self.page_id)
108
+ page_image = page.highres_image if highres else page.lowres_image
109
+
110
+ # Scale to the image size
111
+ bbox = self.polygon.rescale((page.polygon.width, page.polygon.height), page_image.size)
112
+ if expansion:
113
+ bbox = bbox.expand(*expansion)
114
+ bbox = bbox.bbox
115
+ image = page_image.crop(bbox)
116
+ return image
117
+
118
+
119
  def structure_blocks(self, document_page: Document | PageGroup) -> List[Block]:
120
  if self.structure is None:
121
  return []
 
182
  text += "\n"
183
  return text
184
 
185
+ def assemble_html(self, document: Document, child_blocks: List[BlockOutput], parent_structure: Optional[List[str]] = None):
186
  if self.ignore_for_output:
187
  return ""
188
 
 
191
  template += f"<content-ref src='{c.id}'></content-ref>"
192
 
193
  if self.replace_output_newlines:
194
+ template = template.replace("\n", " ")
195
+ template = "<p>" + template + "</p>"
196
 
197
  return template
198
 
 
225
  self.structure[i] = new_block.id
226
  break
227
 
228
+ def render(self, document: Document, parent_structure: Optional[List[str]] = None, section_hierarchy: dict | None = None):
229
  child_content = []
230
  if section_hierarchy is None:
231
  section_hierarchy = {}
 
239
  child_content.append(rendered)
240
 
241
  return BlockOutput(
242
+ html=self.assemble_html(document, child_content, parent_structure),
243
  polygon=self.polygon,
244
  id=self.id,
245
  children=child_content,
marker/schema/blocks/basetable.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from marker.schema import BlockTypes
4
+ from marker.schema.blocks import Block, BlockOutput
5
+ from marker.schema.blocks.tablecell import TableCell
6
+
7
+
8
+ class BaseTable(Block):
9
+ block_type: BlockTypes | None = None
10
+ html: str | None = None
11
+
12
+ def format_cells(self, document, child_blocks):
13
+ child_cells: List[TableCell] = [document.get_block(c.id) for c in child_blocks]
14
+ unique_rows = sorted(list(set([c.row_id for c in child_cells])))
15
+ html_repr = "<table><tbody>"
16
+ for row_id in unique_rows:
17
+ row_cells = sorted([c for c in child_cells if c.row_id == row_id], key=lambda x: x.col_id)
18
+ html_repr += "<tr>"
19
+ for cell in row_cells:
20
+ html_repr += cell.assemble_html(document, child_blocks, None)
21
+ html_repr += "</tr>"
22
+ html_repr += "</tbody></table>"
23
+ return html_repr
24
+
25
+
26
+ def assemble_html(self, document, child_blocks: List[BlockOutput], parent_structure=None):
27
+ # Filter out the table cells, so they don't render twice
28
+ child_ref_blocks = [block for block in child_blocks if block.id.block_type == BlockTypes.Reference]
29
+ template = super().assemble_html(document, child_ref_blocks, parent_structure)
30
+
31
+ if self.html:
32
+ # LLM processor
33
+ return template + self.html
34
+ elif len(child_blocks) > 0 and child_blocks[0].id.block_type == BlockTypes.TableCell:
35
+ # Table processor
36
+ return template + self.format_cells(document, child_blocks)
37
+ else:
38
+ # Default text lines and spans
39
+ return f"<p>{template}</p>"
marker/schema/blocks/caption.py CHANGED
@@ -4,4 +4,6 @@ from marker.schema.blocks import Block
4
 
5
  class Caption(Block):
6
  block_type: BlockTypes = BlockTypes.Caption
 
7
  replace_output_newlines: bool = True
 
 
4
 
5
  class Caption(Block):
6
  block_type: BlockTypes = BlockTypes.Caption
7
+ block_description: str = "A text caption that is directly above or below an image or table. Only used for text describing the image or table. "
8
  replace_output_newlines: bool = True
9
+
marker/schema/blocks/code.py CHANGED
@@ -7,8 +7,9 @@ from marker.schema.blocks import Block
7
  class Code(Block):
8
  block_type: BlockTypes = BlockTypes.Code
9
  code: str | None = None
 
10
 
11
- def assemble_html(self, child_blocks, parent_structure):
12
  code = self.code or ""
13
  return (f"<pre>"
14
  f"{html.escape(code)}"
 
7
  class Code(Block):
8
  block_type: BlockTypes = BlockTypes.Code
9
  code: str | None = None
10
+ block_description: str = "A programming code block."
11
 
12
+ def assemble_html(self, document, child_blocks, parent_structure):
13
  code = self.code or ""
14
  return (f"<pre>"
15
  f"{html.escape(code)}"
marker/schema/blocks/complexregion.py CHANGED
@@ -5,10 +5,11 @@ from marker.schema.blocks import Block
5
  class ComplexRegion(Block):
6
  block_type: BlockTypes = BlockTypes.ComplexRegion
7
  html: str | None = None
 
8
 
9
- def assemble_html(self, child_blocks, parent_structure):
10
  if self.html:
11
  return self.html
12
  else:
13
- template = super().assemble_html(child_blocks, parent_structure)
14
  return f"<p>{template}</p>"
 
5
  class ComplexRegion(Block):
6
  block_type: BlockTypes = BlockTypes.ComplexRegion
7
  html: str | None = None
8
+ block_description: str = "A complex region that can consist of multiple different types of blocks mixed with images. This block is chosen when it is difficult to categorize the region as a single block type."
9
 
10
+ def assemble_html(self, document, child_blocks, parent_structure):
11
  if self.html:
12
  return self.html
13
  else:
14
+ template = super().assemble_html(document, child_blocks, parent_structure)
15
  return f"<p>{template}</p>"
marker/schema/blocks/equation.py CHANGED
@@ -7,11 +7,12 @@ from marker.schema.blocks import Block
7
  class Equation(Block):
8
  block_type: BlockTypes = BlockTypes.Equation
9
  latex: str | None = None
 
10
 
11
- def assemble_html(self, child_blocks, parent_structure=None):
12
  if self.latex:
13
  child_ref_blocks = [block for block in child_blocks if block.id.block_type == BlockTypes.Reference]
14
- html_out = super().assemble_html(child_ref_blocks, parent_structure)
15
  html_out += f"<p block-type='{self.block_type}'>"
16
 
17
  try:
@@ -33,7 +34,7 @@ class Equation(Block):
33
  html_out += "</p>"
34
  return html_out
35
  else:
36
- template = super().assemble_html(child_blocks, parent_structure)
37
  return f"<p block-type='{self.block_type}'>{template}</p>"
38
 
39
  @staticmethod
 
7
  class Equation(Block):
8
  block_type: BlockTypes = BlockTypes.Equation
9
  latex: str | None = None
10
+ block_description: str = "A block math equation."
11
 
12
+ def assemble_html(self, document, child_blocks, parent_structure=None):
13
  if self.latex:
14
  child_ref_blocks = [block for block in child_blocks if block.id.block_type == BlockTypes.Reference]
15
+ html_out = super().assemble_html(document, child_ref_blocks, parent_structure)
16
  html_out += f"<p block-type='{self.block_type}'>"
17
 
18
  try:
 
34
  html_out += "</p>"
35
  return html_out
36
  else:
37
+ template = super().assemble_html(document, child_blocks, parent_structure)
38
  return f"<p block-type='{self.block_type}'>{template}</p>"
39
 
40
  @staticmethod
marker/schema/blocks/figure.py CHANGED
@@ -5,10 +5,11 @@ from marker.schema.blocks import Block
5
  class Figure(Block):
6
  block_type: BlockTypes = BlockTypes.Figure
7
  description: str | None = None
 
8
 
9
- def assemble_html(self, child_blocks, parent_structure):
10
  child_ref_blocks = [block for block in child_blocks if block.id.block_type == BlockTypes.Reference]
11
- html = super().assemble_html(child_ref_blocks, parent_structure)
12
  if self.description:
13
  html += f"<p role='img' data-original-image-id='{self.id}'>Image {self.id} description: {self.description}</p>"
14
  return html
 
5
  class Figure(Block):
6
  block_type: BlockTypes = BlockTypes.Figure
7
  description: str | None = None
8
+ block_description: str = "A chart or other image that contains data."
9
 
10
+ def assemble_html(self, document, child_blocks, parent_structure):
11
  child_ref_blocks = [block for block in child_blocks if block.id.block_type == BlockTypes.Reference]
12
+ html = super().assemble_html(document, child_ref_blocks, parent_structure)
13
  if self.description:
14
  html += f"<p role='img' data-original-image-id='{self.id}'>Image {self.id} description: {self.description}</p>"
15
  return html
marker/schema/blocks/footnote.py CHANGED
@@ -4,4 +4,5 @@ from marker.schema.blocks import Block
4
 
5
  class Footnote(Block):
6
  block_type: BlockTypes = BlockTypes.Footnote
 
7
  replace_output_newlines: bool = True
 
4
 
5
  class Footnote(Block):
6
  block_type: BlockTypes = BlockTypes.Footnote
7
+ block_description: str = "A footnote that explains a term or concept in the document."
8
  replace_output_newlines: bool = True
marker/schema/blocks/form.py CHANGED
@@ -1,20 +1,9 @@
1
  from typing import List
2
 
3
- from tabled.formats import html_format
4
- from tabled.schema import SpanTableCell
5
-
6
  from marker.schema import BlockTypes
7
- from marker.schema.blocks import Block
8
-
9
-
10
- class Form(Block):
11
- block_type: str = BlockTypes.Form
12
- cells: List[SpanTableCell] | None = None
13
- html: str | None = None
14
 
15
- def assemble_html(self, child_blocks, parent_structure=None):
16
- # Some processors convert the form to html
17
- if self.html is not None:
18
- return self.html
19
 
20
- return str(html_format(self.cells))
 
 
 
1
  from typing import List
2
 
 
 
 
3
  from marker.schema import BlockTypes
4
+ from marker.schema.blocks.basetable import BaseTable
 
 
 
 
 
 
5
 
 
 
 
 
6
 
7
+ class Form(BaseTable):
8
+ block_type: BlockTypes = BlockTypes.Form
9
+ block_description: str = "A form, such as a tax form, that contains fields and labels. It most likely doesn't have a table structure."
marker/schema/blocks/handwriting.py CHANGED
@@ -4,4 +4,12 @@ from marker.schema.blocks import Block
4
 
5
  class Handwriting(Block):
6
  block_type: BlockTypes = BlockTypes.Handwriting
 
 
7
  replace_output_newlines: bool = True
 
 
 
 
 
 
 
4
 
5
  class Handwriting(Block):
6
  block_type: BlockTypes = BlockTypes.Handwriting
7
+ block_description: str = "A region that contains handwriting."
8
+ html: str | None = None
9
  replace_output_newlines: bool = True
10
+
11
+ def assemble_html(self, document, child_blocks, parent_structure):
12
+ if self.html:
13
+ return self.html
14
+ else:
15
+ return super().assemble_html(document, child_blocks, parent_structure)