Vik Paruchuri
committed on
Commit
·
a78718d
1
Parent(s):
f712a74
Add gemini to table bench
Browse files- README.md +1 -1
- benchmarks/table/gemini.py +49 -0
- benchmarks/table/table.py +56 -16
README.md
CHANGED
|
@@ -400,7 +400,7 @@ Marker can extract tables from PDFs using `marker.converters.table.TableConverte
|
|
| 400 |
|
| 401 |
| Avg score | Total tables | use_llm |
|
| 402 |
|-----------|--------------|---------|
|
| 403 |
-
| 0.
|
| 404 |
| 0.887 | 54 | True |
|
| 405 |
|
| 406 |
The `--use_llm` flag can significantly improve table recognition performance, as you can see.
|
|
|
|
| 400 |
|
| 401 |
| Avg score | Total tables | use_llm |
|
| 402 |
|-----------|--------------|---------|
|
| 403 |
+
| 0.822 | 54 | False |
|
| 404 |
| 0.887 | 54 | True |
|
| 405 |
|
| 406 |
The `--use_llm` flag can significantly improve table recognition performance, as you can see.
|
benchmarks/table/gemini.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import google.generativeai as genai
|
| 4 |
+
from google.ai.generativelanguage_v1beta.types import content
|
| 5 |
+
from marker.settings import settings
|
| 6 |
+
|
| 7 |
+
prompt = """
|
| 8 |
+
You're an expert document analyst who is good at turning tables in documents into HTML. Analyze the provided image, and convert it to a faithful HTML representation.
|
| 9 |
+
|
| 10 |
+
Guidelines:
|
| 11 |
+
- Keep the HTML simple and concise.
|
| 12 |
+
- Only include the <table> tag and contents.
|
| 13 |
+
- Only use <table>, <tr>, and <td> tags. Only use the colspan and rowspan attributes if necessary. Do not use <tbody>, <thead>, or <th> tags.
|
| 14 |
+
- Make sure the table is as faithful to the image as possible with the given tags.
|
| 15 |
+
|
| 16 |
+
**Instructions**
|
| 17 |
+
1. Analyze the image, and determine the table structure.
|
| 18 |
+
2. Convert the table image to HTML, following the guidelines above.
|
| 19 |
+
3. Output only the HTML for the table, starting with the <table> tag and ending with the </table> tag.
|
| 20 |
+
""".strip()
|
| 21 |
+
|
| 22 |
+
genai.configure(api_key=settings.GOOGLE_API_KEY)
|
| 23 |
+
|
| 24 |
+
def gemini_table_rec(image: Image.Image):
    """Ask Gemini to transcribe a table image into HTML.

    Sends *image* together with the module-level ``prompt`` to
    gemini-1.5-flash, constrained by a JSON response schema with a single
    required ``table_html`` string field, and returns that HTML string.
    """
    # The response must be a JSON object with exactly one required string field.
    response_schema = content.Schema(
        type=content.Type.OBJECT,
        required=["table_html"],
        properties={"table_html": content.Schema(type=content.Type.STRING)},
    )

    gemini = genai.GenerativeModel("gemini-1.5-flash")

    # Image goes first — per the Gemini docs, multimodal prompts perform
    # better when the image precedes the text.
    response = gemini.generate_content(
        [image, prompt],
        stream=False,
        generation_config={
            "temperature": 0,
            "response_schema": response_schema,
            "response_mime_type": "application/json",
        },
        request_options={"timeout": 60},
    )

    raw_json = response.candidates[0].content.parts[0].text
    return json.loads(raw_json)["table_html"]
|
benchmarks/table/table.py
CHANGED
|
@@ -1,12 +1,11 @@
|
|
| 1 |
import os
|
| 2 |
-
from
|
| 3 |
-
|
| 4 |
-
import numpy as np
|
| 5 |
-
|
| 6 |
-
from marker.renderers.json import JSONOutput, JSONBlockOutput
|
| 7 |
|
| 8 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
|
| 9 |
|
|
|
|
|
|
|
| 10 |
import base64
|
| 11 |
import time
|
| 12 |
import datasets
|
|
@@ -16,21 +15,24 @@ import click
|
|
| 16 |
from tabulate import tabulate
|
| 17 |
import json
|
| 18 |
from bs4 import BeautifulSoup
|
| 19 |
-
from concurrent.futures import
|
| 20 |
from pypdfium2._helpers.misc import PdfiumError
|
|
|
|
| 21 |
from marker.util import matrix_intersection_area
|
|
|
|
| 22 |
|
| 23 |
from marker.config.parser import ConfigParser
|
| 24 |
from marker.converters.table import TableConverter
|
| 25 |
from marker.models import create_model_dict
|
| 26 |
|
| 27 |
from scoring import wrap_table_html, similarity_eval_html
|
|
|
|
| 28 |
|
| 29 |
-
def update_teds_score(result):
|
| 30 |
-
prediction, ground_truth = result['
|
| 31 |
prediction, ground_truth = wrap_table_html(prediction), wrap_table_html(ground_truth)
|
| 32 |
score = similarity_eval_html(prediction, ground_truth)
|
| 33 |
-
result.update({'
|
| 34 |
return result
|
| 35 |
|
| 36 |
|
|
@@ -51,7 +53,16 @@ def extract_tables(children: List[JSONBlockOutput]):
|
|
| 51 |
@click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
|
| 52 |
@click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.")
|
| 53 |
@click.option("--table_rec_batch_size", type=int, default=None, help="Batch size for table recognition.")
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
models = create_model_dict()
|
| 56 |
config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size})
|
| 57 |
start = time.time()
|
|
@@ -86,6 +97,9 @@ def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm:
|
|
| 86 |
marker_json = converter(temp_pdf_file.name).children
|
| 87 |
tqdm.disable = False
|
| 88 |
|
|
|
|
|
|
|
|
|
|
| 89 |
if len(marker_json) == 0 or len(gt_tables) == 0:
|
| 90 |
print(f'No tables detected, skipping...')
|
| 91 |
total_unaligned += len(gt_tables)
|
|
@@ -94,6 +108,8 @@ def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm:
|
|
| 94 |
marker_tables = extract_tables(marker_json)
|
| 95 |
marker_table_boxes = [table.bbox for table in marker_tables]
|
| 96 |
page_bbox = marker_json[0].bbox
|
|
|
|
|
|
|
| 97 |
|
| 98 |
# Normalize the bboxes
|
| 99 |
for bbox in marker_table_boxes:
|
|
@@ -136,14 +152,18 @@ def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm:
|
|
| 136 |
unaligned_tables.add(table_idx)
|
| 137 |
continue
|
| 138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
aligned_tables.append(
|
| 140 |
-
(marker_tables[aligned_idx], gt_tables[table_idx])
|
| 141 |
)
|
| 142 |
used_tables.add(aligned_idx)
|
| 143 |
|
| 144 |
total_unaligned += len(unaligned_tables)
|
| 145 |
|
| 146 |
-
for marker_table, gt_table in aligned_tables:
|
| 147 |
gt_table_html = gt_table['html']
|
| 148 |
|
| 149 |
#marker wraps the table in <tbody> which fintabnet data doesn't
|
|
@@ -154,10 +174,12 @@ def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm:
|
|
| 154 |
th_tag.name = 'td'
|
| 155 |
marker_table_html = str(marker_table_soup)
|
| 156 |
marker_table_html = marker_table_html.replace("\n", " ") # Fintabnet uses spaces instead of newlines
|
|
|
|
| 157 |
|
| 158 |
results.append({
|
| 159 |
"marker_table": marker_table_html,
|
| 160 |
-
"gt_table": gt_table_html
|
|
|
|
| 161 |
})
|
| 162 |
except PdfiumError:
|
| 163 |
print('Broken PDF, Skipping...')
|
|
@@ -167,19 +189,37 @@ def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm:
|
|
| 167 |
print(f"Could not align {total_unaligned} tables from fintabnet.")
|
| 168 |
|
| 169 |
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
| 170 |
-
|
| 171 |
tqdm(
|
| 172 |
executor.map(update_teds_score, results), desc='Computing alignment scores', total=len(results)
|
| 173 |
)
|
| 174 |
)
|
| 175 |
-
avg_score = sum([r["score"] for r in results]) / len(results)
|
| 176 |
|
|
|
|
| 177 |
headers = ["Avg score", "Total tables"]
|
| 178 |
-
data = [f"{avg_score:.3f}", len(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
table = tabulate([data], headers=headers, tablefmt="github")
|
| 180 |
print(table)
|
| 181 |
print("Avg score computed by comparing marker predicted HTML with original HTML")
|
| 182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
with open(out_file, "w+") as f:
|
| 184 |
json.dump(results, f, indent=2)
|
| 185 |
|
|
|
|
| 1 |
import os
|
| 2 |
+
from itertools import repeat
|
| 3 |
+
from PIL import Image  # NOTE(review): the committed diff imports Image from tkinter, which is wrong (PIL provides Image) and breaks on headless installs; the import also appears unused in the visible code
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
|
| 6 |
|
| 7 |
+
from typing import List
|
| 8 |
+
import numpy as np
|
| 9 |
import base64
|
| 10 |
import time
|
| 11 |
import datasets
|
|
|
|
| 15 |
from tabulate import tabulate
|
| 16 |
import json
|
| 17 |
from bs4 import BeautifulSoup
|
| 18 |
+
from concurrent.futures import ProcessPoolExecutor
|
| 19 |
from pypdfium2._helpers.misc import PdfiumError
|
| 20 |
+
import pypdfium2 as pdfium
|
| 21 |
from marker.util import matrix_intersection_area
|
| 22 |
+
from marker.renderers.json import JSONOutput, JSONBlockOutput
|
| 23 |
|
| 24 |
from marker.config.parser import ConfigParser
|
| 25 |
from marker.converters.table import TableConverter
|
| 26 |
from marker.models import create_model_dict
|
| 27 |
|
| 28 |
from scoring import wrap_table_html, similarity_eval_html
|
| 29 |
+
from gemini import gemini_table_rec
|
| 30 |
|
| 31 |
+
def update_teds_score(result, prefix: str = "marker"):
    """Score one prediction/ground-truth pair and store the result in place.

    Looks up ``result[f"{prefix}_table"]`` and ``result["gt_table"]``, wraps
    both HTML fragments, computes their TEDS-style similarity, writes it back
    as ``result[f"{prefix}_score"]``, and returns the mutated dict.
    """
    pred_html = wrap_table_html(result[f'{prefix}_table'])
    gt_html = wrap_table_html(result['gt_table'])
    result.update({f'{prefix}_score': similarity_eval_html(pred_html, gt_html)})
    return result
|
| 37 |
|
| 38 |
|
|
|
|
| 53 |
@click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
|
| 54 |
@click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.")
|
| 55 |
@click.option("--table_rec_batch_size", type=int, default=None, help="Batch size for table recognition.")
|
| 56 |
+
@click.option("--use_gemini", is_flag=True, help="Evaluate Gemini for table recognition.")
|
| 57 |
+
def main(
|
| 58 |
+
out_file: str,
|
| 59 |
+
dataset: str,
|
| 60 |
+
max_rows: int,
|
| 61 |
+
max_workers: int,
|
| 62 |
+
use_llm: bool,
|
| 63 |
+
table_rec_batch_size: int | None,
|
| 64 |
+
use_gemini: bool = False
|
| 65 |
+
):
|
| 66 |
models = create_model_dict()
|
| 67 |
config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size})
|
| 68 |
start = time.time()
|
|
|
|
| 97 |
marker_json = converter(temp_pdf_file.name).children
|
| 98 |
tqdm.disable = False
|
| 99 |
|
| 100 |
+
doc = pdfium.PdfDocument(temp_pdf_file.name)
|
| 101 |
+
page_image = doc[0].render(scale=92/72).to_pil()
|
| 102 |
+
|
| 103 |
if len(marker_json) == 0 or len(gt_tables) == 0:
|
| 104 |
print(f'No tables detected, skipping...')
|
| 105 |
total_unaligned += len(gt_tables)
|
|
|
|
| 108 |
marker_tables = extract_tables(marker_json)
|
| 109 |
marker_table_boxes = [table.bbox for table in marker_tables]
|
| 110 |
page_bbox = marker_json[0].bbox
|
| 111 |
+
w_scaler, h_scaler = page_image.width / page_bbox[2], page_image.height / page_bbox[3]
|
| 112 |
+
table_images = [page_image.crop([bbox[0] * w_scaler, bbox[1] * h_scaler, bbox[2] * w_scaler, bbox[3] * h_scaler]) for bbox in marker_table_boxes]
|
| 113 |
|
| 114 |
# Normalize the bboxes
|
| 115 |
for bbox in marker_table_boxes:
|
|
|
|
| 152 |
unaligned_tables.add(table_idx)
|
| 153 |
continue
|
| 154 |
|
| 155 |
+
gemini_html = ""
|
| 156 |
+
if use_gemini:
|
| 157 |
+
gemini_html = gemini_table_rec(table_images[aligned_idx])
|
| 158 |
+
|
| 159 |
aligned_tables.append(
|
| 160 |
+
(marker_tables[aligned_idx], gt_tables[table_idx], gemini_html)
|
| 161 |
)
|
| 162 |
used_tables.add(aligned_idx)
|
| 163 |
|
| 164 |
total_unaligned += len(unaligned_tables)
|
| 165 |
|
| 166 |
+
for marker_table, gt_table, gemini_table in aligned_tables:
|
| 167 |
gt_table_html = gt_table['html']
|
| 168 |
|
| 169 |
#marker wraps the table in <tbody> which fintabnet data doesn't
|
|
|
|
| 174 |
th_tag.name = 'td'
|
| 175 |
marker_table_html = str(marker_table_soup)
|
| 176 |
marker_table_html = marker_table_html.replace("\n", " ") # Fintabnet uses spaces instead of newlines
|
| 177 |
+
gemini_table_html = gemini_table.replace("\n", " ") # Fintabnet uses spaces instead of newlines
|
| 178 |
|
| 179 |
results.append({
|
| 180 |
"marker_table": marker_table_html,
|
| 181 |
+
"gt_table": gt_table_html,
|
| 182 |
+
"gemini_table": gemini_table_html
|
| 183 |
})
|
| 184 |
except PdfiumError:
|
| 185 |
print('Broken PDF, Skipping...')
|
|
|
|
| 189 |
print(f"Could not align {total_unaligned} tables from fintabnet.")
|
| 190 |
|
| 191 |
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
| 192 |
+
marker_results = list(
|
| 193 |
tqdm(
|
| 194 |
executor.map(update_teds_score, results), desc='Computing alignment scores', total=len(results)
|
| 195 |
)
|
| 196 |
)
|
|
|
|
| 197 |
|
| 198 |
+
avg_score = sum([r["marker_score"] for r in marker_results]) / len(marker_results)
|
| 199 |
headers = ["Avg score", "Total tables"]
|
| 200 |
+
data = [f"{avg_score:.3f}", len(marker_results)]
|
| 201 |
+
gemini_results = None
|
| 202 |
+
if use_gemini:
|
| 203 |
+
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
| 204 |
+
gemini_results = list(
|
| 205 |
+
tqdm(
|
| 206 |
+
executor.map(update_teds_score, results, repeat("gemini")), desc='Computing Gemini scores',
|
| 207 |
+
total=len(results)
|
| 208 |
+
)
|
| 209 |
+
)
|
| 210 |
+
avg_gemini_score = sum([r["gemini_score"] for r in gemini_results]) / len(gemini_results)
|
| 211 |
+
headers.append("Avg Gemini score")
|
| 212 |
+
data.append(f"{avg_gemini_score:.3f}")
|
| 213 |
+
|
| 214 |
table = tabulate([data], headers=headers, tablefmt="github")
|
| 215 |
print(table)
|
| 216 |
print("Avg score computed by comparing marker predicted HTML with original HTML")
|
| 217 |
|
| 218 |
+
results = {
|
| 219 |
+
"marker": marker_results,
|
| 220 |
+
"gemini": gemini_results
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
with open(out_file, "w+") as f:
|
| 224 |
json.dump(results, f, indent=2)
|
| 225 |
|