Vik Paruchuri
committed on
Commit
·
6bd5629
1
Parent(s):
457f524
Fix tests
Browse files
- marker/builders/ocr.py +6 -2
- marker/schema/blocks/base.py +2 -1
- marker/scripts/convert.py +1 -1
- tests/builders/test_garbled_pdf.py +12 -4
marker/builders/ocr.py
CHANGED
|
@@ -35,6 +35,10 @@ class OcrBuilder(BaseBuilder):
|
|
| 35 |
"A list of languages to use for OCR.",
|
| 36 |
"Default is None."
|
| 37 |
] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
def __init__(self, detection_model: DetectionPredictor, recognition_model: RecognitionPredictor, config=None):
|
| 40 |
super().__init__(config)
|
|
@@ -67,12 +71,12 @@ class OcrBuilder(BaseBuilder):
|
|
| 67 |
|
| 68 |
# Remove tables because we re-OCR them later with the table processor
|
| 69 |
recognition_results = self.recognition_model(
|
| 70 |
-
images=[page.get_image(highres=False, remove_tables=
|
| 71 |
langs=[self.languages] * len(page_list),
|
| 72 |
det_predictor=self.detection_model,
|
| 73 |
detection_batch_size=int(self.get_detection_batch_size()),
|
| 74 |
recognition_batch_size=int(self.get_recognition_batch_size()),
|
| 75 |
-
highres_images=[page.get_image(highres=True, remove_tables=
|
| 76 |
)
|
| 77 |
|
| 78 |
page_lines = {}
|
|
|
|
| 35 |
"A list of languages to use for OCR.",
|
| 36 |
"Default is None."
|
| 37 |
] = None
|
| 38 |
+
enable_table_ocr: Annotated[
|
| 39 |
+
bool,
|
| 40 |
+
"Whether to skip OCR on tables. The TableProcessor will re-OCR them. Only enable if the TableProcessor is not running.",
|
| 41 |
+
] = False
|
| 42 |
|
| 43 |
def __init__(self, detection_model: DetectionPredictor, recognition_model: RecognitionPredictor, config=None):
|
| 44 |
super().__init__(config)
|
|
|
|
| 71 |
|
| 72 |
# Remove tables because we re-OCR them later with the table processor
|
| 73 |
recognition_results = self.recognition_model(
|
| 74 |
+
images=[page.get_image(highres=False, remove_tables=not self.enable_table_ocr) for page in page_list],
|
| 75 |
langs=[self.languages] * len(page_list),
|
| 76 |
det_predictor=self.detection_model,
|
| 77 |
detection_batch_size=int(self.get_detection_batch_size()),
|
| 78 |
recognition_batch_size=int(self.get_recognition_batch_size()),
|
| 79 |
+
highres_images=[page.get_image(highres=True, remove_tables=not self.enable_table_ocr) for page in page_list]
|
| 80 |
)
|
| 81 |
|
| 82 |
page_lines = {}
|
marker/schema/blocks/base.py
CHANGED
|
@@ -167,9 +167,10 @@ class Block(BaseModel):
|
|
| 167 |
def raw_text(self, document: Document) -> str:
|
| 168 |
from marker.schema.text.line import Line
|
| 169 |
from marker.schema.text.span import Span
|
|
|
|
| 170 |
|
| 171 |
if self.structure is None:
|
| 172 |
-
if isinstance(self, Span):
|
| 173 |
return self.text
|
| 174 |
else:
|
| 175 |
return ""
|
|
|
|
| 167 |
def raw_text(self, document: Document) -> str:
|
| 168 |
from marker.schema.text.line import Line
|
| 169 |
from marker.schema.text.span import Span
|
| 170 |
+
from marker.schema.blocks.tablecell import TableCell
|
| 171 |
|
| 172 |
if self.structure is None:
|
| 173 |
+
if isinstance(self, (Span, TableCell)):
|
| 174 |
return self.text
|
| 175 |
else:
|
| 176 |
return ""
|
marker/scripts/convert.py
CHANGED
|
@@ -100,7 +100,7 @@ def convert_cli(in_folder: str, **kwargs):
|
|
| 100 |
else:
|
| 101 |
model_dict = create_model_dict()
|
| 102 |
for k, v in model_dict.items():
|
| 103 |
-
v.share_memory()
|
| 104 |
|
| 105 |
print(f"Converting {len(files_to_convert)} pdfs in chunk {kwargs['chunk_idx'] + 1}/{kwargs['num_chunks']} with {total_processes} processes and saving to {kwargs['output_dir']}")
|
| 106 |
task_args = [(f, kwargs) for f in files_to_convert]
|
|
|
|
| 100 |
else:
|
| 101 |
model_dict = create_model_dict()
|
| 102 |
for k, v in model_dict.items():
|
| 103 |
+
v.model.share_memory()
|
| 104 |
|
| 105 |
print(f"Converting {len(files_to_convert)} pdfs in chunk {kwargs['chunk_idx'] + 1}/{kwargs['num_chunks']} with {total_processes} processes and saving to {kwargs['output_dir']}")
|
| 106 |
task_args = [(f, kwargs) for f in files_to_convert]
|
tests/builders/test_garbled_pdf.py
CHANGED
|
@@ -2,10 +2,11 @@ import pytest
|
|
| 2 |
|
| 3 |
from marker.builders.document import DocumentBuilder
|
| 4 |
from marker.builders.layout import LayoutBuilder
|
|
|
|
| 5 |
from marker.schema import BlockTypes
|
| 6 |
|
| 7 |
@pytest.mark.filename("water_damage.pdf")
|
| 8 |
-
def test_garbled_pdf(pdf_document):
|
| 9 |
assert pdf_document.pages[0].structure[0] == '/page/0/Table/0'
|
| 10 |
|
| 11 |
table_block = pdf_document.pages[0].get_block(pdf_document.pages[0].structure[0])
|
|
@@ -16,9 +17,16 @@ def test_garbled_pdf(pdf_document):
|
|
| 16 |
assert table_cell.block_type == BlockTypes.Line
|
| 17 |
assert table_cell.structure[0] == "/page/0/Span/2"
|
| 18 |
|
| 19 |
-
span = pdf_document.pages[0].
|
| 20 |
assert span.block_type == BlockTypes.Span
|
| 21 |
-
assert
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
@pytest.mark.filename("hindi_judgement.pdf")
|
|
@@ -30,7 +38,7 @@ def test_garbled_builder(config, pdf_provider, layout_model, ocr_error_model):
|
|
| 30 |
|
| 31 |
bad_ocr_results = layout_builder.surya_ocr_error_detection(document.pages, pdf_provider.page_lines)
|
| 32 |
assert len(bad_ocr_results.labels) == 2
|
| 33 |
-
assert
|
| 34 |
|
| 35 |
|
| 36 |
@pytest.mark.filename("adversarial.pdf")
|
|
|
|
| 2 |
|
| 3 |
from marker.builders.document import DocumentBuilder
|
| 4 |
from marker.builders.layout import LayoutBuilder
|
| 5 |
+
from marker.processors.table import TableProcessor
|
| 6 |
from marker.schema import BlockTypes
|
| 7 |
|
| 8 |
@pytest.mark.filename("water_damage.pdf")
|
| 9 |
+
def test_garbled_pdf(pdf_document, detection_model, recognition_model, table_rec_model):
|
| 10 |
assert pdf_document.pages[0].structure[0] == '/page/0/Table/0'
|
| 11 |
|
| 12 |
table_block = pdf_document.pages[0].get_block(pdf_document.pages[0].structure[0])
|
|
|
|
| 17 |
assert table_cell.block_type == BlockTypes.Line
|
| 18 |
assert table_cell.structure[0] == "/page/0/Span/2"
|
| 19 |
|
| 20 |
+
span = pdf_document.pages[0].contained_blocks(pdf_document, (BlockTypes.Span,))[0]
|
| 21 |
assert span.block_type == BlockTypes.Span
|
| 22 |
+
assert len(span.text.strip()) == 0
|
| 23 |
+
|
| 24 |
+
# We don't OCR in the initial pass, only with the TableProcessor
|
| 25 |
+
processor = TableProcessor(detection_model, recognition_model, table_rec_model)
|
| 26 |
+
processor(pdf_document)
|
| 27 |
+
|
| 28 |
+
table = pdf_document.pages[0].contained_blocks(pdf_document, (BlockTypes.Table,))[0]
|
| 29 |
+
assert "варіант" in table.raw_text(pdf_document)
|
| 30 |
|
| 31 |
|
| 32 |
@pytest.mark.filename("hindi_judgement.pdf")
|
|
|
|
| 38 |
|
| 39 |
bad_ocr_results = layout_builder.surya_ocr_error_detection(document.pages, pdf_provider.page_lines)
|
| 40 |
assert len(bad_ocr_results.labels) == 2
|
| 41 |
+
assert any([l == "bad" for l in bad_ocr_results.labels])
|
| 42 |
|
| 43 |
|
| 44 |
@pytest.mark.filename("adversarial.pdf")
|