Vik Paruchuri committed on
Commit
a78718d
·
1 Parent(s): f712a74

Add gemini to table bench

Browse files
README.md CHANGED
@@ -400,7 +400,7 @@ Marker can extract tables from PDFs using `marker.converters.table.TableConverte
400
 
401
  | Avg score | Total tables | use_llm |
402
  |-----------|--------------|---------|
403
- | 0.82 | 54 | False |
404
  | 0.887 | 54 | True |
405
 
406
  The `--use_llm` flag can significantly improve table recognition performance, as you can see.
 
400
 
401
  | Avg score | Total tables | use_llm |
402
  |-----------|--------------|---------|
403
+ | 0.822 | 54 | False |
404
  | 0.887 | 54 | True |
405
 
406
  The `--use_llm` flag can significantly improve table recognition performance, as you can see.
benchmarks/table/gemini.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from PIL import Image
3
+ import google.generativeai as genai
4
+ from google.ai.generativelanguage_v1beta.types import content
5
+ from marker.settings import settings
6
+
7
+ prompt = """
8
+ You're an expert document analyst who is good at turning tables in documents into HTML. Analyze the provided image, and convert it to a faithful HTML representation.
9
+
10
+ Guidelines:
11
+ - Keep the HTML simple and concise.
12
+ - Only include the <table> tag and contents.
13
+ - Only use <table>, <tr>, and <td> tags. Only use the colspan and rowspan attributes if necessary. Do not use <tbody>, <thead>, or <th> tags.
14
+ - Make sure the table is as faithful to the image as possible with the given tags.
15
+
16
+ **Instructions**
17
+ 1. Analyze the image, and determine the table structure.
18
+ 2. Convert the table image to HTML, following the guidelines above.
19
+ 3. Output only the HTML for the table, starting with the <table> tag and ending with the </table> tag.
20
+ """.strip()
21
+
22
+ genai.configure(api_key=settings.GOOGLE_API_KEY)
23
+
24
def gemini_table_rec(
    image: Image.Image,
    model_name: str = "gemini-1.5-flash",
    timeout: int = 60,
) -> str:
    """Ask a Gemini model to transcribe a table image into HTML.

    The request sends the image together with the module-level ``prompt`` and
    constrains the reply to a JSON object holding a single required string
    field, ``table_html``.

    Args:
        image: Image of the table to transcribe.
        model_name: Name passed to ``genai.GenerativeModel``. Defaults to the
            model this benchmark was originally written against.
        timeout: Value forwarded as the ``timeout`` request option
            (presumably seconds -- confirm against the client library).

    Returns:
        The ``table_html`` string extracted from the model's JSON reply.

    Raises:
        ValueError: If the response carries no candidate or no content part,
            so there is nothing to parse.
    """
    # Response schema: {"table_html": "<string>"} with the field required.
    schema = content.Schema(
        type=content.Type.OBJECT,
        required=["table_html"],
        properties={
            "table_html": content.Schema(
                type=content.Type.STRING,
            )
        },
    )

    model = genai.GenerativeModel(model_name)

    responses = model.generate_content(
        [image, prompt],  # According to gemini docs, it performs better if the image is the first element
        stream=False,
        generation_config={
            "temperature": 0,
            "response_schema": schema,
            "response_mime_type": "application/json",
        },
        request_options={"timeout": timeout},
    )

    # Fail with a readable message instead of a bare IndexError when the
    # model returned nothing usable.
    candidates = responses.candidates
    if not candidates or not candidates[0].content.parts:
        feedback = getattr(responses, "prompt_feedback", None)
        raise ValueError(
            f"Gemini returned no usable content for the table image (prompt feedback: {feedback})"
        )

    output = candidates[0].content.parts[0].text
    return json.loads(output)["table_html"]
benchmarks/table/table.py CHANGED
@@ -1,12 +1,11 @@
1
  import os
2
- from typing import List
3
-
4
- import numpy as np
5
-
6
- from marker.renderers.json import JSONOutput, JSONBlockOutput
7
 
8
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
9
 
 
 
10
  import base64
11
  import time
12
  import datasets
@@ -16,21 +15,24 @@ import click
16
  from tabulate import tabulate
17
  import json
18
  from bs4 import BeautifulSoup
19
- from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
20
  from pypdfium2._helpers.misc import PdfiumError
 
21
  from marker.util import matrix_intersection_area
 
22
 
23
  from marker.config.parser import ConfigParser
24
  from marker.converters.table import TableConverter
25
  from marker.models import create_model_dict
26
 
27
  from scoring import wrap_table_html, similarity_eval_html
 
28
 
29
- def update_teds_score(result):
30
- prediction, ground_truth = result['marker_table'], result['gt_table']
31
  prediction, ground_truth = wrap_table_html(prediction), wrap_table_html(ground_truth)
32
  score = similarity_eval_html(prediction, ground_truth)
33
- result.update({'score':score})
34
  return result
35
 
36
 
@@ -51,7 +53,16 @@ def extract_tables(children: List[JSONBlockOutput]):
51
  @click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
52
  @click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.")
53
  @click.option("--table_rec_batch_size", type=int, default=None, help="Batch size for table recognition.")
54
- def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm: bool, table_rec_batch_size: int | None):
 
 
 
 
 
 
 
 
 
55
  models = create_model_dict()
56
  config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size})
57
  start = time.time()
@@ -86,6 +97,9 @@ def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm:
86
  marker_json = converter(temp_pdf_file.name).children
87
  tqdm.disable = False
88
 
 
 
 
89
  if len(marker_json) == 0 or len(gt_tables) == 0:
90
  print(f'No tables detected, skipping...')
91
  total_unaligned += len(gt_tables)
@@ -94,6 +108,8 @@ def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm:
94
  marker_tables = extract_tables(marker_json)
95
  marker_table_boxes = [table.bbox for table in marker_tables]
96
  page_bbox = marker_json[0].bbox
 
 
97
 
98
  # Normalize the bboxes
99
  for bbox in marker_table_boxes:
@@ -136,14 +152,18 @@ def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm:
136
  unaligned_tables.add(table_idx)
137
  continue
138
 
 
 
 
 
139
  aligned_tables.append(
140
- (marker_tables[aligned_idx], gt_tables[table_idx])
141
  )
142
  used_tables.add(aligned_idx)
143
 
144
  total_unaligned += len(unaligned_tables)
145
 
146
- for marker_table, gt_table in aligned_tables:
147
  gt_table_html = gt_table['html']
148
 
149
  #marker wraps the table in <tbody> which fintabnet data doesn't
@@ -154,10 +174,12 @@ def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm:
154
  th_tag.name = 'td'
155
  marker_table_html = str(marker_table_soup)
156
  marker_table_html = marker_table_html.replace("\n", " ") # Fintabnet uses spaces instead of newlines
 
157
 
158
  results.append({
159
  "marker_table": marker_table_html,
160
- "gt_table": gt_table_html
 
161
  })
162
  except PdfiumError:
163
  print('Broken PDF, Skipping...')
@@ -167,19 +189,37 @@ def main(out_file: str, dataset: str, max_rows: int, max_workers: int, use_llm:
167
  print(f"Could not align {total_unaligned} tables from fintabnet.")
168
 
169
  with ProcessPoolExecutor(max_workers=max_workers) as executor:
170
- results = list(
171
  tqdm(
172
  executor.map(update_teds_score, results), desc='Computing alignment scores', total=len(results)
173
  )
174
  )
175
- avg_score = sum([r["score"] for r in results]) / len(results)
176
 
 
177
  headers = ["Avg score", "Total tables"]
178
- data = [f"{avg_score:.3f}", len(results)]
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  table = tabulate([data], headers=headers, tablefmt="github")
180
  print(table)
181
  print("Avg score computed by comparing marker predicted HTML with original HTML")
182
 
 
 
 
 
 
183
  with open(out_file, "w+") as f:
184
  json.dump(results, f, indent=2)
185
 
 
1
  import os
2
+ from itertools import repeat
3
+ from tkinter import Image
 
 
 
4
 
5
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
6
 
7
+ from typing import List
8
+ import numpy as np
9
  import base64
10
  import time
11
  import datasets
 
15
  from tabulate import tabulate
16
  import json
17
  from bs4 import BeautifulSoup
18
+ from concurrent.futures import ProcessPoolExecutor
19
  from pypdfium2._helpers.misc import PdfiumError
20
+ import pypdfium2 as pdfium
21
  from marker.util import matrix_intersection_area
22
+ from marker.renderers.json import JSONOutput, JSONBlockOutput
23
 
24
  from marker.config.parser import ConfigParser
25
  from marker.converters.table import TableConverter
26
  from marker.models import create_model_dict
27
 
28
  from scoring import wrap_table_html, similarity_eval_html
29
+ from gemini import gemini_table_rec
30
 
31
def update_teds_score(result, prefix: str = "marker"):
    """Compute a similarity score for one result entry and store it in place.

    Reads the predicted table from ``result[f"{prefix}_table"]`` and the
    reference table from ``result["gt_table"]``, passes both through
    ``wrap_table_html``, and compares them with ``similarity_eval_html``.
    The score is written to ``result[f"{prefix}_score"]``.

    Args:
        result: Mapping holding the predicted and ground-truth table HTML.
        prefix: Key prefix selecting which prediction to score.

    Returns:
        The same ``result`` object, now carrying the score.
    """
    # Look up both entries before any processing, so a missing key is
    # reported before the helpers are called.
    raw_prediction = result[f'{prefix}_table']
    raw_truth = result['gt_table']

    wrapped_prediction = wrap_table_html(raw_prediction)
    wrapped_truth = wrap_table_html(raw_truth)

    result[f'{prefix}_score'] = similarity_eval_html(wrapped_prediction, wrapped_truth)
    return result
37
 
38
 
 
53
  @click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
54
  @click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.")
55
  @click.option("--table_rec_batch_size", type=int, default=None, help="Batch size for table recognition.")
56
+ @click.option("--use_gemini", is_flag=True, help="Evaluate Gemini for table recognition.")
57
+ def main(
58
+ out_file: str,
59
+ dataset: str,
60
+ max_rows: int,
61
+ max_workers: int,
62
+ use_llm: bool,
63
+ table_rec_batch_size: int | None,
64
+ use_gemini: bool = False
65
+ ):
66
  models = create_model_dict()
67
  config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size})
68
  start = time.time()
 
97
  marker_json = converter(temp_pdf_file.name).children
98
  tqdm.disable = False
99
 
100
+ doc = pdfium.PdfDocument(temp_pdf_file.name)
101
+ page_image = doc[0].render(scale=92/72).to_pil()
102
+
103
  if len(marker_json) == 0 or len(gt_tables) == 0:
104
  print(f'No tables detected, skipping...')
105
  total_unaligned += len(gt_tables)
 
108
  marker_tables = extract_tables(marker_json)
109
  marker_table_boxes = [table.bbox for table in marker_tables]
110
  page_bbox = marker_json[0].bbox
111
+ w_scaler, h_scaler = page_image.width / page_bbox[2], page_image.height / page_bbox[3]
112
+ table_images = [page_image.crop([bbox[0] * w_scaler, bbox[1] * h_scaler, bbox[2] * w_scaler, bbox[3] * h_scaler]) for bbox in marker_table_boxes]
113
 
114
  # Normalize the bboxes
115
  for bbox in marker_table_boxes:
 
152
  unaligned_tables.add(table_idx)
153
  continue
154
 
155
+ gemini_html = ""
156
+ if use_gemini:
157
+ gemini_html = gemini_table_rec(table_images[aligned_idx])
158
+
159
  aligned_tables.append(
160
+ (marker_tables[aligned_idx], gt_tables[table_idx], gemini_html)
161
  )
162
  used_tables.add(aligned_idx)
163
 
164
  total_unaligned += len(unaligned_tables)
165
 
166
+ for marker_table, gt_table, gemini_table in aligned_tables:
167
  gt_table_html = gt_table['html']
168
 
169
  #marker wraps the table in <tbody> which fintabnet data doesn't
 
174
  th_tag.name = 'td'
175
  marker_table_html = str(marker_table_soup)
176
  marker_table_html = marker_table_html.replace("\n", " ") # Fintabnet uses spaces instead of newlines
177
+ gemini_table_html = gemini_table.replace("\n", " ") # Fintabnet uses spaces instead of newlines
178
 
179
  results.append({
180
  "marker_table": marker_table_html,
181
+ "gt_table": gt_table_html,
182
+ "gemini_table": gemini_table_html
183
  })
184
  except PdfiumError:
185
  print('Broken PDF, Skipping...')
 
189
  print(f"Could not align {total_unaligned} tables from fintabnet.")
190
 
191
  with ProcessPoolExecutor(max_workers=max_workers) as executor:
192
+ marker_results = list(
193
  tqdm(
194
  executor.map(update_teds_score, results), desc='Computing alignment scores', total=len(results)
195
  )
196
  )
 
197
 
198
+ avg_score = sum([r["marker_score"] for r in marker_results]) / len(marker_results)
199
  headers = ["Avg score", "Total tables"]
200
+ data = [f"{avg_score:.3f}", len(marker_results)]
201
+ gemini_results = None
202
+ if use_gemini:
203
+ with ProcessPoolExecutor(max_workers=max_workers) as executor:
204
+ gemini_results = list(
205
+ tqdm(
206
+ executor.map(update_teds_score, results, repeat("gemini")), desc='Computing Gemini scores',
207
+ total=len(results)
208
+ )
209
+ )
210
+ avg_gemini_score = sum([r["gemini_score"] for r in gemini_results]) / len(gemini_results)
211
+ headers.append("Avg Gemini score")
212
+ data.append(f"{avg_gemini_score:.3f}")
213
+
214
  table = tabulate([data], headers=headers, tablefmt="github")
215
  print(table)
216
  print("Avg score computed by comparing marker predicted HTML with original HTML")
217
 
218
+ results = {
219
+ "marker": marker_results,
220
+ "gemini": gemini_results
221
+ }
222
+
223
  with open(out_file, "w+") as f:
224
  json.dump(results, f, indent=2)
225