Vik Paruchuri commited on
Commit
6bd5629
·
1 Parent(s): 457f524
marker/builders/ocr.py CHANGED
@@ -35,6 +35,10 @@ class OcrBuilder(BaseBuilder):
35
  "A list of languages to use for OCR.",
36
  "Default is None."
37
  ] = None
 
 
 
 
38
 
39
  def __init__(self, detection_model: DetectionPredictor, recognition_model: RecognitionPredictor, config=None):
40
  super().__init__(config)
@@ -67,12 +71,12 @@ class OcrBuilder(BaseBuilder):
67
 
68
  # Remove tables because we re-OCR them later with the table processor
69
  recognition_results = self.recognition_model(
70
- images=[page.get_image(highres=False, remove_tables=True) for page in page_list],
71
  langs=[self.languages] * len(page_list),
72
  det_predictor=self.detection_model,
73
  detection_batch_size=int(self.get_detection_batch_size()),
74
  recognition_batch_size=int(self.get_recognition_batch_size()),
75
- highres_images=[page.get_image(highres=True, remove_tables=True) for page in page_list]
76
  )
77
 
78
  page_lines = {}
 
35
  "A list of languages to use for OCR.",
36
  "Default is None."
37
  ] = None
38
+ enable_table_ocr: Annotated[
39
+ bool,
40
+ "Whether to skip OCR on tables. The TableProcessor will re-OCR them. Only enable if the TableProcessor is not running.",
41
+ ] = False
42
 
43
  def __init__(self, detection_model: DetectionPredictor, recognition_model: RecognitionPredictor, config=None):
44
  super().__init__(config)
 
71
 
72
  # Remove tables because we re-OCR them later with the table processor
73
  recognition_results = self.recognition_model(
74
+ images=[page.get_image(highres=False, remove_tables=not self.enable_table_ocr) for page in page_list],
75
  langs=[self.languages] * len(page_list),
76
  det_predictor=self.detection_model,
77
  detection_batch_size=int(self.get_detection_batch_size()),
78
  recognition_batch_size=int(self.get_recognition_batch_size()),
79
+ highres_images=[page.get_image(highres=True, remove_tables=not self.enable_table_ocr) for page in page_list]
80
  )
81
 
82
  page_lines = {}
marker/schema/blocks/base.py CHANGED
@@ -167,9 +167,10 @@ class Block(BaseModel):
167
  def raw_text(self, document: Document) -> str:
168
  from marker.schema.text.line import Line
169
  from marker.schema.text.span import Span
 
170
 
171
  if self.structure is None:
172
- if isinstance(self, Span):
173
  return self.text
174
  else:
175
  return ""
 
167
  def raw_text(self, document: Document) -> str:
168
  from marker.schema.text.line import Line
169
  from marker.schema.text.span import Span
170
+ from marker.schema.blocks.tablecell import TableCell
171
 
172
  if self.structure is None:
173
+ if isinstance(self, (Span, TableCell)):
174
  return self.text
175
  else:
176
  return ""
marker/scripts/convert.py CHANGED
@@ -100,7 +100,7 @@ def convert_cli(in_folder: str, **kwargs):
100
  else:
101
  model_dict = create_model_dict()
102
  for k, v in model_dict.items():
103
- v.share_memory()
104
 
105
  print(f"Converting {len(files_to_convert)} pdfs in chunk {kwargs['chunk_idx'] + 1}/{kwargs['num_chunks']} with {total_processes} processes and saving to {kwargs['output_dir']}")
106
  task_args = [(f, kwargs) for f in files_to_convert]
 
100
  else:
101
  model_dict = create_model_dict()
102
  for k, v in model_dict.items():
103
+ v.model.share_memory()
104
 
105
  print(f"Converting {len(files_to_convert)} pdfs in chunk {kwargs['chunk_idx'] + 1}/{kwargs['num_chunks']} with {total_processes} processes and saving to {kwargs['output_dir']}")
106
  task_args = [(f, kwargs) for f in files_to_convert]
tests/builders/test_garbled_pdf.py CHANGED
@@ -2,10 +2,11 @@ import pytest
2
 
3
  from marker.builders.document import DocumentBuilder
4
  from marker.builders.layout import LayoutBuilder
 
5
  from marker.schema import BlockTypes
6
 
7
  @pytest.mark.filename("water_damage.pdf")
8
- def test_garbled_pdf(pdf_document):
9
  assert pdf_document.pages[0].structure[0] == '/page/0/Table/0'
10
 
11
  table_block = pdf_document.pages[0].get_block(pdf_document.pages[0].structure[0])
@@ -16,9 +17,16 @@ def test_garbled_pdf(pdf_document):
16
  assert table_cell.block_type == BlockTypes.Line
17
  assert table_cell.structure[0] == "/page/0/Span/2"
18
 
19
- span = pdf_document.pages[0].get_block(table_cell.structure[0])
20
  assert span.block_type == BlockTypes.Span
21
- assert "комплекс" in span.text
 
 
 
 
 
 
 
22
 
23
 
24
  @pytest.mark.filename("hindi_judgement.pdf")
@@ -30,7 +38,7 @@ def test_garbled_builder(config, pdf_provider, layout_model, ocr_error_model):
30
 
31
  bad_ocr_results = layout_builder.surya_ocr_error_detection(document.pages, pdf_provider.page_lines)
32
  assert len(bad_ocr_results.labels) == 2
33
- assert all([l == "bad" for l in bad_ocr_results.labels])
34
 
35
 
36
  @pytest.mark.filename("adversarial.pdf")
 
2
 
3
  from marker.builders.document import DocumentBuilder
4
  from marker.builders.layout import LayoutBuilder
5
+ from marker.processors.table import TableProcessor
6
  from marker.schema import BlockTypes
7
 
8
  @pytest.mark.filename("water_damage.pdf")
9
+ def test_garbled_pdf(pdf_document, detection_model, recognition_model, table_rec_model):
10
  assert pdf_document.pages[0].structure[0] == '/page/0/Table/0'
11
 
12
  table_block = pdf_document.pages[0].get_block(pdf_document.pages[0].structure[0])
 
17
  assert table_cell.block_type == BlockTypes.Line
18
  assert table_cell.structure[0] == "/page/0/Span/2"
19
 
20
+ span = pdf_document.pages[0].contained_blocks(pdf_document, (BlockTypes.Span,))[0]
21
  assert span.block_type == BlockTypes.Span
22
+ assert len(span.text.strip()) == 0
23
+
24
+ # We don't OCR in the initial pass, only with the TableProcessor
25
+ processor = TableProcessor(detection_model, recognition_model, table_rec_model)
26
+ processor(pdf_document)
27
+
28
+ table = pdf_document.pages[0].contained_blocks(pdf_document, (BlockTypes.Table,))[0]
29
+ assert "варіант" in table.raw_text(pdf_document)
30
 
31
 
32
  @pytest.mark.filename("hindi_judgement.pdf")
 
38
 
39
  bad_ocr_results = layout_builder.surya_ocr_error_detection(document.pages, pdf_provider.page_lines)
40
  assert len(bad_ocr_results.labels) == 2
41
+ assert any([l == "bad" for l in bad_ocr_results.labels])
42
 
43
 
44
  @pytest.mark.filename("adversarial.pdf")