Vik Paruchuri committed on
Commit
969ff96
·
1 Parent(s): 6ff9f43

Align with surya refactor

Browse files
marker/builders/layout.py CHANGED
@@ -1,11 +1,10 @@
1
  from typing import Annotated, List, Optional, Tuple
2
 
3
  import numpy as np
4
- from surya.layout import batch_layout_detection
5
- from surya.model.layout.encoderdecoder import SuryaLayoutModel
6
- from surya.model.ocr_error.model import DistilBertForSequenceClassification
7
- from surya.ocr_error import batch_ocr_error_detection
8
- from surya.schema import LayoutResult, OCRErrorDetectionResult
9
 
10
  from marker.builders import BaseBuilder
11
  from marker.providers import ProviderOutput, ProviderPageLines
@@ -52,7 +51,7 @@ class LayoutBuilder(BaseBuilder):
52
  "A list of block types to exclude from the layout coverage check.",
53
  ] = (BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup)
54
 
55
- def __init__(self, layout_model: SuryaLayoutModel, ocr_error_model: DistilBertForSequenceClassification, config=None):
56
  self.layout_model = layout_model
57
  self.ocr_error_model = ocr_error_model
58
 
@@ -71,11 +70,8 @@ class LayoutBuilder(BaseBuilder):
71
  return 6
72
 
73
  def surya_layout(self, pages: List[PageGroup]) -> List[LayoutResult]:
74
- processor = self.layout_model.processor
75
- layout_results = batch_layout_detection(
76
  [p.lowres_image for p in pages],
77
- self.layout_model,
78
- processor,
79
  batch_size=int(self.get_batch_size())
80
  )
81
  return layout_results
@@ -97,10 +93,8 @@ class LayoutBuilder(BaseBuilder):
97
 
98
  page_texts.append(page_text)
99
 
100
- ocr_error_detection_results = batch_ocr_error_detection(
101
  page_texts,
102
- self.ocr_error_model,
103
- self.ocr_error_model.tokenizer,
104
  batch_size=int(self.get_batch_size()) # TODO Better Multiplier
105
  )
106
  return ocr_error_detection_results
 
1
  from typing import Annotated, List, Optional, Tuple
2
 
3
  import numpy as np
4
+ from surya.layout import LayoutPredictor
5
+ from surya.layout.schema import LayoutResult
6
+ from surya.ocr_error import OCRErrorPredictor
7
+ from surya.ocr_error.schema import OCRErrorDetectionResult
 
8
 
9
  from marker.builders import BaseBuilder
10
  from marker.providers import ProviderOutput, ProviderPageLines
 
51
  "A list of block types to exclude from the layout coverage check.",
52
  ] = (BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup)
53
 
54
+ def __init__(self, layout_model: LayoutPredictor, ocr_error_model: OCRErrorPredictor, config=None):
55
  self.layout_model = layout_model
56
  self.ocr_error_model = ocr_error_model
57
 
 
70
  return 6
71
 
72
  def surya_layout(self, pages: List[PageGroup]) -> List[LayoutResult]:
73
+ layout_results = self.layout_model(
 
74
  [p.lowres_image for p in pages],
 
 
75
  batch_size=int(self.get_batch_size())
76
  )
77
  return layout_results
 
93
 
94
  page_texts.append(page_text)
95
 
96
+ ocr_error_detection_results = self.ocr_error_model(
97
  page_texts,
 
 
98
  batch_size=int(self.get_batch_size()) # TODO Better Multiplier
99
  )
100
  return ocr_error_detection_results
marker/builders/llm_layout.py CHANGED
@@ -3,8 +3,8 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
3
  from typing import Annotated, Optional
4
 
5
  from google.ai.generativelanguage_v1beta.types import content
6
- from surya.model.layout.encoderdecoder import SuryaLayoutModel
7
- from surya.model.ocr_error.model import DistilBertForSequenceClassification
8
  from tqdm import tqdm
9
 
10
  from marker.builders.layout import LayoutBuilder
@@ -91,7 +91,7 @@ Respond only with one of `Figure`, `Picture`, `ComplexRegion`, `Table`, or `Form
91
  Here is the image of the layout block:
92
  """
93
 
94
- def __init__(self, layout_model: SuryaLayoutModel, ocr_error_model: DistilBertForSequenceClassification, config=None):
95
  super().__init__(layout_model, ocr_error_model, config)
96
 
97
  self.model = GoogleModel(self.google_api_key, self.model_name)
 
3
  from typing import Annotated, Optional
4
 
5
  from google.ai.generativelanguage_v1beta.types import content
6
+ from surya.layout import LayoutPredictor
7
+ from surya.ocr_error import OCRErrorPredictor
8
  from tqdm import tqdm
9
 
10
  from marker.builders.layout import LayoutBuilder
 
91
  Here is the image of the layout block:
92
  """
93
 
94
+ def __init__(self, layout_model: LayoutPredictor, ocr_error_model: OCRErrorPredictor, config=None):
95
  super().__init__(layout_model, ocr_error_model, config)
96
 
97
  self.model = GoogleModel(self.google_api_key, self.model_name)
marker/builders/ocr.py CHANGED
@@ -1,9 +1,8 @@
1
  from typing import Annotated, List, Optional
2
 
3
  from ftfy import fix_text
4
- from surya.model.detection.model import EfficientViTForSemanticSegmentation
5
- from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel
6
- from surya.ocr import run_ocr
7
 
8
  from marker.builders import BaseBuilder
9
  from marker.providers import ProviderOutput, ProviderPageLines
@@ -37,7 +36,7 @@ class OcrBuilder(BaseBuilder):
37
  "Default is None."
38
  ] = None
39
 
40
- def __init__(self, detection_model: EfficientViTForSemanticSegmentation, recognition_model: OCREncoderDecoderModel, config=None):
41
  super().__init__(config)
42
 
43
  self.detection_model = detection_model
@@ -65,13 +64,10 @@ class OcrBuilder(BaseBuilder):
65
 
66
  def ocr_extraction(self, document: Document, provider: PdfProvider) -> ProviderPageLines:
67
  page_list = [page for page in document.pages if page.text_extraction_method == "surya"]
68
- recognition_results = run_ocr(
69
  images=[page.lowres_image for page in page_list],
70
  langs=[self.languages] * len(page_list),
71
- det_model=self.detection_model,
72
- det_processor=self.detection_model.processor,
73
- rec_model=self.recognition_model,
74
- rec_processor=self.recognition_model.processor,
75
  detection_batch_size=int(self.get_detection_batch_size()),
76
  recognition_batch_size=int(self.get_recognition_batch_size()),
77
  highres_images=[page.highres_image for page in page_list]
 
1
  from typing import Annotated, List, Optional
2
 
3
  from ftfy import fix_text
4
+ from surya.detection import DetectionPredictor
5
+ from surya.recognition import RecognitionPredictor
 
6
 
7
  from marker.builders import BaseBuilder
8
  from marker.providers import ProviderOutput, ProviderPageLines
 
36
  "Default is None."
37
  ] = None
38
 
39
+ def __init__(self, detection_model: DetectionPredictor, recognition_model: RecognitionPredictor, config=None):
40
  super().__init__(config)
41
 
42
  self.detection_model = detection_model
 
64
 
65
  def ocr_extraction(self, document: Document, provider: PdfProvider) -> ProviderPageLines:
66
  page_list = [page for page in document.pages if page.text_extraction_method == "surya"]
67
+ recognition_results = self.recognition_model(
68
  images=[page.lowres_image for page in page_list],
69
  langs=[self.languages] * len(page_list),
70
+ det_predictor=self.detection_model,
 
 
 
71
  detection_batch_size=int(self.get_detection_batch_size()),
72
  recognition_batch_size=int(self.get_recognition_batch_size()),
73
  highres_images=[page.highres_image for page in page_list]
marker/models.py CHANGED
@@ -1,86 +1,49 @@
1
  import os
2
 
3
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
4
-
5
- from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
6
- from surya.model.layout.model import load_model as load_layout_model
7
- from surya.model.layout.processor import load_processor as load_layout_processor
8
- from texify.model.model import load_model as load_texify_model
9
- from texify.model.processor import load_processor as load_texify_processor
10
  from marker.settings import settings
11
- from surya.model.recognition.model import load_model as load_recognition_model
12
- from surya.model.recognition.processor import load_processor as load_recognition_processor
13
- from surya.model.table_rec.model import load_model as load_table_model
14
- from surya.model.table_rec.processor import load_processor as load_table_processor
15
- from surya.model.ocr_error.model import load_model as load_ocr_error_model
16
- from surya.model.ocr_error.model import load_tokenizer as load_ocr_error_tokenizer
17
-
18
- from texify.model.model import GenerateVisionEncoderDecoderModel
19
- from surya.model.layout.encoderdecoder import SuryaLayoutModel
20
- from surya.model.detection.model import EfficientViTForSemanticSegmentation
21
- from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel
22
- from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
23
- from surya.model.ocr_error.model import DistilBertForSequenceClassification
24
-
25
 
26
- def setup_table_rec_model(device=None, dtype=None) -> TableRecEncoderDecoderModel:
27
- if device:
28
- table_model = load_table_model(device=device, dtype=dtype)
29
- else:
30
- table_model = load_table_model()
31
- table_model.processor = load_table_processor()
32
- return table_model
33
-
34
-
35
- def setup_recognition_model(device=None, dtype=None) -> OCREncoderDecoderModel:
36
- if device:
37
- rec_model = load_recognition_model(device=device, dtype=dtype)
38
- else:
39
- rec_model = load_recognition_model()
40
- rec_model.processor = load_recognition_processor()
41
- return rec_model
42
 
 
 
43
 
44
- def setup_detection_model(device=None, dtype=None) -> EfficientViTForSemanticSegmentation:
45
- if device:
46
- model = load_detection_model(device=device, dtype=dtype)
47
- else:
48
- model = load_detection_model()
49
- model.processor = load_detection_processor()
50
- return model
51
 
 
 
 
52
 
53
- def setup_texify_model(device=None, dtype=None) -> GenerateVisionEncoderDecoderModel:
54
- if device:
55
- texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=device, dtype=dtype)
56
- else:
57
- texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE)
58
- texify_model.processor = load_texify_processor()
59
- return texify_model
60
 
 
 
 
 
61
 
62
- def setup_layout_model(device=None, dtype=None) -> SuryaLayoutModel:
63
- if device:
64
- model = load_layout_model(device=device, dtype=dtype)
65
- else:
66
- model = load_layout_model()
67
- model.processor = load_layout_processor()
68
- return model
69
 
70
- def setup_ocr_error_model(device=None, dtype=None) -> DistilBertForSequenceClassification:
71
- if device:
72
- model = load_ocr_error_model(device=device, dtype=dtype)
73
- else:
74
- model = load_ocr_error_model()
75
- model.tokenizer = load_ocr_error_tokenizer()
76
- return model
77
 
78
  def create_model_dict(device=None, dtype=None) -> dict:
79
  return {
80
- "layout_model": setup_layout_model(device, dtype),
81
- "texify_model": setup_texify_model(device, dtype),
82
- "recognition_model": setup_recognition_model(device, dtype),
83
- "table_rec_model": setup_table_rec_model(device, dtype),
84
- "detection_model": setup_detection_model(device, dtype),
85
- "ocr_error_model": setup_ocr_error_model(device,dtype)
86
  }
 
1
  import os
2
 
 
 
 
 
 
 
 
3
  from marker.settings import settings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ from typing import List
8
+ from PIL import Image
9
 
10
+ from surya.detection import DetectionPredictor
11
+ from surya.layout import LayoutPredictor
12
+ from surya.ocr_error import OCRErrorPredictor
13
+ from surya.recognition import RecognitionPredictor
14
+ from surya.table_rec import TableRecPredictor
 
 
15
 
16
+ from texify.model.model import load_model as load_texify_model
17
+ from texify.model.processor import load_processor as load_texify_processor
18
+ from texify.inference import batch_inference
19
 
20
+ class TexifyPredictor:
21
+ def __init__(self, device=None, dtype=None):
22
+ if not device:
23
+ device = settings.TORCH_DEVICE_MODEL
24
+ if not dtype:
25
+ dtype = settings.TEXIFY_DTYPE
 
26
 
27
+ self.model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=device, dtype=dtype)
28
+ self.processor = load_texify_processor()
29
+ self.device = device
30
+ self.dtype = dtype
31
 
32
+ def __call__(self, batch_images: List[Image.Image], max_tokens: int):
33
+ return batch_inference(
34
+ batch_images,
35
+ self.model,
36
+ self.processor,
37
+ max_tokens=max_tokens
38
+ )
39
 
 
 
 
 
 
 
 
40
 
41
  def create_model_dict(device=None, dtype=None) -> dict:
42
  return {
43
+ "layout_model": LayoutPredictor(device=device, dtype=dtype),
44
+ "texify_model": TexifyPredictor(device=device, dtype=dtype),
45
+ "recognition_model": RecognitionPredictor(device=device, dtype=dtype),
46
+ "table_rec_model": TableRecPredictor(device=device, dtype=dtype),
47
+ "detection_model": DetectionPredictor(device=device, dtype=dtype),
48
+ "ocr_error_model": OCRErrorPredictor(device=device, dtype=dtype)
49
  }
marker/processors/equation.py CHANGED
@@ -4,6 +4,7 @@ from texify.inference import batch_inference
4
  from texify.model.model import GenerateVisionEncoderDecoderModel
5
  from tqdm import tqdm
6
 
 
7
  from marker.processors import BaseProcessor
8
  from marker.schema import BlockTypes
9
  from marker.schema.document import Document
@@ -32,7 +33,7 @@ class EquationProcessor(BaseProcessor):
32
  "The number of tokens to buffer above max for the Texify model.",
33
  ] = 256
34
 
35
- def __init__(self, texify_model: GenerateVisionEncoderDecoderModel, config=None):
36
  super().__init__(config)
37
 
38
  self.texify_model = texify_model
@@ -92,10 +93,8 @@ class EquationProcessor(BaseProcessor):
92
 
93
  batch_images = [eq["image"] for eq in batch_equations]
94
 
95
- model_output = batch_inference(
96
  batch_images,
97
- self.texify_model,
98
- self.texify_model.processor,
99
  max_tokens=max_length
100
  )
101
 
 
4
  from texify.model.model import GenerateVisionEncoderDecoderModel
5
  from tqdm import tqdm
6
 
7
+ from marker.models import TexifyPredictor
8
  from marker.processors import BaseProcessor
9
  from marker.schema import BlockTypes
10
  from marker.schema.document import Document
 
33
  "The number of tokens to buffer above max for the Texify model.",
34
  ] = 256
35
 
36
+ def __init__(self, texify_model: TexifyPredictor, config=None):
37
  super().__init__(config)
38
 
39
  self.texify_model = texify_model
 
93
 
94
  batch_images = [eq["image"] for eq in batch_equations]
95
 
96
+ model_output = self.texify_model(
97
  batch_images,
 
 
98
  max_tokens=max_length
99
  )
100
 
marker/processors/table.py CHANGED
@@ -2,10 +2,9 @@
2
  from typing import Annotated
3
 
4
  from ftfy import fix_text
5
- from surya.input.pdflines import get_page_text_lines
6
- from surya.model.detection.model import EfficientViTForSemanticSegmentation
7
- from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel
8
- from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
9
  from tabled.assignment import assign_rows_columns
10
  from tabled.inference.recognition import get_cells, recognize_tables
11
 
@@ -42,9 +41,9 @@ class TableProcessor(BaseProcessor):
42
 
43
  def __init__(
44
  self,
45
- detection_model: EfficientViTForSemanticSegmentation,
46
- recognition_model: OCREncoderDecoderModel,
47
- table_rec_model: TableRecEncoderDecoderModel,
48
  config=None
49
  ):
50
  super().__init__(config)
 
2
  from typing import Annotated
3
 
4
  from ftfy import fix_text
5
+ from surya.detection import DetectionPredictor
6
+ from surya.recognition import RecognitionPredictor
7
+ from surya.table_rec import TableRecPredictor
 
8
  from tabled.assignment import assign_rows_columns
9
  from tabled.inference.recognition import get_cells, recognize_tables
10
 
 
41
 
42
  def __init__(
43
  self,
44
+ detection_model: DetectionPredictor,
45
+ recognition_model: RecognitionPredictor,
46
+ table_rec_model: TableRecPredictor,
47
  config=None
48
  ):
49
  super().__init__(config)
pyproject.toml CHANGED
@@ -1,6 +1,6 @@
1
  [tool.poetry]
2
  name = "marker-pdf"
3
- version = "1.2.3"
4
  description = "Convert PDF to markdown with high speed and accuracy."
5
  authors = ["Vik Paruchuri <github@vikas.sh>"]
6
  readme = "README.md"
 
1
  [tool.poetry]
2
  name = "marker-pdf"
3
+ version = "1.3.0"
4
  description = "Convert PDF to markdown with high speed and accuracy."
5
  authors = ["Vik Paruchuri <github@vikas.sh>"]
6
  readme = "README.md"
tests/builders/test_blank_page.py CHANGED
@@ -1,4 +1,4 @@
1
- from surya.schema import LayoutResult
2
 
3
  from marker.builders.document import DocumentBuilder
4
  from marker.builders.layout import LayoutBuilder
 
1
+ from surya.layout.schema import LayoutResult
2
 
3
  from marker.builders.document import DocumentBuilder
4
  from marker.builders.layout import LayoutBuilder
tests/conftest.py CHANGED
@@ -9,9 +9,7 @@ from marker.builders.document import DocumentBuilder
9
  from marker.builders.layout import LayoutBuilder
10
  from marker.builders.ocr import OcrBuilder
11
  from marker.converters.pdf import PdfConverter
12
- from marker.models import setup_detection_model, setup_layout_model, \
13
- setup_recognition_model, setup_table_rec_model, \
14
- setup_texify_model, setup_ocr_error_model
15
  from marker.schema import BlockTypes
16
  from marker.schema.blocks import Block
17
  from marker.renderers.markdown import MarkdownRenderer
@@ -19,46 +17,42 @@ from marker.renderers.json import JSONRenderer
19
  from marker.schema.registry import register_block_class
20
  from marker.util import classes_to_strings
21
 
 
 
 
 
 
 
22
 
23
  @pytest.fixture(scope="session")
24
- def layout_model():
25
- layout_m = setup_layout_model()
26
- yield layout_m
27
- del layout_m
28
 
29
 
30
  @pytest.fixture(scope="session")
31
- def detection_model():
32
- detection_m = setup_detection_model()
33
- yield detection_m
34
- del detection_m
35
 
36
 
37
  @pytest.fixture(scope="session")
38
- def texify_model():
39
- texify_m = setup_texify_model()
40
- yield texify_m
41
- del texify_m
42
 
43
 
44
  @pytest.fixture(scope="session")
45
- def recognition_model():
46
- ocr_m = setup_recognition_model()
47
- yield ocr_m
48
- del ocr_m
49
 
50
 
51
  @pytest.fixture(scope="session")
52
- def table_rec_model():
53
- table_rec_m = setup_table_rec_model()
54
- yield table_rec_m
55
- del table_rec_m
56
 
57
  @pytest.fixture(scope="session")
58
- def ocr_error_model():
59
- ocr_error_m = setup_ocr_error_model()
60
- yield ocr_error_m
61
- del ocr_error_m
62
 
63
  @pytest.fixture(scope="function")
64
  def config(request):
@@ -101,15 +95,7 @@ def pdf_document(request, config, pdf_provider, layout_model, ocr_error_model, r
101
 
102
 
103
  @pytest.fixture(scope="function")
104
- def pdf_converter(request, config, layout_model, texify_model, recognition_model, table_rec_model, detection_model, ocr_error_model, renderer):
105
- model_dict = {
106
- "layout_model": layout_model,
107
- "texify_model": texify_model,
108
- "recognition_model": recognition_model,
109
- "table_rec_model": table_rec_model,
110
- "detection_model": detection_model,
111
- "ocr_error_model": ocr_error_model
112
- }
113
  yield PdfConverter(
114
  artifact_dict=model_dict,
115
  processor_list=None,
 
9
  from marker.builders.layout import LayoutBuilder
10
  from marker.builders.ocr import OcrBuilder
11
  from marker.converters.pdf import PdfConverter
12
+ from marker.models import create_model_dict
 
 
13
  from marker.schema import BlockTypes
14
  from marker.schema.blocks import Block
15
  from marker.renderers.markdown import MarkdownRenderer
 
17
  from marker.schema.registry import register_block_class
18
  from marker.util import classes_to_strings
19
 
20
+ @pytest.fixture(scope="session")
21
+ def model_dict():
22
+ model_dict = create_model_dict()
23
+ yield model_dict
24
+ del model_dict
25
+
26
 
27
  @pytest.fixture(scope="session")
28
+ def layout_model(model_dict):
29
+ yield model_dict["layout_model"]
 
 
30
 
31
 
32
  @pytest.fixture(scope="session")
33
+ def detection_model(model_dict):
34
+ yield model_dict["detection_model"]
 
 
35
 
36
 
37
  @pytest.fixture(scope="session")
38
+ def texify_model(model_dict):
39
+ yield model_dict["texify_model"]
 
 
40
 
41
 
42
  @pytest.fixture(scope="session")
43
+ def recognition_model(model_dict):
44
+ yield model_dict["recognition_model"]
 
 
45
 
46
 
47
  @pytest.fixture(scope="session")
48
+ def table_rec_model(model_dict):
49
+ yield model_dict["table_rec_model"]
50
+
 
51
 
52
  @pytest.fixture(scope="session")
53
+ def ocr_error_model(model_dict):
54
+ yield model_dict["ocr_error_model"]
55
+
 
56
 
57
  @pytest.fixture(scope="function")
58
  def config(request):
 
95
 
96
 
97
  @pytest.fixture(scope="function")
98
+ def pdf_converter(request, config, model_dict, renderer):
 
 
 
 
 
 
 
 
99
  yield PdfConverter(
100
  artifact_dict=model_dict,
101
  processor_list=None,
tests/utils.py CHANGED
@@ -2,11 +2,6 @@ from marker.providers.pdf import PdfProvider
2
  import tempfile
3
 
4
  import datasets
5
- from marker.models import setup_layout_model, setup_recognition_model, setup_detection_model
6
- from marker.builders.document import DocumentBuilder
7
- from marker.builders.layout import LayoutBuilder
8
- from marker.builders.ocr import OcrBuilder
9
- from marker.schema.document import Document
10
 
11
 
12
  def setup_pdf_provider(
 
2
  import tempfile
3
 
4
  import datasets
 
 
 
 
 
5
 
6
 
7
  def setup_pdf_provider(