Vik Paruchuri commited on
Commit
1dfe667
·
1 Parent(s): 30e3e08

Test table merge

Browse files
benchmarks/table/table.py CHANGED
@@ -49,9 +49,10 @@ def extract_tables(children: List[JSONBlockOutput]):
49
  @click.option("--dataset", type=str, default="datalab-to/fintabnet-test", help="Dataset to use")
50
  @click.option("--max_rows", type=int, default=None, help="Maximum number of PDFs to process")
51
  @click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
52
- def main(out_file: str, dataset: str, max_rows: int, max_workers: int):
 
53
  models = create_model_dict()
54
- config_parser = ConfigParser({'output_format': 'json'})
55
  start = time.time()
56
 
57
 
 
49
  @click.option("--dataset", type=str, default="datalab-to/fintabnet-test", help="Dataset to use")
50
  @click.option("--max_rows", type=int, default=None, help="Maximum number of PDFs to process")
51
  @click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
52
+ @click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.")
53
+ def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm: bool):
54
  models = create_model_dict()
55
+ config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm})
56
  start = time.time()
57
 
58
 
marker/builders/llm_layout.py CHANGED
@@ -169,7 +169,6 @@ Respond only with one of `Figure`, `Picture`, `ComplexRegion`, `Table`, or `Form
169
  )
170
 
171
  response = self.model.generate_response(prompt, image, block, response_schema)
172
- print(response)
173
  generated_label = None
174
  if response and "label" in response:
175
  generated_label = response["label"]
 
169
  )
170
 
171
  response = self.model.generate_response(prompt, image, block, response_schema)
 
172
  generated_label = None
173
  if response and "label" in response:
174
  generated_label = response["label"]
marker/processors/llm/llm_table_merge.py CHANGED
@@ -163,7 +163,7 @@ Table 2
163
  same_page_new_column = all([
164
  prev_block.page_id == block.page_id, # On the same page
165
  abs(block.polygon.x_start - prev_block.polygon.x_end) < self.column_gap_threshold,
166
- block.y_start < prev_block.y_end,
167
  block.polygon.width * (1 - self.vertical_table_height_threshold) < prev_block.polygon.width < block.polygon.width * (1 + self.vertical_table_height_threshold), # Similar width
168
  col_match
169
  ])
 
163
  same_page_new_column = all([
164
  prev_block.page_id == block.page_id, # On the same page
165
  abs(block.polygon.x_start - prev_block.polygon.x_end) < self.column_gap_threshold,
166
+ block.polygon.y_start < prev_block.polygon.y_end,
167
  block.polygon.width * (1 - self.vertical_table_height_threshold) < prev_block.polygon.width < block.polygon.width * (1 + self.vertical_table_height_threshold), # Similar width
168
  col_match
169
  ])
tests/processors/test_table_merge.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unittest.mock import Mock
2
+
3
+ import pytest
4
+
5
+ from marker.processors.llm.llm_table_merge import LLMTableMergeProcessor
6
+ from marker.processors.table import TableProcessor
7
+ from marker.schema import BlockTypes
8
+
9
+
10
+ @pytest.mark.filename("table_ex2.pdf")
11
+ def test_llm_table_processor_nomerge(pdf_document, detection_model, table_rec_model, recognition_model, mocker):
12
+ mock_cls = Mock()
13
+ mock_cls.return_value.generate_response.return_value = {
14
+ "merge": "true",
15
+ "direction": "right"
16
+ }
17
+ mocker.patch("marker.processors.llm.GoogleModel", mock_cls)
18
+
19
+ cell_processor = TableProcessor(detection_model, recognition_model, table_rec_model)
20
+ cell_processor(pdf_document)
21
+
22
+ tables = pdf_document.contained_blocks((BlockTypes.Table,))
23
+ assert len(tables) == 3
24
+
25
+ processor = LLMTableMergeProcessor({"use_llm": True, "google_api_key": "test"})
26
+ processor(pdf_document)
27
+
28
+ tables = pdf_document.contained_blocks((BlockTypes.Table,))
29
+ assert len(tables) == 3