Vik Paruchuri committed on
Commit
f7ff7f7
·
1 Parent(s): 621602c

Add tests for llm processors

Browse files
README.md CHANGED
@@ -149,7 +149,7 @@ text, _, images = text_from_rendered(rendered)
149
 
150
  ### Custom configuration
151
 
152
- You can also pass configuration using the `ConfigParser`:
153
 
154
  ```python
155
  from marker.converters.pdf import PdfConverter
@@ -171,6 +171,26 @@ converter = PdfConverter(
171
  rendered = converter("FILEPATH")
172
  ```
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  # Output Formats
175
 
176
  ## Markdown
 
149
 
150
  ### Custom configuration
151
 
152
+ You can pass configuration using the `ConfigParser`:
153
 
154
  ```python
155
  from marker.converters.pdf import PdfConverter
 
171
  rendered = converter("FILEPATH")
172
  ```
173
 
174
+ ### Extract blocks
175
+
176
+ Each document consists of one or more pages. Pages contain blocks, which can themselves contain other blocks. It's possible to programmatically manipulate these blocks.
177
+
178
+ Here's an example of extracting all forms from a document:
179
+
180
+ ```python
181
+ from marker.converters.pdf import PdfConverter
182
+ from marker.models import create_model_dict
183
+ from marker.schema import BlockTypes
184
+
185
+ converter = PdfConverter(
186
+ artifact_dict=create_model_dict(),
187
+ )
188
+ document = converter.build_document("FILEPATH")
189
+ forms = document.contained_blocks((BlockTypes.Form,))
190
+ ```
191
+
192
+ Look at the processors for more examples of extracting and manipulating blocks.
193
+
194
  # Output Formats
195
 
196
  ## Markdown
convert.py CHANGED
@@ -1,5 +1,7 @@
1
  import os
2
 
 
 
3
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
4
  os.environ["IN_STREAMLIT"] = "true" # Avoid multiprocessing inside surya
5
 
 
1
  import os
2
 
3
+ os.environ["GRPC_VERBOSITY"] = "ERROR"
4
+ os.environ["GLOG_minloglevel"] = "2"
5
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
6
  os.environ["IN_STREAMLIT"] = "true" # Avoid multiprocessing inside surya
7
 
convert_single.py CHANGED
@@ -1,9 +1,10 @@
1
  import os
2
 
 
 
3
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
4
 
5
  import time
6
-
7
  import click
8
 
9
  from marker.config.parser import ConfigParser
 
1
  import os
2
 
3
+ os.environ["GRPC_VERBOSITY"] = "ERROR"
4
+ os.environ["GLOG_minloglevel"] = "2"
5
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
6
 
7
  import time
 
8
  import click
9
 
10
  from marker.config.parser import ConfigParser
marker/builders/llm_layout.py CHANGED
@@ -51,7 +51,7 @@ class LLMLayoutBuilder(LayoutBuilder):
51
  """
52
 
53
  google_api_key: Optional[str] = settings.GOOGLE_API_KEY
54
- confidence_threshold: float = 0.7
55
  model_name: str = "gemini-1.5-flash"
56
  max_retries: int = 3
57
  max_concurrency: int = 3
 
51
  """
52
 
53
  google_api_key: Optional[str] = settings.GOOGLE_API_KEY
54
+ confidence_threshold: float = 0.75
55
  model_name: str = "gemini-1.5-flash"
56
  max_retries: int = 3
57
  max_concurrency: int = 3
marker/converters/pdf.py CHANGED
@@ -109,7 +109,7 @@ class PdfConverter(BaseConverter):
109
 
110
  return cls(**resolved_kwargs)
111
 
112
- def __call__(self, filepath: str):
113
  pdf_provider = PdfProvider(filepath, self.config)
114
  layout_builder = self.resolve_dependencies(self.layout_builder_class)
115
  ocr_builder = self.resolve_dependencies(OcrBuilder)
@@ -120,5 +120,9 @@ class PdfConverter(BaseConverter):
120
  processor = self.resolve_dependencies(processor_cls)
121
  processor(document)
122
 
 
 
 
 
123
  renderer = self.resolve_dependencies(self.renderer)
124
  return renderer(document)
 
109
 
110
  return cls(**resolved_kwargs)
111
 
112
+ def build_document(self, filepath: str):
113
  pdf_provider = PdfProvider(filepath, self.config)
114
  layout_builder = self.resolve_dependencies(self.layout_builder_class)
115
  ocr_builder = self.resolve_dependencies(OcrBuilder)
 
120
  processor = self.resolve_dependencies(processor_cls)
121
  processor(document)
122
 
123
+ return document
124
+
125
+ def __call__(self, filepath: str):
126
+ document = self.build_document(filepath)
127
  renderer = self.resolve_dependencies(self.renderer)
128
  return renderer(document)
marker/processors/llm/__init__.py CHANGED
@@ -32,6 +32,9 @@ class BaseLLMProcessor(BaseProcessor):
32
  gemini_rewriting_prompt (str):
33
  The prompt to use for rewriting text.
34
  Default is a string containing the Gemini rewriting prompt.
 
 
 
35
  """
36
 
37
  google_api_key: Optional[str] = settings.GOOGLE_API_KEY
@@ -57,7 +60,10 @@ class BaseLLMProcessor(BaseProcessor):
57
  if not self.use_llm or self.model is None:
58
  return
59
 
60
- self.rewrite_blocks(document)
 
 
 
61
 
62
  def process_rewriting(self, document: Document, page: PageGroup, block: Block):
63
  raise NotImplementedError()
 
32
  gemini_rewriting_prompt (str):
33
  The prompt to use for rewriting text.
34
  Default is a string containing the Gemini rewriting prompt.
35
+ use_llm (bool):
36
+ Whether to use the LLM model.
37
+ Default is False.
38
  """
39
 
40
  google_api_key: Optional[str] = settings.GOOGLE_API_KEY
 
60
  if not self.use_llm or self.model is None:
61
  return
62
 
63
+ try:
64
+ self.rewrite_blocks(document)
65
+ except Exception as e:
66
+ print(f"Error rewriting blocks in {self.__class__.__name__}: {e}")
67
 
68
  def process_rewriting(self, document: Document, page: PageGroup, block: Block):
69
  raise NotImplementedError()
marker/processors/llm/llm_form.py CHANGED
@@ -1,11 +1,8 @@
1
  import markdown2
2
 
3
  from marker.processors.llm import BaseLLMProcessor
4
- from marker.processors.llm.utils import GoogleModel
5
- from concurrent.futures import ThreadPoolExecutor, as_completed
6
 
7
  from google.ai.generativelanguage_v1beta.types import content
8
- from tqdm import tqdm
9
  from tabled.formats import markdown_format
10
 
11
  from marker.schema import BlockTypes
 
1
  import markdown2
2
 
3
  from marker.processors.llm import BaseLLMProcessor
 
 
4
 
5
  from google.ai.generativelanguage_v1beta.types import content
 
6
  from tabled.formats import markdown_format
7
 
8
  from marker.schema import BlockTypes
marker/schema/document.py CHANGED
@@ -1,6 +1,6 @@
1
  from __future__ import annotations
2
 
3
- from typing import List
4
 
5
  from pydantic import BaseModel
6
 
@@ -100,3 +100,9 @@ class Document(BaseModel):
100
  children=child_content,
101
  html=self.assemble_html(child_content)
102
  )
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ from typing import List, Sequence
4
 
5
  from pydantic import BaseModel
6
 
 
100
  children=child_content,
101
  html=self.assemble_html(child_content)
102
  )
103
+
104
+ def contained_blocks(self, block_types: Sequence[BlockTypes] = None) -> List[Block]:
105
+ blocks = []
106
+ for page in self.pages:
107
+ blocks += page.contained_blocks(self, block_types)
108
+ return blocks
poetry.lock CHANGED
@@ -3449,6 +3449,23 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""}
3449
  [package.extras]
3450
  dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
3451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3452
  [[package]]
3453
  name = "python-dateutil"
3454
  version = "2.9.0.post0"
@@ -5407,4 +5424,4 @@ propcache = ">=0.2.0"
5407
  [metadata]
5408
  lock-version = "2.0"
5409
  python-versions = "^3.10"
5410
- content-hash = "20eee90138195d778e93da276c2d02e6547738e8eedf3c0a355eaecb128a58c0"
 
3449
  [package.extras]
3450
  dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
3451
 
3452
+ [[package]]
3453
+ name = "pytest-mock"
3454
+ version = "3.14.0"
3455
+ description = "Thin-wrapper around the mock package for easier use with pytest"
3456
+ optional = false
3457
+ python-versions = ">=3.8"
3458
+ files = [
3459
+ {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"},
3460
+ {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
3461
+ ]
3462
+
3463
+ [package.dependencies]
3464
+ pytest = ">=6.2.5"
3465
+
3466
+ [package.extras]
3467
+ dev = ["pre-commit", "pytest-asyncio", "tox"]
3468
+
3469
  [[package]]
3470
  name = "python-dateutil"
3471
  version = "2.9.0.post0"
 
5424
  [metadata]
5425
  lock-version = "2.0"
5426
  python-versions = "^3.10"
5427
+ content-hash = "2a4dfa94c63b5cf4b614fb4908abd2c80c363e9ed4ebf53b71af9bba90b783fd"
pyproject.toml CHANGED
@@ -50,6 +50,7 @@ fastapi = "^0.115.4"
50
  uvicorn = "^0.32.0"
51
  python-multipart = "^0.0.16"
52
  pytest = "^8.3.3"
 
53
 
54
  [tool.poetry.scripts]
55
  marker = "convert:main"
 
50
  uvicorn = "^0.32.0"
51
  python-multipart = "^0.0.16"
52
  pytest = "^8.3.3"
53
+ pytest-mock = "^3.14.0"
54
 
55
  [tool.poetry.scripts]
56
  marker = "convert:main"
tests/processors/test_llm_processors.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unittest.mock import MagicMock, Mock
2
+
3
+ import pytest
4
+
5
+ from marker.processors.llm.llm_form import LLMFormProcessor
6
+ from marker.processors.llm.llm_table import LLMTableProcessor
7
+ from marker.processors.llm.llm_text import LLMTextProcessor
8
+ from marker.processors.table import TableProcessor
9
+ from marker.schema import BlockTypes
10
+
11
+ @pytest.mark.filename("form_1040.pdf")
12
+ @pytest.mark.config({"page_range": [0]})
13
+ def test_llm_form_processor_no_config(pdf_document):
14
+ processor = LLMFormProcessor()
15
+ processor(pdf_document)
16
+
17
+ forms = pdf_document.contained_blocks((BlockTypes.Form,))
18
+ assert forms[0].html is None
19
+
20
+
21
+ @pytest.mark.filename("form_1040.pdf")
22
+ @pytest.mark.config({"page_range": [0]})
23
+ def test_llm_form_processor_no_cells(pdf_document):
24
+ processor = LLMFormProcessor({"use_llm": True})
25
+ processor(pdf_document)
26
+
27
+ forms = pdf_document.contained_blocks((BlockTypes.Form,))
28
+ assert forms[0].html is None
29
+
30
+
31
+ @pytest.mark.filename("form_1040.pdf")
32
+ @pytest.mark.config({"page_range": [0]})
33
+ def test_llm_form_processor(pdf_document, detection_model, table_rec_model, recognition_model, mocker):
34
+ corrected_markdown = "*This is corrected markdown.*\n" * 100
35
+
36
+ corrected_html = "<em>This is corrected markdown.</em>\n" * 100
37
+ corrected_html = "<p>" + corrected_html.strip() + "</p>\n"
38
+
39
+ mock_cls = Mock()
40
+ mock_cls.return_value.generate_response.return_value = {"corrected_markdown": corrected_markdown}
41
+ mocker.patch("marker.processors.llm.GoogleModel", mock_cls)
42
+
43
+ cell_processor = TableProcessor(detection_model, recognition_model, table_rec_model)
44
+ cell_processor(pdf_document)
45
+
46
+ processor = LLMFormProcessor({"use_llm": True})
47
+ processor(pdf_document)
48
+
49
+ forms = pdf_document.contained_blocks((BlockTypes.Form,))
50
+ assert forms[0].html == corrected_html
51
+
52
+
53
+
54
+ @pytest.mark.filename("table_ex2.pdf")
55
+ @pytest.mark.config({"page_range": [0]})
56
+ def test_llm_table_processor(pdf_document, detection_model, table_rec_model, recognition_model, mocker):
57
+ corrected_markdown = """
58
+ | Column 1 | Column 2 | Column 3 | Column 4 |
59
+ |----------|----------|----------|----------|
60
+ | Value 1 | Value 2 | Value 3 | Value 4 |
61
+ | Value 5 | Value 6 | Value 7 | Value 8 |
62
+ | Value 9 | Value 10 | Value 11 | Value 12 |
63
+ """.strip()
64
+
65
+ mock_cls = Mock()
66
+ mock_cls.return_value.generate_response.return_value = {"corrected_markdown": corrected_markdown}
67
+ mocker.patch("marker.processors.llm.GoogleModel", mock_cls)
68
+
69
+ cell_processor = TableProcessor(detection_model, recognition_model, table_rec_model)
70
+ cell_processor(pdf_document)
71
+
72
+ processor = LLMTableProcessor({"use_llm": True})
73
+ processor(pdf_document)
74
+
75
+ tables = pdf_document.contained_blocks((BlockTypes.Table,))
76
+ assert tables[0].cells[0].text == "Column 1"
77
+
78
+
79
+ @pytest.mark.filename("adversarial.pdf")
80
+ @pytest.mark.config({"page_range": [0]})
81
+ def test_llm_text_processor(pdf_document, mocker):
82
+ inline_math_block = pdf_document.contained_blocks((BlockTypes.TextInlineMath,))[0]
83
+ text_lines = inline_math_block.contained_blocks(pdf_document, (BlockTypes.Line,))
84
+ corrected_lines = ["<i>Text</i>"] * len(text_lines)
85
+
86
+ mock_cls = Mock()
87
+ mock_cls.return_value.generate_response.return_value = {"corrected_lines": corrected_lines}
88
+ mocker.patch("marker.processors.llm.GoogleModel", mock_cls)
89
+
90
+ processor = LLMTextProcessor({"use_llm": True})
91
+ processor(pdf_document)
92
+
93
+ contained_spans = text_lines[0].contained_blocks(pdf_document, (BlockTypes.Span,))
94
+ assert contained_spans[0].text == "Text\n" # Newline inserted at end of line
95
+ assert contained_spans[0].formats == ["italic"]