hkai20000 commited on
Commit
7dc006e
·
verified ·
1 Parent(s): 3f371b2

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +226 -4
main.py CHANGED
@@ -194,6 +194,211 @@ def basic_cleanup(text: str) -> str:
194
  text = " ".join(text.split())
195
  return text
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  def extract_text_structured(result) -> str:
198
  """
199
  Extract text from docTR result preserving logical structure.
@@ -451,11 +656,22 @@ async def process_image(
451
  img_height, img_width = preprocessed_img.shape[:2]
452
 
453
  # Extract text and word bounding boxes
 
 
 
 
454
  structured_text = extract_text_structured(result)
455
  cleaned_text = basic_cleanup(structured_text)
456
  words_with_boxes = extract_words_with_boxes(result)
457
 
458
- print(f"OCR Structured Text:\n{structured_text[:500]}...")
 
 
 
 
 
 
 
459
  print(f"Extracted {len(words_with_boxes)} words with bounding boxes")
460
 
461
  # Perform NER on cleaned text
@@ -485,14 +701,20 @@ async def process_image(
485
  print(f"Found {len(interactions)} drug interactions")
486
 
487
  return {
488
- "structured_text": structured_text,
489
  "cleaned_text": cleaned_text,
490
  "medical_entities": entities_with_boxes,
491
- "interactions": interactions, # NEW: Drug interaction warnings
492
  "model_id": NER_MODELS[ner_model_id]["name"],
493
  "ocr_model": f"{det_arch} + {reco_arch}",
494
  "image_width": img_width,
495
- "image_height": img_height
 
 
 
 
 
 
496
  }
497
 
498
  except Exception as e:
 
194
  text = " ".join(text.split())
195
  return text
196
 
197
+
198
+ # --- TABLE DETECTION AND EXTRACTION ---
199
+
200
def detect_columns(words_data: list, min_gap_ratio: float = 0.03) -> list:
    """Detect column boundaries by analyzing gaps between word x-starts.

    Args:
        words_data: list of word dicts, each with a normalized 'x' start
            (fraction of page width).
        min_gap_ratio: horizontal gap (page-width fraction) between two
            consecutive x-starts that separates column clusters.

    Returns:
        List of (x_start, x_end) column boundaries in [0, 1]. Empty list
        when there are no words; a single (0, 1) span means "no columns".
    """
    if not words_data:
        return []

    starts = sorted({word['x'] for word in words_data})
    if len(starts) < 2:
        return [(0, 1)]

    # Group consecutive x-starts into clusters: a new cluster begins
    # wherever the jump to the next start exceeds min_gap_ratio.
    clusters = [[starts[0]]]
    for prev, cur in zip(starts, starts[1:]):
        if cur - prev > min_gap_ratio:
            clusters.append([cur])
        else:
            clusters[-1].append(cur)

    if len(clusters) < 2:
        return [(0, 1)]

    bounds = []
    for idx, cluster in enumerate(clusters):
        # Small margin to the left so the column captures its own words.
        left = max(0, min(cluster) - 0.01)
        if idx + 1 < len(clusters):
            right = min(1, min(clusters[idx + 1]) - 0.005)
        else:
            right = 1.0
        bounds.append((left, right))
    return bounds
237
+
238
+
239
def detect_rows(words_data: list, y_tolerance: float = 0.015) -> list:
    """Detect row center positions by clustering word y-starts.

    Args:
        words_data: list of word dicts, each with a normalized 'y' start
            (fraction of page height).
        y_tolerance: max vertical distance between consecutive y values
            that still belong to the same row.

    Returns:
        List of mean y-positions, one per detected row; [] for no words.
    """
    if not words_data:
        return []

    ys = sorted({word['y'] for word in words_data})

    # Walk the sorted unique y values; a jump larger than the tolerance
    # closes the current row bucket and starts a new one.
    row_centers = []
    bucket = [ys[0]]
    for prev, cur in zip(ys, ys[1:]):
        if cur - prev <= y_tolerance:
            bucket.append(cur)
        else:
            row_centers.append(sum(bucket) / len(bucket))
            bucket = [cur]
    row_centers.append(sum(bucket) / len(bucket))
    return row_centers
261
+
262
+
263
def extract_table_structure(words_data: list) -> dict:
    """Extract a grid of cells from OCR words using column/row heuristics.

    Args:
        words_data: list of word dicts with normalized 'x', 'y' and 'text'.

    Returns:
        Dict with:
            is_table: True when at least a 2x2 grid was detected.
            columns:  (x_start, x_end) boundaries from detect_columns().
            rows:     row center y-positions from detect_rows().
            cells:    list of rows, each a list of per-column strings
                      (populated only when is_table is True).
            num_columns / num_rows: grid dimensions (only when is_table).
    """
    if not words_data:
        return {'is_table': False, 'columns': [], 'rows': [], 'cells': []}

    columns = detect_columns(words_data)
    rows = detect_rows(words_data)

    # A real table needs at least two columns and two rows.
    is_table = len(columns) >= 2 and len(rows) >= 2
    if not is_table:
        return {'is_table': False, 'columns': columns, 'rows': rows, 'cells': []}

    # Slightly looser than detect_rows' clustering tolerance so words whose
    # y was averaged into the row center still match their row.
    y_tolerance = 0.02
    cells = []

    for row_y in rows:
        row_cells = [''] * len(columns)
        row_words = [w for w in words_data if abs(w['y'] - row_y) <= y_tolerance]
        # Fix: sort left-to-right so multi-word cells read in visual order
        # regardless of the order the OCR emitted the words.
        row_words.sort(key=lambda w: w['x'])

        for word in row_words:
            for col_idx, (col_start, col_end) in enumerate(columns):
                if col_start <= word['x'] < col_end:
                    if row_cells[col_idx]:
                        row_cells[col_idx] += ' ' + word['text']
                    else:
                        row_cells[col_idx] = word['text']
                    break

        cells.append(row_cells)

    return {
        'is_table': True,
        'columns': columns,
        'rows': rows,
        'cells': cells,
        'num_columns': len(columns),
        'num_rows': len(rows)
    }
304
+
305
+
306
def format_table_as_markdown(table_data: dict) -> str:
    """Format extracted table data as a markdown table.

    Args:
        table_data: dict from extract_table_structure(); must have
            is_table=True and a non-empty 'cells' grid to produce output.

    Returns:
        A markdown table string (first row treated as the header, then a
        separator row, then data rows), with every column padded to a
        common width; '' when there is nothing to format.
    """
    if not table_data.get('is_table') or not table_data.get('cells'):
        return ''

    cells = table_data['cells']
    num_cols = len(cells[0])
    if num_cols == 0:
        return ''

    # Minimum width 3 keeps the separator row valid markdown (---).
    col_widths = [3] * num_cols
    for row in cells:
        for i, cell in enumerate(row):
            if i < num_cols:
                col_widths[i] = max(col_widths[i], len(cell))

    lines = []
    for row_idx, row in enumerate(cells):
        # Fix: pad short rows (and truncate long ones) so every emitted
        # row has exactly num_cols cells — otherwise the markdown breaks.
        padded = list(row[:num_cols]) + [''] * (num_cols - len(row))
        line = '| ' + ' | '.join(
            cell.ljust(col_widths[i]) for i, cell in enumerate(padded)
        ) + ' |'
        lines.append(line)

        if row_idx == 0:
            separator = '|' + '|'.join('-' * (w + 2) for w in col_widths) + '|'
            lines.append(separator)

    return '\n'.join(lines)
340
+
341
+
342
def extract_text_with_table_detection(result) -> tuple:
    """
    Extract text from docTR result, detecting and preserving table structure.

    Args:
        result: docTR OCR result with pages -> blocks -> lines -> words,
            where each word exposes .value and .geometry as
            ((x_min, y_min), (x_max, y_max)) in normalized coordinates.

    Returns:
        (structured_text, table_data): a markdown table plus the table dict
        when a table is detected; otherwise plain line-grouped text and
        {'is_table': False}.
    """
    all_words = []

    # Flatten the docTR hierarchy into word dicts with normalized geometry.
    for page in result.pages:
        for block in page.blocks:
            for line in block.lines:
                for word in line.words:
                    x_min, y_min = word.geometry[0]
                    x_max, y_max = word.geometry[1]

                    all_words.append({
                        'text': word.value,
                        'x': x_min,
                        'x_end': x_max,
                        'y': y_min,
                        'y_end': y_max,
                        'width': x_max - x_min,
                        'height': y_max - y_min
                    })

    if not all_words:
        return '', {'is_table': False}

    table_data = extract_table_structure(all_words)

    if table_data['is_table']:
        # Render the detected grid as markdown so structure survives display.
        markdown_table = format_table_as_markdown(table_data)
        return markdown_table, table_data

    # No table: group words into text lines, top-to-bottom then left-to-right.
    # y is bucketed to 1/50ths of the page height so tiny vertical jitter
    # between words on the same line does not reorder them.
    all_words.sort(key=lambda w: (round(w['y'] * 50) / 50, w['x']))

    lines = []
    current_line = []
    prev_y = -1  # sentinel: no previous word yet (y values are in [0, 1])
    y_tolerance = 0.02

    # Fix: dropped an unused per-iteration `current_y` computation; the
    # line-break decision only ever used the raw y values.
    for word in all_words:
        if prev_y != -1 and abs(word['y'] - prev_y) > y_tolerance:
            # Vertical jump: flush the current line and start a new one.
            if current_line:
                lines.append(' '.join(w['text'] for w in current_line))
            current_line = [word]
        else:
            current_line.append(word)

        prev_y = word['y']

    if current_line:
        lines.append(' '.join(w['text'] for w in current_line))

    return '\n'.join(lines), {'is_table': False}
400
+
401
+
402
  def extract_text_structured(result) -> str:
403
  """
404
  Extract text from docTR result preserving logical structure.
 
656
  img_height, img_width = preprocessed_img.shape[:2]
657
 
658
  # Extract text and word bounding boxes
659
+ # Try table detection first
660
+ table_formatted_text, table_data = extract_text_with_table_detection(result)
661
+
662
+ # Also get the regular structured text for NER processing
663
  structured_text = extract_text_structured(result)
664
  cleaned_text = basic_cleanup(structured_text)
665
  words_with_boxes = extract_words_with_boxes(result)
666
 
667
+ # Use table-formatted text if table was detected
668
+ if table_data.get('is_table'):
669
+ display_text = table_formatted_text
670
+ print(f"Table detected with {table_data.get('num_columns', 0)} columns and {table_data.get('num_rows', 0)} rows")
671
+ else:
672
+ display_text = structured_text
673
+
674
+ print(f"OCR Structured Text:\n{display_text[:500]}...")
675
  print(f"Extracted {len(words_with_boxes)} words with bounding boxes")
676
 
677
  # Perform NER on cleaned text
 
701
  print(f"Found {len(interactions)} drug interactions")
702
 
703
  return {
704
+ "structured_text": display_text, # Table-formatted if detected, otherwise regular
705
  "cleaned_text": cleaned_text,
706
  "medical_entities": entities_with_boxes,
707
+ "interactions": interactions, # Drug interaction warnings
708
  "model_id": NER_MODELS[ner_model_id]["name"],
709
  "ocr_model": f"{det_arch} + {reco_arch}",
710
  "image_width": img_width,
711
+ "image_height": img_height,
712
+ "table_detected": table_data.get('is_table', False),
713
+ "table_data": {
714
+ "num_columns": table_data.get('num_columns', 0),
715
+ "num_rows": table_data.get('num_rows', 0),
716
+ "cells": table_data.get('cells', [])
717
+ } if table_data.get('is_table') else None
718
  }
719
 
720
  except Exception as e: