Sathvik-kota committed on
Commit
80ab573
·
verified ·
1 Parent(s): 5ec4a93

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +78 -32
app.py CHANGED
@@ -1,7 +1,7 @@
1
- # app_bill_extractor_final.py
2
  # Humanized, high-accuracy bill extraction API.
3
- # Combines robust OCR preprocessing, TSV-based layout parsing, numeric-column inference,
4
- # and ALWAYS attempts Gemini refinement (if GEMINI_API_KEY set). Made compact & readable.
5
 
6
  import os
7
  import re
@@ -19,7 +19,7 @@ from pytesseract import Output
19
  import numpy as np
20
  import cv2
21
 
22
- # Optional: Google Gemini SDK (if you use it). Code will gracefully work without it.
23
  try:
24
  import google.generativeai as genai
25
  except Exception:
@@ -40,20 +40,29 @@ app = FastAPI(title="Bajaj Datathon - Bill Extractor (final, humanized)")
40
  class BillRequest(BaseModel):
41
  document: str
42
 
43
- # ---------------- Regex, small utils ----------------
44
  NUM_RE = re.compile(r"[-+]?\d{1,3}(?:[,0-9]*)(?:\.\d+)?")
45
  TOTAL_KEYWORDS = re.compile(
46
  r"(grand\s*total|net\s*payable|total\s*amount|amount\s*payable|bill\s*amount|final\s*amount|balance\s*due|sub\s*total|subtotal|round\s*off)",
47
  re.I,
48
  )
49
  FOOTER_KEYWORDS = re.compile(r"(page|printed on|printed:|date:|time:|am|pm)", re.I)
50
- HEADER_KEYWORDS = ["description", "qty", "hrs", "rate", "discount", "net", "amt", "amount", "consultation", "qty/hrs", "qty / hrs"]
51
 
52
- # sanitize OCR text before ever sending to an LLM or using it for heuristics
 
 
 
 
 
 
 
 
 
 
 
53
  def sanitize_ocr_text(s: str) -> str:
54
  if not s:
55
  return ""
56
- # unify dashes and remove odd control characters
57
  s = s.replace("\u2014", "-").replace("\u2013", "-")
58
  s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E]", " ", s)
59
  s = s.replace("\r\n", "\n").replace("\r", "\n")
@@ -102,7 +111,6 @@ def pil_to_cv2(img: Image.Image) -> Any:
102
  return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
103
 
104
  def preprocess_image(pil_img: Image.Image) -> Any:
105
- # quick, robust steps: upscale small images, grayscale, denoise, adaptive threshold
106
  pil_img = pil_img.convert("RGB")
107
  w, h = pil_img.size
108
  target_w = 1500
@@ -120,7 +128,7 @@ def preprocess_image(pil_img: Image.Image) -> Any:
120
  bw = cv2.morphologyEx(bw, cv2.MORPH_OPEN, kernel)
121
  return bw
122
 
123
- # ---------------- OCR TSV helpers ----------------
124
  def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
125
  try:
126
  o = pytesseract.image_to_data(cv_img, output_type=Output.DICT, config="--psm 6")
@@ -148,7 +156,7 @@ def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
148
  cells.append({"text": txt, "conf": conf, "left": left, "top": top, "width": width, "height": height, "center_y": center_y, "center_x": center_x})
149
  return cells
150
 
151
- # ---------------- grouping & merging ----------------
152
  def group_cells_into_rows(cells: List[Dict[str, Any]], y_tolerance: int = 12) -> List[List[Dict[str, Any]]]:
153
  if not cells:
154
  return []
@@ -199,7 +207,7 @@ def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[st
199
  i += 1
200
  return merged
201
 
202
- # ---------------- numeric column detection ----------------
203
  def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 4) -> List[float]:
204
  xs = [c["center_x"] for c in cells if is_numeric_token(c["text"])]
205
  if not xs:
@@ -207,19 +215,21 @@ def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 4) ->
207
  xs = sorted(xs)
208
  if len(xs) == 1:
209
  return [xs[0]]
 
 
 
210
  gaps = [xs[i+1] - xs[i] for i in range(len(xs) - 1)]
211
- mean_gap = float(np.mean(gaps))
212
- std_gap = float(np.std(gaps)) if len(gaps) > 1 else 0.0
213
- gap_thresh = max(30.0, mean_gap + 0.6 * std_gap)
214
  clusters = []
215
  curr = [xs[0]]
216
  for i, g in enumerate(gaps):
217
- if g > gap_thresh and len(clusters) < (max_columns - 1):
218
  clusters.append(curr)
219
  curr = [xs[i+1]]
220
  else:
221
  curr.append(xs[i+1])
222
  clusters.append(curr)
 
223
  centers = [float(np.median(c)) for c in clusters]
224
  if len(centers) > max_columns:
225
  centers = centers[-max_columns:]
@@ -231,7 +241,7 @@ def assign_token_to_column(token_x: float, column_centers: List[float]) -> Optio
231
  distances = [abs(token_x - cx) for cx in column_centers]
232
  return int(np.argmin(distances))
233
 
234
- # ---------------- parse rows into items ----------------
235
  def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
236
  parsed_items = []
237
  rows = merge_multiline_names(rows)
@@ -369,29 +379,33 @@ def detect_subtotals_and_totals(rows_texts: List[str]) -> Dict[str, Optional[flo
369
  if final is None: final = float(round(v, 2))
370
  return {"subtotal": subtotal, "final_total": final}
371
 
372
- # ---------------- Gemini refinement (always attempted) ----------------
373
  def refine_with_gemini(page_items: List[Dict[str, Any]], page_text: str = "") -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
374
  zero_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
375
  if not GEMINI_API_KEY or genai is None:
376
  return page_items, zero_usage
377
  try:
378
  safe_text = sanitize_ocr_text(page_text)
379
- system = (
380
- "You are a strict bill-extraction cleaner. Return ONLY a JSON array (no text) of objects with keys "
381
- "item_name (string), item_amount (float), item_rate (float), item_quantity (float). "
382
- "Do NOT return totals or subtotals as items. Do not invent items. Fix broken names and numeric mismatches."
383
  )
384
- # small few-shot example to anchor the model
385
- few_shot = (
386
- "# EXAMPLE\nitems = [{'item_name':'Consultation Charge | DR PREETHI','item_amount':300.0,'item_rate':0.0,'item_quantity':300.0}]\n"
387
- "=> [{'item_name':'Consultation Charge | DR PREETHI MARY JOSEPH','item_amount':300.0,'item_rate':300.0,'item_quantity':1.0}]\n"
 
 
 
 
 
388
  )
389
- prompt = f"page_text='''{safe_text}'''\nitems = {json.dumps(page_items, ensure_ascii=False)}\n\n{few_shot}\nReturn only a JSON array."
390
  model = genai.GenerativeModel(GEMINI_MODEL_NAME)
391
  response = model.generate_content(
392
  [
393
- {"role": "system", "parts": [system]},
394
- {"role": "user", "parts": [prompt]},
395
  ],
396
  temperature=0.0,
397
  max_output_tokens=1000,
@@ -413,6 +427,7 @@ def refine_with_gemini(page_items: List[Dict[str, Any]], page_text: str = "") ->
413
  })
414
  except Exception:
415
  continue
 
416
  return cleaned, zero_usage
417
  return page_items, zero_usage
418
  except Exception:
@@ -423,6 +438,9 @@ def looks_like_header_text(txt: str, top_of_page: bool = False) -> bool:
423
  if not txt:
424
  return False
425
  t = re.sub(r"\s+", " ", txt.strip().lower())
 
 
 
426
  hits = sum(1 for k in HEADER_KEYWORDS if k in t)
427
  if hits >= 2:
428
  return True
@@ -438,12 +456,12 @@ def looks_like_header_text(txt: str, top_of_page: bool = False) -> bool:
438
  return True
439
  return False
440
 
441
-
442
  def final_item_filter(item: Dict[str, Any], known_page_headers: List[str] = []) -> bool:
443
  name = (item.get("item_name") or "").strip()
444
  if not name:
445
  return False
446
  ln = name.lower()
 
447
  for h in known_page_headers:
448
  if h and h.strip() and h.strip().lower() in ln:
449
  return False
@@ -455,6 +473,10 @@ def final_item_filter(item: Dict[str, Any], known_page_headers: List[str] = [])
455
  return False
456
  if re.fullmatch(r"(charge|charges|services|laboratory|lab|consultation)", ln.strip(), re.I):
457
  return False
 
 
 
 
458
  if float(item.get("item_amount", 0)) <= 0.0:
459
  return False
460
  rate = float(item.get("item_rate", 0) or 0)
@@ -499,25 +521,48 @@ async def extract_bill_data(payload: BillRequest):
499
  proc = preprocess_image(page_img)
500
  cells = image_to_tsv_cells(proc)
501
  rows = group_cells_into_rows(cells, y_tolerance=12)
502
- rows_texts = [" ".join([c["text"] for c in r]) for r in rows]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
  top_headers = []
504
  for i, rt in enumerate(rows_texts[:6]):
505
  if looks_like_header_text(rt, top_of_page=(i < 4)):
506
  top_headers.append(rt.strip().lower())
 
507
  parsed_items = parse_rows_with_columns(rows, cells)
508
- page_text = sanitize_ocr_text(" ".join(rows_texts))
 
509
  refined_items, token_u = refine_with_gemini(parsed_items, page_text)
510
  for k in cumulative_token_usage:
511
  cumulative_token_usage[k] += token_u.get(k, 0)
 
 
512
  cleaned = [p for p in refined_items if final_item_filter(p, known_page_headers=top_headers)]
513
  cleaned = dedupe_items(cleaned)
514
  cleaned = [p for p in cleaned if not looks_like_header_text(p["item_name"].lower())]
 
515
  page_type = "Bill Detail"
516
  page_txt = page_text.lower()
517
  if any(x in page_txt for x in ["pharmacy", "medicine", "tablet"]):
518
  page_type = "Pharmacy"
519
  if "final bill" in page_txt or "grand total" in page_txt:
520
  page_type = "Final Bill"
 
521
  pagewise.append({"page_no": str(idx), "page_type": page_type, "bill_items": cleaned})
522
  except Exception:
523
  pagewise.append({"page_no": str(idx), "page_type": "Bill Detail", "bill_items": []})
@@ -526,6 +571,7 @@ async def extract_bill_data(payload: BillRequest):
526
  total_item_count = sum(len(p.get("bill_items", [])) for p in pagewise)
527
  if not GEMINI_API_KEY or genai is None:
528
  cumulative_token_usage["warning_no_gemini"] = 1
 
529
  return {"is_success": True, "token_usage": cumulative_token_usage, "data": {"pagewise_line_items": pagewise, "total_item_count": total_item_count}}
530
 
531
  # ---------------- debug TSV ----------------
 
1
+ # app_bill_extractor_final_v2.py
2
  # Humanized, high-accuracy bill extraction API.
3
+ # Robust OCR preprocessing, TSV layout parsing, numeric-column inference,
4
+ # header prefiltering, deterministic Gemini refinement (if configured).
5
 
6
  import os
7
  import re
 
19
  import numpy as np
20
  import cv2
21
 
22
+ # Optional: Google Gemini SDK (if available)
23
  try:
24
  import google.generativeai as genai
25
  except Exception:
 
40
  class BillRequest(BaseModel):
41
  document: str
42
 
43
+ # ---------------- Regex and keywords ----------------
44
  NUM_RE = re.compile(r"[-+]?\d{1,3}(?:[,0-9]*)(?:\.\d+)?")
45
  TOTAL_KEYWORDS = re.compile(
46
  r"(grand\s*total|net\s*payable|total\s*amount|amount\s*payable|bill\s*amount|final\s*amount|balance\s*due|sub\s*total|subtotal|round\s*off)",
47
  re.I,
48
  )
49
  FOOTER_KEYWORDS = re.compile(r"(page|printed on|printed:|date:|time:|am|pm)", re.I)
 
50
 
51
# generalized header-related tokens & exact header phrase blacklist (common variants)
# HEADER_KEYWORDS: loose tokens counted per-row to score "header-likeness".
HEADER_KEYWORDS = ["description", "qty", "hrs", "rate", "discount", "net", "amt", "amount", "consultation", "qty/hrs", "qty / hrs"]
# HEADER_PHRASES: exact, whitespace-normalized header lines seen across bill
# layouts; rows matching one of these are dropped outright.
# NOTE: deduplicated — "description qty / hrs rate discount net amt" was
# listed twice, which added a redundant scan to every any()/membership check.
HEADER_PHRASES = [
    "description qty / hrs consultation rate discount net amt",
    "description qty / hrs rate discount net amt",
    "description qty / hrs rate net amt",
    "description qty hrs rate discount net amt",
]
# Defensive lowercasing so later case-insensitive comparisons stay correct
# even if a mixed-case variant is added above.
HEADER_PHRASES = [h.lower() for h in HEADER_PHRASES]
62
+ # ---------------- small utilities ----------------
63
  def sanitize_ocr_text(s: str) -> str:
64
  if not s:
65
  return ""
 
66
  s = s.replace("\u2014", "-").replace("\u2013", "-")
67
  s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E]", " ", s)
68
  s = s.replace("\r\n", "\n").replace("\r", "\n")
 
111
  return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
112
 
113
  def preprocess_image(pil_img: Image.Image) -> Any:
 
114
  pil_img = pil_img.convert("RGB")
115
  w, h = pil_img.size
116
  target_w = 1500
 
128
  bw = cv2.morphologyEx(bw, cv2.MORPH_OPEN, kernel)
129
  return bw
130
 
131
+ # ---------------- OCR TSV ----------------
132
  def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
133
  try:
134
  o = pytesseract.image_to_data(cv_img, output_type=Output.DICT, config="--psm 6")
 
156
  cells.append({"text": txt, "conf": conf, "left": left, "top": top, "width": width, "height": height, "center_y": center_y, "center_x": center_x})
157
  return cells
158
 
159
+ # ---------------- grouping & merge helpers ----------------
160
  def group_cells_into_rows(cells: List[Dict[str, Any]], y_tolerance: int = 12) -> List[List[Dict[str, Any]]]:
161
  if not cells:
162
  return []
 
207
  i += 1
208
  return merged
209
 
210
+ # ---------------- numeric column detection (conservative) ----------------
211
  def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 4) -> List[float]:
212
  xs = [c["center_x"] for c in cells if is_numeric_token(c["text"])]
213
  if not xs:
 
215
  xs = sorted(xs)
216
  if len(xs) == 1:
217
  return [xs[0]]
218
+
219
+ # Conservative min gap to avoid merging separate numeric columns
220
+ min_gap_px = 50.0
221
  gaps = [xs[i+1] - xs[i] for i in range(len(xs) - 1)]
222
+
 
 
223
  clusters = []
224
  curr = [xs[0]]
225
  for i, g in enumerate(gaps):
226
+ if g >= min_gap_px:
227
  clusters.append(curr)
228
  curr = [xs[i+1]]
229
  else:
230
  curr.append(xs[i+1])
231
  clusters.append(curr)
232
+
233
  centers = [float(np.median(c)) for c in clusters]
234
  if len(centers) > max_columns:
235
  centers = centers[-max_columns:]
 
241
  distances = [abs(token_x - cx) for cx in column_centers]
242
  return int(np.argmin(distances))
243
 
244
+ # ---------------- parsing rows into items ----------------
245
  def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
246
  parsed_items = []
247
  rows = merge_multiline_names(rows)
 
379
  if final is None: final = float(round(v, 2))
380
  return {"subtotal": subtotal, "final_total": final}
381
 
382
+ # ---------------- Gemini refinement (deterministic) ----------------
383
  def refine_with_gemini(page_items: List[Dict[str, Any]], page_text: str = "") -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
384
  zero_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
385
  if not GEMINI_API_KEY or genai is None:
386
  return page_items, zero_usage
387
  try:
388
  safe_text = sanitize_ocr_text(page_text)
389
+ system_prompt = (
390
+ "You are a strict bill-extraction cleaner. Return ONLY a JSON array (no explanation, no backticks). "
391
+ "Each entry must be an object with keys: item_name (string), item_amount (float), item_rate (float), item_quantity (float). "
392
+ "Do NOT include subtotal or total lines as items. Do not invent items; only clean/fix/normalize the given items."
393
  )
394
+ user_prompt = (
395
+ f"page_text='''{safe_text}'''\n"
396
+ f"items = {json.dumps(page_items, ensure_ascii=False)}\n\n"
397
+ "Example:\n"
398
+ "items = [{'item_name':'Consultation Charge | DR PREETHI','item_amount':300.0,'item_rate':0.0,'item_quantity':300.0},\n"
399
+ " {'item_name':'Description Qty / Hrs Consultation Rate Discount Net Amt','item_amount':1950.0,'item_rate':1950.0,'item_quantity':1.0}]\n"
400
+ "=>\n"
401
+ "[{'item_name':'Consultation Charge | DR PREETHI MARY JOSEPH','item_amount':300.0,'item_rate':300.0,'item_quantity':1.0}]\n\n"
402
+ "Return only the cleaned JSON array of items."
403
  )
 
404
  model = genai.GenerativeModel(GEMINI_MODEL_NAME)
405
  response = model.generate_content(
406
  [
407
+ {"role": "system", "parts": [system_prompt]},
408
+ {"role": "user", "parts": [user_prompt]},
409
  ],
410
  temperature=0.0,
411
  max_output_tokens=1000,
 
427
  })
428
  except Exception:
429
  continue
430
+ # token usage info not reliably extracted here — return zeros
431
  return cleaned, zero_usage
432
  return page_items, zero_usage
433
  except Exception:
 
438
  if not txt:
439
  return False
440
  t = re.sub(r"\s+", " ", txt.strip().lower())
441
+ # exact phrase blacklist
442
+ if any(h == t for h in HEADER_PHRASES):
443
+ return True
444
  hits = sum(1 for k in HEADER_KEYWORDS if k in t)
445
  if hits >= 2:
446
  return True
 
456
  return True
457
  return False
458
 
 
459
  def final_item_filter(item: Dict[str, Any], known_page_headers: List[str] = []) -> bool:
460
  name = (item.get("item_name") or "").strip()
461
  if not name:
462
  return False
463
  ln = name.lower()
464
+ # exact match against detected headers
465
  for h in known_page_headers:
466
  if h and h.strip() and h.strip().lower() in ln:
467
  return False
 
473
  return False
474
  if re.fullmatch(r"(charge|charges|services|laboratory|lab|consultation)", ln.strip(), re.I):
475
  return False
476
+ # drop obvious section/subtotal labels (but allow items like 'ANAES. CHARGE' which contain a dot)
477
+ if len(name.split()) <= 4 and re.search(r"\b(charges|services|room|radiology|laborat|surgery|procedure)\b", ln):
478
+ if "." not in name and not re.search(r"\b[A-Z]{2,}\b", name):
479
+ return False
480
  if float(item.get("item_amount", 0)) <= 0.0:
481
  return False
482
  rate = float(item.get("item_rate", 0) or 0)
 
521
  proc = preprocess_image(page_img)
522
  cells = image_to_tsv_cells(proc)
523
  rows = group_cells_into_rows(cells, y_tolerance=12)
524
+ rows_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]
525
+
526
+ # === HEADER PREFILTER: remove header-like rows anywhere on page ===
527
+ rows_filtered = []
528
+ for i, (r, rt) in enumerate(zip(rows, rows_texts)):
529
+ top_flag = (i < 6)
530
+ rt_norm = sanitize_ocr_text(rt).lower()
531
+ if looks_like_header_text(rt_norm, top_of_page=top_flag):
532
+ continue
533
+ if any(h in rt_norm for h in HEADER_PHRASES):
534
+ continue
535
+ rows_filtered.append(r)
536
+ # recompute row texts and a simple page_text
537
+ rows = rows_filtered
538
+ rows_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]
539
+ page_text = sanitize_ocr_text(" ".join(rows_texts))
540
+
541
+ # detect page-level top headers (for final filtering)
542
  top_headers = []
543
  for i, rt in enumerate(rows_texts[:6]):
544
  if looks_like_header_text(rt, top_of_page=(i < 4)):
545
  top_headers.append(rt.strip().lower())
546
+
547
  parsed_items = parse_rows_with_columns(rows, cells)
548
+
549
+ # ALWAYS attempt Gemini refinement if available (deterministic settings)
550
  refined_items, token_u = refine_with_gemini(parsed_items, page_text)
551
  for k in cumulative_token_usage:
552
  cumulative_token_usage[k] += token_u.get(k, 0)
553
+
554
+ # final cleaning & dedupe
555
  cleaned = [p for p in refined_items if final_item_filter(p, known_page_headers=top_headers)]
556
  cleaned = dedupe_items(cleaned)
557
  cleaned = [p for p in cleaned if not looks_like_header_text(p["item_name"].lower())]
558
+
559
  page_type = "Bill Detail"
560
  page_txt = page_text.lower()
561
  if any(x in page_txt for x in ["pharmacy", "medicine", "tablet"]):
562
  page_type = "Pharmacy"
563
  if "final bill" in page_txt or "grand total" in page_txt:
564
  page_type = "Final Bill"
565
+
566
  pagewise.append({"page_no": str(idx), "page_type": page_type, "bill_items": cleaned})
567
  except Exception:
568
  pagewise.append({"page_no": str(idx), "page_type": "Bill Detail", "bill_items": []})
 
571
  total_item_count = sum(len(p.get("bill_items", [])) for p in pagewise)
572
  if not GEMINI_API_KEY or genai is None:
573
  cumulative_token_usage["warning_no_gemini"] = 1
574
+
575
  return {"is_success": True, "token_usage": cumulative_token_usage, "data": {"pagewise_line_items": pagewise, "total_item_count": total_item_count}}
576
 
577
  # ---------------- debug TSV ----------------