Vik Paruchuri
commited on
Commit
·
1dfe667
1
Parent(s):
30e3e08
Test table merge
Browse files
benchmarks/table/table.py
CHANGED
|
@@ -49,9 +49,10 @@ def extract_tables(children: List[JSONBlockOutput]):
|
|
| 49 |
@click.option("--dataset", type=str, default="datalab-to/fintabnet-test", help="Dataset to use")
|
| 50 |
@click.option("--max_rows", type=int, default=None, help="Maximum number of PDFs to process")
|
| 51 |
@click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
|
| 52 |
-
|
|
|
|
| 53 |
models = create_model_dict()
|
| 54 |
-
config_parser = ConfigParser({'output_format': 'json'})
|
| 55 |
start = time.time()
|
| 56 |
|
| 57 |
|
|
|
|
| 49 |
@click.option("--dataset", type=str, default="datalab-to/fintabnet-test", help="Dataset to use")
|
| 50 |
@click.option("--max_rows", type=int, default=None, help="Maximum number of PDFs to process")
|
| 51 |
@click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
|
| 52 |
+
@click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.")
|
| 53 |
+
def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm: bool):
|
| 54 |
models = create_model_dict()
|
| 55 |
+
config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm})
|
| 56 |
start = time.time()
|
| 57 |
|
| 58 |
|
marker/builders/llm_layout.py
CHANGED
|
@@ -169,7 +169,6 @@ Respond only with one of `Figure`, `Picture`, `ComplexRegion`, `Table`, or `Form
|
|
| 169 |
)
|
| 170 |
|
| 171 |
response = self.model.generate_response(prompt, image, block, response_schema)
|
| 172 |
-
print(response)
|
| 173 |
generated_label = None
|
| 174 |
if response and "label" in response:
|
| 175 |
generated_label = response["label"]
|
|
|
|
| 169 |
)
|
| 170 |
|
| 171 |
response = self.model.generate_response(prompt, image, block, response_schema)
|
|
|
|
| 172 |
generated_label = None
|
| 173 |
if response and "label" in response:
|
| 174 |
generated_label = response["label"]
|
marker/processors/llm/llm_table_merge.py
CHANGED
|
@@ -163,7 +163,7 @@ Table 2
|
|
| 163 |
same_page_new_column = all([
|
| 164 |
prev_block.page_id == block.page_id, # On the same page
|
| 165 |
abs(block.polygon.x_start - prev_block.polygon.x_end) < self.column_gap_threshold,
|
| 166 |
-
block.y_start < prev_block.y_end,
|
| 167 |
block.polygon.width * (1 - self.vertical_table_height_threshold) < prev_block.polygon.width < block.polygon.width * (1 + self.vertical_table_height_threshold), # Similar width
|
| 168 |
col_match
|
| 169 |
])
|
|
|
|
| 163 |
same_page_new_column = all([
|
| 164 |
prev_block.page_id == block.page_id, # On the same page
|
| 165 |
abs(block.polygon.x_start - prev_block.polygon.x_end) < self.column_gap_threshold,
|
| 166 |
+
block.polygon.y_start < prev_block.polygon.y_end,
|
| 167 |
block.polygon.width * (1 - self.vertical_table_height_threshold) < prev_block.polygon.width < block.polygon.width * (1 + self.vertical_table_height_threshold), # Similar width
|
| 168 |
col_match
|
| 169 |
])
|
tests/processors/test_table_merge.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from unittest.mock import Mock
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from marker.processors.llm.llm_table_merge import LLMTableMergeProcessor
|
| 6 |
+
from marker.processors.table import TableProcessor
|
| 7 |
+
from marker.schema import BlockTypes
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@pytest.mark.filename("table_ex2.pdf")
|
| 11 |
+
def test_llm_table_processor_nomerge(pdf_document, detection_model, table_rec_model, recognition_model, mocker):
|
| 12 |
+
mock_cls = Mock()
|
| 13 |
+
mock_cls.return_value.generate_response.return_value = {
|
| 14 |
+
"merge": "true",
|
| 15 |
+
"direction": "right"
|
| 16 |
+
}
|
| 17 |
+
mocker.patch("marker.processors.llm.GoogleModel", mock_cls)
|
| 18 |
+
|
| 19 |
+
cell_processor = TableProcessor(detection_model, recognition_model, table_rec_model)
|
| 20 |
+
cell_processor(pdf_document)
|
| 21 |
+
|
| 22 |
+
tables = pdf_document.contained_blocks((BlockTypes.Table,))
|
| 23 |
+
assert len(tables) == 3
|
| 24 |
+
|
| 25 |
+
processor = LLMTableMergeProcessor({"use_llm": True, "google_api_key": "test"})
|
| 26 |
+
processor(pdf_document)
|
| 27 |
+
|
| 28 |
+
tables = pdf_document.contained_blocks((BlockTypes.Table,))
|
| 29 |
+
assert len(tables) == 3
|