Sathvik-kota commited on
Commit
56ab53e
·
verified ·
1 Parent(s): 2ad459f

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +435 -244
app.py CHANGED
@@ -1,6 +1,15 @@
 
 
 
 
 
 
 
 
1
  import os
2
  import re
3
  import json
 
4
  from io import BytesIO
5
  from typing import List, Dict, Any, Optional, Tuple
6
 
@@ -9,10 +18,22 @@ from pydantic import BaseModel
9
  import requests
10
  from PIL import Image
11
  from pdf2image import convert_from_bytes
12
- import pytesseract
13
- from pytesseract import Output
14
  import numpy as np
15
  import cv2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # Optional: Google Gemini SDK (if available)
18
  try:
@@ -20,51 +41,85 @@ try:
20
  except Exception:
21
  genai = None
22
 
23
- # ---------------- LLM CONFIG ----------------
 
 
 
24
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
25
  GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME", "gemini-2.5-flash")
 
 
 
 
 
 
26
  if GEMINI_API_KEY and genai is not None:
27
  try:
28
  genai.configure(api_key=GEMINI_API_KEY)
29
- except Exception:
30
- pass
31
-
32
- # ---------------- FastAPI app ----------------
33
- app = FastAPI(title="Bajaj Datathon - Bill Extractor (final, improved)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  class BillRequest(BaseModel):
36
- document: str
37
 
38
- # ---------------- Regex and keywords ----------------
 
 
39
  NUM_RE = re.compile(r"[-+]?\d{1,3}(?:[,0-9]*)(?:\.\d+)?")
40
  TOTAL_KEYWORDS = re.compile(
41
  r"(grand\s*total|net\s*payable|total\s*amount|amount\s*payable|bill\s*amount|final\s*amount|balance\s*due|sub\s*total|subtotal|round\s*off)",
42
  re.I,
43
  )
44
  FOOTER_KEYWORDS = re.compile(r"(page|printed on|printed:|date:|time:|am|pm)", re.I)
45
- HEADER_KEYWORDS = ["description", "qty", "hrs", "rate", "discount", "net", "amt", "amount", "consultation", "qty/hrs", "qty / hrs"]
 
 
 
 
 
46
  HEADER_PHRASES = [
47
  "description qty / hrs consultation rate discount net amt",
48
  "description qty / hrs rate discount net amt",
49
- "description qty / hrs rate net amt",
50
- "description qty hrs rate discount net amt",
51
- "description qty / hrs rate discount net amt",
52
  ]
53
  HEADER_PHRASES = [h.lower() for h in HEADER_PHRASES]
54
 
55
- # ---------------- small utilities ----------------
56
- def sanitize_ocr_text(s: str) -> str:
57
  if not s:
58
  return ""
59
  s = s.replace("\u2014", "-").replace("\u2013", "-")
60
  s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E]", " ", s)
61
  s = s.replace("\r\n", "\n").replace("\r", "\n")
62
  s = re.sub(r"[ \t]+", " ", s)
63
- s = s.strip()
64
- # Correct common OCR mis-recognitions for headers
65
- s = re.sub(r"\bqiy\b", "qty", s, flags=re.IGNORECASE)
66
- s = re.sub(r"\bdeseription\b", "description", s, flags=re.IGNORECASE)
67
- return s[:4000]
68
 
69
  def normalize_num_str(s: Optional[str]) -> Optional[float]:
70
  if s is None:
@@ -81,7 +136,8 @@ def normalize_num_str(s: Optional[str]) -> Optional[float]:
81
  if s in ("", "-", "+"):
82
  return None
83
  try:
84
- return -float(s) if negative else float(s)
 
85
  except Exception:
86
  try:
87
  return float(s.replace(" ", ""))
@@ -94,28 +150,32 @@ def is_numeric_token(t: Optional[str]) -> bool:
94
  def clean_name_text(s: str) -> str:
95
  s = s.replace("—", "-")
96
  s = re.sub(r"\s+", " ", s)
97
- s = s.strip(" -:,.")
98
- s = re.sub(r"\bSG0?(\d+)\b", r"SG\1", s, flags=re.I)
99
- s = re.sub(r"\b(RR)[\s\-]*2\b", r"RR-2", s, flags=re.I)
100
- s = re.sub(r"\bOR\b", "DR", s) # correct OCR 'OR' -> 'DR'
101
  return s.strip()
102
 
103
- # ---------------- image preprocessing ----------------
 
 
104
  def pil_to_cv2(img: Image.Image) -> Any:
105
  arr = np.array(img)
106
  if arr.ndim == 2:
107
  return arr
108
  return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
109
 
110
- def preprocess_image(pil_img: Image.Image) -> Any:
111
  pil_img = pil_img.convert("RGB")
112
  w, h = pil_img.size
113
- target_w = 1500
114
  if w < target_w:
115
  scale = target_w / float(w)
116
  pil_img = pil_img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
117
  cv_img = pil_to_cv2(pil_img)
118
- gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
 
 
 
119
  gray = cv2.fastNlMeansDenoising(gray, h=10)
120
  try:
121
  bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
@@ -126,10 +186,10 @@ def preprocess_image(pil_img: Image.Image) -> Any:
126
  bw = cv2.morphologyEx(bw, cv2.MORPH_OPEN, kernel)
127
  return bw
128
 
129
- # ---------------- OCR TSV ----------------
130
  def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
 
131
  try:
132
- o = pytesseract.image_to_data(cv_img, output_type=Output.DICT, config="--psm 6")
133
  except Exception:
134
  o = pytesseract.image_to_data(cv_img, output_type=Output.DICT)
135
  cells = []
@@ -142,7 +202,8 @@ def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
142
  if not txt:
143
  continue
144
  try:
145
- conf = float(o["conf"][i]) if o["conf"][i] not in (None, "", "-1") else -1.0
 
146
  except Exception:
147
  conf = -1.0
148
  left = int(o.get("left", [0])[i])
@@ -152,11 +213,9 @@ def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
152
  center_y = top + height / 2.0
153
  center_x = left + width / 2.0
154
  cells.append({"text": txt, "conf": conf, "left": left, "top": top,
155
- "width": width, "height": height,
156
- "center_y": center_y, "center_x": center_x})
157
  return cells
158
 
159
- # ---------------- grouping & merge helpers ----------------
160
  def group_cells_into_rows(cells: List[Dict[str, Any]], y_tolerance: int = 12) -> List[List[Dict[str, Any]]]:
161
  if not cells:
162
  return []
@@ -185,7 +244,6 @@ def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[st
185
  row = rows[i]
186
  tokens = [c["text"] for c in row]
187
  has_num = any(is_numeric_token(t) for t in tokens)
188
- # If row has no numbers but next row does, merge them into one line
189
  if not has_num and i + 1 < len(rows):
190
  next_row = rows[i+1]
191
  next_tokens = [c["text"] for c in next_row]
@@ -204,7 +262,6 @@ def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[st
204
  merged.append(sorted(merged_row, key=lambda cc: cc["left"]))
205
  i += 2
206
  continue
207
- # Merge short text rows without numbers (split descriptions)
208
  if not has_num and i + 1 < len(rows):
209
  next_row = rows[i+1]
210
  next_tokens = [c["text"] for c in next_row]
@@ -215,7 +272,10 @@ def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[st
215
  offset = 10
216
  for c in row + next_row:
217
  newc = c.copy()
218
- newc["left"] = newc["left"] if newc["left"] > min_left else (min_left - offset)
 
 
 
219
  newc["center_x"] = newc["left"] + newc.get("width", 0) / 2.0
220
  merged_row.append(newc)
221
  offset += 5
@@ -226,7 +286,6 @@ def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[st
226
  i += 1
227
  return merged
228
 
229
- # ---------------- numeric column detection ----------------
230
  def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 6) -> List[float]:
231
  xs = [c["center_x"] for c in cells if is_numeric_token(c["text"])]
232
  if not xs:
@@ -237,7 +296,7 @@ def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 6) ->
237
  gaps = [xs[i+1] - xs[i] for i in range(len(xs)-1)]
238
  mean_gap = float(np.mean(gaps))
239
  std_gap = float(np.std(gaps)) if len(gaps) > 1 else 0.0
240
- gap_thresh = max(30.0, mean_gap + 0.6 * std_gap)
241
  clusters = []
242
  curr = [xs[0]]
243
  for i, g in enumerate(gaps):
@@ -258,7 +317,9 @@ def assign_token_to_column(token_x: float, column_centers: List[float]) -> Optio
258
  distances = [abs(token_x - cx) for cx in column_centers]
259
  return int(np.argmin(distances))
260
 
261
- # ---------------- parsing rows into items ----------------
 
 
262
  def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
263
  parsed_items = []
264
  rows = merge_multiline_names(rows)
@@ -271,10 +332,10 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
271
  joined_lower = " ".join(tokens).lower()
272
  if FOOTER_KEYWORDS.search(joined_lower) and not any(is_numeric_token(t) for t in tokens):
273
  continue
 
274
  if all(not is_numeric_token(t) for t in tokens):
275
  continue
276
 
277
- # Collect numeric candidates in this row
278
  numeric_values = []
279
  for t in tokens:
280
  if is_numeric_token(t):
@@ -300,11 +361,9 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
300
  raw_name = " ".join(left_text_parts).strip()
301
  name = clean_name_text(raw_name) if raw_name else ""
302
  num_cols = len(column_centers)
303
-
304
  def get_bucket(idx):
305
  vals = numeric_bucket_map.get(idx, [])
306
  return vals[-1] if vals else None
307
-
308
  amount = normalize_num_str(get_bucket(num_cols - 1)) if num_cols >= 1 else None
309
  rate = normalize_num_str(get_bucket(num_cols - 2)) if num_cols >= 2 else None
310
  qty = normalize_num_str(get_bucket(num_cols - 3)) if num_cols >= 3 else None
@@ -316,7 +375,7 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
316
  if amount is not None:
317
  break
318
 
319
- # Infer rate and qty if needed
320
  if amount is not None and numeric_values:
321
  for cand in numeric_values:
322
  try:
@@ -341,11 +400,10 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
341
  qty = float(r)
342
  break
343
 
344
- # Fallback compute rate if needed
345
  if (rate is None or rate == 0) and qty and qty != 0 and amount is not None:
346
  try:
347
  candidate_rate = amount / qty
348
- if candidate_rate >= 2:
349
  rate = candidate_rate
350
  except Exception:
351
  pass
@@ -353,7 +411,6 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
353
  if qty is None:
354
  qty = 1.0
355
 
356
- # Normalize values
357
  try:
358
  amount = float(round(amount, 2))
359
  except Exception:
@@ -373,7 +430,6 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
373
  "item_rate": rate if rate is not None else 0.0,
374
  "item_quantity": qty if qty is not None else 1.0,
375
  })
376
-
377
  else:
378
  numeric_idxs = [i for i, t in enumerate(tokens) if is_numeric_token(t)]
379
  if not numeric_idxs:
@@ -425,100 +481,24 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
425
  "item_rate": float(round(rate, 2)),
426
  "item_quantity": float(qty),
427
  })
428
-
429
  return parsed_items
430
 
431
- # ---------------- dedupe & totals ----------------
432
  def dedupe_items(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
433
  seen = set()
434
  out = []
435
  for it in items:
436
- nm = re.sub(r"\s+", " ", it["item_name"].lower()).strip()
437
- key = (nm[:120], round(float(it["item_amount"]), 2))
438
  if key in seen:
439
  continue
440
  seen.add(key)
441
  out.append(it)
442
  return out
443
 
444
- def detect_subtotals_and_totals(rows_texts: List[str]) -> Dict[str, Optional[float]]:
445
- subtotal = None; final = None
446
- for rt in reversed(rows_texts):
447
- if not rt or rt.strip() == "":
448
- continue
449
- if TOTAL_KEYWORDS.search(rt):
450
- m = NUM_RE.search(rt)
451
- if m:
452
- v = normalize_num_str(m.group(0))
453
- if v is None:
454
- continue
455
- if re.search(r"sub", rt, re.I):
456
- if subtotal is None:
457
- subtotal = float(round(v, 2))
458
- else:
459
- if final is None:
460
- final = float(round(v, 2))
461
- return {"subtotal": subtotal, "final_total": final}
462
-
463
- # ---------------- Gemini refinement (deterministic) ----------------
464
- def refine_with_gemini(page_items: List[Dict[str, Any]], page_text: str = "") -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
465
- zero_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
466
- if not GEMINI_API_KEY or genai is None:
467
- return page_items, zero_usage
468
- try:
469
- safe_text = sanitize_ocr_text(page_text)
470
- system_prompt = (
471
- "You are a strict bill-extraction cleaner. Return ONLY a JSON array (no explanation, no backticks). "
472
- "Each entry must be an object with keys: item_name (string), item_amount (float), item_rate (float), item_quantity (float). "
473
- "Do NOT include subtotal or total lines as items. Do not invent items; only clean/fix/normalize the given items."
474
- )
475
- user_prompt = (
476
- f"page_text='''{safe_text}'''\n"
477
- f"items = {json.dumps(page_items, ensure_ascii=False)}\n\n"
478
- "Example:\n"
479
- "items = [{'item_name':'Consultation Charge | DR PREETHI','item_amount':300.0,'item_rate':0.0,'item_quantity':300.0},\n"
480
- " {'item_name':'Description Qty / Hrs Consultation Rate Discount Net Amt','item_amount':1950.0,'item_rate':1950.0,'item_quantity':1.0}]\n"
481
- "=>\n"
482
- "[{'item_name':'Consultation Charge | DR PREETHI MARY JOSEPH','item_amount':300.0,'item_rate':300.0,'item_quantity':1.0}]\n\n"
483
- "Return only the cleaned JSON array of items."
484
- )
485
- model = genai.GenerativeModel(GEMINI_MODEL_NAME)
486
- response = model.generate_content(
487
- [
488
- {"role": "system", "parts": [system_prompt]},
489
- {"role": "user", "parts": [user_prompt]},
490
- ],
491
- temperature=0.0,
492
- max_output_tokens=1000,
493
- )
494
- raw = response.text.strip()
495
- if raw.startswith("```"):
496
- raw = re.sub(r"^```[a-zA-Z]*", "", raw)
497
- raw = re.sub(r"```$", "", raw).strip()
498
- parsed = json.loads(raw)
499
- if isinstance(parsed, list):
500
- cleaned = []
501
- for obj in parsed:
502
- try:
503
- cleaned.append({
504
- "item_name": str(obj.get("item_name", "")).strip(),
505
- "item_amount": float(obj.get("item_amount", 0.0)),
506
- "item_rate": float(obj.get("item_rate", 0.0) or 0.0),
507
- "item_quantity": float(obj.get("item_quantity", 1.0) or 1.0),
508
- })
509
- except Exception:
510
- continue
511
- return cleaned, zero_usage
512
- return page_items, zero_usage
513
- except Exception:
514
- return page_items, zero_usage
515
-
516
- # ---------------- header heuristics & final filter ----------------
517
  def looks_like_header_text(txt: str, top_of_page: bool = False) -> bool:
518
  if not txt:
519
  return False
520
  t = re.sub(r"\s+", " ", txt.strip().lower())
521
- # exact phrase blacklist
522
  if any(h == t for h in HEADER_PHRASES):
523
  return True
524
  hits = sum(1 for k in HEADER_KEYWORDS if k in t)
@@ -536,101 +516,271 @@ def looks_like_header_text(txt: str, top_of_page: bool = False) -> bool:
536
  return True
537
  return False
538
 
539
- def final_item_filter(item: Dict[str, Any], known_page_headers: List[str] = [], other_item_names: List[str] = []) -> bool:
540
  name = (item.get("item_name") or "").strip()
541
  if not name:
542
  return False
543
  ln = name.lower()
544
- # Remove if this item matches any known header text
545
  for h in known_page_headers:
546
  if h and h.strip() and h.strip().lower() in ln:
547
  return False
548
  if FOOTER_KEYWORDS.search(ln):
549
  return False
550
- if item.get("item_amount", 0) > 1_000_000:
551
- return False
552
- if len(name) <= 2 and not re.search(r"[a-zA-Z]", name):
553
  return False
554
- # (Removed overly restrictive filters for generic terms to retain valid items)
555
-
556
- # Drop items with non-positive amounts
557
- if float(item.get("item_amount", 0)) <= 0.0:
558
  return False
559
- # Sanity check: discard if rate is absurdly higher than amount
560
  rate = float(item.get("item_rate", 0) or 0)
561
- amt = float(item.get("item_amount", 0) or 0)
562
- if rate and rate > amt * 10 and amt < 10000:
563
  return False
564
  return True
565
 
566
- # ---------------- main endpoint ----------------
567
- @app.post("/extract-bill-data")
568
- async def extract_bill_data(payload: BillRequest):
569
- doc_url = payload.document
570
- file_bytes = None
571
-
572
- # --------------------------- Local or remote file ---------------------------
573
- if doc_url.startswith("file://"):
574
- local_path = doc_url.replace("file://", "")
575
- try:
576
- with open(local_path, "rb") as f:
577
- file_bytes = f.read()
578
- except Exception as e:
579
- return {
580
- "is_success": False,
581
- "error": f"Local file read error: {e}",
582
- "data": {"pagewise_line_items": [], "total_item_count": 0},
583
- "token_usage": {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
584
- }
585
- else:
586
- try:
587
- headers = {"User-Agent": "Mozilla/5.0"}
588
- resp = requests.get(doc_url, headers=headers, timeout=30)
589
- if resp.status_code != 200:
590
- raise RuntimeError(f"Download failed status={resp.status_code}")
591
- file_bytes = resp.content
592
- except Exception as e:
593
- return {
594
- "is_success": False,
595
- "error": f"HTTP error: {e}",
596
- "data": {"pagewise_line_items": [], "total_item_count": 0},
597
- "token_usage": {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
598
- }
599
-
600
- if not file_bytes:
601
- return {
602
- "is_success": False,
603
- "error": "No file bytes found.",
604
- "data": {"pagewise_line_items": [], "total_item_count": 0},
605
- "token_usage": {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
606
- }
607
-
608
- images = []
609
- clean_url = doc_url.split("?", 1)[0].lower()
610
  try:
611
- if clean_url.endswith(".pdf"):
612
- images = convert_from_bytes(file_bytes)
613
- elif any(clean_url.endswith(ext) for ext in [".png", ".jpg", ".jpeg", ".tiff", ".bmp"]):
614
- images = [Image.open(BytesIO(file_bytes))]
615
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
616
  try:
617
- images = convert_from_bytes(file_bytes)
 
 
 
 
 
618
  except Exception:
619
- images = []
620
- except Exception:
621
- images = []
 
 
622
 
623
- pagewise = []
624
- cumulative_token_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
625
 
626
- for idx, page_img in enumerate(images, start=1):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
627
  try:
628
- proc = preprocess_image(page_img)
629
  cells = image_to_tsv_cells(proc)
630
  rows = group_cells_into_rows(cells, y_tolerance=12)
631
  rows_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]
632
-
633
- # === Header prefilter: remove header-like rows ===
634
  rows_filtered = []
635
  for i, (r, rt) in enumerate(zip(rows, rows_texts)):
636
  top_flag = (i < 6)
@@ -640,71 +790,112 @@ async def extract_bill_data(payload: BillRequest):
640
  if any(h in rt_norm for h in HEADER_PHRASES):
641
  continue
642
  rows_filtered.append(r)
643
-
644
  rows = rows_filtered
645
- rows_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]
646
- page_text = sanitize_ocr_text(" ".join(rows_texts))
647
-
648
- # Collect detected top headers for final filtering
649
- top_headers = []
650
- for i, rt in enumerate(rows_texts[:6]):
651
- if looks_like_header_text(rt, top_of_page=(i < 4)):
652
- top_headers.append(rt.strip().lower())
653
-
654
  parsed_items = parse_rows_with_columns(rows, cells)
655
-
656
- # Gemini refinement (if enabled)
657
- refined_items, token_u = refine_with_gemini(parsed_items, page_text)
658
- for k in cumulative_token_usage:
659
- cumulative_token_usage[k] += token_u.get(k, 0)
660
-
661
- other_item_names = [it.get("item_name", "") for it in refined_items]
662
- cleaned = [p for p in refined_items if final_item_filter(p, known_page_headers=top_headers, other_item_names=other_item_names)]
663
  cleaned = dedupe_items(cleaned)
664
-
665
  page_type = "Bill Detail"
666
- page_txt = page_text.lower()
667
  if any(x in page_txt for x in ["pharmacy", "medicine", "tablet"]):
668
  page_type = "Pharmacy"
669
- if "final bill" in page_txt or "grand total" in page_txt:
670
- page_type = "Final Bill"
 
 
 
671
 
672
- pagewise.append({"page_no": str(idx), "page_type": page_type, "bill_items": cleaned})
673
- except Exception:
674
- pagewise.append({"page_no": str(idx), "page_type": "Bill Detail", "bill_items": []})
675
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
676
 
677
- total_item_count = sum(len(p.get("bill_items", [])) for p in pagewise)
678
  if not GEMINI_API_KEY or genai is None:
679
- cumulative_token_usage["warning_no_gemini"] = 1
680
 
681
- return {"is_success": True, "token_usage": cumulative_token_usage,
682
- "data": {"pagewise_line_items": pagewise, "total_item_count": total_item_count}}
683
 
684
- # ---------------- debug TSV ----------------
 
 
685
  @app.post("/debug-tsv")
686
  async def debug_tsv(payload: BillRequest):
687
  doc_url = payload.document
688
  try:
689
- resp = requests.get(doc_url, timeout=20)
690
- if resp.status_code != 200:
691
- return {"error": "Download failed"}
692
- file_bytes = resp.content
693
- except Exception:
694
- return {"error": "Download failed"}
695
- clean_url = doc_url.split("?", 1)[0].lower()
696
- if clean_url.endswith(".pdf"):
 
 
 
697
  imgs = convert_from_bytes(file_bytes)
698
  img = imgs[0]
699
- else:
700
- img = Image.open(BytesIO(file_bytes))
701
- proc = preprocess_image(img)
 
 
 
702
  cells = image_to_tsv_cells(proc)
703
  return {"cells": cells}
704
 
705
  @app.get("/")
706
  def health_check():
707
- msg = "Bill extraction API (updated version) live."
708
  if not GEMINI_API_KEY or genai is None:
709
- msg += " (No GEMINI - LLM refinement skipped.)"
710
- return {"status": "ok", "message": msg, "hint": "POST /extract-bill-data with {'document':'<url>'}"}
 
1
+ # app.py
2
+ # High-accuracy bill extraction API with optional Amazon Textract / Google Vision + robust Tesseract fallback.
3
+ # Usage:
4
+ # export OCR_ENGINE=textract # or "vision" or "tesseract"
5
+ # export AWS_REGION=us-east-1 # required for Textract
6
+ # export GEMINI_API_KEY=... # optional
7
+ # uvicorn app:app --host 0.0.0.0 --port 8080
8
+
9
  import os
10
  import re
11
  import json
12
+ import logging
13
  from io import BytesIO
14
  from typing import List, Dict, Any, Optional, Tuple
15
 
 
18
  import requests
19
  from PIL import Image
20
  from pdf2image import convert_from_bytes
 
 
21
  import numpy as np
22
  import cv2
23
+ import pytesseract
24
+ from pytesseract import Output
25
+
26
+ # Optional libs (import lazily)
27
+ try:
28
+ import boto3
29
+ from botocore.exceptions import BotoCoreError, ClientError
30
+ except Exception:
31
+ boto3 = None
32
+
33
+ try:
34
+ from google.cloud import vision
35
+ except Exception:
36
+ vision = None
37
 
38
  # Optional: Google Gemini SDK (if available)
39
  try:
 
41
  except Exception:
42
  genai = None
43
 
44
+ # -------------------------------------------------------------------------
45
+ # Configuration and logging
46
+ # -------------------------------------------------------------------------
47
+ OCR_ENGINE = os.getenv("OCR_ENGINE", "textract").lower() # 'textract' | 'vision' | 'tesseract'
48
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
49
  GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME", "gemini-2.5-flash")
50
+ AWS_REGION = os.getenv("AWS_REGION", "us-east-1")
51
+ TESSERACT_PSM = os.getenv("TESSERACT_PSM", "6") # page segmentation mode default
52
+
53
+ logging.basicConfig(level=logging.INFO)
54
+ logger = logging.getLogger("bill-extractor")
55
+
56
  if GEMINI_API_KEY and genai is not None:
57
  try:
58
  genai.configure(api_key=GEMINI_API_KEY)
59
+ logger.info("Gemini configured")
60
+ except Exception as e:
61
+ logger.warning("Gemini config failed: %s", e)
62
+
63
+ # Boto3 textract client (lazy init)
64
+ _textract_client = None
65
+ def textract_client():
66
+ global _textract_client
67
+ if _textract_client is None:
68
+ if boto3 is None:
69
+ raise RuntimeError("boto3 not installed but OCR_ENGINE=textract requested")
70
+ _textract_client = boto3.client("textract", region_name=AWS_REGION)
71
+ return _textract_client
72
+
73
+ # Google Vision client (lazy)
74
+ _vision_client = None
75
+ def vision_client():
76
+ global _vision_client
77
+ if _vision_client is None:
78
+ if vision is None:
79
+ raise RuntimeError("google-cloud-vision not installed but OCR_ENGINE=vision requested")
80
+ _vision_client = vision.ImageAnnotatorClient()
81
+ return _vision_client
82
+
83
+ # -------------------------------------------------------------------------
84
+ # Request model
85
+ # -------------------------------------------------------------------------
86
+ app = FastAPI(title="Bajaj Datathon - Bill Extractor (high-accuracy)")
87
 
88
  class BillRequest(BaseModel):
89
+ document: str # file://local_path or http(s) url
90
 
91
+ # -------------------------------------------------------------------------
92
+ # Helpers (numbers, cleaning, OCR preprocessing)
93
+ # -------------------------------------------------------------------------
94
  NUM_RE = re.compile(r"[-+]?\d{1,3}(?:[,0-9]*)(?:\.\d+)?")
95
  TOTAL_KEYWORDS = re.compile(
96
  r"(grand\s*total|net\s*payable|total\s*amount|amount\s*payable|bill\s*amount|final\s*amount|balance\s*due|sub\s*total|subtotal|round\s*off)",
97
  re.I,
98
  )
99
  FOOTER_KEYWORDS = re.compile(r"(page|printed on|printed:|date:|time:|am|pm)", re.I)
100
+
101
+ HEADER_KEYWORDS = [
102
+ "description", "qty", "hrs", "rate", "discount", "net", "amt", "amount",
103
+ "consultation", "address", "sex", "age", "mobile", "patient", "category",
104
+ "doctor", "dr", "invoice", "bill", "subtotal", "total", "charges", "service"
105
+ ]
106
  HEADER_PHRASES = [
107
  "description qty / hrs consultation rate discount net amt",
108
  "description qty / hrs rate discount net amt",
 
 
 
109
  ]
110
  HEADER_PHRASES = [h.lower() for h in HEADER_PHRASES]
111
 
112
+ def sanitize_ocr_text(s: Optional[str]) -> str:
 
113
  if not s:
114
  return ""
115
  s = s.replace("\u2014", "-").replace("\u2013", "-")
116
  s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E]", " ", s)
117
  s = s.replace("\r\n", "\n").replace("\r", "\n")
118
  s = re.sub(r"[ \t]+", " ", s)
119
+ # common OCR corrections
120
+ s = re.sub(r"\bqiy\b", "qty", s, flags=re.I)
121
+ s = re.sub(r"\bdeseription\b", "description", s, flags=re.I)
122
+ return s.strip()
 
123
 
124
  def normalize_num_str(s: Optional[str]) -> Optional[float]:
125
  if s is None:
 
136
  if s in ("", "-", "+"):
137
  return None
138
  try:
139
+ val = float(s)
140
+ return -val if negative else val
141
  except Exception:
142
  try:
143
  return float(s.replace(" ", ""))
 
150
  def clean_name_text(s: str) -> str:
151
  s = s.replace("—", "-")
152
  s = re.sub(r"\s+", " ", s)
153
+ s = s.strip(" -:,.=")
154
+ s = re.sub(r"\s+x$", "", s, flags=re.I)
155
+ s = re.sub(r"[\)\}\]]+$", "", s)
156
+ s = re.sub(r"\bOR\b", "DR", s) # OCR OR -> DR
157
  return s.strip()
158
 
159
+ # -------------------------------------------------------------------------
160
+ # Image preprocessing helpers (for Tesseract pipeline)
161
+ # -------------------------------------------------------------------------
162
  def pil_to_cv2(img: Image.Image) -> Any:
163
  arr = np.array(img)
164
  if arr.ndim == 2:
165
  return arr
166
  return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
167
 
168
+ def preprocess_image_for_tesseract(pil_img: Image.Image, target_w: int = 1500) -> Any:
169
  pil_img = pil_img.convert("RGB")
170
  w, h = pil_img.size
 
171
  if w < target_w:
172
  scale = target_w / float(w)
173
  pil_img = pil_img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
174
  cv_img = pil_to_cv2(pil_img)
175
+ if cv_img.ndim == 3:
176
+ gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
177
+ else:
178
+ gray = cv_img
179
  gray = cv2.fastNlMeansDenoising(gray, h=10)
180
  try:
181
  bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
 
186
  bw = cv2.morphologyEx(bw, cv2.MORPH_OPEN, kernel)
187
  return bw
188
 
 
189
  def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
190
+ # returns list of OCR 'cells' compatible with your parsing pipeline
191
  try:
192
+ o = pytesseract.image_to_data(cv_img, output_type=Output.DICT, config=f"--psm {TESSERACT_PSM}")
193
  except Exception:
194
  o = pytesseract.image_to_data(cv_img, output_type=Output.DICT)
195
  cells = []
 
202
  if not txt:
203
  continue
204
  try:
205
+ conf_raw = o.get("conf", [])[i]
206
+ conf = float(conf_raw) if conf_raw not in (None, "", "-1") else -1.0
207
  except Exception:
208
  conf = -1.0
209
  left = int(o.get("left", [0])[i])
 
213
  center_y = top + height / 2.0
214
  center_x = left + width / 2.0
215
  cells.append({"text": txt, "conf": conf, "left": left, "top": top,
216
+ "width": width, "height": height, "center_y": center_y, "center_x": center_x})
 
217
  return cells
218
 
 
219
  def group_cells_into_rows(cells: List[Dict[str, Any]], y_tolerance: int = 12) -> List[List[Dict[str, Any]]]:
220
  if not cells:
221
  return []
 
244
  row = rows[i]
245
  tokens = [c["text"] for c in row]
246
  has_num = any(is_numeric_token(t) for t in tokens)
 
247
  if not has_num and i + 1 < len(rows):
248
  next_row = rows[i+1]
249
  next_tokens = [c["text"] for c in next_row]
 
262
  merged.append(sorted(merged_row, key=lambda cc: cc["left"]))
263
  i += 2
264
  continue
 
265
  if not has_num and i + 1 < len(rows):
266
  next_row = rows[i+1]
267
  next_tokens = [c["text"] for c in next_row]
 
272
  offset = 10
273
  for c in row + next_row:
274
  newc = c.copy()
275
+ if newc["left"] > min_left:
276
+ newc["left"] = newc["left"]
277
+ else:
278
+ newc["left"] = min_left - offset
279
  newc["center_x"] = newc["left"] + newc.get("width", 0) / 2.0
280
  merged_row.append(newc)
281
  offset += 5
 
286
  i += 1
287
  return merged
288
 
 
289
  def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 6) -> List[float]:
290
  xs = [c["center_x"] for c in cells if is_numeric_token(c["text"])]
291
  if not xs:
 
296
  gaps = [xs[i+1] - xs[i] for i in range(len(xs)-1)]
297
  mean_gap = float(np.mean(gaps))
298
  std_gap = float(np.std(gaps)) if len(gaps) > 1 else 0.0
299
+ gap_thresh = max(28.0, mean_gap + 0.6 * std_gap)
300
  clusters = []
301
  curr = [xs[0]]
302
  for i, g in enumerate(gaps):
 
317
  distances = [abs(token_x - cx) for cx in column_centers]
318
  return int(np.argmin(distances))
319
 
320
+ # -------------------------------------------------------------------------
321
+ # Parsing pipeline (shared)
322
+ # -------------------------------------------------------------------------
323
  def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
324
  parsed_items = []
325
  rows = merge_multiline_names(rows)
 
332
  joined_lower = " ".join(tokens).lower()
333
  if FOOTER_KEYWORDS.search(joined_lower) and not any(is_numeric_token(t) for t in tokens):
334
  continue
335
+ # require some numeric token (date-only rows excluded later)
336
  if all(not is_numeric_token(t) for t in tokens):
337
  continue
338
 
 
339
  numeric_values = []
340
  for t in tokens:
341
  if is_numeric_token(t):
 
361
  raw_name = " ".join(left_text_parts).strip()
362
  name = clean_name_text(raw_name) if raw_name else ""
363
  num_cols = len(column_centers)
 
364
  def get_bucket(idx):
365
  vals = numeric_bucket_map.get(idx, [])
366
  return vals[-1] if vals else None
 
367
  amount = normalize_num_str(get_bucket(num_cols - 1)) if num_cols >= 1 else None
368
  rate = normalize_num_str(get_bucket(num_cols - 2)) if num_cols >= 2 else None
369
  qty = normalize_num_str(get_bucket(num_cols - 3)) if num_cols >= 3 else None
 
375
  if amount is not None:
376
  break
377
 
378
+ # infer rate and qty heuristics
379
  if amount is not None and numeric_values:
380
  for cand in numeric_values:
381
  try:
 
400
  qty = float(r)
401
  break
402
 
 
403
  if (rate is None or rate == 0) and qty and qty != 0 and amount is not None:
404
  try:
405
  candidate_rate = amount / qty
406
+ if candidate_rate >= 1.0:
407
  rate = candidate_rate
408
  except Exception:
409
  pass
 
411
  if qty is None:
412
  qty = 1.0
413
 
 
414
  try:
415
  amount = float(round(amount, 2))
416
  except Exception:
 
430
  "item_rate": rate if rate is not None else 0.0,
431
  "item_quantity": qty if qty is not None else 1.0,
432
  })
 
433
  else:
434
  numeric_idxs = [i for i, t in enumerate(tokens) if is_numeric_token(t)]
435
  if not numeric_idxs:
 
481
  "item_rate": float(round(rate, 2)),
482
  "item_quantity": float(qty),
483
  })
 
484
  return parsed_items
485
 
 
486
def dedupe_items(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop duplicate line items, keeping the first occurrence.

    Two items are considered duplicates when their whitespace-normalized,
    lowercased name (truncated to 120 chars) and their amount rounded to
    2 decimals both match.
    """
    unique: List[Dict[str, Any]] = []
    observed = set()
    for entry in items:
        raw_name = (entry.get("item_name", "") or "").lower()
        norm_name = re.sub(r"\s+", " ", raw_name).strip()
        amount_key = round(float(entry.get("item_amount", 0) or 0), 2)
        fingerprint = (norm_name[:120], amount_key)
        if fingerprint not in observed:
            observed.add(fingerprint)
            unique.append(entry)
    return unique
497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
  def looks_like_header_text(txt: str, top_of_page: bool = False) -> bool:
499
  if not txt:
500
  return False
501
  t = re.sub(r"\s+", " ", txt.strip().lower())
 
502
  if any(h == t for h in HEADER_PHRASES):
503
  return True
504
  hits = sum(1 for k in HEADER_KEYWORDS if k in t)
 
516
  return True
517
  return False
518
 
519
def final_item_filter(item: Dict[str, Any], known_page_headers: Optional[List[str]] = None) -> bool:
    """Decide whether a parsed line item is a plausible bill entry.

    Rejects empty names, names containing a known table-header string,
    footer-looking text, non-positive or absurdly large amounts, and rows
    whose rate is wildly inconsistent with a small amount.

    BUGFIX: ``known_page_headers`` previously used a mutable default
    argument (``[]``); replaced with the ``None`` sentinel idiom. Callers
    that pass an explicit list are unaffected.
    """
    name = (item.get("item_name") or "").strip()
    if not name:
        return False
    ln = name.lower()
    for h in (known_page_headers or []):
        if h and h.strip() and h.strip().lower() in ln:
            return False
    if FOOTER_KEYWORDS.search(ln):
        return False
    amt = float(item.get("item_amount", 0) or 0)
    if amt <= 0:
        return False
    # sanity: weird giant amounts are likely OCR garbage
    if amt > 10_000_000:
        return False
    rate = float(item.get("item_rate", 0) or 0)
    # a rate 20x larger than a small amount is almost certainly a mis-read column
    if rate and rate > amt * 20 and amt < 10000:
        return False
    return True
539
 
540
+ # -------------------------------------------------------------------------
541
+ # Gemini refinement (deterministic, optional)
542
+ # -------------------------------------------------------------------------
543
def refine_with_gemini(page_items: List[Dict[str, Any]], page_text: str = "") -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
    """Optionally ask Gemini to clean/normalize parsed line items.

    Returns ``(items, token_usage)``. On any failure — or when Gemini is not
    configured — the input items are returned unchanged, so the OCR pipeline
    never depends on the LLM being reachable.
    """
    zero_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
    if not GEMINI_API_KEY or genai is None:
        return page_items, zero_usage
    try:
        safe_text = sanitize_ocr_text(page_text)[:3000]
        system_prompt = (
            "You are a strict bill-extraction cleaner. Return ONLY a JSON array (no explanation, no backticks). "
            "Each entry must be an object with keys: item_name (string), item_amount (float), item_rate (float), item_quantity (float). "
            "Do NOT include subtotal or total lines as items. Do not invent items; only clean/fix/normalize the given items."
        )
        user_prompt = f"page_text='''{safe_text}'''\nitems={json.dumps(page_items, ensure_ascii=False)}\nReturn only the cleaned JSON array."
        # BUGFIX: the google-generativeai SDK rejects a "system" role inside the
        # contents list, and temperature/max_output_tokens are not accepted as
        # direct kwargs of generate_content() — the old call raised on every
        # invocation and silently fell back. Pass the system prompt via
        # system_instruction and tuning via generation_config instead.
        model = genai.GenerativeModel(GEMINI_MODEL_NAME, system_instruction=system_prompt)
        response = model.generate_content(
            user_prompt,
            generation_config={"temperature": 0.0, "max_output_tokens": 1000},
        )
        raw = response.text.strip()
        if raw.startswith("```"):
            # strip optional markdown code fences the model may emit
            raw = re.sub(r"^```[a-zA-Z]*", "", raw)
            raw = re.sub(r"```$", "", raw).strip()
        parsed = json.loads(raw)
        if not isinstance(parsed, list):
            # model returned a non-array payload; keep original items
            return page_items, zero_usage
        out = []
        for obj in parsed:
            try:
                out.append({
                    "item_name": str(obj.get("item_name", "")).strip(),
                    "item_amount": float(obj.get("item_amount", 0.0)),
                    "item_rate": float(obj.get("item_rate", 0.0) or 0.0),
                    "item_quantity": float(obj.get("item_quantity", 1.0) or 1.0),
                })
            except Exception:
                # skip malformed entries rather than failing the whole page
                continue
        return out, zero_usage
    except Exception as e:
        logger.warning("Gemini refine failed: %s", e)
        return page_items, zero_usage
584
 
585
+ # -------------------------------------------------------------------------
586
+ # OCR engine implementations
587
+ # -------------------------------------------------------------------------
588
def ocr_with_textract(file_bytes: bytes) -> List[Dict[str, Any]]:
    """
    Use Amazon Textract AnalyzeExpense on each page image. Returns list of pages:
    [{"page_no": "1", "page_type": "...", "bill_items": [...]}]
    Note: Textract AnalyzeExpense returns structured expense/line-item data; we map it to our output.

    Each page is rendered to a JPEG and sent to AnalyzeExpense synchronously;
    a failed Textract call produces a page with an empty bill_items list
    instead of aborting the whole document.
    """
    pages_out = []
    client = textract_client()

    # Convert bytes to images and call AnalyzeExpense for each page (synchronous).
    try:
        images = convert_from_bytes(file_bytes)
    except Exception as e:
        # assumes input is a PDF here; non-PDF bytes land in this branch too — TODO confirm
        logger.warning("Textract fallback: PDF->image conversion failed: %s", e)
        return []

    for idx, pil_img in enumerate(images, start=1):
        bio = BytesIO()
        pil_img.save(bio, format="JPEG", quality=90)
        img_bytes = bio.getvalue()
        try:
            resp = client.analyze_expense(Document={'Bytes': img_bytes})
        except (BotoCoreError, ClientError) as e:
            # keep the page in the output so page numbering stays contiguous
            logger.exception("Textract analyze_expense failed: %s", e)
            pages_out.append({"page_no": str(idx), "page_type": "Bill Detail", "bill_items": []})
            continue
        # Parse Textract response
        items = []
        line_item_groups = resp.get("ExpenseDocuments", [])
        if line_item_groups:
            for doc in line_item_groups:
                groups = doc.get("LineItemGroups", [])
                for g in groups:
                    for li in g.get("LineItems", []):
                        # Each line item has LineItemExpenseFields list
                        name_parts = []
                        amount = None
                        rate = None
                        qty = None
                        for f in li.get("LineItemExpenseFields", []):
                            # Type.Text is Textract's label for the field; ValueDetection.Text the raw value
                            tname = f.get("Type", {}).get("Text", "") or ""
                            v = f.get("ValueDetection", {}).get("Text", "") or ""
                            txt_l = tname.lower()
                            if txt_l in ("item", "description", "item description", "service"):
                                name_parts.append(v)
                            elif txt_l in ("amount", "price", "total"):
                                maybe = normalize_num_str(v)
                                if maybe is not None:
                                    amount = maybe
                            elif txt_l in ("quantity", "qty"):
                                maybe = normalize_num_str(v)
                                if maybe is not None:
                                    qty = maybe
                            elif txt_l in ("rate", "unit price", "price per unit"):
                                maybe = normalize_num_str(v)
                                if maybe is not None:
                                    rate = maybe
                            else:
                                # Heuristic: if value looks numeric and field name is empty, try assign
                                if is_numeric_token(v) and amount is None:
                                    maybe = normalize_num_str(v)
                                    if maybe is not None:
                                        amount = maybe
                                elif v and not is_numeric_token(v):
                                    name_parts.append(v)
                        name = " ".join(name_parts).strip() or "UNKNOWN"
                        # Post-process amount/rate/qty
                        if amount is None:
                            # try to find from summary fields
                            # NOTE(review): intentionally a no-op today; summary-field lookup not implemented
                            pass
                        # derive the missing one of qty/rate from the other two when possible
                        if qty is None and rate is not None and amount is not None and rate != 0:
                            try:
                                qty = round(amount / rate, 2)
                            except Exception:
                                qty = 1.0
                        if qty is None:
                            qty = 1.0
                        if rate is None and qty and qty != 0 and amount is not None:
                            try:
                                rate = round(amount / qty, 2)
                            except Exception:
                                rate = 0.0
                        if amount is None:
                            amount = 0.0
                        items.append({
                            "item_name": clean_name_text(name),
                            "item_amount": float(round(amount, 2)),
                            "item_rate": float(round(rate or 0.0, 2)),
                            "item_quantity": float(qty or 1.0),
                        })
        # Fallback: if Textract returned no structured line items, attempt to extract lines from Blocks
        if not items:
            # try to extract lines from DocumentMetadata / Blocks
            blocks = resp.get("Blocks", [])
            lines = []
            for b in blocks:
                if b.get("BlockType") == "LINE":
                    lines.append(b.get("Text", ""))
            # naive fallback: group lines that contain numbers
            for ln in lines:
                tokens = ln.split()
                numbers = [t for t in tokens if is_numeric_token(t)]
                if numbers:
                    # non-numeric tokens become the name; last parseable number becomes the amount
                    name = " ".join([t for t in tokens if not is_numeric_token(t)])
                    amount = None
                    for t in reversed(tokens):
                        if is_numeric_token(t):
                            v = normalize_num_str(t)
                            if v is not None:
                                amount = v
                                break
                    if amount:
                        items.append({
                            "item_name": clean_name_text(name or "UNKNOWN"),
                            "item_amount": float(round(amount, 2)),
                            "item_rate": 0.0,
                            "item_quantity": 1.0,
                        })
        # Filter & dedupe
        items = [it for it in items if final_item_filter(it, [])]
        items = dedupe_items(items)
        page_type = "Bill Detail"
        # crude pharmacy-page detection based on item names only
        items_text = " ".join([it["item_name"] for it in items]).lower()
        if "pharmacy" in items_text or "tablet" in items_text or "medicine" in items_text:
            page_type = "Pharmacy"
        pages_out.append({"page_no": str(idx), "page_type": page_type, "bill_items": items})
    return pages_out
715
+
716
def ocr_with_google_vision(file_bytes: bytes) -> List[Dict[str, Any]]:
    """
    Google Vision Document OCR pipeline. Returns parsed pages (same format).

    Each page is processed independently: a Vision/API failure on one page
    yields that page with an empty bill_items list instead of aborting the
    whole document (consistent with the Textract and Tesseract pipelines).
    """
    client = vision_client()
    pages_out = []
    try:
        images = convert_from_bytes(file_bytes)
    except Exception as e:
        logger.warning("Vision pipeline: PDF->image conversion failed: %s", e)
        return []
    for idx, pil_img in enumerate(images, start=1):
        try:
            bio = BytesIO()
            pil_img.save(bio, format="JPEG", quality=90)
            content = bio.getvalue()
            image = vision.Image(content=content)
            resp = client.document_text_detection(image=image)
            # BUGFIX: Vision reports API-level problems in resp.error rather
            # than raising — previously these were silently ignored.
            if getattr(resp, "error", None) and resp.error.message:
                raise RuntimeError(resp.error.message)
            text = resp.full_text_annotation.text if resp.full_text_annotation else ""
            # Build pseudo-cells from words using bounding boxes if available
            cells = []
            for page in (resp.full_text_annotation.pages or []):
                for block in page.blocks:
                    for para in block.paragraphs:
                        for word in para.words:
                            word_text = "".join([sym.text for sym in word.symbols])
                            bbox = word.bounding_box
                            # compute approximate left/top/width/height from the quad vertices
                            xs = [v.x for v in bbox.vertices]
                            ys = [v.y for v in bbox.vertices]
                            left = int(min(xs)) if xs else 0
                            top = int(min(ys)) if ys else 0
                            width = int(max(xs) - min(xs)) if xs else 0
                            height = int(max(ys) - min(ys)) if ys else 0
                            center_x = left + width / 2.0
                            center_y = top + height / 2.0
                            cells.append({"text": word_text, "conf": -1.0, "left": left, "top": top, "width": width, "height": height, "center_x": center_x, "center_y": center_y})
            # row grouping + parse using shared functions
            rows = group_cells_into_rows(cells, y_tolerance=14)
            parsed_items = parse_rows_with_columns(rows, cells)
            cleaned = [p for p in parsed_items if final_item_filter(p, [])]
            cleaned = dedupe_items(cleaned)
            page_type = "Bill Detail"
            page_txt = text.lower()
            if any(x in page_txt for x in ["pharmacy", "medicine", "tablet"]):
                page_type = "Pharmacy"
            pages_out.append({"page_no": str(idx), "page_type": page_type, "bill_items": cleaned})
        except Exception as e:
            # keep page numbering contiguous even when one page fails
            logger.exception("Vision parse page failed: %s", e)
            pages_out.append({"page_no": str(idx), "page_type": "Bill Detail", "bill_items": []})
    return pages_out
763
+
764
+ def ocr_with_tesseract(file_bytes: bytes) -> List[Dict[str,Any]]:
765
+ """Tesseract pipeline using your preprocessing + TSV + parsing functions."""
766
+ pages_out = []
767
+ try:
768
+ images = convert_from_bytes(file_bytes)
769
+ except Exception as e:
770
+ # maybe it's a single image format (jpg/png)
771
+ try:
772
+ im = Image.open(BytesIO(file_bytes))
773
+ images = [im]
774
+ except Exception:
775
+ logger.exception("Tesseract pipeline can't open file: %s", e)
776
+ return []
777
+ for idx, pil_img in enumerate(images, start=1):
778
  try:
779
+ proc = preprocess_image_for_tesseract(pil_img)
780
  cells = image_to_tsv_cells(proc)
781
  rows = group_cells_into_rows(cells, y_tolerance=12)
782
  rows_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]
783
+ # header prefilter
 
784
  rows_filtered = []
785
  for i, (r, rt) in enumerate(zip(rows, rows_texts)):
786
  top_flag = (i < 6)
 
790
  if any(h in rt_norm for h in HEADER_PHRASES):
791
  continue
792
  rows_filtered.append(r)
 
793
  rows = rows_filtered
 
 
 
 
 
 
 
 
 
794
  parsed_items = parse_rows_with_columns(rows, cells)
795
+ refined_items, _ = refine_with_gemini(parsed_items, sanitize_ocr_text(" ".join(rows_texts)))
796
+ cleaned = [p for p in refined_items if final_item_filter(p, [])]
 
 
 
 
 
 
797
  cleaned = dedupe_items(cleaned)
 
798
  page_type = "Bill Detail"
799
+ page_txt = " ".join(rows_texts).lower()
800
  if any(x in page_txt for x in ["pharmacy", "medicine", "tablet"]):
801
  page_type = "Pharmacy"
802
+ pages_out.append({"page_no": str(idx), "page_type": page_type, "bill_items": cleaned})
803
+ except Exception as e:
804
+ logger.exception("Tesseract parse page failed: %s", e)
805
+ pages_out.append({"page_no": str(idx), "page_type": "Bill Detail", "bill_items": []})
806
+ return pages_out
807
 
808
+ # -------------------------------------------------------------------------
809
+ # Main endpoint
810
+ # -------------------------------------------------------------------------
811
@app.post("/extract-bill-data")
async def extract_bill_data(payload: BillRequest):
    """Download a bill (http(s) URL or file:// path), OCR it, return line items.

    Always answers HTTP 200 with the shape:
      {"is_success": bool, "token_usage": {...},
       "data": {"pagewise_line_items": [...], "total_item_count": int}}
    """
    def _failure(msg: str) -> Dict[str, Any]:
        # Uniform error payload so clients always see the same response shape.
        return {
            "is_success": False,
            "error": msg,
            "data": {"pagewise_line_items": [], "total_item_count": 0},
            "token_usage": {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
        }

    doc_url = payload.document
    file_bytes = None

    # local file support
    if doc_url.startswith("file://"):
        # BUGFIX: strip only the leading scheme; str.replace("file://", "")
        # would remove every occurrence anywhere in the string.
        local_path = doc_url[len("file://"):]
        try:
            with open(local_path, "rb") as f:
                file_bytes = f.read()
        except Exception as e:
            return _failure(f"Local file read error: {e}")
    else:
        try:
            headers = {"User-Agent": "Mozilla/5.0"}
            resp = requests.get(doc_url, headers=headers, timeout=30)
            if resp.status_code != 200:
                return _failure(f"Download failed status={resp.status_code}")
            file_bytes = resp.content
        except Exception as e:
            return _failure(f"HTTP error: {e}")

    if not file_bytes:
        return _failure("No file bytes found")

    pages = []
    token_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
    engine = OCR_ENGINE
    logger.info("Using OCR engine: %s", engine)

    try:
        if engine == "textract":
            pages = ocr_with_textract(file_bytes)
        elif engine == "vision":
            pages = ocr_with_google_vision(file_bytes)
        else:
            pages = ocr_with_tesseract(file_bytes)
    except Exception as e:
        logger.exception("OCR engine failed: %s", e)
        # fallback to tesseract pipeline
        try:
            pages = ocr_with_tesseract(file_bytes)
        except Exception as e:
            logger.exception("Tesseract fallback also failed: %s", e)
            pages = []

    total_item_count = sum(len(p.get("bill_items", [])) for p in pages)
    if not GEMINI_API_KEY or genai is None:
        # surfaced in token_usage so callers can tell LLM refinement was skipped
        token_usage["warning_no_gemini"] = 1

    return {"is_success": True, "token_usage": token_usage, "data": {"pagewise_line_items": pages, "total_item_count": total_item_count}}
 
866
 
867
+ # -------------------------------------------------------------------------
868
+ # Debug endpoint to return tsv cell info for inspection
869
+ # -------------------------------------------------------------------------
870
@app.post("/debug-tsv")
async def debug_tsv(payload: BillRequest):
    """Debug helper: OCR the first page and return the raw Tesseract TSV cells.

    Accepts the same payload as /extract-bill-data; failures are reported as
    {"error": "..."} rather than raising.
    """
    doc_url = payload.document
    try:
        if doc_url.startswith("file://"):
            # BUGFIX: strip only the leading scheme; str.replace("file://", "")
            # would remove every occurrence anywhere in the string.
            local_path = doc_url[len("file://"):]
            with open(local_path, "rb") as f:
                file_bytes = f.read()
        else:
            resp = requests.get(doc_url, timeout=20)
            resp.raise_for_status()
            file_bytes = resp.content
    except Exception as e:
        return {"error": f"Download failed: {e}"}
    try:
        # first try treating the bytes as a PDF; fall back to a plain image
        imgs = convert_from_bytes(file_bytes)
        img = imgs[0]
    except Exception:
        try:
            img = Image.open(BytesIO(file_bytes)).convert("RGB")
        except Exception as e:
            return {"error": f"Image conversion failed: {e}"}
    proc = preprocess_image_for_tesseract(img)
    cells = image_to_tsv_cells(proc)
    return {"cells": cells}
895
 
896
@app.get("/")
def health_check():
    """Liveness endpoint reporting the active OCR engine and Gemini status."""
    msg = f"Bill extraction API live. OCR_ENGINE={OCR_ENGINE}"
    if not GEMINI_API_KEY or genai is None:
        # fixed garbled status message (separator between clauses was lost)
        msg += " (Gemini not configured - LLM refinement skipped.)"
    return {"status": "ok", "message": msg, "hint": "POST /extract-bill-data with {'document':'<url or file://path>'}"}