Sathvik-kota commited on
Commit
1d0b76b
·
verified ·
1 Parent(s): 310da4b

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +463 -546
app.py CHANGED
@@ -1,7 +1,7 @@
1
- ###############################################
2
- # Bajaj Datathon - FINAL PATCHED BILL EXTRACTOR
3
- # High Accuracy | Robust OCR | Gemini Refinement
4
- ###############################################
5
 
6
  import os
7
  import re
@@ -19,42 +19,37 @@ from pytesseract import Output
19
  import numpy as np
20
  import cv2
21
 
22
- # Optional Gemini SDK
23
  try:
24
  import google.generativeai as genai
25
- except:
26
  genai = None
27
 
28
  # ---------------- LLM CONFIG ----------------
29
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
30
  GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME", "gemini-2.5-flash")
31
-
32
  if GEMINI_API_KEY and genai is not None:
33
  try:
34
  genai.configure(api_key=GEMINI_API_KEY)
35
- except:
36
  pass
37
 
38
-
39
- # ---------------- FASTAPI APP ----------------
40
- app = FastAPI(title="Bajaj Datathon - Bill Extractor (patched v3)")
41
 
42
  class BillRequest(BaseModel):
43
  document: str
44
 
45
-
46
- ###############################################
47
- # COMMON REGEX AND UTILITY FUNCTIONS
48
- ###############################################
49
-
50
  NUM_RE = re.compile(r"[-+]?\d{1,3}(?:[,0-9]*)(?:\.\d+)?")
 
 
 
 
 
51
 
52
- HEADER_KEYWORDS = [
53
- "description", "qty", "hrs", "rate",
54
- "discount", "net", "amt", "amount",
55
- "qty/hrs", "qty / hrs"
56
- ]
57
-
58
  HEADER_PHRASES = [
59
  "description qty / hrs consultation rate discount net amt",
60
  "description qty / hrs rate discount net amt",
@@ -64,14 +59,7 @@ HEADER_PHRASES = [
64
  ]
65
  HEADER_PHRASES = [h.lower() for h in HEADER_PHRASES]
66
 
67
- TOTAL_KEYWORDS = re.compile(
68
- r"(grand\s*total|net\s*payable|total\s*amount|amount\s*payable|bill\s*amount|final\s*amount|balance\s*due|sub\s*total|subtotal|round\s*off)",
69
- re.I,
70
- )
71
-
72
- FOOTER_KEYWORDS = re.compile(r"(page|printed on|printed:|date:|time:|am|pm)", re.I)
73
-
74
-
75
  def sanitize_ocr_text(s: str) -> str:
76
  if not s:
77
  return ""
@@ -80,718 +68,647 @@ def sanitize_ocr_text(s: str) -> str:
80
  s = s.replace("\r\n", "\n").replace("\r", "\n")
81
  s = re.sub(r"[ \t]+", " ", s)
82
  s = s.strip()
83
- return s[:5000]
84
-
85
 
86
  def normalize_num_str(s: Optional[str]) -> Optional[float]:
87
  if s is None:
88
  return None
89
  s = str(s).strip()
90
- s = re.sub(r"[^\d\-\+\,\.\(\)]", "", s)
91
  if s == "":
92
  return None
 
93
  negative = False
94
  if s.startswith("(") and s.endswith(")"):
95
  negative = True
96
  s = s[1:-1]
97
  s = s.replace(",", "")
98
- try:
99
- v = float(s)
100
- return -v if negative else v
101
- except:
102
  return None
103
-
 
 
 
 
 
 
104
 
105
  def is_numeric_token(t: Optional[str]) -> bool:
106
  return bool(t and NUM_RE.search(str(t)))
107
 
108
-
109
  def clean_name_text(s: str) -> str:
110
  s = s.replace("—", "-")
111
  s = re.sub(r"\s+", " ", s)
112
  s = s.strip(" -:,.")
113
- # Fix doctor prefix only if followed by name
114
- s = re.sub(r"\bOR (?=[A-Z][a-z])", "DR ", s)
 
 
115
  return s.strip()
116
 
117
-
118
- ###############################################
119
- # IMAGE PREPROCESSING
120
- ###############################################
121
-
122
- def pil_to_cv2(img: Image.Image):
123
  arr = np.array(img)
124
  if arr.ndim == 2:
125
  return arr
126
  return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
127
 
128
-
129
- def preprocess_image(pil_img: Image.Image):
130
  pil_img = pil_img.convert("RGB")
131
  w, h = pil_img.size
132
-
133
- if w < 1500:
134
- scale = 1500 / float(w)
135
  pil_img = pil_img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
136
-
137
- img = pil_to_cv2(pil_img)
138
- gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
139
-
140
  gray = cv2.fastNlMeansDenoising(gray, h=10)
141
-
142
  try:
143
- bw = cv2.adaptiveThreshold(gray, 255,
144
- cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
145
- cv2.THRESH_BINARY, 41, 15)
146
- except:
147
- _, bw = cv2.threshold(gray, 127, 255,
148
- cv2.THRESH_BINARY + cv2.THRESH_OTSU)
149
-
150
- bw = cv2.morphologyEx(bw, cv2.MORPH_OPEN, np.ones((1, 1), np.uint8))
151
  return bw
152
 
153
-
154
- ###############################################
155
- # OCR TSV EXTRACTION
156
- ###############################################
157
-
158
- def image_to_tsv_cells(cv_img):
159
  try:
160
- ocr = pytesseract.image_to_data(
161
- cv_img,
162
- output_type=Output.DICT,
163
- config="--psm 6"
164
- )
165
- except:
166
- ocr = pytesseract.image_to_data(cv_img, output_type=Output.DICT)
167
-
168
  cells = []
169
- n = len(ocr.get("text", []))
170
-
171
  for i in range(n):
172
- t = (ocr["text"][i] or "").strip()
173
- if not t:
 
 
 
174
  continue
175
  try:
176
- conf = float(ocr["conf"][i])
177
- except:
178
  conf = -1.0
179
-
180
- left = int(ocr.get("left", [0])[i])
181
- top = int(ocr.get("top", [0])[i])
182
- width = int(ocr.get("width", [0])[i])
183
- height = int(ocr.get("height", [0])[i])
184
-
185
- cells.append({
186
- "text": t,
187
- "conf": conf,
188
- "left": left,
189
- "top": top,
190
- "width": width,
191
- "height": height,
192
- "center_x": left + width / 2,
193
- "center_y": top + height / 2,
194
- })
195
  return cells
196
 
197
-
198
- ###############################################
199
- # GROUPING INTO TEXT LINES
200
- ###############################################
201
-
202
- def group_cells_into_rows(cells, y_tol=12):
203
  if not cells:
204
  return []
205
- cells = sorted(cells, key=lambda c: (c["center_y"], c["center_x"]))
206
-
207
  rows = []
208
- current = [cells[0]]
209
- last_y = cells[0]["center_y"]
210
-
211
- for c in cells[1:]:
212
- if abs(c["center_y"] - last_y) <= y_tol:
213
  current.append(c)
214
  last_y = (last_y * (len(current) - 1) + c["center_y"]) / len(current)
215
  else:
216
  rows.append(sorted(current, key=lambda cc: cc["left"]))
217
  current = [c]
218
  last_y = c["center_y"]
219
-
220
  if current:
221
  rows.append(sorted(current, key=lambda cc: cc["left"]))
222
-
223
  return rows
224
 
225
-
226
- ###############################################
227
- # DOCTOR-NAME MERGING (PATCH)
228
- ###############################################
229
-
230
- def merge_multiline_names(rows):
231
  if not rows:
232
  return rows
233
-
234
  merged = []
235
  i = 0
236
  while i < len(rows):
237
  row = rows[i]
238
  tokens = [c["text"] for c in row]
239
- joined = " ".join(tokens)
240
-
241
  has_num = any(is_numeric_token(t) for t in tokens)
242
-
243
- # --- Doctor Name Merge Fix ---
244
- if (not has_num and
245
- re.search(r"\bdr\b", joined.lower()) and
246
- i + 1 < len(rows)):
247
-
248
- next_tokens = " ".join([c["text"] for c in rows[i + 1]])
249
- if not any(is_numeric_token(x) for x in next_tokens.split()):
250
- merged_row = row + rows[i + 1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  merged.append(sorted(merged_row, key=lambda cc: cc["left"]))
252
  i += 2
253
  continue
254
-
255
  merged.append(row)
256
  i += 1
257
-
258
  return merged
259
 
260
-
261
- ###############################################
262
- # DETECT NUMERIC COLUMNS
263
- ###############################################
264
-
265
- def detect_numeric_columns(cells, max_cols=4):
266
  xs = [c["center_x"] for c in cells if is_numeric_token(c["text"])]
267
  if not xs:
268
  return []
269
-
270
  xs = sorted(xs)
271
  if len(xs) == 1:
272
  return [xs[0]]
273
-
274
- gaps = [xs[i + 1] - xs[i] for i in range(len(xs) - 1)]
275
  mean_gap = float(np.mean(gaps))
276
  std_gap = float(np.std(gaps)) if len(gaps) > 1 else 0.0
277
- thresh = max(30.0, mean_gap + 0.6 * std_gap)
278
-
279
  clusters = []
280
  curr = [xs[0]]
281
-
282
  for i, g in enumerate(gaps):
283
- if g > thresh and len(clusters) < (max_cols - 1):
284
  clusters.append(curr)
285
- curr = [xs[i + 1]]
286
  else:
287
- curr.append(xs[i + 1])
288
-
289
  clusters.append(curr)
290
-
291
  centers = [float(np.median(c)) for c in clusters]
292
- centers = centers[-max_cols:]
 
293
  return sorted(centers)
294
 
295
-
296
- def assign_token_to_column(x, centers):
297
- if not centers:
298
  return None
299
- dist = [abs(x - c) for c in centers]
300
- return int(np.argmin(dist))
301
-
302
-
303
- ###############################################
304
- # STRONG HEADER DETECTION (PATCHED)
305
- ###############################################
306
-
307
- def looks_like_header_text(txt: str, top_of_page=False):
308
- if not txt:
309
- return False
310
-
311
- t = re.sub(r"\s+", " ", txt.strip().lower())
312
-
313
- patterns = [
314
- r"description.*qty",
315
- r"qty.*rate",
316
- r"rate.*amount",
317
- r"net\s*amt",
318
- r"discount",
319
- r"hrs\s*/\s*qty",
320
- r"qty\s*/\s*hrs",
321
- ]
322
- for p in patterns:
323
- if re.search(p, t):
324
- return True
325
-
326
- if any(h == t for h in HEADER_PHRASES):
327
- return True
328
 
329
- hits = sum(1 for k in HEADER_KEYWORDS if k in t)
330
- if hits >= 3:
331
- return True
332
-
333
- tokens = re.split(r"[ \|,/]+", t)
334
- num = sum(1 for tok in tokens if NUM_RE.search(tok))
335
- if num >= 3:
336
- return True
337
-
338
- if top_of_page and hits >= 2:
339
- return True
340
-
341
- return False
342
-
343
-
344
- ###############################################
345
- # PARSE ROWS INTO ITEMS
346
- ###############################################
347
-
348
- def parse_rows_with_columns(rows, cells):
349
  rows = merge_multiline_names(rows)
350
- col_centers = detect_numeric_columns(cells)
351
-
352
- parsed = []
353
 
354
  for row in rows:
355
- texts = [c["text"] for c in row]
356
- joined = " ".join(texts).lower()
357
-
358
- if FOOTER_KEYWORDS.search(joined) and not any(is_numeric_token(t) for t in texts):
 
359
  continue
360
- if all(not is_numeric_token(t) for t in texts):
361
  continue
362
 
 
363
  numeric_values = []
364
- for t in texts:
365
  if is_numeric_token(t):
366
  v = normalize_num_str(t)
367
  if v is not None:
368
  numeric_values.append(float(v))
 
 
369
 
370
- # De-duplicate & sort largest first
371
- numeric_values = sorted(list({float(v) for v in numeric_values}), reverse=True)
372
-
373
- # Drop tiny noise
374
- numeric_values = [v for v in numeric_values if v >= 5 or (v < 5 and len(numeric_values) == 1)]
375
-
376
- if col_centers:
377
- left_text = []
378
- bucket = {i: [] for i in range(len(col_centers))}
379
-
380
  for c in row:
381
  t = c["text"]
382
- x = c["center_x"]
383
  if is_numeric_token(t):
384
- idx = assign_token_to_column(x, col_centers)
385
- if idx is not None:
386
- bucket[idx].append(t)
 
 
387
  else:
388
- left_text.append(t)
389
-
390
- name_raw = " ".join(left_text).strip()
391
- name = clean_name_text(name_raw)
392
-
393
- N = len(col_centers)
394
 
395
- def pick(k):
396
- vals = bucket.get(k, [])
 
397
  return vals[-1] if vals else None
398
 
399
- amount = normalize_num_str(pick(N - 1)) if N >= 1 else None
400
- rate = normalize_num_str(pick(N - 2)) if N >= 2 else None
401
- qty = normalize_num_str(pick(N - 3)) if N >= 3 else None
402
 
403
- # fallback amount
404
  if amount is None:
405
- for t in reversed(texts):
406
  if is_numeric_token(t):
407
  amount = normalize_num_str(t)
408
  if amount is not None:
409
  break
410
 
411
- # strong qty/rate inference
412
- if amount is not None and rate is not None:
413
- ratio = amount / rate if rate else None
414
- if ratio and 1 <= round(ratio) <= 10:
415
- qty = float(round(ratio))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
 
 
 
 
 
 
 
 
 
 
 
 
417
  if qty is None:
418
  qty = 1.0
419
 
420
- if amount == 0 and rate and qty:
421
- amount = rate * qty
422
-
423
- try: amount = float(round(amount, 2))
424
- except: continue
425
-
426
- try: rate = float(round(rate or 0.0, 2))
427
- except: rate = 0.0
428
-
429
- try: qty = float(qty)
430
- except: qty = 1.0
 
 
431
 
432
- parsed.append({
433
  "item_name": name if name else "UNKNOWN",
434
  "item_amount": amount,
435
- "item_rate": rate,
436
- "item_quantity": qty
437
  })
438
 
439
  else:
440
- idxs = [i for i, t in enumerate(texts) if is_numeric_token(t)]
441
- if not idxs:
442
  continue
443
-
444
- amt = normalize_num_str(texts[idxs[-1]])
445
  if amt is None:
446
  continue
447
-
448
- name = " ".join(texts[: idxs[-1]]).strip()
449
  if not name:
450
  continue
 
451
 
452
- rate = 0.0
453
- qty = 1.0
454
-
455
- possible = []
456
- for i in idxs:
457
- v = normalize_num_str(texts[i])
458
  if v is not None:
459
- possible.append(float(v))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
 
461
- possible = sorted(list({v for v in possible}), reverse=True)
 
 
 
462
 
463
- for p in possible:
464
- if p <= 1 or p >= amt:
465
- continue
466
- ratio = amt / p
467
- r = round(ratio)
468
- if 1 <= r <= 10:
469
- rate = p
470
- qty = r
471
- break
472
-
473
- parsed.append({
474
  "item_name": clean_name_text(name),
475
  "item_amount": float(round(amt, 2)),
476
  "item_rate": float(round(rate, 2)),
477
- "item_quantity": float(qty)
478
  })
479
 
480
- return parsed
481
-
482
-
483
- ###############################################
484
- # FINAL ITEM FILTER
485
- ###############################################
486
-
487
- def final_item_filter(item, headers, all_names):
488
- name = item["item_name"].strip()
489
- ln = name.lower()
490
-
491
- if not name:
492
- return False
493
-
494
- for h in headers:
495
- if h in ln:
496
- return False
497
-
498
- if FOOTER_KEYWORDS.search(ln):
499
- return False
500
-
501
- if item["item_amount"] <= 0:
502
- return False
503
-
504
- words = ln.split()
505
- short = len(words) <= 3
506
-
507
- if any(k in ln for k in ["charges", "services", "room", "radiology", "surgery"]) and short:
508
- lower_other = " ".join(all_names).lower()
509
- if any(z in lower_other for z in [
510
- "rent","ward","nursing","surgeon","anaes","ot","procedure"
511
- ]):
512
- return False
513
-
514
- rate = item["item_rate"]
515
- amt = item["item_amount"]
516
- if rate and rate > amt * 10 and amt < 10000:
517
- return False
518
-
519
- return True
520
-
521
 
522
- ###############################################
523
- # POST VALIDATION (PATCH)
524
- ###############################################
525
-
526
- def post_validate_items(items):
527
  out = []
528
  for it in items:
529
- amt = it["item_amount"]
530
- rate = it["item_rate"]
531
- qty = it["item_quantity"]
532
-
533
- if amt == 0 and rate > 0:
534
- amt = rate * qty
535
-
536
- if rate > 0:
537
- ideal = rate * qty
538
- if abs(ideal - amt) > max(2, 0.15 * ideal):
539
- q = amt / rate
540
- if 1 <= round(q) <= 10:
541
- qty = round(q)
542
-
543
- it["item_amount"] = round(amt, 2)
544
- it["item_rate"] = round(rate, 2)
545
- it["item_quantity"] = float(qty)
546
-
547
  out.append(it)
548
  return out
549
 
550
-
551
- ###############################################
552
- # SUBTOTAL / FINAL TOTAL DETECTION
553
- ###############################################
554
-
555
- def detect_subtotals_and_totals(rows):
556
- sub = None
557
- final = None
558
-
559
- for rt in rows[::-1]:
560
- if not rt.strip():
561
  continue
562
-
563
  if TOTAL_KEYWORDS.search(rt):
564
  m = NUM_RE.search(rt)
565
  if m:
566
  v = normalize_num_str(m.group(0))
567
  if v is None:
568
  continue
569
-
570
- if "sub" in rt.lower():
571
- if sub is None:
572
- sub = round(v, 2)
573
  else:
574
- if final is None:
575
- final = round(v, 2)
576
-
577
- return {"subtotal": sub, "final_total": final}
578
-
579
 
580
- ###############################################
581
- # GEMINI REFINER (PATCHED PROMPT)
582
- ###############################################
583
-
584
- def refine_with_gemini(items, page_text=""):
585
  if not GEMINI_API_KEY or genai is None:
586
- return items, {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
587
-
588
- safe = sanitize_ocr_text(page_text)
589
-
590
- system_prompt = (
591
- "You are a strict hospital bill item cleaner.\n"
592
- "Return ONLY a JSON array of cleaned line items.\n"
593
- "Do NOT include section headers, totals, subtotals, page numbers.\n"
594
- "Do NOT invent items.\n"
595
- )
596
-
597
- user_prompt = f"""
598
- Extract ONLY valid line items from the bill.
599
-
600
- RULES YOU MUST FOLLOW:
601
- - Do NOT create new items.
602
- - Do NOT output section headers (Room Charges, Lab Services, Radiology).
603
- - Merge broken names (doctor names on multiple lines).
604
- - Use exact item names from OCR text.
605
- - Recompute rate/qty if amount = rate×qty is clear.
606
- - Ignore totals or summary lines.
607
- - Ignore page numbers.
608
- - Always output: item_name, item_amount, item_rate, item_quantity.
609
-
610
- OCR TEXT:
611
- {safe}
612
-
613
- INITIAL ITEMS:
614
- {json.dumps(items, ensure_ascii=False)}
615
-
616
- Return ONLY a JSON array:
617
- [
618
- {{"item_name":"...","item_amount":float,"item_rate":float,"item_quantity":float}}
619
- ]
620
- """
621
-
622
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
623
  model = genai.GenerativeModel(GEMINI_MODEL_NAME)
624
- resp = model.generate_content(
625
  [
626
  {"role": "system", "parts": [system_prompt]},
627
  {"role": "user", "parts": [user_prompt]},
628
  ],
629
  temperature=0.0,
630
- max_output_tokens=1200,
631
  )
632
-
633
- raw = resp.text.strip()
634
- raw = raw.replace("```json", "").replace("```", "").strip()
 
635
  parsed = json.loads(raw)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636
 
637
- cleaned = []
638
- for obj in parsed:
639
- cleaned.append({
640
- "item_name": str(obj.get("item_name", "")).strip(),
641
- "item_amount": float(obj.get("item_amount", 0.0)),
642
- "item_rate": float(obj.get("item_rate", 0.0)),
643
- "item_quantity": float(obj.get("item_quantity", 1.0)),
644
- })
645
-
646
- return cleaned, {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
647
-
648
- except:
649
- return items, {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
650
-
 
 
 
 
 
 
 
 
651
 
652
- ###############################################
653
- # MAIN EXTRACTION ENDPOINT
654
- ###############################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
 
 
656
  @app.post("/extract-bill-data")
657
  async def extract_bill_data(payload: BillRequest):
658
-
659
- url = payload.document
660
-
661
- # download
662
  try:
663
- r = requests.get(url, headers={"User-Agent": "Mozilla"}, timeout=30)
664
- if r.status_code != 200:
665
- raise RuntimeError("Download failed")
666
- data = r.content
667
- except:
668
- return {
669
- "is_success": False,
670
- "token_usage": {},
671
- "data": {"pagewise_line_items": [], "total_item_count": 0}
672
- }
673
-
674
- # load image(s)
675
  try:
676
- if url.lower().split("?")[0].endswith(".pdf"):
677
- imgs = convert_from_bytes(data)
 
 
678
  else:
679
- imgs = [Image.open(BytesIO(data))]
680
- except:
681
- imgs = []
 
 
 
682
 
683
  pagewise = []
684
- total_tokens = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
685
-
686
- for idx, img in enumerate(imgs, 1):
687
 
 
688
  try:
689
- proc = preprocess_image(img)
690
  cells = image_to_tsv_cells(proc)
691
- rows = group_cells_into_rows(cells)
692
-
693
- row_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]
694
-
695
- # remove headers
696
- filtered = []
697
- for i, (r, t) in enumerate(zip(rows, row_texts)):
698
- if looks_like_header_text(t, top_of_page=(i < 5)):
 
699
  continue
700
- filtered.append(r)
701
-
702
- rows = filtered
703
- row_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]
704
- page_text = " ".join(row_texts)
 
 
705
 
 
706
  top_headers = []
707
- for t in row_texts[:5]:
708
- if looks_like_header_text(t, top_of_page=True):
709
- top_headers.append(t.lower())
710
 
711
  parsed_items = parse_rows_with_columns(rows, cells)
712
 
713
- refined, usage = refine_with_gemini(parsed_items, page_text)
714
-
715
- for k in total_tokens:
716
- total_tokens[k] += usage.get(k, 0)
717
-
718
- all_names = [x["item_name"] for x in refined]
719
 
720
- cleaned = [
721
- x for x in refined
722
- if final_item_filter(x, top_headers, all_names)
723
- ]
724
 
725
- cleaned = post_validate_items(cleaned)
726
-
727
- totals = detect_subtotals_and_totals(row_texts)
 
728
 
729
  page_type = "Bill Detail"
730
- low = page_text.lower()
731
- if "pharmacy" in low:
732
  page_type = "Pharmacy"
733
- if "final bill" in low or "grand total" in low:
734
  page_type = "Final Bill"
735
 
736
- pagewise.append({
737
- "page_no": str(idx),
738
- "page_type": page_type,
739
- "bill_items": cleaned,
740
- "subtotal": totals["subtotal"],
741
- "final_page_total": totals["final_total"]
742
- })
743
-
744
- except:
745
- pagewise.append({
746
- "page_no": str(idx),
747
- "page_type": "Bill Detail",
748
- "bill_items": [],
749
- "subtotal": None,
750
- "final_page_total": None
751
- })
752
-
753
- # global final total = sum of all item amounts
754
- final_sum = 0.0
755
- for p in pagewise:
756
- for it in p["bill_items"]:
757
- final_sum += it["item_amount"]
758
-
759
- total_item_count = sum(len(p["bill_items"]) for p in pagewise)
760
-
761
- return {
762
- "is_success": True,
763
- "token_usage": total_tokens,
764
- "data": {
765
- "pagewise_line_items": pagewise,
766
- "total_item_count": total_item_count,
767
- "final_total": round(final_sum, 2)
768
- }
769
- }
770
 
 
 
 
771
 
772
- ###############################################
773
- # DEBUG ENDPOINT
774
- ###############################################
775
 
 
776
  @app.post("/debug-tsv")
777
  async def debug_tsv(payload: BillRequest):
 
778
  try:
779
- r = requests.get(payload.document, timeout=20)
780
- img = Image.open(BytesIO(r.content))
781
- proc = preprocess_image(img)
782
- cells = image_to_tsv_cells(proc)
783
- return {"cells": cells}
784
- except:
785
- return {"error": "debug failed"}
786
-
787
-
788
- ###############################################
789
- # HEALTH CHECK
790
- ###############################################
 
 
 
791
 
792
  @app.get("/")
793
- def ping():
794
- msg = "Bill extractor live."
795
- if not GEMINI_API_KEY:
796
- msg += " (Gemini missing)"
797
- return {"status": "ok", "message": msg}
 
1
+ # app_bill_extractor_final_v2.py
2
+ # Humanized, high-accuracy bill extraction API.
3
+ # Robust OCR preprocessing, TSV layout parsing, numeric-column inference,
4
+ # header prefiltering, deterministic Gemini refinement (if configured).
5
 
6
  import os
7
  import re
 
19
  import numpy as np
20
  import cv2
21
 
22
+ # Optional: Google Gemini SDK (if available)
23
  try:
24
  import google.generativeai as genai
25
+ except Exception:
26
  genai = None
27
 
28
  # ---------------- LLM CONFIG ----------------
29
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
30
  GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME", "gemini-2.5-flash")
 
31
  if GEMINI_API_KEY and genai is not None:
32
  try:
33
  genai.configure(api_key=GEMINI_API_KEY)
34
+ except Exception:
35
  pass
36
 
37
+ # ---------------- FastAPI app ----------------
38
+ app = FastAPI(title="Bajaj Datathon - Bill Extractor (final, humanized)")
 
39
 
40
  class BillRequest(BaseModel):
41
  document: str
42
 
43
+ # ---------------- Regex and keywords ----------------
 
 
 
 
44
  NUM_RE = re.compile(r"[-+]?\d{1,3}(?:[,0-9]*)(?:\.\d+)?")
45
+ TOTAL_KEYWORDS = re.compile(
46
+ r"(grand\s*total|net\s*payable|total\s*amount|amount\s*payable|bill\s*amount|final\s*amount|balance\s*due|sub\s*total|subtotal|round\s*off)",
47
+ re.I,
48
+ )
49
+ FOOTER_KEYWORDS = re.compile(r"(page|printed on|printed:|date:|time:|am|pm)", re.I)
50
 
51
+ # generalized header-related tokens & exact header phrase blacklist (common variants)
52
+ HEADER_KEYWORDS = ["description", "qty", "hrs", "rate", "discount", "net", "amt", "amount", "consultation", "qty/hrs", "qty / hrs"]
 
 
 
 
53
  HEADER_PHRASES = [
54
  "description qty / hrs consultation rate discount net amt",
55
  "description qty / hrs rate discount net amt",
 
59
  ]
60
  HEADER_PHRASES = [h.lower() for h in HEADER_PHRASES]
61
 
62
+ # ---------------- small utilities ----------------
 
 
 
 
 
 
 
63
  def sanitize_ocr_text(s: str) -> str:
64
  if not s:
65
  return ""
 
68
  s = s.replace("\r\n", "\n").replace("\r", "\n")
69
  s = re.sub(r"[ \t]+", " ", s)
70
  s = s.strip()
71
+ return s[:4000]
 
72
 
73
  def normalize_num_str(s: Optional[str]) -> Optional[float]:
74
  if s is None:
75
  return None
76
  s = str(s).strip()
 
77
  if s == "":
78
  return None
79
+ s = re.sub(r"[^\d\-\+\,\.\(\)]", "", s)
80
  negative = False
81
  if s.startswith("(") and s.endswith(")"):
82
  negative = True
83
  s = s[1:-1]
84
  s = s.replace(",", "")
85
+ if s in ("", "-", "+"):
 
 
 
86
  return None
87
+ try:
88
+ return -float(s) if negative else float(s)
89
+ except Exception:
90
+ try:
91
+ return float(s.replace(" ", ""))
92
+ except Exception:
93
+ return None
94
 
95
  def is_numeric_token(t: Optional[str]) -> bool:
96
  return bool(t and NUM_RE.search(str(t)))
97
 
 
98
  def clean_name_text(s: str) -> str:
99
  s = s.replace("—", "-")
100
  s = re.sub(r"\s+", " ", s)
101
  s = s.strip(" -:,.")
102
+ s = re.sub(r"\bSG0?(\d+)\b", r"SG\1", s, flags=re.I)
103
+ s = re.sub(r"\b(RR)[\s\-]*2\b", r"RR-2", s, flags=re.I)
104
+ # fix common OCR mistakes for doctor prefixes
105
+ s = re.sub(r"\bOR\b", "DR", s) # sometimes OCR turns 'DR' -> 'OR'
106
  return s.strip()
107
 
108
+ # ---------------- image preprocessing ----------------
109
+ def pil_to_cv2(img: Image.Image) -> Any:
 
 
 
 
110
  arr = np.array(img)
111
  if arr.ndim == 2:
112
  return arr
113
  return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
114
 
115
+ def preprocess_image(pil_img: Image.Image) -> Any:
 
116
  pil_img = pil_img.convert("RGB")
117
  w, h = pil_img.size
118
+ target_w = 1500
119
+ if w < target_w:
120
+ scale = target_w / float(w)
121
  pil_img = pil_img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
122
+ cv_img = pil_to_cv2(pil_img)
123
+ gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
 
 
124
  gray = cv2.fastNlMeansDenoising(gray, h=10)
 
125
  try:
126
+ bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 41, 15)
127
+ except Exception:
128
+ _, bw = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
129
+ kernel = np.ones((1,1), np.uint8)
130
+ bw = cv2.morphologyEx(bw, cv2.MORPH_OPEN, kernel)
 
 
 
131
  return bw
132
 
133
+ # ---------------- OCR TSV ----------------
134
+ def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
 
 
 
 
135
  try:
136
+ o = pytesseract.image_to_data(cv_img, output_type=Output.DICT, config="--psm 6")
137
+ except Exception:
138
+ o = pytesseract.image_to_data(cv_img, output_type=Output.DICT)
 
 
 
 
 
139
  cells = []
140
+ n = len(o.get("text", []))
 
141
  for i in range(n):
142
+ raw = o["text"][i]
143
+ if raw is None:
144
+ continue
145
+ txt = str(raw).strip()
146
+ if not txt:
147
  continue
148
  try:
149
+ conf = float(o["conf"][i]) if o["conf"][i] not in (None, "", "-1") else -1.0
150
+ except Exception:
151
  conf = -1.0
152
+ left = int(o.get("left", [0])[i])
153
+ top = int(o.get("top", [0])[i])
154
+ width = int(o.get("width", [0])[i])
155
+ height = int(o.get("height", [0])[i])
156
+ center_y = top + height / 2.0
157
+ center_x = left + width / 2.0
158
+ cells.append({"text": txt, "conf": conf, "left": left, "top": top, "width": width, "height": height, "center_y": center_y, "center_x": center_x})
 
 
 
 
 
 
 
 
 
159
  return cells
160
 
161
+ # ---------------- grouping & merge helpers ----------------
162
+ def group_cells_into_rows(cells: List[Dict[str, Any]], y_tolerance: int = 12) -> List[List[Dict[str, Any]]]:
 
 
 
 
163
  if not cells:
164
  return []
165
+ sorted_cells = sorted(cells, key=lambda c: (c["center_y"], c["center_x"]))
 
166
  rows = []
167
+ current = [sorted_cells[0]]
168
+ last_y = sorted_cells[0]["center_y"]
169
+ for c in sorted_cells[1:]:
170
+ if abs(c["center_y"] - last_y) <= y_tolerance:
 
171
  current.append(c)
172
  last_y = (last_y * (len(current) - 1) + c["center_y"]) / len(current)
173
  else:
174
  rows.append(sorted(current, key=lambda cc: cc["left"]))
175
  current = [c]
176
  last_y = c["center_y"]
 
177
  if current:
178
  rows.append(sorted(current, key=lambda cc: cc["left"]))
 
179
  return rows
180
 
181
+ def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
 
 
 
 
 
182
  if not rows:
183
  return rows
 
184
  merged = []
185
  i = 0
186
  while i < len(rows):
187
  row = rows[i]
188
  tokens = [c["text"] for c in row]
 
 
189
  has_num = any(is_numeric_token(t) for t in tokens)
190
+ # if row looks pure text and next row contains numbers but short left text tokens, merge
191
+ if not has_num and i + 1 < len(rows):
192
+ next_row = rows[i+1]
193
+ next_tokens = [c["text"] for c in next_row]
194
+ next_has_num = any(is_numeric_token(t) for t in next_tokens)
195
+ if next_has_num and len(tokens) >= 2 and len([t for t in next_tokens if not is_numeric_token(t)]) <= 3:
196
+ merged_row = []
197
+ min_left = min((c["left"] for c in next_row), default=0)
198
+ offset = 10
199
+ for c in row:
200
+ newc = c.copy()
201
+ newc["left"] = min_left - offset
202
+ newc["center_x"] = newc["left"] + newc.get("width", 0) / 2.0
203
+ merged_row.append(newc)
204
+ offset += 10
205
+ merged_row.extend(next_row)
206
+ merged.append(sorted(merged_row, key=lambda cc: cc["left"]))
207
+ i += 2
208
+ continue
209
+ # Additional merge: If a row ends with a trailing token like a doctor's name line with single token and next row also text, merge (helps names split across 2+ lines)
210
+ if not has_num and i + 1 < len(rows):
211
+ next_row = rows[i+1]
212
+ next_tokens = [c["text"] for c in next_row]
213
+ next_has_num = any(is_numeric_token(t) for t in next_tokens)
214
+ if not next_has_num and len(tokens) <= 3 and len(next_tokens) <= 3:
215
+ # merge both textual lines into one (keeps relative left ordering by shifting)
216
+ merged_row = []
217
+ min_left = min((c["left"] for c in next_row + row), default=0)
218
+ offset = 10
219
+ for c in row + next_row:
220
+ newc = c.copy()
221
+ if newc["left"] > min_left:
222
+ newc["left"] = newc["left"]
223
+ else:
224
+ newc["left"] = min_left - offset
225
+ newc["center_x"] = newc["left"] + newc.get("width", 0) / 2.0
226
+ merged_row.append(newc)
227
+ offset += 5
228
  merged.append(sorted(merged_row, key=lambda cc: cc["left"]))
229
  i += 2
230
  continue
 
231
  merged.append(row)
232
  i += 1
 
233
  return merged
234
 
235
+ # ---------------- numeric column detection ----------------
236
+ # >>> CHANGE: adaptive clustering (restored to conservative adaptive threshold)
237
+ def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 4) -> List[float]:
 
 
 
238
  xs = [c["center_x"] for c in cells if is_numeric_token(c["text"])]
239
  if not xs:
240
  return []
 
241
  xs = sorted(xs)
242
  if len(xs) == 1:
243
  return [xs[0]]
244
+ gaps = [xs[i+1] - xs[i] for i in range(len(xs)-1)]
 
245
  mean_gap = float(np.mean(gaps))
246
  std_gap = float(np.std(gaps)) if len(gaps) > 1 else 0.0
247
+ gap_thresh = max(30.0, mean_gap + 0.6 * std_gap)
 
248
  clusters = []
249
  curr = [xs[0]]
 
250
  for i, g in enumerate(gaps):
251
+ if g > gap_thresh and len(clusters) < (max_columns - 1):
252
  clusters.append(curr)
253
+ curr = [xs[i+1]]
254
  else:
255
+ curr.append(xs[i+1])
 
256
  clusters.append(curr)
 
257
  centers = [float(np.median(c)) for c in clusters]
258
+ if len(centers) > max_columns:
259
+ centers = centers[-max_columns:]
260
  return sorted(centers)
261
 
262
+ def assign_token_to_column(token_x: float, column_centers: List[float]) -> Optional[int]:
263
+ if not column_centers:
 
264
  return None
265
+ distances = [abs(token_x - cx) for cx in column_centers]
266
+ return int(np.argmin(distances))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
+ # ---------------- parsing rows into items ----------------
269
def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert OCR rows into bill line items (name, amount, rate, quantity).

    Two strategies are used:
      * If numeric column centers can be detected page-wide, each numeric
        token is bucketed into its nearest column; the right-most bucket is
        the amount, then rate, then qty.
      * Otherwise, the last numeric token in a row is taken as the amount
        and rate/qty are inferred from the remaining numeric tokens.
    Rows with no numeric token, or footer-only rows, are skipped.
    """
    parsed_items = []
    rows = merge_multiline_names(rows)
    column_centers = detect_numeric_columns(page_cells, max_columns=4)

    for row in rows:
        tokens = [c["text"] for c in row]
        if not tokens:
            continue
        joined_lower = " ".join(tokens).lower()
        # Footer rows without any number carry no line-item data.
        if FOOTER_KEYWORDS.search(joined_lower) and not any(is_numeric_token(t) for t in tokens):
            continue
        if all(not is_numeric_token(t) for t in tokens):
            continue

        # gather numeric candidates (unique, filtered)
        numeric_values = []
        for t in tokens:
            if is_numeric_token(t):
                v = normalize_num_str(t)
                if v is not None:
                    numeric_values.append(float(v))
        # de-duplicate and sort descending (larger candidates first)
        numeric_values = sorted(list({int(x) if float(x).is_integer() else x for x in numeric_values}), reverse=True)

        if column_centers:
            # Column-based parse: split tokens into name parts (non-numeric)
            # and per-column numeric buckets.
            left_text_parts = []
            numeric_bucket_map = {i: [] for i in range(len(column_centers))}
            for c in row:
                t = c["text"]
                cx = c["center_x"]
                if is_numeric_token(t):
                    col_idx = assign_token_to_column(cx, column_centers)
                    if col_idx is None:
                        # No column match: fall back to the right-most column.
                        numeric_bucket_map[len(column_centers) - 1].append(t)
                    else:
                        numeric_bucket_map[col_idx].append(t)
                else:
                    left_text_parts.append(t)
            raw_name = " ".join(left_text_parts).strip()
            name = clean_name_text(raw_name) if raw_name else ""

            num_cols = len(column_centers)

            def get_bucket(idx):
                # Last token wins if several landed in the same column.
                vals = numeric_bucket_map.get(idx, [])
                return vals[-1] if vals else None

            # Right-most column = amount, then rate, then qty (if present).
            amount = normalize_num_str(get_bucket(num_cols - 1)) if num_cols >= 1 else None
            rate = normalize_num_str(get_bucket(num_cols - 2)) if num_cols >= 2 else None
            qty = normalize_num_str(get_bucket(num_cols - 3)) if num_cols >= 3 else None

            # Fallback: use the last numeric token in the row as the amount.
            if amount is None:
                for t in reversed(tokens):
                    if is_numeric_token(t):
                        amount = normalize_num_str(t)
                        if amount is not None:
                            break

            # >>> CHANGE: safer inference — skip tiny candidates like 1, enforce qty bounds, require close ratio
            if amount is not None and numeric_values:
                # Only accept candidate as rate if candidate >= 2 (or amount is tiny) and candidate < amount
                for cand in numeric_values:
                    try:
                        cand_float = float(cand)
                    except:
                        continue
                    if cand_float <= 1.0:
                        continue
                    if amount <= 5 and cand_float < 1.0:
                        continue
                    if cand_float >= amount:
                        continue
                    # amount / candidate should be close to an integer qty.
                    ratio = amount / cand_float if cand_float else None
                    if ratio is None:
                        continue
                    r = round(ratio)
                    if r < 1 or r > 200:
                        continue
                    # require relative closeness threshold (adaptive)
                    if abs(ratio - r) <= max(0.03 * r, 0.15):
                        # Accept only if qty reasonable (<=100)
                        if r <= 100:
                            rate = cand_float
                            qty = float(r)
                            break

            # fallback compute rate if qty found but rate missing
            if (rate is None or rate == 0) and qty and qty != 0 and amount is not None:
                try:
                    candidate_rate = amount / qty
                    # require candidate_rate > 1 (avoid tiny rates) and reasonable
                    if candidate_rate >= 2:
                        rate = candidate_rate
                except Exception:
                    pass

            # final defaults
            if qty is None:
                qty = 1.0

            # normalize and sanity-check: a row without a usable amount is dropped.
            try:
                amount = float(round(amount, 2))
            except Exception:
                continue
            try:
                rate = float(round(rate, 2)) if rate is not None else 0.0
            except Exception:
                rate = 0.0
            try:
                qty = float(qty)
            except Exception:
                qty = 1.0

            parsed_items.append({
                "item_name": name if name else "UNKNOWN",
                "item_amount": amount,
                "item_rate": rate if rate is not None else 0.0,
                "item_quantity": qty if qty is not None else 1.0,
            })

        else:
            # No detectable columns: positional parse. Everything before the
            # last numeric token is the item name; the last numeric is the amount.
            numeric_idxs = [i for i, t in enumerate(tokens) if is_numeric_token(t)]
            if not numeric_idxs:
                continue
            last = numeric_idxs[-1]
            amt = normalize_num_str(tokens[last])
            if amt is None:
                continue
            name = " ".join(tokens[:last]).strip()
            if not name:
                continue
            rate = None; qty = None

            # try to pick rate/qty from previous numeric tokens (right-to-left)
            # and use the safer inference logic (ignore candidate == 1)
            right_nums = []
            for i in numeric_idxs:
                v = normalize_num_str(tokens[i])
                if v is not None:
                    right_nums.append(float(v))
            right_nums = sorted(list({int(x) if float(x).is_integer() else x for x in right_nums}), reverse=True)

            # attempt direct mapping: last numeric = amount, previous maybe rate / qty
            if len(right_nums) >= 2:
                cand = right_nums[1]
                if float(cand) > 1 and float(cand) < float(amt):
                    # check ratio
                    ratio = float(amt) / float(cand) if cand else None
                    if ratio:
                        r = round(ratio)
                        if 1 <= r <= 200 and abs(ratio - r) <= max(0.03 * r, 0.15) and r <= 100:
                            rate = float(cand)
                            qty = float(r)
            # fallback: conservative search like above
            if rate is None and right_nums:
                for cand in right_nums:
                    if cand <= 1.0 or cand >= float(amt):
                        continue
                    ratio = float(amt) / float(cand)
                    r = round(ratio)
                    if 1 <= r <= 100 and abs(ratio - r) <= max(0.03 * r, 0.15):
                        rate = float(cand)
                        qty = float(r)
                        break

            if qty is None:
                qty = 1.0
            if rate is None:
                rate = 0.0

            parsed_items.append({
                "item_name": clean_name_text(name),
                "item_amount": float(round(amt, 2)),
                "item_rate": float(round(rate, 2)),
                "item_quantity": float(qty),
            })

    return parsed_items
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
 
449
+ # ---------------- dedupe & totals ----------------
450
def dedupe_items(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop repeated line items, keeping the first occurrence in order.

    Two items are duplicates when their whitespace-normalised, lower-cased
    name (truncated to 120 chars) and their amount rounded to 2 decimals
    are both equal.
    """
    unique: List[Dict[str, Any]] = []
    observed = set()
    for entry in items:
        normalised_name = " ".join(entry["item_name"].lower().split())
        fingerprint = (normalised_name[:120], round(float(entry["item_amount"]), 2))
        if fingerprint not in observed:
            observed.add(fingerprint)
            unique.append(entry)
    return unique
461
 
462
def detect_subtotals_and_totals(rows_texts: List[str]) -> Dict[str, Optional[float]]:
    """Scan row texts bottom-up for subtotal and final-total amounts.

    A row counts when it matches TOTAL_KEYWORDS and contains a number; a row
    also containing "sub" fills the subtotal slot, otherwise the final-total
    slot. Each slot keeps the first match found scanning from the bottom.
    Returns {"subtotal": float|None, "final_total": float|None}.
    """
    totals: Dict[str, Optional[float]] = {"subtotal": None, "final_total": None}
    for line in reversed(rows_texts):
        if not line or not line.strip():
            continue
        if not TOTAL_KEYWORDS.search(line):
            continue
        match = NUM_RE.search(line)
        if not match:
            continue
        value = normalize_num_str(match.group(0))
        if value is None:
            continue
        slot = "subtotal" if re.search(r"sub", line, re.I) else "final_total"
        if totals[slot] is None:
            totals[slot] = float(round(value, 2))
    return totals
 
 
 
478
 
479
+ # ---------------- Gemini refinement (deterministic) ----------------
480
def refine_with_gemini(page_items: List[Dict[str, Any]], page_text: str = "") -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
    """Ask Gemini to clean/normalize extracted line items (deterministic).

    Args:
        page_items: items parsed from one page.
        page_text: sanitized OCR text of the page, given as context.

    Returns:
        (items, token_usage). On any failure (missing key/SDK, API error,
        unparsable response) the original items are returned unchanged with
        zeroed usage — refinement is strictly best-effort.
    """
    usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
    if not GEMINI_API_KEY or genai is None:
        return page_items, usage
    try:
        safe_text = sanitize_ocr_text(page_text)
        system_prompt = (
            "You are a strict bill-extraction cleaner. Return ONLY a JSON array (no explanation, no backticks). "
            "Each entry must be an object with keys: item_name (string), item_amount (float), item_rate (float), item_quantity (float). "
            "Do NOT include subtotal or total lines as items. Do not invent items; only clean/fix/normalize the given items."
        )
        user_prompt = (
            f"page_text='''{safe_text}'''\n"
            f"items = {json.dumps(page_items, ensure_ascii=False)}\n\n"
            "Example:\n"
            "items = [{'item_name':'Consultation Charge | DR PREETHI','item_amount':300.0,'item_rate':0.0,'item_quantity':300.0},\n"
            "         {'item_name':'Description Qty / Hrs Consultation Rate Discount Net Amt','item_amount':1950.0,'item_rate':1950.0,'item_quantity':1.0}]\n"
            "=>\n"
            "[{'item_name':'Consultation Charge | DR PREETHI MARY JOSEPH','item_amount':300.0,'item_rate':300.0,'item_quantity':1.0}]\n\n"
            "Return only the cleaned JSON array of items."
        )
        # BUGFIX: generate_content() does not accept temperature/max_output_tokens
        # as direct keyword arguments, and a {"role": "system"} entry in the
        # contents list is rejected by the google-generativeai SDK. Both raised
        # inside the broad except below, silently disabling refinement. The
        # system prompt now goes through system_instruction and sampling
        # settings through generation_config.
        model = genai.GenerativeModel(GEMINI_MODEL_NAME, system_instruction=system_prompt)
        response = model.generate_content(
            user_prompt,
            generation_config={"temperature": 0.0, "max_output_tokens": 1000},
        )
        # Report real token usage when the SDK exposes it (guarded: older SDK
        # versions may not populate usage_metadata).
        meta = getattr(response, "usage_metadata", None)
        if meta is not None:
            usage["input_tokens"] = int(getattr(meta, "prompt_token_count", 0) or 0)
            usage["output_tokens"] = int(getattr(meta, "candidates_token_count", 0) or 0)
            usage["total_tokens"] = int(getattr(meta, "total_token_count", 0) or 0)
        raw = response.text.strip()
        # Strip a markdown code fence if the model added one despite instructions.
        if raw.startswith("```"):
            raw = re.sub(r"^```[a-zA-Z]*", "", raw)
            raw = re.sub(r"```$", "", raw).strip()
        parsed = json.loads(raw)
        if isinstance(parsed, list):
            cleaned = []
            for obj in parsed:
                try:
                    cleaned.append({
                        "item_name": str(obj.get("item_name", "")).strip(),
                        "item_amount": float(obj.get("item_amount", 0.0)),
                        "item_rate": float(obj.get("item_rate", 0.0) or 0.0),
                        "item_quantity": float(obj.get("item_quantity", 1.0) or 1.0),
                    })
                except Exception:
                    # Skip malformed entries rather than failing the whole page.
                    continue
            return cleaned, usage
        return page_items, usage
    except Exception:
        return page_items, usage
531
 
532
+ # ---------------- header heuristics & final filter ----------------
533
def looks_like_header_text(txt: str, top_of_page: bool = False) -> bool:
    """Heuristically decide whether a row of text is a table header.

    Checks, in order: exact-phrase blacklist, substring keyword count,
    exact-token keyword count, a relaxed rule for rows near the top of
    the page, a digit-free rate/net+amt pattern, and common header
    prefixes ("description", "qty").
    """
    if not txt:
        return False
    normalized = " ".join(txt.strip().lower().split())
    # exact phrase blacklist
    if normalized in HEADER_PHRASES:
        return True
    substring_hits = sum(1 for kw in HEADER_KEYWORDS if kw in normalized)
    if substring_hits >= 2:
        return True
    pieces = re.split(r"[\s\|,/:]+", normalized)
    exact_hits = sum(1 for piece in pieces if piece in HEADER_KEYWORDS)
    if exact_hits >= 3:
        return True
    # Near the top of the page a short row with two keyword tokens is enough.
    if top_of_page and len(pieces) <= 10 and exact_hits >= 2:
        return True
    has_digit = any(ch.isdigit() for ch in normalized)
    if "amt" in normalized and ("rate" in normalized or "net" in normalized) and not has_digit:
        return True
    return normalized.startswith(("description", "qty"))
554
 
555
def final_item_filter(item: Dict[str, Any], known_page_headers: Optional[List[str]] = None, other_item_names: Optional[List[str]] = None) -> bool:
    """Return True when a parsed item should survive into the final output.

    Args:
        item: dict with item_name / item_amount / item_rate keys.
        known_page_headers: header strings detected at the top of the page;
            an item whose name contains one of them is dropped.
        other_item_names: names of the other items on the page, used to
            drop generic section-header rows when more specific items exist.
    """
    # BUGFIX: the defaults were mutable lists (`= []`), which Python shares
    # across all calls — use None sentinels and materialise per call instead.
    if known_page_headers is None:
        known_page_headers = []
    if other_item_names is None:
        other_item_names = []
    name = (item.get("item_name") or "").strip()
    if not name:
        return False
    ln = name.lower()
    # header exact detection
    for h in known_page_headers:
        if h and h.strip() and h.strip().lower() in ln:
            return False
    if FOOTER_KEYWORDS.search(ln):
        return False
    # Implausibly large amounts are almost certainly OCR noise.
    if item.get("item_amount", 0) > 1_000_000:
        return False
    if len(name) <= 2 and not re.search(r"[a-zA-Z]", name):
        return False
    # avoid pure section headers (short & header words)
    words = ln.split()
    header_word_hits = sum(1 for k in HEADER_KEYWORDS if k in ln)
    if header_word_hits >= 1 and len(words) <= 3:
        # if page contains more detailed items with 'room'/'rent'/'nursing' etc, remove this generic header
        lower_other = " ".join(other_item_names).lower()
        if any(k in lower_other for k in ["room", "rent", "nursing", "ward", "surgeon", "anaes", "ot", "charges", "procedure", "radiology"]):
            return False
    # also if name is exactly one of the short header words, drop
    if ln in ("charge", "charges", "services", "consultation", "room", "radiology", "surgery"):
        return False
    # drop non-informative labels even if they have amount (summary rows)
    if len(words) <= 4 and re.search(r"\b(charges|services|room|radiolog|laborat|surgery|procedure|rent|nursing)\b", ln):
        # try to detect if it's a summary (presence of other more specific items)
        lower_other = " ".join(other_item_names).lower()
        if any(tok in lower_other for tok in ["rent", "room", "ward", "nursing", "surgeon", "anaes", "ot"]):
            return False
    if float(item.get("item_amount", 0)) <= 0.0:
        return False
    # sanity check rate vs amount: a rate 10x the amount on a small bill is noise
    rate = float(item.get("item_rate", 0) or 0)
    amt = float(item.get("item_amount", 0) or 0)
    if rate and rate > amt * 10 and amt < 10000:
        return False
    return True
595
 
596
+ # ---------------- main endpoint ----------------
597
@app.post("/extract-bill-data")
async def extract_bill_data(payload: BillRequest):
    """Full extraction pipeline for one document URL.

    Steps: download → rasterize (PDF or image) → per page: preprocess, OCR
    to cells, group into rows, prefilter header rows, parse line items,
    optionally refine via Gemini, filter/dedupe, classify the page type.
    Always returns a JSON envelope; download failure yields is_success=False,
    and any per-page failure yields an empty item list for that page.
    """
    doc_url = payload.document
    try:
        headers = {"User-Agent": "Mozilla/5.0"}
        resp = requests.get(doc_url, headers=headers, timeout=30)
        if resp.status_code != 200:
            raise RuntimeError(f"download failed status={resp.status_code}")
        file_bytes = resp.content
    except Exception:
        # Download failed: report failure but keep the response shape stable.
        return {"is_success": False, "token_usage": {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}, "data": {"pagewise_line_items": [], "total_item_count": 0}}

    images = []
    # Extension check on the URL path only (query string stripped).
    clean_url = doc_url.split("?", 1)[0].lower()
    try:
        if clean_url.endswith(".pdf"):
            images = convert_from_bytes(file_bytes)
        elif any(clean_url.endswith(ext) for ext in [".png", ".jpg", ".jpeg", ".tiff", ".bmp"]):
            images = [Image.open(BytesIO(file_bytes))]
        else:
            # Unknown extension: try treating the bytes as a PDF anyway.
            try:
                images = convert_from_bytes(file_bytes)
            except Exception:
                images = []
    except Exception:
        images = []

    pagewise = []
    cumulative_token_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}

    for idx, page_img in enumerate(images, start=1):
        try:
            proc = preprocess_image(page_img)
            cells = image_to_tsv_cells(proc)
            rows = group_cells_into_rows(cells, y_tolerance=12)
            rows_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]

            # === HEADER PREFILTER: remove header-like rows anywhere on page ===
            rows_filtered = []
            for i, (r, rt) in enumerate(zip(rows, rows_texts)):
                # The first 6 rows are judged with the stricter top-of-page rules.
                top_flag = (i < 6)
                rt_norm = sanitize_ocr_text(rt).lower()
                if looks_like_header_text(rt_norm, top_of_page=top_flag):
                    continue
                if any(h in rt_norm for h in HEADER_PHRASES):
                    continue
                rows_filtered.append(r)
            # recompute row texts and a simple page_text
            rows = rows_filtered
            rows_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]
            page_text = sanitize_ocr_text(" ".join(rows_texts))

            # detect page-level top headers (for final filtering)
            top_headers = []
            for i, rt in enumerate(rows_texts[:6]):
                if looks_like_header_text(rt, top_of_page=(i < 4)):
                    top_headers.append(rt.strip().lower())

            parsed_items = parse_rows_with_columns(rows, cells)

            # ALWAYS attempt Gemini refinement if available (deterministic settings)
            refined_items, token_u = refine_with_gemini(parsed_items, page_text)
            for k in cumulative_token_usage:
                cumulative_token_usage[k] += token_u.get(k, 0)

            # Prepare other_item_names for contextual filtering (helps remove generic section headers)
            other_item_names = [it.get("item_name","") for it in refined_items]

            # final cleaning & dedupe
            cleaned = [p for p in refined_items if final_item_filter(p, known_page_headers=top_headers, other_item_names=other_item_names)]
            cleaned = dedupe_items(cleaned)
            cleaned = [p for p in cleaned if not looks_like_header_text(p["item_name"].lower())]

            # Page classification by keyword; "Final Bill" takes precedence.
            page_type = "Bill Detail"
            page_txt = page_text.lower()
            if any(x in page_txt for x in ["pharmacy", "medicine", "tablet"]):
                page_type = "Pharmacy"
            if "final bill" in page_txt or "grand total" in page_txt:
                page_type = "Final Bill"

            pagewise.append({"page_no": str(idx), "page_type": page_type, "bill_items": cleaned})
        except Exception:
            # A failing page must not kill the whole document.
            pagewise.append({"page_no": str(idx), "page_type": "Bill Detail", "bill_items": []})
            continue

    total_item_count = sum(len(p.get("bill_items", [])) for p in pagewise)
    if not GEMINI_API_KEY or genai is None:
        # Flag in the usage payload that LLM refinement never ran.
        cumulative_token_usage["warning_no_gemini"] = 1

    return {"is_success": True, "token_usage": cumulative_token_usage, "data": {"pagewise_line_items": pagewise, "total_item_count": total_item_count}}
 
 
687
 
688
+ # ---------------- debug TSV ----------------
689
@app.post("/debug-tsv")
async def debug_tsv(payload: BillRequest):
    """Debug endpoint: OCR the first page of a document and return raw TSV cells.

    Always answers with JSON: {"cells": [...]} on success, {"error": ...}
    on download or processing failure.
    """
    doc_url = payload.document
    try:
        resp = requests.get(doc_url, timeout=20)
        if resp.status_code != 200:
            return {"error": "Download failed"}
        file_bytes = resp.content
    except Exception:
        return {"error": "Download failed"}
    # BUGFIX: decode/OCR was unguarded — a corrupt file raised
    # (UnidentifiedImageError) and an empty PDF raised IndexError on imgs[0],
    # surfacing as HTTP 500 instead of this endpoint's JSON error convention.
    try:
        clean_url = doc_url.split("?", 1)[0].lower()
        if clean_url.endswith(".pdf"):
            imgs = convert_from_bytes(file_bytes)
            if not imgs:
                return {"error": "Could not process document"}
            img = imgs[0]
        else:
            img = Image.open(BytesIO(file_bytes))
        proc = preprocess_image(img)
        cells = image_to_tsv_cells(proc)
        return {"cells": cells}
    except Exception:
        return {"error": "Could not process document"}
708
 
709
@app.get("/")
def health_check():
    """Liveness probe; the message notes when Gemini refinement is unavailable."""
    gemini_ready = bool(GEMINI_API_KEY) and genai is not None
    message_parts = ["Bill extraction API (final) live."]
    if not gemini_ready:
        message_parts.append(" (No GEMINI_API_KEY/configured SDK — LLM refinement skipped.)")
    return {
        "status": "ok",
        "message": "".join(message_parts),
        "hint": "POST /extract-bill-data with {'document':'<url>'}",
    }