hkai20000 committed on
Commit
af2ef1c
·
verified ·
1 Parent(s): 7dc006e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +115 -160
main.py CHANGED
@@ -4,12 +4,15 @@ from fastapi.middleware.cors import CORSMiddleware
4
  from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
5
  from doctr.io import DocumentFile
6
  from doctr.models import ocr_predictor
 
 
7
  import cv2
8
  import numpy as np
9
  from PIL import Image
10
  import io
11
  import json
12
  import os
 
13
  from typing import Dict, Any, Optional, List
14
 
15
  app = FastAPI(title="ScanAssured OCR & NER API")
@@ -195,112 +198,100 @@ def basic_cleanup(text: str) -> str:
195
  return text
196
 
197
 
198
- # --- TABLE DETECTION AND EXTRACTION ---
199
 
200
def detect_columns(words_data: list, min_gap_ratio: float = 0.03) -> list:
    """
    Detect column boundaries by clustering the left edges of words.

    The distinct ``x`` values are sorted and split into clusters wherever two
    consecutive values are more than ``min_gap_ratio`` apart; every cluster
    becomes one column.

    Args:
        words_data: Word dicts; only the ``'x'`` key is read here.
        min_gap_ratio: Smallest gap between consecutive x values that starts
            a new column.

    Returns:
        ``[]`` for empty input, ``[(0, 1)]`` when fewer than two clusters are
        found, otherwise one ``(x_start, x_end)`` tuple per column, ordered
        left to right. Neighbouring columns share a single boundary, so the
        half-open ranges ``x_start <= x < x_end`` never overlap and every
        clustered x value lies inside its own column.
    """
    if not words_data:
        return []

    x_starts = sorted({w['x'] for w in words_data})

    # Single-linkage clustering: a new cluster starts at every large gap.
    clusters = [[x_starts[0]]]
    for x in x_starts[1:]:
        if x - clusters[-1][-1] > min_gap_ratio:
            clusters.append([x])
        else:
            clusters[-1].append(x)

    if len(clusters) < 2:
        return [(0, 1)]

    # One shared boundary per pair of neighbours, placed slightly left of the
    # right-hand cluster. The padding is capped at half the gap so that the
    # boundary can never cut into the left-hand cluster when the gap is tiny.
    boundaries = []
    for left, right in zip(clusters, clusters[1:]):
        gap = right[0] - left[-1]
        boundary = right[0] - min(0.005, gap / 2)
        boundaries.append(min(1, max(0, boundary)))

    starts = [max(0, clusters[0][0] - 0.01)] + boundaries
    ends = boundaries + [1.0]
    return list(zip(starts, ends))
237
-
238
-
239
def detect_rows(words_data: list, y_tolerance: float = 0.015) -> list:
    """
    Detect text rows by clustering the top edges of words.

    The distinct ``y`` values are sorted and consecutive values that are at
    most ``y_tolerance`` apart are merged into the same row.

    Args:
        words_data: Word dicts; only the ``'y'`` key is read here.
        y_tolerance: Largest distance between consecutive y values that still
            counts as the same row.

    Returns:
        One representative y per row (the mean of the distinct y values in
        that row), ordered top to bottom; ``[]`` for empty input.
    """
    if not words_data:
        return []

    y_positions = sorted({w['y'] for w in words_data})

    # Non-empty input always yields at least one y value, so the first group
    # can be seeded directly.
    groups = [[y_positions[0]]]
    for y in y_positions[1:]:
        if y - groups[-1][-1] <= y_tolerance:
            groups[-1].append(y)
        else:
            groups.append([y])

    return [sum(group) / len(group) for group in groups]
261
 
262
-
263
def extract_table_structure(words_data: list) -> dict:
    """
    Extract table structure from words, returning rows and columns.

    Columns and rows come from ``detect_columns`` and ``detect_rows``. The
    input is treated as a table only when at least two columns and two rows
    are found.

    Args:
        words_data: Word dicts; the ``'x'``, ``'y'`` and ``'text'`` keys are
            read here.

    Returns:
        A dict with ``'is_table'``, ``'columns'``, ``'rows'`` and ``'cells'``.
        When ``'is_table'`` is True, ``'cells'`` is a list of rows, each a
        list with one string per column, and ``'num_columns'`` /
        ``'num_rows'`` are included as well.
    """
    if not words_data:
        return {'is_table': False, 'columns': [], 'rows': [], 'cells': []}

    columns = detect_columns(words_data)
    rows = detect_rows(words_data)

    if len(columns) < 2 or len(rows) < 2:
        return {'is_table': False, 'columns': columns, 'rows': rows, 'cells': []}

    # Collect the word texts of every cell, in input order.
    cell_words = [[[] for _ in columns] for _ in rows]

    for word in words_data:
        # Assign each word to its single nearest row. Matching with a fixed
        # tolerance around the row centre can place one word in two adjacent
        # rows, or in none when a row cluster is wide.
        row_idx = min(range(len(rows)), key=lambda i: abs(rows[i] - word['y']))

        for col_idx, (col_start, col_end) in enumerate(columns):
            if col_start <= word['x'] < col_end:
                cell_words[row_idx][col_idx].append(word['text'])
                break

    cells = [[' '.join(texts) for texts in row] for row in cell_words]

    return {
        'is_table': True,
        'columns': columns,
        'rows': rows,
        'cells': cells,
        'num_columns': len(columns),
        'num_rows': len(rows)
    }
304
 
305
 
306
  def format_table_as_markdown(table_data: dict) -> str:
@@ -312,22 +303,27 @@ def format_table_as_markdown(table_data: dict) -> str:
312
  if not cells:
313
  return ''
314
 
315
- num_cols = len(cells[0]) if cells else 0
316
  if num_cols == 0:
317
  return ''
318
 
319
  lines = []
320
  col_widths = [3] * num_cols
 
 
 
321
  for row in cells:
322
- for i, cell in enumerate(row):
 
 
323
  if i < num_cols:
324
- col_widths[i] = max(col_widths[i], len(cell))
325
 
326
- for row_idx, row in enumerate(cells):
327
  formatted_cells = []
328
  for i, cell in enumerate(row):
329
  if i < num_cols:
330
- formatted_cells.append(cell.ljust(col_widths[i]))
331
 
332
  line = '| ' + ' | '.join(formatted_cells) + ' |'
333
  lines.append(line)
@@ -339,64 +335,18 @@ def format_table_as_markdown(table_data: dict) -> str:
339
  return '\n'.join(lines)
340
 
341
 
342
def extract_text_with_table_detection(result) -> tuple:
    """
    Extract text from docTR result, detecting and preserving table structure.

    Every word in ``result.pages[*].blocks[*].lines[*].words`` is collected
    with its bounding box, then handed to ``extract_table_structure``.

    Args:
        result: OCR result exposing ``pages`` -> ``blocks`` -> ``lines`` ->
            ``words``; each word provides ``value`` and ``geometry`` (read as
            ``geometry[0]`` = top-left and ``geometry[1]`` = bottom-right).

    Returns:
        ``(structured_text, table_data)``. When a table is detected the text
        is the markdown produced by ``format_table_as_markdown``; otherwise
        it is the words regrouped into lines, with
        ``table_data == {'is_table': False}``.
    """
    all_words = []

    for page in result.pages:
        for block in page.blocks:
            for line in block.lines:
                for word in line.words:
                    x_min = word.geometry[0][0]
                    y_min = word.geometry[0][1]
                    x_max = word.geometry[1][0]
                    y_max = word.geometry[1][1]

                    all_words.append({
                        'text': word.value,
                        'x': x_min,
                        'x_end': x_max,
                        'y': y_min,
                        'y_end': y_max,
                        'width': x_max - x_min,
                        'height': y_max - y_min
                    })

    if not all_words:
        return '', {'is_table': False}

    table_data = extract_table_structure(all_words)

    if table_data['is_table']:
        markdown_table = format_table_as_markdown(table_data)
        return markdown_table, table_data

    # No table: order words by coarse y bucket (steps of 0.02) and then by x,
    # and start a new line whenever y jumps by more than the tolerance.
    all_words.sort(key=lambda w: (round(w['y'] * 50) / 50, w['x']))

    lines = []
    current_line = []
    prev_y = -1  # sentinel: no word seen yet
    y_tolerance = 0.02

    for word in all_words:
        if prev_y != -1 and abs(word['y'] - prev_y) > y_tolerance:
            if current_line:
                lines.append(' '.join(w['text'] for w in current_line))
            current_line = [word]
        else:
            current_line.append(word)

        prev_y = word['y']

    if current_line:
        lines.append(' '.join(w['text'] for w in current_line))

    return '\n'.join(lines), {'is_table': False}
400
 
401
 
402
  def extract_text_structured(result) -> str:
@@ -655,24 +605,29 @@ async def process_image(
655
  # Get image dimensions for frontend highlighting
656
  img_height, img_width = preprocessed_img.shape[:2]
657
 
658
- # Extract text and word bounding boxes
659
- # Try table detection first
660
- table_formatted_text, table_data = extract_text_with_table_detection(result)
661
-
662
- # Also get the regular structured text for NER processing
663
  structured_text = extract_text_structured(result)
664
  cleaned_text = basic_cleanup(structured_text)
665
  words_with_boxes = extract_words_with_boxes(result)
666
 
 
 
 
 
 
 
 
 
 
667
  # Use table-formatted text if table was detected
668
  if table_data.get('is_table'):
669
  display_text = table_formatted_text
670
  print(f"Table detected with {table_data.get('num_columns', 0)} columns and {table_data.get('num_rows', 0)} rows")
 
 
671
  else:
672
  display_text = structured_text
673
-
674
- print(f"OCR Structured Text:\n{display_text[:500]}...")
675
- print(f"Extracted {len(words_with_boxes)} words with bounding boxes")
676
 
677
  # Perform NER on cleaned text
678
  print("Running NER...")
 
4
  from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
5
  from doctr.io import DocumentFile
6
  from doctr.models import ocr_predictor
7
+ from img2table.document import Image as Img2TableImage
8
+ from img2table.ocr import DocTR
9
  import cv2
10
  import numpy as np
11
  from PIL import Image
12
  import io
13
  import json
14
  import os
15
+ import tempfile
16
  from typing import Dict, Any, Optional, List
17
 
18
  app = FastAPI(title="ScanAssured OCR & NER API")
 
198
  return text
199
 
200
 
201
+ # --- TABLE DETECTION WITH IMG2TABLE ---
202
 
203
# Holds the lazily created img2table OCR wrapper under the 'doctr' key.
img2table_ocr_cache = {}


def get_img2table_ocr():
    """Return the shared img2table DocTR OCR instance, creating it on first use."""
    if 'doctr' in img2table_ocr_cache:
        return img2table_ocr_cache['doctr']

    ocr_instance = DocTR()
    img2table_ocr_cache['doctr'] = ocr_instance
    return ocr_instance
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
 
 
212
 
213
def _img2table_cell_text(cell) -> str:
    """Return the stripped text of one table cell ('' for an empty cell)."""
    if cell is None:
        return ''
    if isinstance(cell, str):
        return cell.strip()
    if hasattr(cell, 'value'):
        return str(cell.value).strip() if cell.value else ''
    return str(cell).strip()


def _img2table_rows(table) -> list:
    """Return a table's content as a list of rows of stripped cell strings."""
    content = getattr(table, 'content', None)
    if not content:
        return []
    # img2table exposes ``content`` as a mapping of row index -> list of
    # cells; iterating the mapping itself would yield the integer keys, so
    # take the values. Plain sequences of rows are accepted as well.
    rows = content.values() if isinstance(content, dict) else content
    return [[_img2table_cell_text(cell) for cell in row] for row in rows]


def extract_tables_with_img2table(image_bytes: bytes, img_width: int, img_height: int) -> dict:
    """
    Use img2table to detect and extract table structure from image.

    Args:
        image_bytes: Encoded image data; written to a temporary file that is
            handed to img2table.
        img_width: Not used by this function; kept for interface stability.
        img_height: Not used by this function; kept for interface stability.

    Returns:
        ``{'is_table': False, 'tables': []}`` when no non-empty table is
        found. On success, a dict with ``'is_table': True``, the ``'cells'``,
        ``'num_rows'`` and ``'num_columns'`` of the largest table (by cell
        count), plus ``'tables'`` (all non-empty tables) and
        ``'total_tables'``. Any exception is caught and reported through an
        extra ``'error'`` key on a ``'is_table': False`` result.
    """
    tmp_path = None
    try:
        # delete=False so the file can be reopened by path after it is closed.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(image_bytes)

        img2table_img = Img2TableImage(src=tmp_path)

        tables = img2table_img.extract_tables(
            ocr=get_img2table_ocr(),
            implicit_rows=True,       # Detect rows even without horizontal lines
            implicit_columns=True,    # Detect columns even without vertical lines
            borderless_tables=True,   # Detect tables without borders
            min_confidence=50         # Minimum OCR confidence
        )

        # Keep only tables that contain at least one non-empty cell.
        all_tables = []
        for table in tables or []:
            cells = _img2table_rows(table)
            if any(any(row) for row in cells):
                all_tables.append({
                    'cells': cells,
                    'num_rows': len(cells),
                    # Widest row, so ragged tables are not under-reported.
                    'num_columns': max(len(row) for row in cells)
                })

        if not all_tables:
            return {'is_table': False, 'tables': []}

        # Return the largest table (most cells) as primary
        primary_table = max(all_tables, key=lambda t: t['num_rows'] * t['num_columns'])

        return {
            'is_table': True,
            'cells': primary_table['cells'],
            'num_rows': primary_table['num_rows'],
            'num_columns': primary_table['num_columns'],
            'tables': all_tables,
            'total_tables': len(all_tables)
        }

    except Exception as e:
        # Table detection is best-effort: report the failure and let the
        # caller fall back to plain OCR text.
        print(f"img2table extraction error: {e}")
        import traceback
        traceback.print_exc()
        return {'is_table': False, 'tables': [], 'error': str(e)}

    finally:
        # Always remove the temp file, including when img2table raised.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
 
 
 
295
 
296
 
297
  def format_table_as_markdown(table_data: dict) -> str:
 
303
  if not cells:
304
  return ''
305
 
306
+ num_cols = max(len(row) for row in cells) if cells else 0
307
  if num_cols == 0:
308
  return ''
309
 
310
  lines = []
311
  col_widths = [3] * num_cols
312
+
313
+ # Normalize rows to have same number of columns
314
+ normalized_cells = []
315
  for row in cells:
316
+ normalized_row = list(row) + [''] * (num_cols - len(row))
317
+ normalized_cells.append(normalized_row)
318
+ for i, cell in enumerate(normalized_row):
319
  if i < num_cols:
320
+ col_widths[i] = max(col_widths[i], len(str(cell)))
321
 
322
+ for row_idx, row in enumerate(normalized_cells):
323
  formatted_cells = []
324
  for i, cell in enumerate(row):
325
  if i < num_cols:
326
+ formatted_cells.append(str(cell).ljust(col_widths[i]))
327
 
328
  line = '| ' + ' | '.join(formatted_cells) + ' |'
329
  lines.append(line)
 
335
  return '\n'.join(lines)
336
 
337
 
338
def extract_text_with_table_detection(image_bytes: bytes, img_width: int, img_height: int) -> tuple:
    """
    Run img2table on the image and render the detected table as markdown.

    Returns a ``(markdown_text, table_data)`` pair. When no table is detected
    the pair is ``('', {'is_table': False})``.
    """
    detection = extract_tables_with_img2table(image_bytes, img_width, img_height)

    # Guard clause: nothing table-like was found (or detection failed).
    if not detection.get('is_table'):
        return '', {'is_table': False}

    return format_table_as_markdown(detection), detection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
 
352
  def extract_text_structured(result) -> str:
 
605
  # Get image dimensions for frontend highlighting
606
  img_height, img_width = preprocessed_img.shape[:2]
607
 
608
+ # Extract text and word bounding boxes using docTR
 
 
 
 
609
  structured_text = extract_text_structured(result)
610
  cleaned_text = basic_cleanup(structured_text)
611
  words_with_boxes = extract_words_with_boxes(result)
612
 
613
+ print(f"OCR Structured Text:\n{structured_text[:500]}...")
614
+ print(f"Extracted {len(words_with_boxes)} words with bounding boxes")
615
+
616
+ # Try table detection with img2table
617
+ print("Running img2table for table detection...")
618
+ table_formatted_text, table_data = extract_text_with_table_detection(
619
+ img_bytes, img_width, img_height
620
+ )
621
+
622
  # Use table-formatted text if table was detected
623
  if table_data.get('is_table'):
624
  display_text = table_formatted_text
625
  print(f"Table detected with {table_data.get('num_columns', 0)} columns and {table_data.get('num_rows', 0)} rows")
626
+ if table_data.get('total_tables', 0) > 1:
627
+ print(f"Total tables found: {table_data.get('total_tables')}")
628
  else:
629
  display_text = structured_text
630
+ print("No table detected, using regular OCR text")
 
 
631
 
632
  # Perform NER on cleaned text
633
  print("Running NER...")