Sathvik-kota commited on
Commit
2ad459f
·
verified ·
1 Parent(s): 811fc30

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +128 -186
app.py CHANGED
@@ -1,17 +1,10 @@
1
- # app_bill_extractor_final_v2.py
2
- # Humanized, high-accuracy bill extraction API.
3
- # Robust OCR preprocessing, TSV layout parsing, numeric-column inference,
4
- # header prefiltering, deterministic Gemini refinement (if configured).
5
-
6
  import os
7
  import re
8
  import json
9
- import logging
10
  from io import BytesIO
11
  from typing import List, Dict, Any, Optional, Tuple
12
 
13
- import uvicorn
14
- from fastapi import FastAPI, BackgroundTasks
15
  from pydantic import BaseModel
16
  import requests
17
  from PIL import Image
@@ -27,40 +20,29 @@ try:
27
  except Exception:
28
  genai = None
29
 
30
- # ---------------- logging ----------------
31
- logging.basicConfig(level=logging.INFO)
32
- logger = logging.getLogger("bill-extractor")
33
-
34
- # ---------------- FastAPI app ----------------
35
- app = FastAPI(title="Bajaj Datathon - Bill Extractor (final, humanized)")
36
-
37
- # ---------------- request model ----------------
38
- class BillRequest(BaseModel):
39
- document: str
40
-
41
  # ---------------- LLM CONFIG ----------------
42
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
43
  GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME", "gemini-2.5-flash")
44
  if GEMINI_API_KEY and genai is not None:
45
  try:
46
  genai.configure(api_key=GEMINI_API_KEY)
47
- logger.info("Gemini SDK configured.")
48
- except Exception as e:
49
- logger.warning("Failed to configure Gemini SDK: %s", e)
 
 
50
 
51
- # ---------------- Regex and keywords (updated) ----------------
 
 
 
52
  NUM_RE = re.compile(r"[-+]?\d{1,3}(?:[,0-9]*)(?:\.\d+)?")
53
  TOTAL_KEYWORDS = re.compile(
54
- r"(grand\s*total|net\s*payable|total\s*amount|amount\s*payable|bill\s*amount|final\s*amount|balance\s*due|sub\s*total|subtotal|total)",
55
  re.I,
56
  )
57
  FOOTER_KEYWORDS = re.compile(r"(page|printed on|printed:|date:|time:|am|pm)", re.I)
58
-
59
- HEADER_KEYWORDS = [
60
- "description", "qty", "hrs", "rate", "discount", "net", "amt", "amount",
61
- "consultation", "address", "sex", "age", "mobile", "patient", "category",
62
- "doctor", "dr", "invoice", "bill", "subtotal", "total", "charges", "service"
63
- ]
64
  HEADER_PHRASES = [
65
  "description qty / hrs consultation rate discount net amt",
66
  "description qty / hrs rate discount net amt",
@@ -79,6 +61,9 @@ def sanitize_ocr_text(s: str) -> str:
79
  s = s.replace("\r\n", "\n").replace("\r", "\n")
80
  s = re.sub(r"[ \t]+", " ", s)
81
  s = s.strip()
 
 
 
82
  return s[:4000]
83
 
84
  def normalize_num_str(s: Optional[str]) -> Optional[float]:
@@ -106,28 +91,14 @@ def normalize_num_str(s: Optional[str]) -> Optional[float]:
106
  def is_numeric_token(t: Optional[str]) -> bool:
107
  return bool(t and NUM_RE.search(str(t)))
108
 
109
- def looks_like_date_num(s: str) -> bool:
110
- s_digits = re.sub(r"[^\d]", "", s or "")
111
- if len(s_digits) >= 7:
112
- if s_digits.endswith(("2025","2024","2023","2022","2026")):
113
- return True
114
- try:
115
- if float(s_digits) > 1e6:
116
- return True
117
- except:
118
- pass
119
- return False
120
-
121
  def clean_name_text(s: str) -> str:
122
  s = s.replace("—", "-")
123
  s = re.sub(r"\s+", " ", s)
124
- s = s.strip(" -:,.=")
125
- s = re.sub(r"\s+x$", "", s, flags=re.I)
126
- s = re.sub(r"[\)\}\]]+$", "", s)
127
- s = re.sub(r"\bOR\b", "DR", s)
128
  s = s.strip(" -:,.")
129
- s = s.strip()
130
- return s
 
 
131
 
132
  # ---------------- image preprocessing ----------------
133
  def pil_to_cv2(img: Image.Image) -> Any:
@@ -137,7 +108,6 @@ def pil_to_cv2(img: Image.Image) -> Any:
137
  return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
138
 
139
  def preprocess_image(pil_img: Image.Image) -> Any:
140
- # convert and upscale if small
141
  pil_img = pil_img.convert("RGB")
142
  w, h = pil_img.size
143
  target_w = 1500
@@ -145,11 +115,7 @@ def preprocess_image(pil_img: Image.Image) -> Any:
145
  scale = target_w / float(w)
146
  pil_img = pil_img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
147
  cv_img = pil_to_cv2(pil_img)
148
- # grayscale and denoise
149
- if cv_img.ndim == 3:
150
- gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
151
- else:
152
- gray = cv_img
153
  gray = cv2.fastNlMeansDenoising(gray, h=10)
154
  try:
155
  bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
@@ -162,7 +128,6 @@ def preprocess_image(pil_img: Image.Image) -> Any:
162
 
163
  # ---------------- OCR TSV ----------------
164
  def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
165
- # pytesseract expects either a PIL image or numpy array
166
  try:
167
  o = pytesseract.image_to_data(cv_img, output_type=Output.DICT, config="--psm 6")
168
  except Exception:
@@ -187,10 +152,11 @@ def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
187
  center_y = top + height / 2.0
188
  center_x = left + width / 2.0
189
  cells.append({"text": txt, "conf": conf, "left": left, "top": top,
190
- "width": width, "height": height, "center_y": center_y, "center_x": center_x})
 
191
  return cells
192
 
193
- # ---------------- grouping & merging helpers ----------------
194
  def group_cells_into_rows(cells: List[Dict[str, Any]], y_tolerance: int = 12) -> List[List[Dict[str, Any]]]:
195
  if not cells:
196
  return []
@@ -219,6 +185,7 @@ def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[st
219
  row = rows[i]
220
  tokens = [c["text"] for c in row]
221
  has_num = any(is_numeric_token(t) for t in tokens)
 
222
  if not has_num and i + 1 < len(rows):
223
  next_row = rows[i+1]
224
  next_tokens = [c["text"] for c in next_row]
@@ -237,6 +204,7 @@ def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[st
237
  merged.append(sorted(merged_row, key=lambda cc: cc["left"]))
238
  i += 2
239
  continue
 
240
  if not has_num and i + 1 < len(rows):
241
  next_row = rows[i+1]
242
  next_tokens = [c["text"] for c in next_row]
@@ -247,10 +215,7 @@ def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[st
247
  offset = 10
248
  for c in row + next_row:
249
  newc = c.copy()
250
- if newc["left"] > min_left:
251
- newc["left"] = newc["left"]
252
- else:
253
- newc["left"] = min_left - offset
254
  newc["center_x"] = newc["left"] + newc.get("width", 0) / 2.0
255
  merged_row.append(newc)
256
  offset += 5
@@ -262,7 +227,7 @@ def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[st
262
  return merged
263
 
264
  # ---------------- numeric column detection ----------------
265
- def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 4) -> List[float]:
266
  xs = [c["center_x"] for c in cells if is_numeric_token(c["text"])]
267
  if not xs:
268
  return []
@@ -293,60 +258,11 @@ def assign_token_to_column(token_x: float, column_centers: List[float]) -> Optio
293
  distances = [abs(token_x - cx) for cx in column_centers]
294
  return int(np.argmin(distances))
295
 
296
- # ---------------- Gemini refinement (deterministic) ----------------
297
- def refine_with_gemini(page_items: List[Dict[str, Any]], page_text: str = "") -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
298
- zero_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
299
- if not GEMINI_API_KEY or genai is None:
300
- return page_items, zero_usage
301
- try:
302
- safe_text = sanitize_ocr_text(page_text)
303
- system_prompt = (
304
- "You are a strict bill-extraction cleaner. Return ONLY a JSON array (no explanation, no backticks). "
305
- "Each entry must be an object with keys: item_name (string), item_amount (float), item_rate (float), item_quantity (float). "
306
- "Do NOT include subtotal or total lines as items. Do not invent items; only clean/fix/normalize the given items."
307
- )
308
- user_prompt = (
309
- f"page_text='''{safe_text}'''\n"
310
- f"items = {json.dumps(page_items, ensure_ascii=False)}\n\n"
311
- "Return only the cleaned JSON array of items."
312
- )
313
- model = genai.GenerativeModel(GEMINI_MODEL_NAME)
314
- response = model.generate_content(
315
- [
316
- {"role": "system", "parts": [system_prompt]},
317
- {"role": "user", "parts": [user_prompt]},
318
- ],
319
- temperature=0.0,
320
- max_output_tokens=1000,
321
- )
322
- raw = response.text.strip()
323
- if raw.startswith("```"):
324
- raw = re.sub(r"^```[a-zA-Z]*", "", raw)
325
- raw = re.sub(r"```$", "", raw).strip()
326
- parsed = json.loads(raw)
327
- if isinstance(parsed, list):
328
- cleaned = []
329
- for obj in parsed:
330
- try:
331
- cleaned.append({
332
- "item_name": str(obj.get("item_name", "")).strip(),
333
- "item_amount": float(obj.get("item_amount", 0.0)),
334
- "item_rate": float(obj.get("item_rate", 0.0) or 0.0),
335
- "item_quantity": float(obj.get("item_quantity", 1.0) or 1.0),
336
- })
337
- except Exception:
338
- continue
339
- return cleaned, zero_usage
340
- return page_items, zero_usage
341
- except Exception as e:
342
- logger.warning("Gemini refinement failed: %s", e)
343
- return page_items, zero_usage
344
-
345
- # ---------------- parsing rows into items (modified) ----------------
346
  def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
347
  parsed_items = []
348
  rows = merge_multiline_names(rows)
349
- column_centers = detect_numeric_columns(page_cells, max_columns=4)
350
 
351
  for row in rows:
352
  tokens = [c["text"] for c in row]
@@ -358,23 +274,23 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
358
  if all(not is_numeric_token(t) for t in tokens):
359
  continue
360
 
 
361
  numeric_values = []
362
  for t in tokens:
363
  if is_numeric_token(t):
364
- if looks_like_date_num(t):
365
- continue
366
  v = normalize_num_str(t)
367
  if v is not None:
368
  numeric_values.append(float(v))
369
- numeric_values = sorted({int(x) if float(x).is_integer() else x for x in numeric_values}, reverse=True)
370
 
371
  if column_centers:
372
  left_text_parts = []
373
  numeric_bucket_map = {i: [] for i in range(len(column_centers))}
374
  for c in row:
375
  t = c["text"]
376
- if is_numeric_token(t) and not looks_like_date_num(t):
377
- col_idx = assign_token_to_column(c["center_x"], column_centers)
 
378
  if col_idx is None:
379
  numeric_bucket_map[len(column_centers) - 1].append(t)
380
  else:
@@ -383,23 +299,24 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
383
  left_text_parts.append(t)
384
  raw_name = " ".join(left_text_parts).strip()
385
  name = clean_name_text(raw_name) if raw_name else ""
386
-
387
  num_cols = len(column_centers)
 
388
  def get_bucket(idx):
389
  vals = numeric_bucket_map.get(idx, [])
390
  return vals[-1] if vals else None
391
 
392
  amount = normalize_num_str(get_bucket(num_cols - 1)) if num_cols >= 1 else None
393
- rate = normalize_num_str(get_bucket(num_cols - 2)) if num_cols >= 2 else None
394
- qty = normalize_num_str(get_bucket(num_cols - 3)) if num_cols >= 3 else None
395
 
396
  if amount is None:
397
  for t in reversed(tokens):
398
- if is_numeric_token(t) and not looks_like_date_num(t):
399
  amount = normalize_num_str(t)
400
  if amount is not None:
401
  break
402
 
 
403
  if amount is not None and numeric_values:
404
  for cand in numeric_values:
405
  try:
@@ -424,6 +341,7 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
424
  qty = float(r)
425
  break
426
 
 
427
  if (rate is None or rate == 0) and qty and qty != 0 and amount is not None:
428
  try:
429
  candidate_rate = amount / qty
@@ -435,17 +353,18 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
435
  if qty is None:
436
  qty = 1.0
437
 
 
438
  try:
439
  amount = float(round(amount, 2))
440
- except:
441
  continue
442
  try:
443
  rate = float(round(rate, 2)) if rate is not None else 0.0
444
- except:
445
  rate = 0.0
446
  try:
447
  qty = float(qty)
448
- except:
449
  qty = 1.0
450
 
451
  parsed_items.append({
@@ -456,7 +375,7 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
456
  })
457
 
458
  else:
459
- numeric_idxs = [i for i, t in enumerate(tokens) if is_numeric_token(t) and not looks_like_date_num(t)]
460
  if not numeric_idxs:
461
  continue
462
  last = numeric_idxs[-1]
@@ -473,11 +392,11 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
473
  v = normalize_num_str(tokens[i])
474
  if v is not None:
475
  right_nums.append(float(v))
476
- right_nums = sorted({int(x) if float(x).is_integer() else x for x in right_nums}, reverse=True)
477
 
478
  if len(right_nums) >= 2:
479
  cand = right_nums[1]
480
- if 1 < cand < float(amt):
481
  ratio = float(amt) / float(cand) if cand else None
482
  if ratio:
483
  r = round(ratio)
@@ -524,7 +443,7 @@ def dedupe_items(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
524
 
525
  def detect_subtotals_and_totals(rows_texts: List[str]) -> Dict[str, Optional[float]]:
526
  subtotal = None; final = None
527
- for rt in rows_texts[::-1]:
528
  if not rt or rt.strip() == "":
529
  continue
530
  if TOTAL_KEYWORDS.search(rt):
@@ -534,16 +453,72 @@ def detect_subtotals_and_totals(rows_texts: List[str]) -> Dict[str, Optional[flo
534
  if v is None:
535
  continue
536
  if re.search(r"sub", rt, re.I):
537
- if subtotal is None: subtotal = float(round(v, 2))
 
538
  else:
539
- if final is None: final = float(round(v, 2))
 
540
  return {"subtotal": subtotal, "final_total": final}
541
 
542
- # ---------------- header heuristics & final filter (updated) ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
  def looks_like_header_text(txt: str, top_of_page: bool = False) -> bool:
544
  if not txt:
545
  return False
546
  t = re.sub(r"\s+", " ", txt.strip().lower())
 
547
  if any(h == t for h in HEADER_PHRASES):
548
  return True
549
  hits = sum(1 for k in HEADER_KEYWORDS if k in t)
@@ -559,8 +534,6 @@ def looks_like_header_text(txt: str, top_of_page: bool = False) -> bool:
559
  return True
560
  if t.startswith("description") or t.startswith("qty") or t.startswith("qty /"):
561
  return True
562
- if "sponsor" in t or "admission" in t or "age" in t or "sex" in t or "mobile" in t or "address" in t:
563
- return True
564
  return False
565
 
566
  def final_item_filter(item: Dict[str, Any], known_page_headers: List[str] = [], other_item_names: List[str] = []) -> bool:
@@ -568,41 +541,25 @@ def final_item_filter(item: Dict[str, Any], known_page_headers: List[str] = [],
568
  if not name:
569
  return False
570
  ln = name.lower()
571
- if name.upper() == "UNKNOWN" or ln == "unknown":
572
- return False
573
- if ln == "x":
574
- return False
575
  for h in known_page_headers:
576
  if h and h.strip() and h.strip().lower() in ln:
577
  return False
578
- if re.search(r"\b(total|subtotal|grand total)\b", ln):
579
- return False
580
  if FOOTER_KEYWORDS.search(ln):
581
  return False
582
  if item.get("item_amount", 0) > 1_000_000:
583
  return False
584
  if len(name) <= 2 and not re.search(r"[a-zA-Z]", name):
585
  return False
586
- words = ln.split()
587
- header_word_hits = sum(1 for k in HEADER_KEYWORDS if k in ln)
588
- if header_word_hits >= 1 and len(words) <= 3:
589
- lower_other = " ".join(other_item_names).lower()
590
- if any(k in lower_other for k in ["room", "rent", "nursing", "ward", "surgeon", "anaes", "ot", "charges", "procedure", "radiology"]):
591
- return False
592
- if ln in ("charge", "charges", "services", "consultation", "room", "radiology", "surgery"):
593
- return False
594
- if len(words) <= 4 and re.search(r"\b(charges|services|room|radiolog|laborat|surgery|procedure|rent|nursing)\b", ln):
595
- lower_other = " ".join(other_item_names).lower()
596
- if any(tok in lower_other for tok in ["rent", "room", "ward", "nursing", "surgeon", "anaes", "ot"]):
597
- return False
598
- amt = float(item.get("item_amount", 0) or 0)
599
- rate = float(item.get("item_rate", 0) or 0)
600
- qty = float(item.get("item_quantity", 0) or 0)
601
- if qty <= 0:
602
- return False
603
- if rate and rate > amt:
604
  return False
605
- if amt <= 0.0:
 
 
 
606
  return False
607
  return True
608
 
@@ -612,14 +569,13 @@ async def extract_bill_data(payload: BillRequest):
612
  doc_url = payload.document
613
  file_bytes = None
614
 
615
- # 1. local file support
616
  if doc_url.startswith("file://"):
617
  local_path = doc_url.replace("file://", "")
618
  try:
619
  with open(local_path, "rb") as f:
620
  file_bytes = f.read()
621
  except Exception as e:
622
- logger.error("Local file read error: %s", e)
623
  return {
624
  "is_success": False,
625
  "error": f"Local file read error: {e}",
@@ -634,7 +590,6 @@ async def extract_bill_data(payload: BillRequest):
634
  raise RuntimeError(f"Download failed status={resp.status_code}")
635
  file_bytes = resp.content
636
  except Exception as e:
637
- logger.error("HTTP download error: %s", e)
638
  return {
639
  "is_success": False,
640
  "error": f"HTTP error: {e}",
@@ -662,8 +617,7 @@ async def extract_bill_data(payload: BillRequest):
662
  images = convert_from_bytes(file_bytes)
663
  except Exception:
664
  images = []
665
- except Exception as e:
666
- logger.warning("Image conversion failed: %s", e)
667
  images = []
668
 
669
  pagewise = []
@@ -676,7 +630,7 @@ async def extract_bill_data(payload: BillRequest):
676
  rows = group_cells_into_rows(cells, y_tolerance=12)
677
  rows_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]
678
 
679
- # header prefilter
680
  rows_filtered = []
681
  for i, (r, rt) in enumerate(zip(rows, rows_texts)):
682
  top_flag = (i < 6)
@@ -691,6 +645,7 @@ async def extract_bill_data(payload: BillRequest):
691
  rows_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]
692
  page_text = sanitize_ocr_text(" ".join(rows_texts))
693
 
 
694
  top_headers = []
695
  for i, rt in enumerate(rows_texts[:6]):
696
  if looks_like_header_text(rt, top_of_page=(i < 4)):
@@ -698,26 +653,24 @@ async def extract_bill_data(payload: BillRequest):
698
 
699
  parsed_items = parse_rows_with_columns(rows, cells)
700
 
 
701
  refined_items, token_u = refine_with_gemini(parsed_items, page_text)
702
  for k in cumulative_token_usage:
703
  cumulative_token_usage[k] += token_u.get(k, 0)
704
 
705
- other_item_names = [it.get("item_name","") for it in refined_items]
706
-
707
  cleaned = [p for p in refined_items if final_item_filter(p, known_page_headers=top_headers, other_item_names=other_item_names)]
708
  cleaned = dedupe_items(cleaned)
709
- cleaned = [p for p in cleaned if not looks_like_header_text(p["item_name"].lower())]
710
 
711
  page_type = "Bill Detail"
712
  page_txt = page_text.lower()
713
  if any(x in page_txt for x in ["pharmacy", "medicine", "tablet"]):
714
  page_type = "Pharmacy"
715
- if "final bill" in page_txt or "grand total" in page_txt or "grandtotal" in page_txt:
716
  page_type = "Final Bill"
717
 
718
  pagewise.append({"page_no": str(idx), "page_type": page_type, "bill_items": cleaned})
719
- except Exception as e:
720
- logger.exception("Failed to parse page %s: %s", idx, e)
721
  pagewise.append({"page_no": str(idx), "page_type": "Bill Detail", "bill_items": []})
722
  continue
723
 
@@ -725,7 +678,8 @@ async def extract_bill_data(payload: BillRequest):
725
  if not GEMINI_API_KEY or genai is None:
726
  cumulative_token_usage["warning_no_gemini"] = 1
727
 
728
- return {"is_success": True, "token_usage": cumulative_token_usage, "data": {"pagewise_line_items": pagewise, "total_item_count": total_item_count}}
 
729
 
730
  # ---------------- debug TSV ----------------
731
  @app.post("/debug-tsv")
@@ -750,19 +704,7 @@ async def debug_tsv(payload: BillRequest):
750
 
751
  @app.get("/")
752
  def health_check():
753
- msg = "Bill extraction API (final) live."
754
  if not GEMINI_API_KEY or genai is None:
755
- msg += " (No GEMINI_API_KEY/configured SDK LLM refinement skipped.)"
756
  return {"status": "ok", "message": msg, "hint": "POST /extract-bill-data with {'document':'<url>'}"}
757
-
758
- @app.get("/run-all-samples")
759
- async def run_all_samples():
760
- try:
761
- import run_all_samples
762
- run_all_samples.main()
763
- return {"status": "done", "results_ready": True}
764
- except Exception as e:
765
- logger.exception("run_all_samples failed: %s", e)
766
- return {"status": "error", "error": str(e)}
767
-
768
-
 
 
 
 
 
 
1
  import os
2
  import re
3
  import json
 
4
  from io import BytesIO
5
  from typing import List, Dict, Any, Optional, Tuple
6
 
7
+ from fastapi import FastAPI
 
8
  from pydantic import BaseModel
9
  import requests
10
  from PIL import Image
 
20
  except Exception:
21
  genai = None
22
 
 
 
 
 
 
 
 
 
 
 
 
23
  # ---------------- LLM CONFIG ----------------
24
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
25
  GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME", "gemini-2.5-flash")
26
  if GEMINI_API_KEY and genai is not None:
27
  try:
28
  genai.configure(api_key=GEMINI_API_KEY)
29
+ except Exception:
30
+ pass
31
+
32
+ # ---------------- FastAPI app ----------------
33
+ app = FastAPI(title="Bajaj Datathon - Bill Extractor (final, improved)")
34
 
35
+ class BillRequest(BaseModel):
36
+ document: str
37
+
38
+ # ---------------- Regex and keywords ----------------
39
  NUM_RE = re.compile(r"[-+]?\d{1,3}(?:[,0-9]*)(?:\.\d+)?")
40
  TOTAL_KEYWORDS = re.compile(
41
+ r"(grand\s*total|net\s*payable|total\s*amount|amount\s*payable|bill\s*amount|final\s*amount|balance\s*due|sub\s*total|subtotal|round\s*off)",
42
  re.I,
43
  )
44
  FOOTER_KEYWORDS = re.compile(r"(page|printed on|printed:|date:|time:|am|pm)", re.I)
45
+ HEADER_KEYWORDS = ["description", "qty", "hrs", "rate", "discount", "net", "amt", "amount", "consultation", "qty/hrs", "qty / hrs"]
 
 
 
 
 
46
  HEADER_PHRASES = [
47
  "description qty / hrs consultation rate discount net amt",
48
  "description qty / hrs rate discount net amt",
 
61
  s = s.replace("\r\n", "\n").replace("\r", "\n")
62
  s = re.sub(r"[ \t]+", " ", s)
63
  s = s.strip()
64
+ # Correct common OCR mis-recognitions for headers
65
+ s = re.sub(r"\bqiy\b", "qty", s, flags=re.IGNORECASE)
66
+ s = re.sub(r"\bdeseription\b", "description", s, flags=re.IGNORECASE)
67
  return s[:4000]
68
 
69
  def normalize_num_str(s: Optional[str]) -> Optional[float]:
 
91
  def is_numeric_token(t: Optional[str]) -> bool:
92
  return bool(t and NUM_RE.search(str(t)))
93
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  def clean_name_text(s: str) -> str:
95
  s = s.replace("—", "-")
96
  s = re.sub(r"\s+", " ", s)
 
 
 
 
97
  s = s.strip(" -:,.")
98
+ s = re.sub(r"\bSG0?(\d+)\b", r"SG\1", s, flags=re.I)
99
+ s = re.sub(r"\b(RR)[\s\-]*2\b", r"RR-2", s, flags=re.I)
100
+ s = re.sub(r"\bOR\b", "DR", s) # correct OCR 'OR' -> 'DR'
101
+ return s.strip()
102
 
103
  # ---------------- image preprocessing ----------------
104
  def pil_to_cv2(img: Image.Image) -> Any:
 
108
  return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
109
 
110
  def preprocess_image(pil_img: Image.Image) -> Any:
 
111
  pil_img = pil_img.convert("RGB")
112
  w, h = pil_img.size
113
  target_w = 1500
 
115
  scale = target_w / float(w)
116
  pil_img = pil_img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
117
  cv_img = pil_to_cv2(pil_img)
118
+ gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
 
 
 
 
119
  gray = cv2.fastNlMeansDenoising(gray, h=10)
120
  try:
121
  bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
 
128
 
129
  # ---------------- OCR TSV ----------------
130
  def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
 
131
  try:
132
  o = pytesseract.image_to_data(cv_img, output_type=Output.DICT, config="--psm 6")
133
  except Exception:
 
152
  center_y = top + height / 2.0
153
  center_x = left + width / 2.0
154
  cells.append({"text": txt, "conf": conf, "left": left, "top": top,
155
+ "width": width, "height": height,
156
+ "center_y": center_y, "center_x": center_x})
157
  return cells
158
 
159
+ # ---------------- grouping & merge helpers ----------------
160
  def group_cells_into_rows(cells: List[Dict[str, Any]], y_tolerance: int = 12) -> List[List[Dict[str, Any]]]:
161
  if not cells:
162
  return []
 
185
  row = rows[i]
186
  tokens = [c["text"] for c in row]
187
  has_num = any(is_numeric_token(t) for t in tokens)
188
+ # If row has no numbers but next row does, merge them into one line
189
  if not has_num and i + 1 < len(rows):
190
  next_row = rows[i+1]
191
  next_tokens = [c["text"] for c in next_row]
 
204
  merged.append(sorted(merged_row, key=lambda cc: cc["left"]))
205
  i += 2
206
  continue
207
+ # Merge short text rows without numbers (split descriptions)
208
  if not has_num and i + 1 < len(rows):
209
  next_row = rows[i+1]
210
  next_tokens = [c["text"] for c in next_row]
 
215
  offset = 10
216
  for c in row + next_row:
217
  newc = c.copy()
218
+ newc["left"] = newc["left"] if newc["left"] > min_left else (min_left - offset)
 
 
 
219
  newc["center_x"] = newc["left"] + newc.get("width", 0) / 2.0
220
  merged_row.append(newc)
221
  offset += 5
 
227
  return merged
228
 
229
  # ---------------- numeric column detection ----------------
230
+ def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 6) -> List[float]:
231
  xs = [c["center_x"] for c in cells if is_numeric_token(c["text"])]
232
  if not xs:
233
  return []
 
258
  distances = [abs(token_x - cx) for cx in column_centers]
259
  return int(np.argmin(distances))
260
 
261
+ # ---------------- parsing rows into items ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
263
  parsed_items = []
264
  rows = merge_multiline_names(rows)
265
+ column_centers = detect_numeric_columns(page_cells, max_columns=6)
266
 
267
  for row in rows:
268
  tokens = [c["text"] for c in row]
 
274
  if all(not is_numeric_token(t) for t in tokens):
275
  continue
276
 
277
+ # Collect numeric candidates in this row
278
  numeric_values = []
279
  for t in tokens:
280
  if is_numeric_token(t):
 
 
281
  v = normalize_num_str(t)
282
  if v is not None:
283
  numeric_values.append(float(v))
284
+ numeric_values = sorted(list({int(x) if float(x).is_integer() else x for x in numeric_values}), reverse=True)
285
 
286
  if column_centers:
287
  left_text_parts = []
288
  numeric_bucket_map = {i: [] for i in range(len(column_centers))}
289
  for c in row:
290
  t = c["text"]
291
+ cx = c["center_x"]
292
+ if is_numeric_token(t):
293
+ col_idx = assign_token_to_column(cx, column_centers)
294
  if col_idx is None:
295
  numeric_bucket_map[len(column_centers) - 1].append(t)
296
  else:
 
299
  left_text_parts.append(t)
300
  raw_name = " ".join(left_text_parts).strip()
301
  name = clean_name_text(raw_name) if raw_name else ""
 
302
  num_cols = len(column_centers)
303
+
304
  def get_bucket(idx):
305
  vals = numeric_bucket_map.get(idx, [])
306
  return vals[-1] if vals else None
307
 
308
  amount = normalize_num_str(get_bucket(num_cols - 1)) if num_cols >= 1 else None
309
+ rate = normalize_num_str(get_bucket(num_cols - 2)) if num_cols >= 2 else None
310
+ qty = normalize_num_str(get_bucket(num_cols - 3)) if num_cols >= 3 else None
311
 
312
  if amount is None:
313
  for t in reversed(tokens):
314
+ if is_numeric_token(t):
315
  amount = normalize_num_str(t)
316
  if amount is not None:
317
  break
318
 
319
+ # Infer rate and qty if needed
320
  if amount is not None and numeric_values:
321
  for cand in numeric_values:
322
  try:
 
341
  qty = float(r)
342
  break
343
 
344
+ # Fallback compute rate if needed
345
  if (rate is None or rate == 0) and qty and qty != 0 and amount is not None:
346
  try:
347
  candidate_rate = amount / qty
 
353
  if qty is None:
354
  qty = 1.0
355
 
356
+ # Normalize values
357
  try:
358
  amount = float(round(amount, 2))
359
+ except Exception:
360
  continue
361
  try:
362
  rate = float(round(rate, 2)) if rate is not None else 0.0
363
+ except Exception:
364
  rate = 0.0
365
  try:
366
  qty = float(qty)
367
+ except Exception:
368
  qty = 1.0
369
 
370
  parsed_items.append({
 
375
  })
376
 
377
  else:
378
+ numeric_idxs = [i for i, t in enumerate(tokens) if is_numeric_token(t)]
379
  if not numeric_idxs:
380
  continue
381
  last = numeric_idxs[-1]
 
392
  v = normalize_num_str(tokens[i])
393
  if v is not None:
394
  right_nums.append(float(v))
395
+ right_nums = sorted(list({int(x) if float(x).is_integer() else x for x in right_nums}), reverse=True)
396
 
397
  if len(right_nums) >= 2:
398
  cand = right_nums[1]
399
+ if float(cand) > 1 and float(cand) < float(amt):
400
  ratio = float(amt) / float(cand) if cand else None
401
  if ratio:
402
  r = round(ratio)
 
443
 
444
  def detect_subtotals_and_totals(rows_texts: List[str]) -> Dict[str, Optional[float]]:
445
  subtotal = None; final = None
446
+ for rt in reversed(rows_texts):
447
  if not rt or rt.strip() == "":
448
  continue
449
  if TOTAL_KEYWORDS.search(rt):
 
453
  if v is None:
454
  continue
455
  if re.search(r"sub", rt, re.I):
456
+ if subtotal is None:
457
+ subtotal = float(round(v, 2))
458
  else:
459
+ if final is None:
460
+ final = float(round(v, 2))
461
  return {"subtotal": subtotal, "final_total": final}
462
 
463
+ # ---------------- Gemini refinement (deterministic) ----------------
464
+ def refine_with_gemini(page_items: List[Dict[str, Any]], page_text: str = "") -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
465
+ zero_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
466
+ if not GEMINI_API_KEY or genai is None:
467
+ return page_items, zero_usage
468
+ try:
469
+ safe_text = sanitize_ocr_text(page_text)
470
+ system_prompt = (
471
+ "You are a strict bill-extraction cleaner. Return ONLY a JSON array (no explanation, no backticks). "
472
+ "Each entry must be an object with keys: item_name (string), item_amount (float), item_rate (float), item_quantity (float). "
473
+ "Do NOT include subtotal or total lines as items. Do not invent items; only clean/fix/normalize the given items."
474
+ )
475
+ user_prompt = (
476
+ f"page_text='''{safe_text}'''\n"
477
+ f"items = {json.dumps(page_items, ensure_ascii=False)}\n\n"
478
+ "Example:\n"
479
+ "items = [{'item_name':'Consultation Charge | DR PREETHI','item_amount':300.0,'item_rate':0.0,'item_quantity':300.0},\n"
480
+ " {'item_name':'Description Qty / Hrs Consultation Rate Discount Net Amt','item_amount':1950.0,'item_rate':1950.0,'item_quantity':1.0}]\n"
481
+ "=>\n"
482
+ "[{'item_name':'Consultation Charge | DR PREETHI MARY JOSEPH','item_amount':300.0,'item_rate':300.0,'item_quantity':1.0}]\n\n"
483
+ "Return only the cleaned JSON array of items."
484
+ )
485
+ model = genai.GenerativeModel(GEMINI_MODEL_NAME)
486
+ response = model.generate_content(
487
+ [
488
+ {"role": "system", "parts": [system_prompt]},
489
+ {"role": "user", "parts": [user_prompt]},
490
+ ],
491
+ temperature=0.0,
492
+ max_output_tokens=1000,
493
+ )
494
+ raw = response.text.strip()
495
+ if raw.startswith("```"):
496
+ raw = re.sub(r"^```[a-zA-Z]*", "", raw)
497
+ raw = re.sub(r"```$", "", raw).strip()
498
+ parsed = json.loads(raw)
499
+ if isinstance(parsed, list):
500
+ cleaned = []
501
+ for obj in parsed:
502
+ try:
503
+ cleaned.append({
504
+ "item_name": str(obj.get("item_name", "")).strip(),
505
+ "item_amount": float(obj.get("item_amount", 0.0)),
506
+ "item_rate": float(obj.get("item_rate", 0.0) or 0.0),
507
+ "item_quantity": float(obj.get("item_quantity", 1.0) or 1.0),
508
+ })
509
+ except Exception:
510
+ continue
511
+ return cleaned, zero_usage
512
+ return page_items, zero_usage
513
+ except Exception:
514
+ return page_items, zero_usage
515
+
516
+ # ---------------- header heuristics & final filter ----------------
517
  def looks_like_header_text(txt: str, top_of_page: bool = False) -> bool:
518
  if not txt:
519
  return False
520
  t = re.sub(r"\s+", " ", txt.strip().lower())
521
+ # exact phrase blacklist
522
  if any(h == t for h in HEADER_PHRASES):
523
  return True
524
  hits = sum(1 for k in HEADER_KEYWORDS if k in t)
 
534
  return True
535
  if t.startswith("description") or t.startswith("qty") or t.startswith("qty /"):
536
  return True
 
 
537
  return False
538
 
539
  def final_item_filter(item: Dict[str, Any], known_page_headers: List[str] = [], other_item_names: List[str] = []) -> bool:
 
541
  if not name:
542
  return False
543
  ln = name.lower()
544
+ # Remove if this item matches any known header text
 
 
 
545
  for h in known_page_headers:
546
  if h and h.strip() and h.strip().lower() in ln:
547
  return False
 
 
548
  if FOOTER_KEYWORDS.search(ln):
549
  return False
550
  if item.get("item_amount", 0) > 1_000_000:
551
  return False
552
  if len(name) <= 2 and not re.search(r"[a-zA-Z]", name):
553
  return False
554
+ # (Removed overly restrictive filters for generic terms to retain valid items)
555
+
556
+ # Drop items with non-positive amounts
557
+ if float(item.get("item_amount", 0)) <= 0.0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
558
  return False
559
+ # Sanity check: discard if rate is absurdly higher than amount
560
+ rate = float(item.get("item_rate", 0) or 0)
561
+ amt = float(item.get("item_amount", 0) or 0)
562
+ if rate and rate > amt * 10 and amt < 10000:
563
  return False
564
  return True
565
 
 
569
  doc_url = payload.document
570
  file_bytes = None
571
 
572
+ # --------------------------- Local or remote file ---------------------------
573
  if doc_url.startswith("file://"):
574
  local_path = doc_url.replace("file://", "")
575
  try:
576
  with open(local_path, "rb") as f:
577
  file_bytes = f.read()
578
  except Exception as e:
 
579
  return {
580
  "is_success": False,
581
  "error": f"Local file read error: {e}",
 
590
  raise RuntimeError(f"Download failed status={resp.status_code}")
591
  file_bytes = resp.content
592
  except Exception as e:
 
593
  return {
594
  "is_success": False,
595
  "error": f"HTTP error: {e}",
 
617
  images = convert_from_bytes(file_bytes)
618
  except Exception:
619
  images = []
620
+ except Exception:
 
621
  images = []
622
 
623
  pagewise = []
 
630
  rows = group_cells_into_rows(cells, y_tolerance=12)
631
  rows_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]
632
 
633
+ # === Header prefilter: remove header-like rows ===
634
  rows_filtered = []
635
  for i, (r, rt) in enumerate(zip(rows, rows_texts)):
636
  top_flag = (i < 6)
 
645
  rows_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]
646
  page_text = sanitize_ocr_text(" ".join(rows_texts))
647
 
648
+ # Collect detected top headers for final filtering
649
  top_headers = []
650
  for i, rt in enumerate(rows_texts[:6]):
651
  if looks_like_header_text(rt, top_of_page=(i < 4)):
 
653
 
654
  parsed_items = parse_rows_with_columns(rows, cells)
655
 
656
+ # Gemini refinement (if enabled)
657
  refined_items, token_u = refine_with_gemini(parsed_items, page_text)
658
  for k in cumulative_token_usage:
659
  cumulative_token_usage[k] += token_u.get(k, 0)
660
 
661
+ other_item_names = [it.get("item_name", "") for it in refined_items]
 
662
  cleaned = [p for p in refined_items if final_item_filter(p, known_page_headers=top_headers, other_item_names=other_item_names)]
663
  cleaned = dedupe_items(cleaned)
 
664
 
665
  page_type = "Bill Detail"
666
  page_txt = page_text.lower()
667
  if any(x in page_txt for x in ["pharmacy", "medicine", "tablet"]):
668
  page_type = "Pharmacy"
669
+ if "final bill" in page_txt or "grand total" in page_txt:
670
  page_type = "Final Bill"
671
 
672
  pagewise.append({"page_no": str(idx), "page_type": page_type, "bill_items": cleaned})
673
+ except Exception:
 
674
  pagewise.append({"page_no": str(idx), "page_type": "Bill Detail", "bill_items": []})
675
  continue
676
 
 
678
  if not GEMINI_API_KEY or genai is None:
679
  cumulative_token_usage["warning_no_gemini"] = 1
680
 
681
+ return {"is_success": True, "token_usage": cumulative_token_usage,
682
+ "data": {"pagewise_line_items": pagewise, "total_item_count": total_item_count}}
683
 
684
  # ---------------- debug TSV ----------------
685
  @app.post("/debug-tsv")
 
704
 
705
@app.get("/")
def health_check():
    """Liveness endpoint.

    Returns a small status payload telling callers the service is up,
    whether Gemini-based LLM refinement is active (key + SDK present),
    and how to invoke the extraction endpoint.
    """
    status_message = "Bill extraction API (updated version) live."
    # Refinement is silently skipped when either the API key or the SDK is absent;
    # surface that in the health message so operators notice degraded mode.
    gemini_ready = bool(GEMINI_API_KEY) and genai is not None
    if not gemini_ready:
        status_message += " (No GEMINI - LLM refinement skipped.)"
    return {
        "status": "ok",
        "message": status_message,
        "hint": "POST /extract-bill-data with {'document':'<url>'}",
    }