Moses Paul R committed on
Commit
10e6ee7
·
1 Parent(s): 5aa8146

refactor block merging [skip ci]

Browse files
marker/v2/builders/layout.py CHANGED
@@ -7,7 +7,7 @@ from marker.settings import settings
7
  from marker.v2.builders import BaseBuilder
8
  from marker.v2.providers.pdf import PageLines, PageSpans, PdfProvider
9
  from marker.v2.schema import BlockTypes
10
- from marker.v2.schema.blocks import LAYOUT_BLOCK_REGISTRY, Block, Text
11
  from marker.v2.schema.document import Document
12
  from marker.v2.schema.groups.page import PageGroup
13
  from marker.v2.schema.polygon import PolygonBox
@@ -59,56 +59,8 @@ class LayoutBuilder(BaseBuilder):
59
  if not self.check_layout_coverage(document_page, provider_lines):
60
  document_page.text_extraction_method = "surya"
61
  continue
62
-
63
  line_spans = provider_page_spans[document_page.page_id]
64
- provider_line_idxs = set(range(len(provider_lines)))
65
- max_intersections = {}
66
- for line_idx, line in enumerate(provider_lines):
67
- for block_idx, block in enumerate(document_page.children):
68
- if block.block_type in [BlockTypes.Line, BlockTypes.Span]:
69
- continue
70
- intersection_pct = line.polygon.intersection_pct(block.polygon)
71
- if line_idx not in max_intersections:
72
- max_intersections[line_idx] = (intersection_pct, block_idx)
73
- elif intersection_pct > max_intersections[line_idx][0]:
74
- max_intersections[line_idx] = (intersection_pct, block_idx)
75
-
76
- assigned_line_idxs = set()
77
- for line_idx, line in enumerate(provider_lines):
78
- if line_idx in max_intersections and max_intersections[line_idx][0] > 0.0:
79
- document_page.add_full_block(line)
80
- block_idx = max_intersections[line_idx][1]
81
- block: Block = document_page.children[block_idx]
82
- block.add_structure(line)
83
- block.polygon = block.polygon.merge([line.polygon])
84
- block.text_extraction_method = "pdftext"
85
- assigned_line_idxs.add(line_idx)
86
- for span in line_spans[line_idx]:
87
- document_page.add_full_block(span)
88
- line.add_structure(span)
89
-
90
- for line_idx in provider_line_idxs.difference(assigned_line_idxs):
91
- min_dist = None
92
- min_dist_idx = None
93
- line = provider_lines[line_idx]
94
- for block_idx, block in enumerate(document_page.children):
95
- if block.block_type in [BlockTypes.Line, BlockTypes.Span]:
96
- continue
97
- dist = line.polygon.center_distance(block.polygon)
98
- if min_dist_idx is None or dist < min_dist:
99
- min_dist = dist
100
- min_dist_idx = block_idx
101
-
102
- if min_dist_idx is not None:
103
- document_page.add_full_block(line)
104
- nearest_block = document_page.children[min_dist_idx]
105
- nearest_block.add_structure(line)
106
- nearest_block.polygon = nearest_block.polygon.merge([line.polygon])
107
- nearest_block.text_extraction_method = "pdftext"
108
- assigned_line_idxs.add(line_idx)
109
- for span in line_spans[line_idx]:
110
- document_page.add_full_block(span)
111
- line.add_structure(span)
112
 
113
  def check_layout_coverage(
114
  self,
 
7
  from marker.v2.builders import BaseBuilder
8
  from marker.v2.providers.pdf import PageLines, PageSpans, PdfProvider
9
  from marker.v2.schema import BlockTypes
10
+ from marker.v2.schema.blocks import LAYOUT_BLOCK_REGISTRY
11
  from marker.v2.schema.document import Document
12
  from marker.v2.schema.groups.page import PageGroup
13
  from marker.v2.schema.polygon import PolygonBox
 
59
  if not self.check_layout_coverage(document_page, provider_lines):
60
  document_page.text_extraction_method = "surya"
61
  continue
 
62
  line_spans = provider_page_spans[document_page.page_id]
63
+ document_page.merge_blocks(provider_lines, line_spans, text_extraction_method="pdftext")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  def check_layout_coverage(
66
  self,
marker/v2/builders/ocr.py CHANGED
@@ -97,52 +97,4 @@ class OcrBuilder(BaseBuilder):
97
  def merge_blocks(self, document: Document, page_lines: PageLines, page_spans: PageSpans):
98
  ocred_pages = [page for page in document.pages if page.text_extraction_method == "surya"]
99
  for document_page, lines, line_spans in zip(ocred_pages, page_lines.values(), page_spans.values()):
100
-
101
- line_idxs = set(range(len(lines)))
102
- max_intersections = {}
103
- for line_idx, line in enumerate(lines):
104
- for block_idx, block in enumerate(document_page.children):
105
- if block.block_type in [BlockTypes.Line, BlockTypes.Span]:
106
- continue
107
- intersection_pct = line.polygon.intersection_pct(block.polygon)
108
- if line_idx not in max_intersections:
109
- max_intersections[line_idx] = (intersection_pct, block_idx)
110
- elif intersection_pct > max_intersections[line_idx][0]:
111
- max_intersections[line_idx] = (intersection_pct, block_idx)
112
-
113
- assigned_line_idxs = set()
114
- for line_idx, line in enumerate(lines):
115
- if line_idx in max_intersections and max_intersections[line_idx][0] > 0.0:
116
- document_page.add_full_block(line)
117
- block_idx = max_intersections[line_idx][1]
118
- block: Block = document_page.children[block_idx]
119
- block.add_structure(line)
120
- block.polygon = block.polygon.merge([line.polygon])
121
- block.text_extraction_method = "surya"
122
- assigned_line_idxs.add(line_idx)
123
- for span in line_spans[line_idx]:
124
- document_page.add_full_block(span)
125
- line.add_structure(span)
126
-
127
- for line_idx in line_idxs.difference(assigned_line_idxs):
128
- min_dist = None
129
- min_dist_idx = None
130
- line = lines[line_idx]
131
- for block_idx, block in enumerate(document_page.children):
132
- if block.block_type in [BlockTypes.Line, BlockTypes.Span]:
133
- continue
134
- dist = line.polygon.center_distance(block.polygon)
135
- if min_dist_idx is None or dist < min_dist:
136
- min_dist = dist
137
- min_dist_idx = block_idx
138
-
139
- if min_dist_idx is not None:
140
- document_page.add_full_block(line)
141
- nearest_block = document_page.children[min_dist_idx]
142
- nearest_block.add_structure(line)
143
- nearest_block.polygon = nearest_block.polygon.merge([line.polygon])
144
- nearest_block.text_extraction_method = "surya"
145
- assigned_line_idxs.add(line_idx)
146
- for span in line_spans[line_idx]:
147
- document_page.add_full_block(span)
148
- line.add_structure(span)
 
97
  def merge_blocks(self, document: Document, page_lines: PageLines, page_spans: PageSpans):
98
  ocred_pages = [page for page in document.pages if page.text_extraction_method == "surya"]
99
  for document_page, lines, line_spans in zip(ocred_pages, page_lines.values(), page_spans.values()):
100
+ document_page.merge_blocks(lines, line_spans, text_extraction_method="surya")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
marker/v2/schema/groups/page.py CHANGED
@@ -1,10 +1,12 @@
1
- from typing import List
2
 
3
  from PIL import Image
4
 
5
  from marker.v2.schema import BlockTypes
6
  from marker.v2.schema.blocks import Block, BlockId
7
  from marker.v2.schema.polygon import PolygonBox
 
 
8
 
9
 
10
  class PageGroup(Block):
@@ -45,3 +47,60 @@ class PageGroup(Block):
45
  block: Block = self.children[block_id.block_id]
46
  assert block.block_id == block_id.block_id
47
  return block
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
 
3
  from PIL import Image
4
 
5
  from marker.v2.schema import BlockTypes
6
  from marker.v2.schema.blocks import Block, BlockId
7
  from marker.v2.schema.polygon import PolygonBox
8
+ from marker.v2.schema.text.line import Line
9
+ from marker.v2.schema.text.span import Span
10
 
11
 
12
  class PageGroup(Block):
 
47
  block: Block = self.children[block_id.block_id]
48
  assert block.block_id == block_id.block_id
49
  return block
50
+
51
def merge_blocks(
    self,
    page_lines: "List[Line]",
    line_spans: "Dict[int, List[Span]]",
    text_extraction_method: str,
    excluded_block_types=None,
):
    """Attach extracted lines (and their spans) to this page's blocks.

    Each line is registered on the page and added to the structure of one
    existing child block:

    1. A line goes to the candidate block with the highest
       ``intersection_pct`` when that value is greater than zero. All
       intersections are measured before any block polygon is grown.
    2. Every remaining line goes to the candidate block with the smallest
       ``center_distance``, measured against the polygons as they stand at
       that moment (earlier attachments may already have grown them).

    The receiving block's polygon is merged with the line's polygon and its
    ``text_extraction_method`` is overwritten. Ties keep the first candidate
    in ``self.children`` order. If the page has no candidate blocks, nothing
    is attached.

    Args:
        page_lines: Lines to distribute over the page's blocks.
        line_spans: Spans for each line, keyed by the line's index in
            ``page_lines``; an entry is required for every attached line.
        text_extraction_method: Value stored on every block that receives
            a line.
        excluded_block_types: Block types that may never receive a line.
            Defaults to ``BlockTypes.Line`` and ``BlockTypes.Span``.
    """
    # Resolve the default per call instead of sharing one mutable list
    # object between every invocation.
    if excluded_block_types is None:
        excluded_block_types = (BlockTypes.Line, BlockTypes.Span)

    def candidate_blocks():
        # (index, block) pairs for children that are allowed to own lines.
        return [
            (block_idx, block)
            for block_idx, block in enumerate(self.children)
            if block.block_type not in excluded_block_types
        ]

    def attach(line_idx, block_idx):
        # Register the line and its spans on the page, then hang the line
        # off the chosen block and grow that block to cover it.
        line = page_lines[line_idx]
        self.add_full_block(line)
        block = self.children[block_idx]
        block.add_structure(line)
        block.polygon = block.polygon.merge([line.polygon])
        block.text_extraction_method = text_extraction_method
        for span in line_spans[line_idx]:
            self.add_full_block(span)
            line.add_structure(span)

    layout_blocks = candidate_blocks()
    if not layout_blocks:
        # No block can receive a line, so there is nothing to merge.
        return

    # Pass 1: best overlap per line, computed against the untouched
    # polygons. max() keeps the first candidate on ties.
    best_overlap = [
        max(
            (
                (line.polygon.intersection_pct(block.polygon), block_idx)
                for block_idx, block in layout_blocks
            ),
            key=lambda match: match[0],
        )
        for line in page_lines
    ]

    # Pass 2: attach every line that overlaps a block. The rest are kept
    # in ascending index order so the fallback pass is deterministic.
    leftover_line_idxs = []
    for line_idx, (intersection_pct, block_idx) in enumerate(best_overlap):
        if intersection_pct > 0.0:
            attach(line_idx, block_idx)
        else:
            leftover_line_idxs.append(line_idx)

    # Pass 3: lines that overlap nothing go to the nearest block. The
    # candidates are re-read for each line because attaching changes both
    # the page's children and the block polygons.
    for line_idx in leftover_line_idxs:
        line_polygon = page_lines[line_idx].polygon
        nearest_idx, _ = min(
            candidate_blocks(),
            key=lambda candidate: line_polygon.center_distance(candidate[1].polygon),
        )
        attach(line_idx, nearest_idx)