Moses Paul R committed on
Commit
ca8a504
·
2 Parent(s): 69b4d9a 6bd5629

Merge remote-tracking branch 'origin/dev' into dev-mose/input-formats-2

Browse files
README.md CHANGED
@@ -219,7 +219,12 @@ rendered = converter("FILEPATH")
219
  text, _, images = text_from_rendered(rendered)
220
  ```
221
 
222
- This takes all the same configuration as the PdfConverter. You can specify the configuration `force_layout_block=Table` to avoid layout detection and instead assume every page is a table.
 
 
 
 
 
223
 
224
  # Output Formats
225
 
@@ -400,8 +405,8 @@ Marker can extract tables from PDFs using `marker.converters.table.TableConverte
400
 
401
  | Avg score | Total tables | use_llm |
402
  |-----------|--------------|---------|
403
- | 0.824 | 54 | False |
404
- | 0.873 | 54 | True |
405
 
406
  The `--use_llm` flag can significantly improve table recognition performance, as you can see.
407
 
 
219
  text, _, images = text_from_rendered(rendered)
220
  ```
221
 
222
+ This takes all the same configuration as the PdfConverter. You can specify the configuration `--force_layout_block=Table` to avoid layout detection and instead assume every page is a table.
223
+
224
+ You can also run this via the CLI with
225
+ ```shell
226
+ python convert_single.py FILENAME --use_llm --force_layout_block Table --converter_cls marker.converters.table.TableConverter
227
+ ```
228
 
229
  # Output Formats
230
 
 
405
 
406
  | Avg score | Total tables | use_llm |
407
  |-----------|--------------|---------|
408
+ | 0.822 | 54 | False |
409
+ | 0.887 | 54 | True |
410
 
411
  The `--use_llm` flag can significantly improve table recognition performance, as you can see.
412
 
benchmarks/table/gemini.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from PIL import Image
3
+ import google.generativeai as genai
4
+ from google.ai.generativelanguage_v1beta.types import content
5
+ from marker.settings import settings
6
+
7
+ prompt = """
8
+ You're an expert document analyst who is good at turning tables in documents into HTML. Analyze the provided image, and convert it to a faithful HTML representation.
9
+
10
+ Guidelines:
11
+ - Keep the HTML simple and concise.
12
+ - Only include the <table> tag and contents.
13
+ - Only use <table>, <tr>, and <td> tags. Only use the colspan and rowspan attributes if necessary. Do not use <tbody>, <thead>, or <th> tags.
14
+ - Make sure the table is as faithful to the image as possible with the given tags.
15
+
16
+ **Instructions**
17
+ 1. Analyze the image, and determine the table structure.
18
+ 2. Convert the table image to HTML, following the guidelines above.
19
+ 3. Output only the HTML for the table, starting with the <table> tag and ending with the </table> tag.
20
+ """.strip()
21
+
22
+ genai.configure(api_key=settings.GOOGLE_API_KEY)
23
+
24
+ def gemini_table_rec(image: Image.Image):
25
+ schema = content.Schema(
26
+ type=content.Type.OBJECT,
27
+ required=["table_html"],
28
+ properties={
29
+ "table_html": content.Schema(
30
+ type=content.Type.STRING,
31
+ )
32
+ }
33
+ )
34
+
35
+ model = genai.GenerativeModel("gemini-1.5-flash")
36
+
37
+ responses = model.generate_content(
38
+ [image, prompt], # According to gemini docs, it performs better if the image is the first element
39
+ stream=False,
40
+ generation_config={
41
+ "temperature": 0,
42
+ "response_schema": schema,
43
+ "response_mime_type": "application/json",
44
+ },
45
+ request_options={'timeout': 60}
46
+ )
47
+
48
+ output = responses.candidates[0].content.parts[0].text
49
+ return json.loads(output)["table_html"]
benchmarks/table/table.py CHANGED
@@ -1,12 +1,11 @@
1
  import os
2
- from typing import List
3
-
4
- import numpy as np
5
-
6
- from marker.renderers.json import JSONOutput, JSONBlockOutput
7
 
8
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
9
 
 
 
10
  import base64
11
  import time
12
  import datasets
@@ -16,21 +15,24 @@ import click
16
  from tabulate import tabulate
17
  import json
18
  from bs4 import BeautifulSoup
19
- from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
20
  from pypdfium2._helpers.misc import PdfiumError
 
21
  from marker.util import matrix_intersection_area
 
22
 
23
  from marker.config.parser import ConfigParser
24
  from marker.converters.table import TableConverter
25
  from marker.models import create_model_dict
26
 
27
  from scoring import wrap_table_html, similarity_eval_html
 
28
 
29
- def update_teds_score(result):
30
- prediction, ground_truth = result['marker_table'], result['gt_table']
31
  prediction, ground_truth = wrap_table_html(prediction), wrap_table_html(ground_truth)
32
  score = similarity_eval_html(prediction, ground_truth)
33
- result.update({'score':score})
34
  return result
35
 
36
 
@@ -51,7 +53,16 @@ def extract_tables(children: List[JSONBlockOutput]):
51
  @click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
52
  @click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.")
53
  @click.option("--table_rec_batch_size", type=int, default=None, help="Batch size for table recognition.")
54
- def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm: bool, table_rec_batch_size: int | None):
 
 
 
 
 
 
 
 
 
55
  models = create_model_dict()
56
  config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size})
57
  start = time.time()
@@ -86,6 +97,9 @@ def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm:
86
  marker_json = converter(temp_pdf_file.name).children
87
  tqdm.disable = False
88
 
 
 
 
89
  if len(marker_json) == 0 or len(gt_tables) == 0:
90
  print(f'No tables detected, skipping...')
91
  total_unaligned += len(gt_tables)
@@ -94,6 +108,8 @@ def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm:
94
  marker_tables = extract_tables(marker_json)
95
  marker_table_boxes = [table.bbox for table in marker_tables]
96
  page_bbox = marker_json[0].bbox
 
 
97
 
98
  # Normalize the bboxes
99
  for bbox in marker_table_boxes:
@@ -136,14 +152,18 @@ def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm:
136
  unaligned_tables.add(table_idx)
137
  continue
138
 
 
 
 
 
139
  aligned_tables.append(
140
- (marker_tables[aligned_idx], gt_tables[table_idx])
141
  )
142
  used_tables.add(aligned_idx)
143
 
144
  total_unaligned += len(unaligned_tables)
145
 
146
- for marker_table, gt_table in aligned_tables:
147
  gt_table_html = gt_table['html']
148
 
149
  #marker wraps the table in <tbody> which fintabnet data doesn't
@@ -154,10 +174,12 @@ def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm:
154
  th_tag.name = 'td'
155
  marker_table_html = str(marker_table_soup)
156
  marker_table_html = marker_table_html.replace("\n", " ") # Fintabnet uses spaces instead of newlines
 
157
 
158
  results.append({
159
  "marker_table": marker_table_html,
160
- "gt_table": gt_table_html
 
161
  })
162
  except PdfiumError:
163
  print('Broken PDF, Skipping...')
@@ -167,19 +189,37 @@ def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm:
167
  print(f"Could not align {total_unaligned} tables from fintabnet.")
168
 
169
  with ProcessPoolExecutor(max_workers=max_workers) as executor:
170
- results = list(
171
  tqdm(
172
  executor.map(update_teds_score, results), desc='Computing alignment scores', total=len(results)
173
  )
174
  )
175
- avg_score = sum([r["score"] for r in results]) / len(results)
176
 
 
177
  headers = ["Avg score", "Total tables"]
178
- data = [f"{avg_score:.3f}", len(results)]
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  table = tabulate([data], headers=headers, tablefmt="github")
180
  print(table)
181
  print("Avg score computed by comparing marker predicted HTML with original HTML")
182
 
 
 
 
 
 
183
  with open(out_file, "w+") as f:
184
  json.dump(results, f, indent=2)
185
 
 
1
  import os
2
+ from itertools import repeat
3
+ from tkinter import Image
 
 
 
4
 
5
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
6
 
7
+ from typing import List
8
+ import numpy as np
9
  import base64
10
  import time
11
  import datasets
 
15
  from tabulate import tabulate
16
  import json
17
  from bs4 import BeautifulSoup
18
+ from concurrent.futures import ProcessPoolExecutor
19
  from pypdfium2._helpers.misc import PdfiumError
20
+ import pypdfium2 as pdfium
21
  from marker.util import matrix_intersection_area
22
+ from marker.renderers.json import JSONOutput, JSONBlockOutput
23
 
24
  from marker.config.parser import ConfigParser
25
  from marker.converters.table import TableConverter
26
  from marker.models import create_model_dict
27
 
28
  from scoring import wrap_table_html, similarity_eval_html
29
+ from gemini import gemini_table_rec
30
 
31
+ def update_teds_score(result, prefix: str = "marker"):
32
+ prediction, ground_truth = result[f'{prefix}_table'], result['gt_table']
33
  prediction, ground_truth = wrap_table_html(prediction), wrap_table_html(ground_truth)
34
  score = similarity_eval_html(prediction, ground_truth)
35
+ result.update({f'{prefix}_score':score})
36
  return result
37
 
38
 
 
53
  @click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
54
  @click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.")
55
  @click.option("--table_rec_batch_size", type=int, default=None, help="Batch size for table recognition.")
56
+ @click.option("--use_gemini", is_flag=True, help="Evaluate Gemini for table recognition.")
57
+ def main(
58
+ out_file: str,
59
+ dataset: str,
60
+ max_rows: int,
61
+ max_workers: int,
62
+ use_llm: bool,
63
+ table_rec_batch_size: int | None,
64
+ use_gemini: bool = False
65
+ ):
66
  models = create_model_dict()
67
  config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size})
68
  start = time.time()
 
97
  marker_json = converter(temp_pdf_file.name).children
98
  tqdm.disable = False
99
 
100
+ doc = pdfium.PdfDocument(temp_pdf_file.name)
101
+ page_image = doc[0].render(scale=92/72).to_pil()
102
+
103
  if len(marker_json) == 0 or len(gt_tables) == 0:
104
  print(f'No tables detected, skipping...')
105
  total_unaligned += len(gt_tables)
 
108
  marker_tables = extract_tables(marker_json)
109
  marker_table_boxes = [table.bbox for table in marker_tables]
110
  page_bbox = marker_json[0].bbox
111
+ w_scaler, h_scaler = page_image.width / page_bbox[2], page_image.height / page_bbox[3]
112
+ table_images = [page_image.crop([bbox[0] * w_scaler, bbox[1] * h_scaler, bbox[2] * w_scaler, bbox[3] * h_scaler]) for bbox in marker_table_boxes]
113
 
114
  # Normalize the bboxes
115
  for bbox in marker_table_boxes:
 
152
  unaligned_tables.add(table_idx)
153
  continue
154
 
155
+ gemini_html = ""
156
+ if use_gemini:
157
+ gemini_html = gemini_table_rec(table_images[aligned_idx])
158
+
159
  aligned_tables.append(
160
+ (marker_tables[aligned_idx], gt_tables[table_idx], gemini_html)
161
  )
162
  used_tables.add(aligned_idx)
163
 
164
  total_unaligned += len(unaligned_tables)
165
 
166
+ for marker_table, gt_table, gemini_table in aligned_tables:
167
  gt_table_html = gt_table['html']
168
 
169
  #marker wraps the table in <tbody> which fintabnet data doesn't
 
174
  th_tag.name = 'td'
175
  marker_table_html = str(marker_table_soup)
176
  marker_table_html = marker_table_html.replace("\n", " ") # Fintabnet uses spaces instead of newlines
177
+ gemini_table_html = gemini_table.replace("\n", " ") # Fintabnet uses spaces instead of newlines
178
 
179
  results.append({
180
  "marker_table": marker_table_html,
181
+ "gt_table": gt_table_html,
182
+ "gemini_table": gemini_table_html
183
  })
184
  except PdfiumError:
185
  print('Broken PDF, Skipping...')
 
189
  print(f"Could not align {total_unaligned} tables from fintabnet.")
190
 
191
  with ProcessPoolExecutor(max_workers=max_workers) as executor:
192
+ marker_results = list(
193
  tqdm(
194
  executor.map(update_teds_score, results), desc='Computing alignment scores', total=len(results)
195
  )
196
  )
 
197
 
198
+ avg_score = sum([r["marker_score"] for r in marker_results]) / len(marker_results)
199
  headers = ["Avg score", "Total tables"]
200
+ data = [f"{avg_score:.3f}", len(marker_results)]
201
+ gemini_results = None
202
+ if use_gemini:
203
+ with ProcessPoolExecutor(max_workers=max_workers) as executor:
204
+ gemini_results = list(
205
+ tqdm(
206
+ executor.map(update_teds_score, results, repeat("gemini")), desc='Computing Gemini scores',
207
+ total=len(results)
208
+ )
209
+ )
210
+ avg_gemini_score = sum([r["gemini_score"] for r in gemini_results]) / len(gemini_results)
211
+ headers.append("Avg Gemini score")
212
+ data.append(f"{avg_gemini_score:.3f}")
213
+
214
  table = tabulate([data], headers=headers, tablefmt="github")
215
  print(table)
216
  print("Avg score computed by comparing marker predicted HTML with original HTML")
217
 
218
+ results = {
219
+ "marker": marker_results,
220
+ "gemini": gemini_results
221
+ }
222
+
223
  with open(out_file, "w+") as f:
224
  json.dump(results, f, indent=2)
225
 
chunk_convert.py CHANGED
@@ -1,4 +1,4 @@
1
- from marker.scripts import chunk_convert_cli
2
 
3
  if __name__ == "__main__":
4
  chunk_convert_cli()
 
1
+ from marker.scripts.chunk_convert import chunk_convert_cli
2
 
3
  if __name__ == "__main__":
4
  chunk_convert_cli()
convert.py CHANGED
@@ -1,4 +1,4 @@
1
- from marker.scripts import convert_cli
2
 
3
  if __name__ == "__main__":
4
  convert_cli()
 
1
+ from marker.scripts.convert import convert_cli
2
 
3
  if __name__ == "__main__":
4
  convert_cli()
convert_single.py CHANGED
@@ -1,4 +1,4 @@
1
- from marker.scripts import convert_single_cli
2
 
3
  if __name__ == "__main__":
4
  convert_single_cli()
 
1
+ from marker.scripts.convert_single import convert_single_cli
2
 
3
  if __name__ == "__main__":
4
  convert_single_cli()
marker/builders/ocr.py CHANGED
@@ -35,6 +35,10 @@ class OcrBuilder(BaseBuilder):
35
  "A list of languages to use for OCR.",
36
  "Default is None."
37
  ] = None
 
 
 
 
38
 
39
  def __init__(self, detection_model: DetectionPredictor, recognition_model: RecognitionPredictor, config=None):
40
  super().__init__(config)
@@ -67,12 +71,12 @@ class OcrBuilder(BaseBuilder):
67
 
68
  # Remove tables because we re-OCR them later with the table processor
69
  recognition_results = self.recognition_model(
70
- images=[page.get_image(highres=False, remove_tables=True) for page in page_list],
71
  langs=[self.languages] * len(page_list),
72
  det_predictor=self.detection_model,
73
  detection_batch_size=int(self.get_detection_batch_size()),
74
  recognition_batch_size=int(self.get_recognition_batch_size()),
75
- highres_images=[page.get_image(highres=True, remove_tables=True) for page in page_list]
76
  )
77
 
78
  page_lines = {}
 
35
  "A list of languages to use for OCR.",
36
  "Default is None."
37
  ] = None
38
+ enable_table_ocr: Annotated[
39
+ bool,
40
+ "Whether to skip OCR on tables. The TableProcessor will re-OCR them. Only enable if the TableProcessor is not running.",
41
+ ] = False
42
 
43
  def __init__(self, detection_model: DetectionPredictor, recognition_model: RecognitionPredictor, config=None):
44
  super().__init__(config)
 
71
 
72
  # Remove tables because we re-OCR them later with the table processor
73
  recognition_results = self.recognition_model(
74
+ images=[page.get_image(highres=False, remove_tables=not self.enable_table_ocr) for page in page_list],
75
  langs=[self.languages] * len(page_list),
76
  det_predictor=self.detection_model,
77
  detection_batch_size=int(self.get_detection_batch_size()),
78
  recognition_batch_size=int(self.get_recognition_batch_size()),
79
+ highres_images=[page.get_image(highres=True, remove_tables=not self.enable_table_ocr) for page in page_list]
80
  )
81
 
82
  page_lines = {}
marker/processors/llm/llm_form.py CHANGED
@@ -17,7 +17,7 @@ Values and labels should appear in html tables, with the labels on the left side
17
  **Instructions:**
18
  1. Carefully examine the provided form block image.
19
  2. Analyze the html representation of the form.
20
- 3. If the html representation is largely correct, then write "No corrections needed."
21
  4. If the html representation contains errors, generate the corrected html representation.
22
  5. Output only either the corrected html representation or "No corrections needed."
23
  **Example:**
 
17
  **Instructions:**
18
  1. Carefully examine the provided form block image.
19
  2. Analyze the html representation of the form.
20
+ 3. If the html representation is largely correct, or you cannot read the image properly, then write "No corrections needed."
21
  4. If the html representation contains errors, generate the corrected html representation.
22
  5. Output only either the corrected html representation or "No corrections needed."
23
  **Example:**
marker/processors/llm/llm_table.py CHANGED
@@ -16,9 +16,9 @@ class LLMTableProcessor(BaseLLMProcessor):
16
  Tuple[BlockTypes],
17
  "The block types to process.",
18
  ] = (BlockTypes.Table, BlockTypes.TableOfContents)
19
- max_row_count: Annotated[
20
  int,
21
- "If the table has more rows than this, don't run LLM processor. (LLMs can be inaccurate with a lot of rows)",
22
  ] = 75
23
  table_rewriting_prompt: Annotated[
24
  str,
@@ -37,7 +37,7 @@ Some guidelines:
37
  **Instructions:**
38
  1. Carefully examine the provided text block image.
39
  2. Analyze the html representation of the table.
40
- 3. If the html representation is largely correct, then write "No corrections needed."
41
  4. If the html representation contains errors, generate the corrected html representation.
42
  5. Output only either the corrected html representation or "No corrections needed."
43
  **Example:**
@@ -74,7 +74,9 @@ No corrections needed.
74
 
75
  # LLMs don't handle tables with a lot of rows very well
76
  row_count = len(set([cell.row_id for cell in children]))
77
- if row_count > self.max_row_count:
 
 
78
  return
79
 
80
  block_html = block.render(document).html
 
16
  Tuple[BlockTypes],
17
  "The block types to process.",
18
  ] = (BlockTypes.Table, BlockTypes.TableOfContents)
19
+ max_rows_per_batch: Annotated[
20
  int,
21
+ "If the table has more rows than this, chunk the table. (LLMs can be inaccurate with a lot of rows)",
22
  ] = 75
23
  table_rewriting_prompt: Annotated[
24
  str,
 
37
  **Instructions:**
38
  1. Carefully examine the provided text block image.
39
  2. Analyze the html representation of the table.
40
+ 3. If the html representation is largely correct, or you cannot read the image properly, then write "No corrections needed."
41
  4. If the html representation contains errors, generate the corrected html representation.
42
  5. Output only either the corrected html representation or "No corrections needed."
43
  **Example:**
 
74
 
75
  # LLMs don't handle tables with a lot of rows very well
76
  row_count = len(set([cell.row_id for cell in children]))
77
+
78
+ # TODO: eventually chunk the table and inference each chunk
79
+ if row_count > self.max_rows_per_batch:
80
  return
81
 
82
  block_html = block.render(document).html
marker/processors/llm/llm_table_merge.py CHANGED
@@ -55,7 +55,7 @@ You'll specify your judgement in json format - first whether Table 2 should be m
55
 
56
  Table 2 should be merged at the bottom of Table 1 if Table 2 has no headers, and the rows have similar values, meaning that Table 2 continues Table 1. Table 2 should be merged to the right of Table 1 if each row in Table 2 matches a row in Table 1, meaning that Table 2 contains additional columns that augment Table 1.
57
 
58
- Only merge Table 1 and Table 2 if Table 2 cannot be interpreted without merging.
59
 
60
  **Instructions:**
61
  1. Carefully examine the provided table images. Table 1 is the first image, and Table 2 is the second image.
 
55
 
56
  Table 2 should be merged at the bottom of Table 1 if Table 2 has no headers, and the rows have similar values, meaning that Table 2 continues Table 1. Table 2 should be merged to the right of Table 1 if each row in Table 2 matches a row in Table 1, meaning that Table 2 contains additional columns that augment Table 1.
57
 
58
+ Only merge Table 1 and Table 2 if Table 2 cannot be interpreted without merging. Only merge Table 1 and Table 2 if you can read both images properly.
59
 
60
  **Instructions:**
61
  1. Carefully examine the provided table images. Table 1 is the first image, and Table 2 is the second image.
marker/processors/table.py CHANGED
@@ -2,6 +2,8 @@ import re
2
  from collections import defaultdict
3
  from copy import deepcopy
4
  from typing import Annotated, List
 
 
5
 
6
  from ftfy import fix_text
7
  from surya.detection import DetectionPredictor
@@ -67,7 +69,7 @@ class TableProcessor(BaseProcessor):
67
  table_data = []
68
  for page in document.pages:
69
  for block in page.contained_blocks(document, self.block_types):
70
- image = block.get_image(document, highres=True, expansion=(.01, .01))
71
  image_poly = block.polygon.rescale((page.polygon.width, page.polygon.height), page.get_image(highres=True).size)
72
 
73
  table_data.append({
@@ -165,22 +167,35 @@ class TableProcessor(BaseProcessor):
165
 
166
  # Other cells that span into this row
167
  rowspan_cells = [c for c in table.cells if c.row_id != row and c.row_id + c.rowspan > row > c.row_id]
168
- should_split = all([
169
- len(row_cells) > 0,
170
  len(rowspan_cells) == 0,
171
  all([r == 1 for r in rowspans]),
172
  all([l > 1 for l in line_lens]),
173
  all([l == line_lens[0] for l in line_lens])
174
  ])
 
 
 
 
 
 
 
 
 
175
  if should_split:
176
- for i in range(0, line_lens[0]):
177
  for cell in row_cells:
178
- line = cell.text_lines[i]
 
 
 
 
179
  cell_id = max_cell_id + new_cell_count
180
  new_cells.append(
181
  SuryaTableCell(
182
- polygon=line["bbox"],
183
- text_lines=[line],
184
  rowspan=1,
185
  colspan=cell.colspan,
186
  row_id=cell.row_id + shift_up + i,
 
2
  from collections import defaultdict
3
  from copy import deepcopy
4
  from typing import Annotated, List
5
+ from collections import Counter
6
+ from PIL import ImageDraw
7
 
8
  from ftfy import fix_text
9
  from surya.detection import DetectionPredictor
 
69
  table_data = []
70
  for page in document.pages:
71
  for block in page.contained_blocks(document, self.block_types):
72
+ image = block.get_image(document, highres=True)
73
  image_poly = block.polygon.rescale((page.polygon.width, page.polygon.height), page.get_image(highres=True).size)
74
 
75
  table_data.append({
 
167
 
168
  # Other cells that span into this row
169
  rowspan_cells = [c for c in table.cells if c.row_id != row and c.row_id + c.rowspan > row > c.row_id]
170
+ should_split_entire_row = all([
171
+ len(row_cells) > 1,
172
  len(rowspan_cells) == 0,
173
  all([r == 1 for r in rowspans]),
174
  all([l > 1 for l in line_lens]),
175
  all([l == line_lens[0] for l in line_lens])
176
  ])
177
+ line_lens_counter = Counter(line_lens)
178
+ counter_keys = sorted(list(line_lens_counter.keys()))
179
+ should_split_partial_row = all([
180
+ len(row_cells) > 3, # Only split if there are more than 3 cells
181
+ len(rowspan_cells) == 0,
182
+ all([r == 1 for r in rowspans]),
183
+ len(line_lens_counter) == 2 and counter_keys[0] <= 1 and counter_keys[1] > 1 and line_lens_counter[counter_keys[0]] == 1, # Allow a single column with a single line - keys are the line lens, values are the counts
184
+ ])
185
+ should_split = should_split_entire_row or should_split_partial_row
186
  if should_split:
187
+ for i in range(0, max(line_lens)):
188
  for cell in row_cells:
189
+ # Calculate height based on number of splits
190
+ split_height = cell.bbox[3] - cell.bbox[1]
191
+ current_bbox = [cell.bbox[0], cell.bbox[1] + i * split_height, cell.bbox[2], cell.bbox[1] + (i + 1) * split_height]
192
+
193
+ line = [cell.text_lines[i]] if cell.text_lines and i < len(cell.text_lines) else None
194
  cell_id = max_cell_id + new_cell_count
195
  new_cells.append(
196
  SuryaTableCell(
197
+ polygon=current_bbox,
198
+ text_lines=line,
199
  rowspan=1,
200
  colspan=cell.colspan,
201
  row_id=cell.row_id + shift_up + i,
marker/schema/blocks/base.py CHANGED
@@ -167,9 +167,10 @@ class Block(BaseModel):
167
  def raw_text(self, document: Document) -> str:
168
  from marker.schema.text.line import Line
169
  from marker.schema.text.span import Span
 
170
 
171
  if self.structure is None:
172
- if isinstance(self, Span):
173
  return self.text
174
  else:
175
  return ""
 
167
  def raw_text(self, document: Document) -> str:
168
  from marker.schema.text.line import Line
169
  from marker.schema.text.span import Span
170
+ from marker.schema.blocks.tablecell import TableCell
171
 
172
  if self.structure is None:
173
+ if isinstance(self, (Span, TableCell)):
174
  return self.text
175
  else:
176
  return ""
marker/scripts/__init__.py CHANGED
@@ -1,5 +0,0 @@
1
- from marker.scripts.convert_single import convert_single_cli
2
- from marker.scripts.convert import convert_cli
3
- from marker.scripts.server import server_cli
4
- from marker.scripts.run_streamlit_app import streamlit_app_cli
5
- from marker.scripts.chunk_convert import chunk_convert_cli
 
 
 
 
 
 
marker/scripts/convert.py CHANGED
@@ -100,7 +100,7 @@ def convert_cli(in_folder: str, **kwargs):
100
  else:
101
  model_dict = create_model_dict()
102
  for k, v in model_dict.items():
103
- v.share_memory()
104
 
105
  print(f"Converting {len(files_to_convert)} pdfs in chunk {kwargs['chunk_idx'] + 1}/{kwargs['num_chunks']} with {total_processes} processes and saving to {kwargs['output_dir']}")
106
  task_args = [(f, kwargs) for f in files_to_convert]
 
100
  else:
101
  model_dict = create_model_dict()
102
  for k, v in model_dict.items():
103
+ v.model.share_memory()
104
 
105
  print(f"Converting {len(files_to_convert)} pdfs in chunk {kwargs['chunk_idx'] + 1}/{kwargs['num_chunks']} with {total_processes} processes and saving to {kwargs['output_dir']}")
106
  task_args = [(f, kwargs) for f in files_to_convert]
marker/scripts/server.py CHANGED
@@ -3,7 +3,6 @@ import traceback
3
  import click
4
  import os
5
 
6
- import uvicorn
7
  from pydantic import BaseModel, Field
8
  from starlette.responses import HTMLResponse
9
 
@@ -163,6 +162,7 @@ async def convert_pdf_upload(
163
  @click.option("--port", type=int, default=8000, help="Port to run the server on")
164
  @click.option("--host", type=str, default="127.0.0.1", help="Host to run the server on")
165
  def server_cli(port: int, host: str):
 
166
  # Run the server
167
  uvicorn.run(
168
  app,
 
3
  import click
4
  import os
5
 
 
6
  from pydantic import BaseModel, Field
7
  from starlette.responses import HTMLResponse
8
 
 
162
  @click.option("--port", type=int, default=8000, help="Port to run the server on")
163
  @click.option("--host", type=str, default="127.0.0.1", help="Host to run the server on")
164
  def server_cli(port: int, host: str):
165
+ import uvicorn
166
  # Run the server
167
  uvicorn.run(
168
  app,
marker/scripts/streamlit_app.py CHANGED
@@ -1,11 +1,10 @@
1
  import os
 
 
2
 
3
  from marker.settings import settings
4
  from streamlit.runtime.uploaded_file_manager import UploadedFile
5
 
6
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
7
- os.environ["IN_STREAMLIT"] = "true"
8
-
9
  import base64
10
  import io
11
  import re
@@ -69,15 +68,12 @@ def markdown_insert_images(markdown, images):
69
  def get_page_image(pdf_file, page_num, dpi=96):
70
  if "pdf" in pdf_file.type:
71
  doc = open_pdf(pdf_file)
72
- renderer = doc.render(
73
- pypdfium2.PdfBitmap.to_pil,
74
- page_indices=[page_num],
75
  scale=dpi / 72,
76
- )
77
- png = list(renderer)[0]
78
- png_image = png.convert("RGB")
79
  else:
80
- png_image = Image.open(in_file).convert("RGB")
81
  return png_image
82
 
83
 
 
1
  import os
2
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
3
+ os.environ["IN_STREAMLIT"] = "true"
4
 
5
  from marker.settings import settings
6
  from streamlit.runtime.uploaded_file_manager import UploadedFile
7
 
 
 
 
8
  import base64
9
  import io
10
  import re
 
68
  def get_page_image(pdf_file, page_num, dpi=96):
69
  if "pdf" in pdf_file.type:
70
  doc = open_pdf(pdf_file)
71
+ page = doc[page_num]
72
+ png_image = page.render(
 
73
  scale=dpi / 72,
74
+ ).to_pil().convert("RGB")
 
 
75
  else:
76
+ png_image = Image.open(pdf_file).convert("RGB")
77
  return png_image
78
 
79
 
marker_app.py CHANGED
@@ -1,4 +1,4 @@
1
- from marker.scripts import streamlit_app_cli
2
 
3
  if __name__ == "__main__":
4
  streamlit_app_cli()
 
1
+ from marker.scripts.run_streamlit_app import streamlit_app_cli
2
 
3
  if __name__ == "__main__":
4
  streamlit_app_cli()
marker_server.py CHANGED
@@ -1,4 +1,4 @@
1
- from marker.scripts import server_cli
2
 
3
  if __name__ == "__main__":
4
  server_cli()
 
1
+ from marker.scripts.server import server_cli
2
 
3
  if __name__ == "__main__":
4
  server_cli()
poetry.lock CHANGED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -23,13 +23,12 @@ transformers = "^4.45.2"
23
  python-dotenv = "^1.0.0"
24
  torch = "^2.5.1"
25
  tqdm = "^4.66.1"
26
- tabulate = "^0.9.0"
27
  ftfy = "^6.1.1"
28
  texify = "^0.2.1"
29
  rapidfuzz = "^3.8.1"
30
- surya-ocr = "~0.8.3"
31
  regex = "^2024.4.28"
32
- pdftext = "~0.4.1"
33
  markdownify = "^0.13.1"
34
  click = "^8.1.7"
35
  google-generativeai = "^0.8.3"
@@ -53,6 +52,7 @@ pytest-mock = "^3.14.0"
53
  apted = "1.0.3"
54
  distance = "0.1.3"
55
  lxml = "5.3.0"
 
56
 
57
  [tool.poetry.scripts]
58
  marker = "marker.scripts.convert:convert_cli"
 
23
  python-dotenv = "^1.0.0"
24
  torch = "^2.5.1"
25
  tqdm = "^4.66.1"
 
26
  ftfy = "^6.1.1"
27
  texify = "^0.2.1"
28
  rapidfuzz = "^3.8.1"
29
+ surya-ocr = "~0.9.0"
30
  regex = "^2024.4.28"
31
+ pdftext = "~0.5.0"
32
  markdownify = "^0.13.1"
33
  click = "^8.1.7"
34
  google-generativeai = "^0.8.3"
 
52
  apted = "1.0.3"
53
  distance = "0.1.3"
54
  lxml = "5.3.0"
55
+ tabulate = "^0.9.0"
56
 
57
  [tool.poetry.scripts]
58
  marker = "marker.scripts.convert:convert_cli"
signatures/version1/cla.json CHANGED
@@ -111,6 +111,38 @@
111
  "created_at": "2024-12-05T13:13:34Z",
112
  "repoId": 712111618,
113
  "pullRequestNo": 416
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  }
115
  ]
116
  }
 
111
  "created_at": "2024-12-05T13:13:34Z",
112
  "repoId": 712111618,
113
  "pullRequestNo": 416
114
+ },
115
+ {
116
+ "name": "tarun-menta",
117
+ "id": 66506307,
118
+ "comment_id": 2543907406,
119
+ "created_at": "2024-12-15T15:06:32Z",
120
+ "repoId": 712111618,
121
+ "pullRequestNo": 427
122
+ },
123
+ {
124
+ "name": "ZeyuTeng96",
125
+ "id": 96521059,
126
+ "comment_id": 2567236036,
127
+ "created_at": "2025-01-02T02:36:02Z",
128
+ "repoId": 712111618,
129
+ "pullRequestNo": 452
130
+ },
131
+ {
132
+ "name": "xiaoyao9184",
133
+ "id": 6614349,
134
+ "comment_id": 2571623521,
135
+ "created_at": "2025-01-05T13:15:34Z",
136
+ "repoId": 712111618,
137
+ "pullRequestNo": 463
138
+ },
139
+ {
140
+ "name": "yasyf",
141
+ "id": 709645,
142
+ "comment_id": 2571679069,
143
+ "created_at": "2025-01-05T16:23:12Z",
144
+ "repoId": 712111618,
145
+ "pullRequestNo": 464
146
  }
147
  ]
148
  }
tests/builders/test_garbled_pdf.py CHANGED
@@ -2,10 +2,11 @@ import pytest
2
 
3
  from marker.builders.document import DocumentBuilder
4
  from marker.builders.layout import LayoutBuilder
 
5
  from marker.schema import BlockTypes
6
 
7
  @pytest.mark.filename("water_damage.pdf")
8
- def test_garbled_pdf(pdf_document):
9
  assert pdf_document.pages[0].structure[0] == '/page/0/Table/0'
10
 
11
  table_block = pdf_document.pages[0].get_block(pdf_document.pages[0].structure[0])
@@ -16,9 +17,16 @@ def test_garbled_pdf(pdf_document):
16
  assert table_cell.block_type == BlockTypes.Line
17
  assert table_cell.structure[0] == "/page/0/Span/2"
18
 
19
- span = pdf_document.pages[0].get_block(table_cell.structure[0])
20
  assert span.block_type == BlockTypes.Span
21
- assert "комплекс" in span.text
 
 
 
 
 
 
 
22
 
23
 
24
  @pytest.mark.filename("hindi_judgement.pdf")
@@ -30,7 +38,7 @@ def test_garbled_builder(config, pdf_provider, layout_model, ocr_error_model):
30
 
31
  bad_ocr_results = layout_builder.surya_ocr_error_detection(document.pages, pdf_provider.page_lines)
32
  assert len(bad_ocr_results.labels) == 2
33
- assert all([l == "bad" for l in bad_ocr_results.labels])
34
 
35
 
36
  @pytest.mark.filename("adversarial.pdf")
 
2
 
3
  from marker.builders.document import DocumentBuilder
4
  from marker.builders.layout import LayoutBuilder
5
+ from marker.processors.table import TableProcessor
6
  from marker.schema import BlockTypes
7
 
8
  @pytest.mark.filename("water_damage.pdf")
9
+ def test_garbled_pdf(pdf_document, detection_model, recognition_model, table_rec_model):
10
  assert pdf_document.pages[0].structure[0] == '/page/0/Table/0'
11
 
12
  table_block = pdf_document.pages[0].get_block(pdf_document.pages[0].structure[0])
 
17
  assert table_cell.block_type == BlockTypes.Line
18
  assert table_cell.structure[0] == "/page/0/Span/2"
19
 
20
+ span = pdf_document.pages[0].contained_blocks(pdf_document, (BlockTypes.Span,))[0]
21
  assert span.block_type == BlockTypes.Span
22
+ assert len(span.text.strip()) == 0
23
+
24
+ # We don't OCR in the initial pass, only with the TableProcessor
25
+ processor = TableProcessor(detection_model, recognition_model, table_rec_model)
26
+ processor(pdf_document)
27
+
28
+ table = pdf_document.pages[0].contained_blocks(pdf_document, (BlockTypes.Table,))[0]
29
+ assert "варіант" in table.raw_text(pdf_document)
30
 
31
 
32
  @pytest.mark.filename("hindi_judgement.pdf")
 
38
 
39
  bad_ocr_results = layout_builder.surya_ocr_error_detection(document.pages, pdf_provider.page_lines)
40
  assert len(bad_ocr_results.labels) == 2
41
+ assert any([l == "bad" for l in bad_ocr_results.labels])
42
 
43
 
44
  @pytest.mark.filename("adversarial.pdf")