Sathvik-kota committed on
Commit
5ec4a93
·
verified ·
1 Parent(s): 1404047

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +451 -445
app.py CHANGED
@@ -1,8 +1,7 @@
1
- """
2
- Bajaj Finserv Datathon Bill Extraction Service
3
- Clean, modular and human-written version (Option A)
4
- Maintains your exact logic but reorganized for readability and robustness.
5
- """
6
 
7
  import os
8
  import re
@@ -10,542 +9,549 @@ import json
10
  from io import BytesIO
11
  from typing import List, Dict, Any, Optional, Tuple
12
 
13
- import cv2
14
- import numpy as np
15
  import requests
16
  from PIL import Image
17
  from pdf2image import convert_from_bytes
18
- from fastapi import FastAPI
19
- from pydantic import BaseModel
20
  import pytesseract
21
  from pytesseract import Output
22
- import google.generativeai as genai
 
23
 
 
 
 
 
 
24
 
25
- # -------------------------------------------------------
26
- # GEMINI CONFIG
27
- # -------------------------------------------------------
28
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
29
- GEMINI_MODEL = "gemini-2.5-flash"
30
-
31
- if GEMINI_API_KEY:
32
- genai.configure(api_key=GEMINI_API_KEY)
33
-
34
-
35
- # -------------------------------------------------------
36
- # FASTAPI APP
37
- # -------------------------------------------------------
38
- app = FastAPI(title="Bajaj Datathon - Bill Extractor (Clean vA)")
39
 
 
 
40
 
41
  class BillRequest(BaseModel):
42
  document: str
43
 
44
-
45
- # -------------------------------------------------------
46
- # REGEX + CONSTANTS
47
- # -------------------------------------------------------
48
  NUM_RE = re.compile(r"[-+]?\d{1,3}(?:[,0-9]*)(?:\.\d+)?")
49
-
50
- TOTAL_KEYS = re.compile(
51
- r"(grand\s*total|net\s*payable|total\s*amount|amount\s*payable|bill\s*amount|"
52
- r"final\s*amount|balance\s*due|sub\s*total|subtotal|round\s*off)",
53
- re.I
54
  )
55
-
56
- HEADER_HINT = re.compile(
57
- r"^(consultation|room|nursing|surgery|radiology|laboratory|charges|services|investigation|package|section)$",
58
- re.I
59
- )
60
-
61
- FOOTER_HINT = re.compile(r"(page|printed|date|time|am|pm|printed on)", re.I)
62
-
63
-
64
- # =======================================================
65
- # UTILITY HELPERS
66
- # =======================================================
67
-
68
- def normalize_number(raw: Optional[str]) -> Optional[float]:
69
- """Convert OCR number-like text into a clean float."""
70
- if not raw:
 
71
  return None
72
-
73
- text = re.sub(r"[^\d\-\+\,\.\(\)]", "", str(raw)).strip()
74
- if not text:
75
  return None
76
-
77
- # Handle negative (accounting) format: (150.00)
78
- is_negative = text.startswith("(") and text.endswith(")")
79
- if is_negative:
80
- text = text[1:-1]
81
-
82
- try:
83
- val = float(text.replace(",", ""))
84
- return -val if is_negative else val
85
- except:
86
  return None
87
-
88
-
89
- def is_numeric(text: str) -> bool:
90
- return bool(NUM_RE.search(str(text)))
91
-
92
-
93
- def clean_item_name(text: str) -> str:
94
- """Normalizes the left-side description of an item."""
95
- t = text.replace("—", "-")
96
- t = re.sub(r"\s+", " ", t)
97
- t = t.strip(" -:,.")
98
- t = re.sub(r"\bSG0?(\d+)\b", r"SG\1", t, flags=re.I)
99
- t = re.sub(r"\b(RR)[\s\-]*2\b", r"RR-2", t, flags=re.I)
100
- return t.strip()
101
-
102
-
103
- # =======================================================
104
- # IMAGE PROCESSING
105
- # =======================================================
106
-
107
- def pil_to_cv(pil: Image.Image) -> np.ndarray:
108
- np_img = np.array(pil)
109
- return np_img if np_img.ndim == 2 else cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
110
-
111
-
112
- def preprocess_image(pil_img: Image.Image) -> np.ndarray:
113
- """Resize, denoise & binarize image to improve OCR accuracy."""
 
114
  pil_img = pil_img.convert("RGB")
115
  w, h = pil_img.size
116
-
117
- # Upscale if very small
118
- if w < 1500:
119
- scale = 1500 / float(w)
120
  pil_img = pil_img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
121
-
122
- img = pil_to_cv(pil_img)
123
- gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
124
  gray = cv2.fastNlMeansDenoising(gray, h=10)
125
-
126
  try:
127
- bw = cv2.adaptiveThreshold(
128
- gray, 255,
129
- cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
130
- cv2.THRESH_BINARY,
131
- 41, 15
132
- )
133
  except Exception:
134
  _, bw = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
135
-
136
- bw = cv2.morphologyEx(bw, cv2.MORPH_OPEN, np.ones((1, 1), np.uint8))
137
  return bw
138
 
139
-
140
- # =======================================================
141
- # OCR TSV PARSING
142
- # =======================================================
143
-
144
- def run_tesseract(cv_img: np.ndarray) -> List[Dict[str, Any]]:
145
- """Extracts word-level bounding boxes and confidence from image."""
146
  try:
147
- data = pytesseract.image_to_data(cv_img, output_type=Output.DICT, config="--psm 6")
148
- except:
149
- data = pytesseract.image_to_data(cv_img, output_type=Output.DICT)
150
-
151
  cells = []
152
- n = len(data["text"])
153
-
154
  for i in range(n):
155
- txt = str(data["text"][i]).strip()
 
 
 
156
  if not txt:
157
  continue
158
-
159
- conf = float(data["conf"][i]) if data["conf"][i] not in ("", "-1") else -1.0
160
-
161
- left = int(data["left"][i])
162
- top = int(data["top"][i])
163
- w = int(data["width"][i])
164
- h = int(data["height"][i])
165
-
166
- cells.append({
167
- "text": txt,
168
- "conf": conf,
169
- "left": left,
170
- "top": top,
171
- "width": w,
172
- "height": h,
173
- "center_x": left + w / 2,
174
- "center_y": top + h / 2,
175
- })
176
-
177
  return cells
178
 
179
-
180
- # =======================================================
181
- # ROW GROUPING + MERGING
182
- # =======================================================
183
-
184
- def group_cells(cells: List[Dict[str, Any]], tol: int = 12) -> List[List[Dict[str, Any]]]:
185
- """Groups words into horizontal text rows."""
186
  if not cells:
187
  return []
188
-
189
- cells = sorted(cells, key=lambda c: (c["center_y"], c["center_x"]))
190
- rows, current = [], [cells[0]]
191
- last = cells[0]["center_y"]
192
-
193
- for c in cells[1:]:
194
- if abs(c["center_y"] - last) <= tol:
195
  current.append(c)
 
196
  else:
197
- rows.append(sorted(current, key=lambda x: x["left"]))
198
  current = [c]
199
- last = c["center_y"]
200
-
201
- rows.append(sorted(current, key=lambda x: x["left"]))
202
  return rows
203
 
204
-
205
- def merge_multiline_descriptions(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
206
- """
207
- Some items have description on one line and numbers on the next.
208
- This merges them into a single row.
209
- """
210
  if not rows:
211
  return rows
212
-
213
  merged = []
214
  i = 0
215
-
216
  while i < len(rows):
217
  row = rows[i]
218
  tokens = [c["text"] for c in row]
219
- row_has_num = any(is_numeric(t) for t in tokens)
220
-
221
- # If row is only text and next row is numeric: merge
222
- if not row_has_num and i + 1 < len(rows):
223
- next_row = rows[i + 1]
224
  next_tokens = [c["text"] for c in next_row]
225
-
226
- if any(is_numeric(t) for t in next_tokens):
227
- # prepend text row to numeric row
228
- new_row = []
229
-
230
- # push all text cells slightly left of next row
231
- base_left = min([c["left"] for c in next_row]) - 50
232
-
233
- offset = 0
234
- for cell in row:
235
- c = dict(cell)
236
- c["left"] = base_left + offset
237
- c["center_x"] = c["left"] + c["width"] / 2
238
- new_row.append(c)
239
- offset += 15
240
-
241
- new_row.extend(next_row)
242
- merged.append(sorted(new_row, key=lambda x: x["left"]))
243
  i += 2
244
  continue
245
-
246
  merged.append(row)
247
  i += 1
248
-
249
  return merged
250
 
251
-
252
- # =======================================================
253
- # COLUMN DETECTION
254
- # =======================================================
255
-
256
- def detect_column_centers(cells: List[Dict[str, Any]], max_cols=4) -> List[float]:
257
- xs = sorted([c["center_x"] for c in cells if is_numeric(c["text"])])
258
-
259
  if not xs:
260
  return []
261
-
262
  if len(xs) == 1:
263
- return xs
264
-
265
- gaps = [xs[i + 1] - xs[i] for i in range(len(xs) - 1)]
266
- gap_thresh = max(30, np.mean(gaps) + 0.6 * np.std(gaps))
267
-
268
  clusters = []
269
  curr = [xs[0]]
270
-
271
  for i, g in enumerate(gaps):
272
- if g > gap_thresh and len(clusters) < max_cols - 1:
273
  clusters.append(curr)
274
- curr = [xs[i + 1]]
275
  else:
276
- curr.append(xs[i + 1])
277
-
278
  clusters.append(curr)
279
- centers = sorted([np.median(c) for c in clusters])[:max_cols]
280
- return centers
 
 
281
 
282
-
283
- def nearest_column(x: float, centers: List[float]) -> int:
284
- distances = [abs(x - c) for c in centers]
 
285
  return int(np.argmin(distances))
286
 
287
-
288
- # =======================================================
289
- # ROW PARSER (MAIN LOGIC)
290
- # =======================================================
291
-
292
- def parse_rows(rows: List[List[Dict[str, Any]]], cells: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
293
- """Extract structured line items using detected columns."""
294
- items = []
295
-
296
- rows = merge_multiline_descriptions(rows)
297
- col_centers = detect_column_centers(cells, max_cols=4)
298
-
299
  for row in rows:
300
  tokens = [c["text"] for c in row]
301
-
302
  if not tokens:
303
  continue
304
-
305
- joined = " ".join(tokens).lower()
306
-
307
- # Skip footer lines like "Page 1/4"
308
- if FOOTER_HINT.search(joined) and not any(is_numeric(t) for t in tokens):
309
- continue
310
-
311
- # Skip headings that do not contain numbers
312
- if not any(is_numeric(t) for t in tokens):
313
- continue
314
-
315
- # --- Parse row using detected columns ---
316
- left_parts = []
317
- numeric_buckets = {i: [] for i in range(len(col_centers))}
318
-
319
- for c in row:
320
- t = c["text"]
321
- if is_numeric(t):
322
- col = nearest_column(c["center_x"], col_centers) if col_centers else len(col_centers) - 1
323
- numeric_buckets[col].append(t)
324
- else:
325
- left_parts.append(t)
326
-
327
- name = clean_item_name(" ".join(left_parts))
328
- num_cols = len(col_centers)
329
-
330
- # Extract numeric fields by column order (qty, rate, amount)
331
- def bucket(idx): return numeric_buckets.get(idx, [])[-1] if numeric_buckets.get(idx) else None
332
-
333
- amount = normalize_number(bucket(num_cols - 1))
334
- rate = normalize_number(bucket(num_cols - 2)) if num_cols >= 2 else None
335
- qty = normalize_number(bucket(num_cols - 3)) if num_cols >= 3 else None
336
-
337
- # Fallbacks
338
- if amount is None:
339
- for t in reversed(tokens):
340
- if is_numeric(t):
341
- amount = normalize_number(t)
342
- break
343
-
344
- if qty is None and amount and rate:
345
- q_est = amount / rate
346
- rounded = round(q_est)
347
- if abs(q_est - rounded) <= 0.2:
348
- qty = rounded
349
-
350
- if qty is None:
351
- qty = 1.0
352
-
353
- if (rate is None or rate == 0) and qty and amount:
354
- rate = round(amount / qty, 2)
355
-
356
- if amount is None or amount <= 0:
357
  continue
358
-
359
- if HEADER_HINT.search(name):
360
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
 
362
- items.append({
363
- "item_name": name or "UNKNOWN",
364
- "item_amount": float(round(amount, 2)),
365
- "item_rate": float(round(rate or 0.0, 2)),
366
- "item_quantity": float(qty)
367
- })
368
-
369
- return items
370
-
371
-
372
- # =======================================================
373
- # DEDUPE ITEMS + DETECT TOTALS
374
- # =======================================================
375
-
376
- def dedupe(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
377
  seen = set()
378
  out = []
379
-
380
  for it in items:
381
- key = (it["item_name"].lower()[:120], round(it["item_amount"], 2))
382
- if key not in seen:
383
- seen.add(key)
384
- out.append(it)
385
-
 
386
  return out
387
 
388
-
389
- # =======================================================
390
- # OPTIONAL: GEMINI REFINEMENT
391
- # =======================================================
392
-
393
- def refine_with_llm(items: List[Dict[str, Any]], text: str):
394
- """Uses Gemini only when inconsistencies are high."""
395
- if not GEMINI_API_KEY:
396
- return items, {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
397
-
 
 
 
 
 
 
 
 
 
 
 
 
398
  try:
399
- prompt = (
400
- "You are a precise bill-item cleaner. Fix broken names, validate qty = amount/rate, "
401
- "and remove any invalid rows. Return JSON array only.\n\n"
402
- f"Full text: '''{text[:3000]}'''\n"
403
- f"Detected items: {json.dumps(items)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  )
405
-
406
- model = genai.GenerativeModel(GEMINI_MODEL)
407
- response = model.generate_content(prompt)
408
-
409
  raw = response.text.strip()
410
- raw = raw.replace("```json", "").replace("```", "")
 
 
411
  parsed = json.loads(raw)
412
-
413
- final_items = []
414
- for obj in parsed:
415
- final_items.append({
416
- "item_name": str(obj.get("item_name", "")).strip(),
417
- "item_amount": float(obj.get("item_amount", 0)),
418
- "item_rate": float(obj.get("item_rate", 0)),
419
- "item_quantity": float(obj.get("item_quantity", 1)),
420
- })
421
-
422
- return final_items, {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
423
-
424
- except:
425
- return items, {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
426
-
427
-
428
- # =======================================================
429
- # MAIN API ENDPOINT
430
- # =======================================================
431
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  @app.post("/extract-bill-data")
433
  async def extract_bill_data(payload: BillRequest):
434
-
435
- # ---------------------------------------------------
436
- # 1. DOWNLOAD FILE
437
- # ---------------------------------------------------
438
  try:
439
- resp = requests.get(payload.document, headers={"User-Agent": "Mozilla"}, timeout=30)
440
- resp.raise_for_status()
441
- data_bytes = resp.content
442
- except:
443
- return {
444
- "is_success": False,
445
- "token_usage": {},
446
- "data": {"pagewise_line_items": [], "total_item_count": 0}
447
- }
448
-
449
- # ---------------------------------------------------
450
- # 2. LOAD PAGES (PDF / IMAGE)
451
- # ---------------------------------------------------
452
- pages = []
453
-
454
- url_no_query = payload.document.split("?", 1)[0].lower()
455
  try:
456
- if url_no_query.endswith(".pdf"):
457
- pages = convert_from_bytes(data_bytes)
 
 
458
  else:
459
- pages = [Image.open(BytesIO(data_bytes))]
460
- except:
461
- return {
462
- "is_success": False,
463
- "token_usage": {},
464
- "data": {"pagewise_line_items": [], "total_item_count": 0}
465
- }
466
-
467
- # ---------------------------------------------------
468
- # 3. PROCESS EACH PAGE
469
- # ---------------------------------------------------
470
- results = []
471
- gemini_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
472
-
473
- for idx, page in enumerate(pages, start=1):
474
- try:
475
- proc = preprocess_image(page)
476
- cells = run_tesseract(proc)
477
- rows = group_cells(cells)
478
-
479
- page_text = " ".join(" ".join(c["text"] for c in r) for r in rows).lower()
480
-
481
- items = parse_rows(rows, cells)
482
- items = dedupe(items)
483
-
484
- # decide whether to refine with LLM
485
- use_llm = False
486
- if GEMINI_API_KEY and len(items) > 0:
487
- inconsistent = sum(
488
- 1 for it in items
489
- if abs(it["item_quantity"] * it["item_rate"] - it["item_amount"]) > max(2, 0.03 * it["item_amount"])
490
- )
491
- if inconsistent > max(1, len(items) // 6):
492
- use_llm = True
493
-
494
- if use_llm:
495
- items, usage = refine_with_llm(items, page_text)
496
- for k in gemini_usage:
497
- gemini_usage[k] += usage[k]
498
-
499
- results.append({
500
- "page_no": str(idx),
501
- "page_type": "Bill Detail",
502
- "bill_items": items,
503
- })
504
-
505
- except Exception:
506
- results.append({
507
- "page_no": str(idx),
508
- "page_type": "Bill Detail",
509
- "bill_items": []
510
- })
511
 
512
- total_count = sum(len(p["bill_items"]) for p in results)
 
513
 
514
- return {
515
- "is_success": True,
516
- "token_usage": gemini_usage,
517
- "data": {
518
- "pagewise_line_items": results,
519
- "total_item_count": total_count
520
- }
521
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522
 
 
 
 
 
523
 
524
- # -------------------------------------------------------
525
- # RAW TSV DEBUG ENDPOINT
526
- # -------------------------------------------------------
527
  @app.post("/debug-tsv")
528
  async def debug_tsv(payload: BillRequest):
 
529
  try:
530
- resp = requests.get(payload.document, timeout=20)
531
- resp.raise_for_status()
532
- data = resp.content
533
- except:
534
- return {"error": "Unable to download"}
535
-
536
- url = payload.document.split("?", 1)[0].lower()
537
-
538
- if url.endswith(".pdf"):
539
- img = convert_from_bytes(data)[0]
540
  else:
541
- img = Image.open(BytesIO(data))
542
-
543
  proc = preprocess_image(img)
544
- return {"cells": run_tesseract(proc)}
545
-
546
 
547
  @app.get("/")
548
- def root():
549
- return {"status": "ok", "message": "Bill extraction API running"}
550
-
551
-
 
 
1
+ # app_bill_extractor_final.py
2
+ # Humanized, high-accuracy bill extraction API.
3
+ # Combines robust OCR preprocessing, TSV-based layout parsing, numeric-column inference,
4
+ # and ALWAYS attempts Gemini refinement (if GEMINI_API_KEY set). Made compact & readable.
 
5
 
6
  import os
7
  import re
 
9
  from io import BytesIO
10
  from typing import List, Dict, Any, Optional, Tuple
11
 
12
+ from fastapi import FastAPI
13
+ from pydantic import BaseModel
14
  import requests
15
  from PIL import Image
16
  from pdf2image import convert_from_bytes
 
 
17
  import pytesseract
18
  from pytesseract import Output
19
+ import numpy as np
20
+ import cv2
21
 
22
+ # Optional: Google Gemini SDK (if you use it). Code will gracefully work without it.
23
+ try:
24
+ import google.generativeai as genai
25
+ except Exception:
26
+ genai = None
27
 
28
+ # ---------------- LLM CONFIG ----------------
 
 
29
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
30
+ GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME", "gemini-2.5-flash")
31
+ if GEMINI_API_KEY and genai is not None:
32
+ try:
33
+ genai.configure(api_key=GEMINI_API_KEY)
34
+ except Exception:
35
+ pass
 
 
 
 
36
 
37
+ # ---------------- FastAPI app ----------------
38
+ app = FastAPI(title="Bajaj Datathon - Bill Extractor (final, humanized)")
39
 
40
  class BillRequest(BaseModel):
41
  document: str
42
 
43
+ # ---------------- Regex, small utils ----------------
 
 
 
44
  NUM_RE = re.compile(r"[-+]?\d{1,3}(?:[,0-9]*)(?:\.\d+)?")
45
+ TOTAL_KEYWORDS = re.compile(
46
+ r"(grand\s*total|net\s*payable|total\s*amount|amount\s*payable|bill\s*amount|final\s*amount|balance\s*due|sub\s*total|subtotal|round\s*off)",
47
+ re.I,
 
 
48
  )
49
+ FOOTER_KEYWORDS = re.compile(r"(page|printed on|printed:|date:|time:|am|pm)", re.I)
50
+ HEADER_KEYWORDS = ["description", "qty", "hrs", "rate", "discount", "net", "amt", "amount", "consultation", "qty/hrs", "qty / hrs"]
51
+
52
+ # sanitize OCR text before ever sending to an LLM or using it for heuristics
53
def sanitize_ocr_text(s: str) -> str:
    """Clean raw OCR text before handing it to heuristics or an LLM.

    Unifies unicode dashes, blanks non-printable/non-ASCII characters,
    normalises line endings, collapses runs of spaces/tabs (newlines are
    kept), and truncates to 4000 characters.
    """
    if not s:
        return ""
    # em dash / en dash -> plain hyphen
    cleaned = s.replace("\u2014", "-").replace("\u2013", "-")
    # drop anything outside tab/newline/CR/printable-ASCII
    cleaned = re.sub(r"[^\x09\x0A\x0D\x20-\x7E]", " ", cleaned)
    cleaned = cleaned.replace("\r\n", "\n").replace("\r", "\n")
    cleaned = re.sub(r"[ \t]+", " ", cleaned).strip()
    return cleaned[:4000]
63
+
64
def normalize_num_str(s: Optional[str]) -> Optional[float]:
    """Parse an OCR token into a float, or return None when it is not numeric.

    Strips currency symbols and other non-numeric characters, treats the
    accounting form "(123.45)" as negative, and removes thousands commas.
    """
    if s is None:
        return None
    text = str(s).strip()
    if not text:
        return None
    # keep only digits, signs, commas, dots and parentheses
    text = re.sub(r"[^\d\-\+\,\.\(\)]", "", text)
    negative = text.startswith("(") and text.endswith(")")
    if negative:
        text = text[1:-1]
    text = text.replace(",", "")
    if text in ("", "-", "+"):
        return None
    try:
        value = float(text)
    except Exception:
        # NOTE(review): spaces were already stripped by the regex above, so
        # this fallback appears unreachable; kept to preserve behaviour.
        # As in the original, the sign is NOT applied on this path.
        try:
            return float(text.replace(" ", ""))
        except Exception:
            return None
    return -value if negative else value
85
+
86
def is_numeric_token(t: Optional[str]) -> bool:
    """Return True when *t* is non-empty and contains a number-like substring."""
    if not t:
        return False
    return NUM_RE.search(str(t)) is not None
88
+
89
+ def clean_name_text(s: str) -> str:
90
+ s = s.replace("", "-")
91
+ s = re.sub(r"\s+", " ", s)
92
+ s = s.strip(" -:,.")
93
+ s = re.sub(r"\bSG0?(\d+)\b", r"SG\1", s, flags=re.I)
94
+ s = re.sub(r"\b(RR)[\s\-]*2\b", r"RR-2", s, flags=re.I)
95
+ return s.strip()
96
+
97
+ # ---------------- image preprocessing ----------------
98
def pil_to_cv2(img: Image.Image) -> Any:
    """Convert a PIL image to an OpenCV array.

    Grayscale arrays pass through unchanged; RGB images are converted to
    OpenCV's BGR channel order.
    """
    arr = np.array(img)
    return arr if arr.ndim == 2 else cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
103
+
104
def preprocess_image(pil_img: Image.Image) -> Any:
    """Prepare a page image for OCR: upscale small scans, grayscale, denoise, binarise."""
    pil_img = pil_img.convert("RGB")
    width, height = pil_img.size
    min_width = 1500
    if width < min_width:
        factor = min_width / float(width)
        pil_img = pil_img.resize((int(width * factor), int(height * factor)), Image.LANCZOS)
    gray = cv2.cvtColor(pil_to_cv2(pil_img), cv2.COLOR_BGR2GRAY)
    gray = cv2.fastNlMeansDenoising(gray, h=10)
    try:
        binary = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 41, 15)
    except Exception:
        # adaptive threshold can reject degenerate inputs; Otsu is the fallback
        _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    # NOTE(review): a 1x1 opening kernel looks like a no-op — confirm intent
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, np.ones((1, 1), np.uint8))
    return binary
122
 
123
+ # ---------------- OCR TSV helpers ----------------
124
def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
    """Run Tesseract in TSV mode and return one cell dict per recognised word.

    Each cell carries the word text, its confidence (-1.0 when unknown), the
    bounding box, and the precomputed box center.
    """
    try:
        data = pytesseract.image_to_data(cv_img, output_type=Output.DICT, config="--psm 6")
    except Exception:
        # --psm 6 can fail on some inputs; retry with tesseract defaults
        data = pytesseract.image_to_data(cv_img, output_type=Output.DICT)

    cells: List[Dict[str, Any]] = []
    for i in range(len(data.get("text", []))):
        word = data["text"][i]
        if word is None:
            continue
        word = str(word).strip()
        if not word:
            continue
        try:
            confidence = float(data["conf"][i]) if data["conf"][i] not in (None, "", "-1") else -1.0
        except Exception:
            confidence = -1.0
        x = int(data.get("left", [0])[i])
        y = int(data.get("top", [0])[i])
        w = int(data.get("width", [0])[i])
        h = int(data.get("height", [0])[i])
        cells.append({
            "text": word,
            "conf": confidence,
            "left": x,
            "top": y,
            "width": w,
            "height": h,
            "center_y": y + h / 2.0,
            "center_x": x + w / 2.0,
        })
    return cells
150
 
151
+ # ---------------- grouping & merging ----------------
152
def group_cells_into_rows(cells: List[Dict[str, Any]], y_tolerance: int = 12) -> List[List[Dict[str, Any]]]:
    """Cluster OCR word cells into visual rows by vertical proximity.

    A cell joins the current row when its center_y is within *y_tolerance*
    of the row's running mean center_y; each finished row is sorted
    left-to-right by its "left" coordinate.
    """
    if not cells:
        return []
    ordered = sorted(cells, key=lambda cell: (cell["center_y"], cell["center_x"]))
    rows: List[List[Dict[str, Any]]] = []
    bucket = [ordered[0]]
    mean_y = ordered[0]["center_y"]
    for cell in ordered[1:]:
        if abs(cell["center_y"] - mean_y) <= y_tolerance:
            bucket.append(cell)
            # keep a running mean of the row's vertical centers
            mean_y = (mean_y * (len(bucket) - 1) + cell["center_y"]) / len(bucket)
        else:
            rows.append(sorted(bucket, key=lambda c: c["left"]))
            bucket = [cell]
            mean_y = cell["center_y"]
    rows.append(sorted(bucket, key=lambda c: c["left"]))
    return rows
170
 
171
+ def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
 
 
 
 
 
172
  if not rows:
173
  return rows
 
174
  merged = []
175
  i = 0
 
176
  while i < len(rows):
177
  row = rows[i]
178
  tokens = [c["text"] for c in row]
179
+ has_num = any(is_numeric_token(t) for t in tokens)
180
+ if not has_num and i + 1 < len(rows):
181
+ next_row = rows[i+1]
 
 
182
  next_tokens = [c["text"] for c in next_row]
183
+ next_has_num = any(is_numeric_token(t) for t in next_tokens)
184
+ if next_has_num and len(tokens) >= 2 and len([t for t in next_tokens if not is_numeric_token(t)]) <= 3:
185
+ merged_row = []
186
+ min_left = min((c["left"] for c in next_row), default=0)
187
+ offset = 10
188
+ for c in row:
189
+ newc = c.copy()
190
+ newc["left"] = min_left - offset
191
+ newc["center_x"] = newc["left"] + newc.get("width", 0) / 2.0
192
+ merged_row.append(newc)
193
+ offset += 10
194
+ merged_row.extend(next_row)
195
+ merged.append(sorted(merged_row, key=lambda cc: cc["left"]))
 
 
 
 
 
196
  i += 2
197
  continue
 
198
  merged.append(row)
199
  i += 1
 
200
  return merged
201
 
202
+ # ---------------- numeric column detection ----------------
203
+ def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 4) -> List[float]:
204
+ xs = [c["center_x"] for c in cells if is_numeric_token(c["text"])]
 
 
 
 
 
205
  if not xs:
206
  return []
207
+ xs = sorted(xs)
208
  if len(xs) == 1:
209
+ return [xs[0]]
210
+ gaps = [xs[i+1] - xs[i] for i in range(len(xs) - 1)]
211
+ mean_gap = float(np.mean(gaps))
212
+ std_gap = float(np.std(gaps)) if len(gaps) > 1 else 0.0
213
+ gap_thresh = max(30.0, mean_gap + 0.6 * std_gap)
214
  clusters = []
215
  curr = [xs[0]]
 
216
  for i, g in enumerate(gaps):
217
+ if g > gap_thresh and len(clusters) < (max_columns - 1):
218
  clusters.append(curr)
219
+ curr = [xs[i+1]]
220
  else:
221
+ curr.append(xs[i+1])
 
222
  clusters.append(curr)
223
+ centers = [float(np.median(c)) for c in clusters]
224
+ if len(centers) > max_columns:
225
+ centers = centers[-max_columns:]
226
+ return sorted(centers)
227
 
228
def assign_token_to_column(token_x: float, column_centers: List[float]) -> Optional[int]:
    """Return the index of the column center closest to *token_x* (None if no columns).

    Ties go to the leftmost (first) column, matching argmin semantics.
    """
    if not column_centers:
        return None
    return min(range(len(column_centers)), key=lambda i: abs(token_x - column_centers[i]))
233
 
234
+ # ---------------- parse rows into items ----------------
235
def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Turn OCR rows into structured line items using page-level numeric columns.

    When column centers can be detected, numeric tokens are bucketed by
    nearest column and read right-to-left as (..., qty, rate, amount);
    otherwise the rightmost numeric tokens of each row are used positionally.
    Rows with no numbers, and footer rows, are skipped.
    """
    parsed_items: List[Dict[str, Any]] = []
    rows = merge_multiline_names(rows)
    column_centers = detect_numeric_columns(page_cells, max_columns=4)

    for row in rows:
        tokens = [cell["text"] for cell in row]
        if not tokens:
            continue
        joined_lower = " ".join(tokens).lower()
        row_has_number = any(is_numeric_token(t) for t in tokens)
        # footer lines ("page 1/4", timestamps) with no numbers are noise
        if FOOTER_KEYWORDS.search(joined_lower) and not row_has_number:
            continue
        if not row_has_number:
            continue

        if column_centers:
            name_parts: List[str] = []
            buckets: Dict[int, List[str]] = {i: [] for i in range(len(column_centers))}
            for cell in row:
                token = cell["text"]
                if is_numeric_token(token):
                    idx = assign_token_to_column(cell["center_x"], column_centers)
                    if idx is None:
                        idx = len(column_centers) - 1
                    buckets[idx].append(token)
                else:
                    name_parts.append(token)

            raw_name = " ".join(name_parts).strip()
            name = clean_name_text(raw_name) if raw_name else ""
            num_cols = len(column_centers)

            def last_in(idx: int) -> Optional[str]:
                # last token that fell into the bucket wins
                vals = buckets.get(idx, [])
                return vals[-1] if vals else None

            amount = normalize_num_str(last_in(num_cols - 1)) if num_cols >= 1 else None
            rate = normalize_num_str(last_in(num_cols - 2)) if num_cols >= 2 else None
            qty = normalize_num_str(last_in(num_cols - 3)) if num_cols >= 3 else None

            # fallback: take the rightmost numeric token in the row as amount
            if amount is None:
                for token in reversed(tokens):
                    if is_numeric_token(token):
                        amount = normalize_num_str(token)
                        break

            # infer quantity from amount/rate when the ratio is near an integer
            if (qty is None or qty == 0) and amount is not None and rate:
                ratio = amount / rate
                rounded = round(ratio)
                if rounded >= 1 and abs(ratio - rounded) <= max(0.04 * rounded, 0.2):
                    qty = float(rounded)
            # or from a trailing "3x"-style token in the description
            if qty is None:
                for part in reversed(name_parts):
                    m = re.match(r"^(\d+)(?:[xX])?$", part)
                    if m:
                        qty = float(m.group(1))
                        break
            if qty is None:
                qty = 1.0
            if (rate is None or rate == 0) and qty and amount is not None:
                rate = round(amount / qty, 2)

            try:
                amount = float(round(amount, 2)) if amount is not None else None
            except Exception:
                amount = None
            try:
                rate = float(round(rate, 2)) if rate is not None else 0.0
            except Exception:
                rate = 0.0
            try:
                qty = float(qty) if qty is not None else 1.0
            except Exception:
                qty = 1.0
            if amount is None or amount == 0:
                continue
            parsed_items.append({
                "item_name": name if name else "UNKNOWN",
                "item_amount": float(round(amount, 2)),
                "item_rate": float(round(rate, 2)) if rate else 0.0,
                "item_quantity": float(qty) if qty else 1.0,
            })
        else:
            # no columns detected: read the row's numbers positionally
            numeric_positions = [i for i, t in enumerate(tokens) if is_numeric_token(t)]
            if not numeric_positions:
                continue
            last = numeric_positions[-1]
            amt = normalize_num_str(tokens[last])
            if amt is None:
                continue
            name = " ".join(tokens[:last]).strip()
            if not name:
                continue
            rate, qty = 0.0, 1.0
            if len(numeric_positions) >= 2:
                r = normalize_num_str(tokens[numeric_positions[-2]])
                rate = r if r is not None else 0.0
            if len(numeric_positions) >= 3:
                q = normalize_num_str(tokens[numeric_positions[-3]])
                qty = q if q is not None else 1.0
            parsed_items.append({
                "item_name": clean_name_text(name),
                "item_amount": float(round(amt, 2)),
                "item_rate": float(round(rate, 2)),
                "item_quantity": float(qty),
            })
    return parsed_items
341
 
342
+ # ---------------- dedupe & totals ----------------
343
def dedupe_items(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop repeated line items, keyed on normalised name + rounded amount.

    The first occurrence wins; order is otherwise preserved.
    """
    unique: List[Dict[str, Any]] = []
    seen_keys = set()
    for item in items:
        normalised = re.sub(r"\s+", " ", item["item_name"].lower()).strip()
        fingerprint = (normalised[:120], round(float(item["item_amount"]), 2))
        if fingerprint not in seen_keys:
            seen_keys.add(fingerprint)
            unique.append(item)
    return unique
354
 
355
def detect_subtotals_and_totals(rows_texts: List[str]) -> Dict[str, Optional[float]]:
    """Scan row texts bottom-up for subtotal / final-total lines.

    The first match from the bottom wins for each slot ("sub..." keywords go
    to subtotal, anything else matching TOTAL_KEYWORDS goes to final_total).
    """
    subtotal: Optional[float] = None
    final: Optional[float] = None
    for row_text in reversed(rows_texts):
        if not row_text or not row_text.strip():
            continue
        if not TOTAL_KEYWORDS.search(row_text):
            continue
        match = NUM_RE.search(row_text)
        if not match:
            continue
        value = normalize_num_str(match.group(0))
        if value is None:
            continue
        if re.search(r"sub", row_text, re.I):
            if subtotal is None:
                subtotal = float(round(value, 2))
        elif final is None:
            final = float(round(value, 2))
    return {"subtotal": subtotal, "final_total": final}
371
+
372
+ # ---------------- Gemini refinement (always attempted) ----------------
373
def refine_with_gemini(page_items: List[Dict[str, Any]], page_text: str = "") -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
    """Ask Gemini to clean up the heuristically parsed line items.

    Sends the page text plus the parsed items and expects back a JSON array of
    {item_name, item_amount, item_rate, item_quantity} objects. On any failure
    (no API key, SDK error, unparseable response) the original items are
    returned unchanged with zero token usage — refinement is best-effort.

    Fixes over the previous version:
    - `temperature`/`max_output_tokens` were passed as bare kwargs to
      generate_content(), which the SDK rejects with TypeError; they now go
      through `generation_config`, so refinement actually runs.
    - The "system" role is not accepted in the contents list; the system
      prompt now goes via `system_instruction` on the model.
    - Token usage is now reported from `response.usage_metadata` instead of
      always returning zeros.
    """
    zero_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
    if not GEMINI_API_KEY or genai is None:
        return page_items, zero_usage
    try:
        safe_text = sanitize_ocr_text(page_text)
        system = (
            "You are a strict bill-extraction cleaner. Return ONLY a JSON array (no text) of objects with keys "
            "item_name (string), item_amount (float), item_rate (float), item_quantity (float). "
            "Do NOT return totals or subtotals as items. Do not invent items. Fix broken names and numeric mismatches."
        )
        # Small few-shot example to anchor the model's output format.
        few_shot = (
            "# EXAMPLE\nitems = [{'item_name':'Consultation Charge | DR PREETHI','item_amount':300.0,'item_rate':0.0,'item_quantity':300.0}]\n"
            "=> [{'item_name':'Consultation Charge | DR PREETHI MARY JOSEPH','item_amount':300.0,'item_rate':300.0,'item_quantity':1.0}]\n"
        )
        prompt = f"page_text='''{safe_text}'''\nitems = {json.dumps(page_items, ensure_ascii=False)}\n\n{few_shot}\nReturn only a JSON array."
        # System prompts must go via system_instruction; the contents list
        # only accepts "user"/"model" roles in this SDK.
        model = genai.GenerativeModel(GEMINI_MODEL_NAME, system_instruction=system)
        response = model.generate_content(
            prompt,
            generation_config=genai.types.GenerationConfig(
                temperature=0.0,
                max_output_tokens=1000,
            ),
        )
        raw = response.text.strip()
        # Strip an optional ```json ... ``` fence around the payload.
        if raw.startswith("```"):
            raw = re.sub(r"^```[a-zA-Z]*", "", raw)
            raw = re.sub(r"```$", "", raw).strip()
        parsed = json.loads(raw)
        # Report real token usage when the SDK provides it.
        usage = dict(zero_usage)
        meta = getattr(response, "usage_metadata", None)
        if meta is not None:
            usage = {
                "total_tokens": int(getattr(meta, "total_token_count", 0) or 0),
                "input_tokens": int(getattr(meta, "prompt_token_count", 0) or 0),
                "output_tokens": int(getattr(meta, "candidates_token_count", 0) or 0),
            }
        if isinstance(parsed, list):
            cleaned = []
            for obj in parsed:
                try:
                    cleaned.append({
                        "item_name": str(obj.get("item_name", "")).strip(),
                        "item_amount": float(obj.get("item_amount", 0.0)),
                        "item_rate": float(obj.get("item_rate", 0.0) or 0.0),
                        "item_quantity": float(obj.get("item_quantity", 1.0) or 1.0),
                    })
                except Exception:
                    # Skip malformed objects rather than failing the page.
                    continue
            return cleaned, usage
        return page_items, zero_usage
    except Exception:
        # Best-effort: any SDK/parse failure falls back to the raw items.
        return page_items, zero_usage
420
+
421
+ # ---------------- header heuristics & final filter ----------------
422
def looks_like_header_text(txt: str, top_of_page: bool = False) -> bool:
    """Heuristically decide whether an OCR row is a table/section header.

    Combines several signals: header-keyword substring hits, exact token
    hits, position near the top of the page, digit-free "rate/net ... amt"
    column captions, and common leading words like "description"/"qty".
    """
    if not txt:
        return False
    flat = re.sub(r"\s+", " ", txt.strip().lower())
    # Two or more keyword substrings anywhere in the row is a strong signal.
    substring_hits = sum(1 for kw in HEADER_KEYWORDS if kw in flat)
    if substring_hits >= 2:
        return True
    pieces = re.split(r"[\s\|,/:]+", flat)
    exact_hits = sum(1 for piece in pieces if piece in HEADER_KEYWORDS)
    if exact_hits >= 3:
        return True
    # Short rows near the top of the page need fewer exact hits.
    if top_of_page and len(pieces) <= 10 and exact_hits >= 2:
        return True
    # Column captions like "rate ... amt" carry no digits.
    has_digit = any(ch.isdigit() for ch in flat)
    if ("rate" in flat or "net" in flat) and "amt" in flat and not has_digit:
        return True
    return flat.startswith(("description", "qty", "qty /"))
440
+
441
+
442
def final_item_filter(item: Dict[str, Any], known_page_headers: Optional[List[str]] = None) -> bool:
    """Return True when *item* looks like a real billable line item.

    Rejects items whose name matches a known page header or footer keyword,
    items with implausible amounts (non-positive, > 1,000,000, or a rate more
    than 10x a small amount), names that are just generic section words, and
    near-empty non-alphabetic names.

    Fix: the previous signature used a mutable default (`= []`), a classic
    Python pitfall; a None sentinel is backward compatible for all callers.
    """
    if known_page_headers is None:
        known_page_headers = []
    name = (item.get("item_name") or "").strip()
    if not name:
        return False
    ln = name.lower()
    # Drop anything that embeds a header row detected on this page.
    for h in known_page_headers:
        if h and h.strip() and h.strip().lower() in ln:
            return False
    if FOOTER_KEYWORDS.search(ln):
        return False
    # Sanity cap: no single line item should exceed a million.
    if item.get("item_amount", 0) > 1_000_000:
        return False
    # Very short names with no letters are OCR noise.
    if len(name) <= 2 and not re.search(r"[a-zA-Z]", name):
        return False
    # A bare section word is a header leftover, not an item.
    if re.fullmatch(r"(charge|charges|services|laboratory|lab|consultation)", ln.strip(), re.I):
        return False
    if float(item.get("item_amount", 0)) <= 0.0:
        return False
    # A rate wildly larger than a small amount suggests a column mix-up.
    rate = float(item.get("item_rate", 0) or 0)
    amt = float(item.get("item_amount", 0) or 0)
    if rate and rate > amt * 10 and amt < 10000:
        return False
    return True
465
+
466
+ # ---------------- main endpoint ----------------
467
@app.post("/extract-bill-data")
async def extract_bill_data(payload: BillRequest):
    """Download a bill document, OCR each page, and return structured line items.

    Response shape: {"is_success": bool, "token_usage": {...},
    "data": {"pagewise_line_items": [...], "total_item_count": int}}.
    A failed download returns is_success=False; OCR/parse failures on a page
    yield an empty item list for that page rather than failing the request.
    """
    doc_url = payload.document
    # Fetch the document; a UA header avoids some hosts rejecting bare clients.
    try:
        headers = {"User-Agent": "Mozilla/5.0"}
        resp = requests.get(doc_url, headers=headers, timeout=30)
        if resp.status_code != 200:
            raise RuntimeError(f"download failed status={resp.status_code}")
        file_bytes = resp.content
    except Exception:
        # Any download problem maps to an is_success=False envelope.
        return {"is_success": False, "token_usage": {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}, "data": {"pagewise_line_items": [], "total_item_count": 0}}

    images = []
    # Strip the query string so the extension check sees the real suffix.
    clean_url = doc_url.split("?", 1)[0].lower()
    try:
        if clean_url.endswith(".pdf"):
            images = convert_from_bytes(file_bytes)
        elif any(clean_url.endswith(ext) for ext in [".png", ".jpg", ".jpeg", ".tiff", ".bmp"]):
            images = [Image.open(BytesIO(file_bytes))]
        else:
            # Unknown extension: try PDF rasterization as a best-effort fallback.
            try:
                images = convert_from_bytes(file_bytes)
            except Exception:
                images = []
    except Exception:
        images = []

    pagewise = []
    cumulative_token_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}

    for idx, page_img in enumerate(images, start=1):
        try:
            # OCR pipeline: preprocess -> TSV cells -> rows -> row texts.
            proc = preprocess_image(page_img)
            cells = image_to_tsv_cells(proc)
            rows = group_cells_into_rows(cells, y_tolerance=12)
            rows_texts = [" ".join([c["text"] for c in r]) for r in rows]
            # Collect header-looking rows from the top of the page so they can
            # be filtered back out of the parsed items.
            top_headers = []
            for i, rt in enumerate(rows_texts[:6]):
                if looks_like_header_text(rt, top_of_page=(i < 4)):
                    top_headers.append(rt.strip().lower())
            parsed_items = parse_rows_with_columns(rows, cells)
            page_text = sanitize_ocr_text(" ".join(rows_texts))
            # Optional LLM cleanup; accumulates token usage per page.
            refined_items, token_u = refine_with_gemini(parsed_items, page_text)
            for k in cumulative_token_usage:
                cumulative_token_usage[k] += token_u.get(k, 0)
            # Filter headers/footers/implausible rows, then dedupe, then drop
            # anything that still reads like a header after refinement.
            cleaned = [p for p in refined_items if final_item_filter(p, known_page_headers=top_headers)]
            cleaned = dedupe_items(cleaned)
            cleaned = [p for p in cleaned if not looks_like_header_text(p["item_name"].lower())]
            # Coarse page classification from keyword presence; "Final Bill"
            # takes precedence over "Pharmacy" because it is checked last.
            page_type = "Bill Detail"
            page_txt = page_text.lower()
            if any(x in page_txt for x in ["pharmacy", "medicine", "tablet"]):
                page_type = "Pharmacy"
            if "final bill" in page_txt or "grand total" in page_txt:
                page_type = "Final Bill"
            pagewise.append({"page_no": str(idx), "page_type": page_type, "bill_items": cleaned})
        except Exception:
            # Per-page failure is non-fatal: emit an empty page and move on.
            pagewise.append({"page_no": str(idx), "page_type": "Bill Detail", "bill_items": []})
            continue

    total_item_count = sum(len(p.get("bill_items", [])) for p in pagewise)
    if not GEMINI_API_KEY or genai is None:
        # Flag in the usage block that LLM refinement was skipped entirely.
        cumulative_token_usage["warning_no_gemini"] = 1
    return {"is_success": True, "token_usage": cumulative_token_usage, "data": {"pagewise_line_items": pagewise, "total_item_count": total_item_count}}
530
 
531
+ # ---------------- debug TSV ----------------
 
 
532
@app.post("/debug-tsv")
async def debug_tsv(payload: BillRequest):
    """Debug endpoint: OCR the first page of a document and return raw TSV cells.

    Returns {"cells": [...]} on success, or {"error": "..."} on any download
    or decode failure.

    Fix: image/PDF decoding is now guarded — previously a corrupt image, a
    non-PDF payload, or a zero-page PDF (imgs[0] -> IndexError) escaped as an
    unhandled 500 instead of the JSON error shape used elsewhere.
    """
    doc_url = payload.document
    try:
        resp = requests.get(doc_url, timeout=20)
        if resp.status_code != 200:
            return {"error": "Download failed"}
        file_bytes = resp.content
    except Exception:
        return {"error": "Download failed"}
    # Strip the query string so the extension check sees the real suffix.
    clean_url = doc_url.split("?", 1)[0].lower()
    try:
        if clean_url.endswith(".pdf"):
            imgs = convert_from_bytes(file_bytes)
            if not imgs:
                return {"error": "Empty PDF"}
            img = imgs[0]
        else:
            img = Image.open(BytesIO(file_bytes))
    except Exception:
        return {"error": "Could not decode document"}

    proc = preprocess_image(img)
    cells = image_to_tsv_cells(proc)
    return {"cells": cells}
551
 
552
@app.get("/")
def health_check():
    """Liveness probe; also reports whether Gemini refinement is available."""
    llm_ready = bool(GEMINI_API_KEY) and genai is not None
    message = "Bill extraction API (final) live."
    if not llm_ready:
        message += " (No GEMINI_API_KEY/configured SDK — LLM refinement skipped.)"
    return {"status": "ok", "message": message, "hint": "POST /extract-bill-data with {'document':'<url>'}"}