Vik Paruchuri commited on
Commit
ed65502
·
1 Parent(s): e6cc383

Fix ocr converter

Browse files
README.md CHANGED
@@ -227,6 +227,27 @@ You can also run this via the CLI with
227
  marker_single FILENAME --use_llm --force_layout_block Table --converter_cls marker.converters.table.TableConverter --output_format json
228
  ```
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  # Output Formats
231
 
232
  ## Markdown
 
227
  marker_single FILENAME --use_llm --force_layout_block Table --converter_cls marker.converters.table.TableConverter --output_format json
228
  ```
229
 
230
+ ### OCR Only
231
+
232
+ If you only want to run OCR, you can also do that through the `OCRConverter`.
233
+
234
+ ```python
235
+ from marker.converters.ocr import OCRConverter
236
+ from marker.models import create_model_dict
237
+
238
+ converter = OCRConverter(
239
+ artifact_dict=create_model_dict(),
240
+ )
241
+ rendered = converter("FILEPATH")
242
+ ```
243
+
244
+ This takes all the same configuration as the PdfConverter.
245
+
246
+ You can also run this via the CLI with
247
+ ```shell
248
+ marker_single FILENAME --converter_cls marker.converters.ocr.OCRConverter
249
+ ```
250
+
251
  # Output Formats
252
 
253
  ## Markdown
marker/builders/ocr.py CHANGED
@@ -171,10 +171,12 @@ class OcrBuilder(BaseBuilder):
171
  before_span, after_span = None, None
172
  if before_text:
173
  before_span = copy.deepcopy(span)
 
174
  before_span.text = before_text
175
  if after_text:
176
  after_span = copy.deepcopy(span)
177
  after_span.text = after_text
 
178
 
179
  match_span = copy.deepcopy(span)
180
  match_span.text = match_text
@@ -214,7 +216,6 @@ class OcrBuilder(BaseBuilder):
214
  if not matched:
215
  remaining_span = copy.deepcopy(original_span)
216
  remaining_span.text = remaining_text
217
- remaining_span.structure = []
218
  final_new_spans.append(remaining_span)
219
  break
220
 
@@ -287,10 +288,11 @@ class OcrBuilder(BaseBuilder):
287
  current_span.html = (
288
  f'<math display="inline">{current_span.text}</math>'
289
  )
 
 
290
  spans.append(current_span)
291
  current_span = None
292
 
293
- current_chars = self.assign_chars(current_span, current_chars)
294
  continue
295
 
296
  if not current_span:
 
171
  before_span, after_span = None, None
172
  if before_text:
173
  before_span = copy.deepcopy(span)
174
+ before_span.structure = [] # Avoid duplicate characters
175
  before_span.text = before_text
176
  if after_text:
177
  after_span = copy.deepcopy(span)
178
  after_span.text = after_text
179
+ after_span.structure = [] # Avoid duplicate characters
180
 
181
  match_span = copy.deepcopy(span)
182
  match_span.text = match_text
 
216
  if not matched:
217
  remaining_span = copy.deepcopy(original_span)
218
  remaining_span.text = remaining_text
 
219
  final_new_spans.append(remaining_span)
220
  break
221
 
 
288
  current_span.html = (
289
  f'<math display="inline">{current_span.text}</math>'
290
  )
291
+
292
+ current_chars = self.assign_chars(current_span, current_chars)
293
  spans.append(current_span)
294
  current_span = None
295
 
 
296
  continue
297
 
298
  if not current_span:
marker/converters/ocr.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ from marker.builders.document import DocumentBuilder
4
+ from marker.builders.line import LineBuilder
5
+ from marker.builders.ocr import OcrBuilder
6
+ from marker.converters.pdf import PdfConverter
7
+ from marker.processors import BaseProcessor
8
+ from marker.processors.equation import EquationProcessor
9
+ from marker.providers.registry import provider_from_filepath
10
+ from marker.renderers.ocr_json import OCRJSONRenderer
11
+
12
+
13
+ class OCRConverter(PdfConverter):
14
+ default_processors: Tuple[BaseProcessor, ...] = (EquationProcessor,)
15
+
16
+ def __init__(self, *args, **kwargs):
17
+ super().__init__(*args, **kwargs)
18
+
19
+ if not self.config:
20
+ self.config = {}
21
+
22
+ self.config["format_lines"] = True
23
+ self.config["keep_chars"] = True
24
+ self.renderer = OCRJSONRenderer
25
+
26
+ def build_document(self, filepath: str):
27
+ provider_cls = provider_from_filepath(filepath)
28
+ layout_builder = self.resolve_dependencies(self.layout_builder_class)
29
+ line_builder = self.resolve_dependencies(LineBuilder)
30
+ ocr_builder = self.resolve_dependencies(OcrBuilder)
31
+ document_builder = DocumentBuilder(self.config)
32
+
33
+ provider = provider_cls(filepath, self.config)
34
+ document = document_builder(provider, layout_builder, line_builder, ocr_builder)
35
+
36
+ for processor in self.processor_list:
37
+ processor(document)
38
+
39
+ return document
40
+
41
+ def __call__(self, filepath: str):
42
+ document = self.build_document(filepath)
43
+ renderer = self.resolve_dependencies(self.renderer)
44
+ return renderer(document)
marker/output.py CHANGED
@@ -8,6 +8,7 @@ from PIL import Image
8
  from marker.renderers.html import HTMLOutput
9
  from marker.renderers.json import JSONOutput, JSONBlockOutput
10
  from marker.renderers.markdown import MarkdownOutput
 
11
  from marker.schema.blocks import BlockOutput
12
  from marker.settings import settings
13
 
@@ -57,6 +58,8 @@ def text_from_rendered(rendered: BaseModel):
57
  return rendered.html, "html", rendered.images
58
  elif isinstance(rendered, JSONOutput):
59
  return rendered.model_dump_json(exclude=["metadata"], indent=2), "json", {}
 
 
60
  else:
61
  raise ValueError("Invalid output type")
62
 
 
8
  from marker.renderers.html import HTMLOutput
9
  from marker.renderers.json import JSONOutput, JSONBlockOutput
10
  from marker.renderers.markdown import MarkdownOutput
11
+ from marker.renderers.ocr_json import OCRJSONOutput
12
  from marker.schema.blocks import BlockOutput
13
  from marker.settings import settings
14
 
 
58
  return rendered.html, "html", rendered.images
59
  elif isinstance(rendered, JSONOutput):
60
  return rendered.model_dump_json(exclude=["metadata"], indent=2), "json", {}
61
+ elif isinstance(rendered, OCRJSONOutput):
62
+ return rendered.model_dump_json(exclude=["metadata"], indent=2), "json", {}
63
  else:
64
  raise ValueError("Invalid output type")
65
 
marker/providers/pdf.py CHANGED
@@ -239,7 +239,7 @@ class PdfProvider(BaseProvider):
239
  )
240
  span_chars = [
241
  CharClass(
242
- char=c["char"],
243
  polygon=PolygonBox.from_bbox(
244
  c["bbox"], ensure_nonzero_area=True
245
  ),
 
239
  )
240
  span_chars = [
241
  CharClass(
242
+ text=c["char"],
243
  polygon=PolygonBox.from_bbox(
244
  c["bbox"], ensure_nonzero_area=True
245
  ),
marker/renderers/ocr_json.py CHANGED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated, List, Tuple
2
+
3
+ from pydantic import BaseModel
4
+
5
+ from marker.renderers import BaseRenderer
6
+ from marker.schema import BlockTypes
7
+ from marker.schema.document import Document
8
+
9
+
10
+ class OCRJSONCharOutput(BaseModel):
11
+ id: str
12
+ block_type: str
13
+ text: str
14
+ polygon: List[List[float]]
15
+ bbox: List[float]
16
+
17
+
18
+ class OCRJSONLineOutput(BaseModel):
19
+ id: str
20
+ block_type: str
21
+ html: str
22
+ polygon: List[List[float]]
23
+ bbox: List[float]
24
+ children: List["OCRJSONCharOutput"] | None = None
25
+
26
+
27
+ class OCRJSONPageOutput(BaseModel):
28
+ id: str
29
+ block_type: str
30
+ polygon: List[List[float]]
31
+ bbox: List[float]
32
+ children: List[OCRJSONLineOutput] | None = None
33
+
34
+
35
+ class OCRJSONOutput(BaseModel):
36
+ children: List[OCRJSONPageOutput]
37
+ block_type: str = str(BlockTypes.Document)
38
+ metadata: dict | None = None
39
+
40
+
41
+ class OCRJSONRenderer(BaseRenderer):
42
+ """
43
+ A renderer for OCR JSON output.
44
+ """
45
+
46
+ image_blocks: Annotated[
47
+ Tuple[BlockTypes],
48
+ "The list of block types to consider as images.",
49
+ ] = (BlockTypes.Picture, BlockTypes.Figure)
50
+ page_blocks: Annotated[
51
+ Tuple[BlockTypes],
52
+ "The list of block types to consider as pages.",
53
+ ] = (BlockTypes.Page,)
54
+
55
+ def extract_json(self, document: Document) -> List[OCRJSONPageOutput]:
56
+ pages = []
57
+ for page in document.pages:
58
+ page_equations = [
59
+ b for b in page.children if b.block_type == BlockTypes.Equation
60
+ ]
61
+ equation_lines = []
62
+ for equation in page_equations:
63
+ if not equation.structure:
64
+ continue
65
+
66
+ equation_lines += [
67
+ line
68
+ for line in equation.structure
69
+ if line.block_type == BlockTypes.Line
70
+ ]
71
+
72
+ page_lines = [
73
+ block
74
+ for block in page.children
75
+ if block.block_type == BlockTypes.Line
76
+ and block.id not in equation_lines
77
+ ]
78
+
79
+ lines = []
80
+ for line in page_lines + page_equations:
81
+ line_obj = OCRJSONLineOutput(
82
+ id=str(line.id),
83
+ block_type=str(line.block_type),
84
+ html="",
85
+ polygon=line.polygon.polygon,
86
+ bbox=line.polygon.bbox,
87
+ )
88
+ if line in page_equations:
89
+ line_obj.html = line.html
90
+ else:
91
+ line_obj.html = line.formatted_text(document)
92
+ spans = [document.get_block(span_id) for span_id in line.structure]
93
+ children = []
94
+ for span in spans:
95
+ if not span.structure:
96
+ continue
97
+
98
+ span_chars = [
99
+ document.get_block(char_id) for char_id in span.structure
100
+ ]
101
+ children.extend(
102
+ [
103
+ OCRJSONCharOutput(
104
+ id=str(char.id),
105
+ block_type=str(char.block_type),
106
+ text=char.text,
107
+ polygon=char.polygon.polygon,
108
+ bbox=char.polygon.bbox,
109
+ )
110
+ for char in span_chars
111
+ ]
112
+ )
113
+ line_obj.children = children
114
+ lines.append(line_obj)
115
+
116
+ page = OCRJSONPageOutput(
117
+ id=str(page.id),
118
+ block_type=str(page.block_type),
119
+ polygon=page.polygon.polygon,
120
+ bbox=page.polygon.bbox,
121
+ children=lines,
122
+ )
123
+ pages.append(page)
124
+
125
+ return pages
126
+
127
+ def __call__(self, document: Document) -> OCRJSONOutput:
128
+ return OCRJSONOutput(children=self.extract_json(document), metadata=None)
marker/schema/groups/page.py CHANGED
@@ -253,14 +253,20 @@ class PageGroup(Group):
253
  block.add_structure(line)
254
  block.polygon = block.polygon.merge([line.polygon])
255
  block.text_extraction_method = text_extraction_method
256
- for span in spans:
257
  self.add_full_block(span)
258
  line.add_structure(span)
259
 
260
  if not keep_chars:
261
  continue
262
 
263
- for char in provider_output.chars:
 
 
 
 
 
 
264
  self.add_full_block(char)
265
  span.add_structure(char)
266
 
 
253
  block.add_structure(line)
254
  block.polygon = block.polygon.merge([line.polygon])
255
  block.text_extraction_method = text_extraction_method
256
+ for span_idx, span in enumerate(spans):
257
  self.add_full_block(span)
258
  line.add_structure(span)
259
 
260
  if not keep_chars:
261
  continue
262
 
263
+ # Provider doesn't have chars
264
+ if len(provider_output.chars) == 0:
265
+ continue
266
+
267
+ # Loop through characters associated with the span
268
+ for char in provider_output.chars[span_idx]:
269
+ char.page_id = self.page_id
270
  self.add_full_block(char)
271
  span.add_structure(char)
272
 
marker/schema/text/char.py CHANGED
@@ -6,5 +6,5 @@ class Char(Block):
6
  block_type: BlockTypes = BlockTypes.Char
7
  block_description: str = "A single character inside a span."
8
 
9
- char: str
10
  idx: int
 
6
  block_type: BlockTypes = BlockTypes.Char
7
  block_description: str = "A single character inside a span."
8
 
9
+ text: str
10
  idx: int
tests/conftest.py CHANGED
@@ -1,4 +1,3 @@
1
- from marker.providers.pdf import PdfProvider
2
  import tempfile
3
  from typing import Dict, Type
4
 
@@ -19,7 +18,6 @@ from marker.schema.blocks import Block
19
  from marker.renderers.markdown import MarkdownRenderer
20
  from marker.renderers.json import JSONRenderer
21
  from marker.schema.registry import register_block_class
22
- from marker.services.gemini import GoogleGeminiService
23
  from marker.util import classes_to_strings, strings_to_classes
24
 
25
 
@@ -54,6 +52,7 @@ def table_rec_model(model_dict):
54
  def ocr_error_model(model_dict):
55
  yield model_dict["ocr_error_model"]
56
 
 
57
  @pytest.fixture(scope="function")
58
  def config(request):
59
  config_mark = request.node.get_closest_marker("config")
@@ -65,20 +64,22 @@ def config(request):
65
 
66
  return config
67
 
 
68
  @pytest.fixture(scope="session")
69
  def pdf_dataset():
70
  return datasets.load_dataset("datalab-to/pdfs", split="train")
71
 
 
72
  @pytest.fixture(scope="function")
73
  def temp_doc(request, pdf_dataset):
74
  filename_mark = request.node.get_closest_marker("filename")
75
  filename = filename_mark.args[0] if filename_mark else "adversarial.pdf"
76
 
77
- idx = pdf_dataset['filename'].index(filename)
78
  suffix = filename.split(".")[-1]
79
 
80
  temp_pdf = tempfile.NamedTemporaryFile(suffix=f".{suffix}")
81
- temp_pdf.write(pdf_dataset['pdf'][idx])
82
  temp_pdf.flush()
83
  yield temp_pdf
84
 
@@ -88,8 +89,17 @@ def doc_provider(request, config, temp_doc):
88
  provider_cls = provider_from_filepath(temp_doc.name)
89
  yield provider_cls(temp_doc.name, config)
90
 
 
91
  @pytest.fixture(scope="function")
92
- def pdf_document(request, config, doc_provider, layout_model, ocr_error_model, recognition_model, detection_model):
 
 
 
 
 
 
 
 
93
  layout_builder = LayoutBuilder(layout_model, config)
94
  line_builder = LineBuilder(detection_model, ocr_error_model, config)
95
  ocr_builder = OcrBuilder(recognition_model, config)
@@ -107,7 +117,7 @@ def pdf_converter(request, config, model_dict, renderer, llm_service):
107
  processor_list=None,
108
  renderer=classes_to_strings([renderer])[0],
109
  config=config,
110
- llm_service=llm_service
111
  )
112
 
113
 
 
 
1
  import tempfile
2
  from typing import Dict, Type
3
 
 
18
  from marker.renderers.markdown import MarkdownRenderer
19
  from marker.renderers.json import JSONRenderer
20
  from marker.schema.registry import register_block_class
 
21
  from marker.util import classes_to_strings, strings_to_classes
22
 
23
 
 
52
  def ocr_error_model(model_dict):
53
  yield model_dict["ocr_error_model"]
54
 
55
+
56
  @pytest.fixture(scope="function")
57
  def config(request):
58
  config_mark = request.node.get_closest_marker("config")
 
64
 
65
  return config
66
 
67
+
68
  @pytest.fixture(scope="session")
69
  def pdf_dataset():
70
  return datasets.load_dataset("datalab-to/pdfs", split="train")
71
 
72
+
73
  @pytest.fixture(scope="function")
74
  def temp_doc(request, pdf_dataset):
75
  filename_mark = request.node.get_closest_marker("filename")
76
  filename = filename_mark.args[0] if filename_mark else "adversarial.pdf"
77
 
78
+ idx = pdf_dataset["filename"].index(filename)
79
  suffix = filename.split(".")[-1]
80
 
81
  temp_pdf = tempfile.NamedTemporaryFile(suffix=f".{suffix}")
82
+ temp_pdf.write(pdf_dataset["pdf"][idx])
83
  temp_pdf.flush()
84
  yield temp_pdf
85
 
 
89
  provider_cls = provider_from_filepath(temp_doc.name)
90
  yield provider_cls(temp_doc.name, config)
91
 
92
+
93
  @pytest.fixture(scope="function")
94
+ def pdf_document(
95
+ request,
96
+ config,
97
+ doc_provider,
98
+ layout_model,
99
+ ocr_error_model,
100
+ recognition_model,
101
+ detection_model,
102
+ ):
103
  layout_builder = LayoutBuilder(layout_model, config)
104
  line_builder = LineBuilder(detection_model, ocr_error_model, config)
105
  ocr_builder = OcrBuilder(recognition_model, config)
 
117
  processor_list=None,
118
  renderer=classes_to_strings([renderer])[0],
119
  config=config,
120
+ llm_service=llm_service,
121
  )
122
 
123
 
tests/converters/test_ocr_converter.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from marker.converters.ocr import OCRConverter
4
+ from marker.renderers.ocr_json import OCRJSONOutput
5
+
6
+
7
+ def _ocr_converter(config, model_dict, temp_pdf):
8
+ converter = OCRConverter(artifact_dict=model_dict, config=config)
9
+
10
+ ocr_json: OCRJSONOutput = converter(temp_pdf.name)
11
+ pages = ocr_json.pages
12
+
13
+ assert len(pages) == 1
14
+ breakpoint()
15
+
16
+
17
+ @pytest.mark.config({"page_range": [0]})
18
+ def test_ocr_converter(config, model_dict, temp_doc):
19
+ _ocr_converter(config, model_dict, temp_doc)