Sathvik-kota committed on
Commit
1404047
·
verified ·
1 Parent(s): 5ca19e8

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +440 -411
app.py CHANGED
@@ -1,522 +1,551 @@
1
- # app.py (Final v2 — added multiline-name merging + qty inference)
 
 
 
 
 
2
  import os
3
  import re
4
  import json
5
  from io import BytesIO
6
  from typing import List, Dict, Any, Optional, Tuple
7
 
8
- from fastapi import FastAPI
9
- from pydantic import BaseModel
10
  import requests
11
  from PIL import Image
12
  from pdf2image import convert_from_bytes
 
 
13
  import pytesseract
14
  from pytesseract import Output
15
- import numpy as np
16
- import cv2
17
  import google.generativeai as genai
18
 
19
- # ---------------- LLM CONFIG (Gemini) ----------------
 
 
 
20
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
21
- GEMINI_MODEL_NAME = "gemini-2.5-flash"
 
22
  if GEMINI_API_KEY:
23
  genai.configure(api_key=GEMINI_API_KEY)
24
 
25
- # ---------------- FASTAPI APP ----------------
26
- app = FastAPI(title="Bajaj Datathon - Bill Extractor (v2)")
 
 
 
 
27
 
28
class BillRequest(BaseModel):
    """Request payload: URL of the bill document to extract."""
    document: str
30
 
31
# ---------------- Globals & regex ----------------
NUM_RE = re.compile(r"[-+]?\d{1,3}(?:[,0-9]*)(?:\.\d+)?")
TOTAL_KEYWORDS = re.compile(
    r"(grand\s*total|net\s*payable|total\s*amount|amount\s*payable|bill\s*amount|final\s*amount|balance\s*due|sub\s*total|subtotal|round\s*off)",
    re.I
)
HEADER_LIKE = re.compile(r"^(consultation|room|nursing|surgery|radiology|laborat|laboratory|charges|services|investigation|package|section)$", re.I)
FOOTER_KEYWORDS = re.compile(r"(page|printed|printed on|page\s*\d+|printed:|date:|time:|am|pm)", re.I)

# ---------------- Utilities ----------------
def normalize_num_str(s: Optional[str]) -> Optional[float]:
    """Parse an OCR token like '1,234.50' or '(150)' into a float.

    Returns None when the token holds no parseable number.  A value wrapped
    in parentheses is treated as negative (accounting notation).

    Fixes: removed the dead retry that stripped spaces (the character filter
    below already removes them) and narrowed the broad exception handler to
    ValueError, the only error float() can raise on a str.
    """
    if s is None:
        return None
    # Keep only digits, sign, comma, dot and parentheses; drops currency
    # symbols, spaces and stray OCR characters.
    s = re.sub(r"[^\d\-\+\,\.\(\)]", "", str(s).strip())
    if s == "":
        return None
    negative = False
    if s.startswith("(") and s.endswith(")"):
        negative = True
        s = s[1:-1]
    s = s.replace(",", "")
    if s == "" or s in ("-", "+"):
        return None
    try:
        return -float(s) if negative else float(s)
    except ValueError:
        # Leftover punctuation (e.g. "1.2.3") cannot be parsed.
        return None

def is_numeric_token(t: Optional[str]) -> bool:
    """True when the token contains anything that looks like a number."""
    if not t:
        return False
    return bool(NUM_RE.search(str(t)))

def clean_name_text(s: str) -> str:
    """Normalise an OCR'd item description (dashes, whitespace, codes).

    Fix: dropped the no-op ``s.replace("|", "|")`` present in the original.
    """
    s = s.replace("—", "-")        # em-dash -> ASCII hyphen
    s = re.sub(r"\s+", " ", s)     # collapse whitespace runs
    s = s.strip(" -:,.")           # trim stray punctuation
    # Repair common OCR slips in hospital service codes, e.g. "SG01" -> "SG1".
    s = re.sub(r"\bSG0?(\d+)\b", r"SG\1", s, flags=re.I)
    s = re.sub(r"\b(RR)[\s\-]*2\b", r"RR-2", s, flags=re.I)
    return s.strip()
78
-
79
# ---------------- preprocessing ----------------
def pil_to_cv2(img: Image.Image) -> Any:
    """Convert a PIL image to an OpenCV array (BGR, or unchanged grayscale)."""
    pixels = np.array(img)
    if pixels.ndim == 2:
        return pixels
    return cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)

def preprocess_image(pil_img: Image.Image) -> Any:
    """Upscale, denoise and binarize a page image before OCR."""
    pil_img = pil_img.convert("RGB")
    w, h = pil_img.size
    # Guarantee a minimum width so Tesseract sees enough pixels per glyph.
    if w < 1500:
        factor = 1500 / float(w)
        pil_img = pil_img.resize((int(w * factor), int(h * factor)), Image.LANCZOS)
    gray = cv2.cvtColor(pil_to_cv2(pil_img), cv2.COLOR_BGR2GRAY)
    gray = cv2.fastNlMeansDenoising(gray, h=10)
    # Adaptive threshold copes with uneven lighting; Otsu is the fallback.
    try:
        bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 41, 15)
    except Exception:
        _, bw = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    # 1x1 opening kernel: kept for parity with the original pipeline.
    return cv2.morphologyEx(bw, cv2.MORPH_OPEN, np.ones((1, 1), np.uint8))
104
 
105
# ---------------- OCR TSV ----------------
def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
    """Run Tesseract word-level OCR and return one dict per detected word.

    Each cell carries the text, confidence and bounding-box geometry
    (left/top/width/height plus derived centre coordinates).
    """
    try:
        ocr = pytesseract.image_to_data(cv_img, output_type=Output.DICT, config="--psm 6")
    except Exception:
        # --psm 6 not accepted: retry with the default segmentation mode.
        ocr = pytesseract.image_to_data(cv_img, output_type=Output.DICT)
    cells = []
    for i in range(len(ocr.get("text", []))):
        word = ocr["text"][i]
        if word is None:
            continue
        word = str(word).strip()
        if not word:
            continue
        try:
            confidence = float(ocr["conf"][i]) if ocr["conf"][i] not in (None, "", "-1") else -1.0
        except Exception:
            confidence = -1.0
        x = int(ocr.get("left", [0])[i])
        y = int(ocr.get("top", [0])[i])
        w = int(ocr.get("width", [0])[i])
        h = int(ocr.get("height", [0])[i])
        cells.append({
            "text": word,
            "conf": confidence,
            "left": x,
            "top": y,
            "width": w,
            "height": h,
            "center_y": y + h / 2.0,
            "center_x": x + w / 2.0,
        })
    return cells
132
 
133
# ---------------- group rows ----------------
def group_cells_into_rows(cells: List[Dict[str, Any]], y_tolerance: int = 12) -> List[List[Dict[str, Any]]]:
    """Cluster word cells into visual rows by vertical proximity.

    A word joins the current row while its centre-y is within ``y_tolerance``
    of the row's running mean centre line; each finished row is sorted
    left-to-right.
    """
    if not cells:
        return []
    ordered = sorted(cells, key=lambda c: (c["center_y"], c["center_x"]))
    rows = []
    bucket = [ordered[0]]
    anchor_y = ordered[0]["center_y"]
    for cell in ordered[1:]:
        if abs(cell["center_y"] - anchor_y) <= y_tolerance:
            bucket.append(cell)
            # Running mean of the row's centre line.
            anchor_y = (anchor_y * (len(bucket) - 1) + cell["center_y"]) / len(bucket)
        else:
            rows.append(sorted(bucket, key=lambda c: c["left"]))
            bucket = [cell]
            anchor_y = cell["center_y"]
    if bucket:
        rows.append(sorted(bucket, key=lambda c: c["left"]))
    return rows
152
 
153
# ---------------- merge multiline names ----------------
def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
    """
    If a textual row (no numbers) is immediately followed by a numeric row with
    a short left-text, merge the text tokens into the numeric row so the item
    name spans both lines.

    Bug fix: synthetic left positions are now assigned in INCREASING order.
    Previously the offset grew leftwards (``min_left - offset`` with an
    increasing offset), so the later sort by ``left`` reversed the word order
    of every merged name.
    """
    if not rows:
        return rows
    merged = []
    i = 0
    while i < len(rows):
        row = rows[i]
        tokens = [c["text"] for c in row]
        if not any(is_numeric_token(t) for t in tokens) and i + 1 < len(rows):
            next_row = rows[i + 1]
            next_tokens = [c["text"] for c in next_row]
            next_has_num = any(is_numeric_token(t) for t in next_tokens)
            few_words_next = len([t for t in next_tokens if not is_numeric_token(t)]) <= 2
            # Merge only when the next row clearly holds the numbers and has
            # little text of its own, and this row looks like a real name.
            if next_has_num and len(tokens) >= 2 and few_words_next:
                min_left = min((c["left"] for c in next_row), default=0)
                # Re-home the textual cells strictly left of the numeric row,
                # preserving their original order.
                base = min_left - 10 * len(row) - 5
                merged_row = []
                for k, c in enumerate(row):
                    newc = c.copy()
                    newc["left"] = base + 10 * k
                    newc["center_x"] = newc["left"] + newc.get("width", 0) / 2.0
                    merged_row.append(newc)
                merged_row.extend(next_row)
                merged.append(sorted(merged_row, key=lambda cc: cc["left"]))
                i += 2
                continue
        merged.append(row)
        i += 1
    return merged
198
 
199
# ---------------- numeric column detection ----------------
def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 4) -> List[float]:
    """Estimate the x-centres of numeric columns via 1-D gap clustering.

    Sorts the x-centres of all numeric tokens and starts a new cluster
    wherever the gap clearly exceeds the typical spacing, capped at
    ``max_columns`` clusters (keeping the rightmost ones).
    """
    xs = sorted(c["center_x"] for c in cells if is_numeric_token(c["text"]))
    if not xs:
        return []
    if len(xs) == 1:
        return [xs[0]]
    gaps = [b - a for a, b in zip(xs, xs[1:])]
    mean_gap = float(np.mean(gaps))
    std_gap = float(np.std(gaps)) if len(gaps) > 1 else 0.0
    threshold = max(30.0, mean_gap + 0.6 * std_gap)
    clusters = [[xs[0]]]
    for gap, x in zip(gaps, xs[1:]):
        if gap > threshold and len(clusters) < max_columns:
            clusters.append([x])
        else:
            clusters[-1].append(x)
    centers = [float(np.median(c)) for c in clusters]
    if len(centers) > max_columns:
        centers = centers[-max_columns:]
    return sorted(centers)

def assign_token_to_column(token_x: float, column_centers: List[float]) -> Optional[int]:
    """Index of the nearest column centre, or None when no columns exist."""
    if not column_centers:
        return None
    return int(np.argmin([abs(token_x - cx) for cx in column_centers]))
230
 
231
# ---------------- main row parser with qty inference ----------------
def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Turn OCR rows into structured line items.

    Uses the page-wide numeric column centres to split each row into a textual
    item name (left) and up to three numeric fields (qty, rate, amount, read
    right to left).  Quantity and rate are inferred from the amount when
    missing.  Falls back to "rightmost number is the amount" when no columns
    were detected.

    Fix: removed the dead header-name check — rows with a zero or missing
    amount have already been skipped by the time it ran, so its amount
    condition could never be true.
    """
    parsed_items = []
    # Merge descriptions that wrapped onto the line above their numbers.
    rows = merge_multiline_names(rows)
    column_centers = detect_numeric_columns(page_cells, max_columns=4)
    for row in rows:
        tokens = [c["text"] for c in row]
        if not tokens:
            continue
        joined_lower = " ".join(tokens).lower()
        # Footer lines ("Page 2", "Printed on ...") carry no numbers.
        if FOOTER_KEYWORDS.search(joined_lower) and not any(is_numeric_token(t) for t in tokens):
            continue
        # Section headings carry no numbers either.
        if all(not is_numeric_token(t) for t in tokens):
            continue
        if column_centers:
            left_text_parts = []
            numeric_bucket_map = {i: [] for i in range(len(column_centers))}
            for c in row:
                t = c["text"]
                if is_numeric_token(t):
                    col_idx = assign_token_to_column(c["center_x"], column_centers)
                    if col_idx is None:
                        col_idx = len(column_centers) - 1
                    numeric_bucket_map[col_idx].append(t)
                else:
                    left_text_parts.append(t)
            raw_name = " ".join(left_text_parts).strip()
            name = clean_name_text(raw_name) if raw_name else ""
            num_cols = len(column_centers)

            def get_bucket(idx):
                # Last token wins when a column holds several numbers.
                vals = numeric_bucket_map.get(idx, [])
                return vals[-1] if vals else None

            amount = rate = qty = None
            if num_cols >= 1:
                amount = normalize_num_str(get_bucket(num_cols - 1))
            if num_cols >= 2:
                rate = normalize_num_str(get_bucket(num_cols - 2))
            if num_cols >= 3:
                qty = normalize_num_str(get_bucket(num_cols - 3))
            # Fallback: rightmost number on the row is the amount.
            if amount is None:
                for t in reversed(tokens):
                    if is_numeric_token(t):
                        amount = normalize_num_str(t)
                        break
            # Infer qty = amount / rate when the ratio is close to an integer.
            if (qty is None or qty == 0) and amount is not None and rate:
                ratio = amount / rate
                rounded = round(ratio)
                if rounded >= 1 and abs(ratio - rounded) <= max(0.04 * rounded, 0.2):
                    qty = float(rounded)
            # Last resort: a trailing "3" / "3x" token inside the name text.
            if qty is None:
                for pt in reversed(left_text_parts):
                    m = re.match(r"^(\d+)(?:[xX])?$", pt)
                    if m:
                        qty = float(m.group(1))
                        break
            if qty is None:
                qty = 1.0
            # Infer rate from amount and qty when missing.
            if (rate is None or rate == 0) and qty and amount is not None:
                rate = round(amount / qty, 2)
            # Normalise types defensively; OCR can still produce oddities.
            try:
                amount = float(round(amount, 2)) if amount is not None else None
            except Exception:
                amount = None
            try:
                rate = float(round(rate, 2)) if rate is not None else 0.0
            except Exception:
                rate = 0.0
            try:
                qty = float(qty) if qty is not None else 1.0
            except Exception:
                qty = 1.0
            # Zero / missing amount almost always means a header row.
            if amount is None or amount == 0:
                continue
            parsed_items.append({
                "item_name": name if name else "UNKNOWN",
                "item_amount": float(round(amount, 2)),
                "item_rate": float(round(rate, 2)) if rate else 0.0,
                "item_quantity": float(qty) if qty else 1.0,
            })
        else:
            # No reliable columns: last number is the amount, everything
            # before it is the name.
            numeric_idxs = [i for i, t in enumerate(tokens) if is_numeric_token(t)]
            if not numeric_idxs:
                continue
            amt = normalize_num_str(tokens[numeric_idxs[-1]])
            if amt is None:
                continue
            name = " ".join(tokens[:numeric_idxs[-1]]).strip()
            if not name:
                continue
            rate, qty = 0.0, 1.0
            if len(numeric_idxs) >= 2:
                r = normalize_num_str(tokens[numeric_idxs[-2]])
                rate = r if r is not None else 0.0
            if len(numeric_idxs) >= 3:
                q = normalize_num_str(tokens[numeric_idxs[-3]])
                qty = q if q is not None else 1.0
            parsed_items.append({
                "item_name": clean_name_text(name),
                "item_amount": float(round(amt, 2)),
                "item_rate": float(round(rate, 2)),
                "item_quantity": float(qty),
            })
    return parsed_items
350
 
351
# ---------------- dedupe & totals ----------------
def dedupe_items(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop repeated items, keyed on (whitespace-normalised name, amount)."""
    seen = set()
    unique = []
    for item in items:
        norm_name = re.sub(r"\s+", " ", item["item_name"].lower()).strip()
        fingerprint = (norm_name[:120], round(float(item["item_amount"]), 2))
        if fingerprint not in seen:
            seen.add(fingerprint)
            unique.append(item)
    return unique
363
 
364
def detect_subtotals_and_totals(rows_texts: List[str]) -> Dict[str, Optional[float]]:
    """Scan row texts bottom-up for subtotal / grand-total figures.

    The first matching figure of each kind (seen from the bottom of the page)
    wins; absent figures come back as None.
    """
    subtotal = None
    final = None
    for line in reversed(rows_texts):
        if not line or not line.strip():
            continue
        if not TOTAL_KEYWORDS.search(line):
            continue
        match = NUM_RE.search(line)
        if not match:
            continue
        value = normalize_num_str(match.group(0))
        if value is None:
            continue
        if re.search(r"sub", line, re.I):
            if subtotal is None:
                subtotal = float(round(value, 2))
        elif final is None:
            final = float(round(value, 2))
    return {"subtotal": subtotal, "final_total": final}
380
-
381
# ---------------- Gemini refinement (optional) ----------------
def refine_with_gemini(page_items: List[Dict[str, Any]], page_text: str = "") -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
    """Optionally clean extracted items with Gemini; fail open on any error.

    Fix: the previous call passed a contents list containing a "system" role,
    which the google-generativeai SDK rejects (only "user"/"model" roles are
    valid in contents), so every call raised and the refinement silently never
    ran.  The system instruction is now folded into a single user prompt.
    """
    zero_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
    if not GEMINI_API_KEY:
        return page_items, zero_usage
    try:
        prompt = (
            "Return only valid JSON array.\n"
            "You are a precise bill extraction cleaner. Given items with item_name, item_quantity, item_rate, item_amount, "
            "fix broken names, infer quantity if qty missing by checking amount and rate, and remove header/footer rows. "
            "Return only a JSON array of cleaned items.\n\n"
            f"page_text='''{page_text[:4000]}'''\nitems = {json.dumps(page_items, ensure_ascii=False)}"
        )
        model = genai.GenerativeModel(GEMINI_MODEL_NAME)
        response = model.generate_content(prompt)
        raw = response.text.strip()
        # Strip markdown code fences if the model wrapped its answer in them.
        if raw.startswith("```"):
            raw = re.sub(r"^```[a-zA-Z]*", "", raw)
            raw = re.sub(r"```$", "", raw).strip()
        parsed = json.loads(raw)
        if isinstance(parsed, list):
            cleaned = []
            for obj in parsed:
                try:
                    cleaned.append({
                        "item_name": str(obj.get("item_name", "")).strip(),
                        "item_amount": float(obj.get("item_amount", 0.0)),
                        "item_rate": float(obj.get("item_rate", 0.0) or 0.0),
                        "item_quantity": float(obj.get("item_quantity", 1.0) or 1.0),
                    })
                except Exception:
                    continue
            return cleaned, zero_usage
        return page_items, zero_usage
    except Exception:
        # Best-effort: any API/parse failure falls back to the heuristic items.
        return page_items, zero_usage
 
419
# ---------------- main endpoint ----------------
@app.post("/extract-bill-data")
async def extract_bill_data(payload: BillRequest):
    """Download a bill (PDF or image), OCR every page and return line items."""
    url = payload.document
    # Step 1: fetch the document bytes.
    try:
        resp = requests.get(url, headers={"User-Agent": "Mozilla/5.0"}, timeout=30)
        if resp.status_code != 200:
            raise RuntimeError(f"download failed status={resp.status_code}")
        file_bytes = resp.content
    except Exception:
        return {
            "is_success": False,
            "token_usage": {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
            "data": {"pagewise_line_items": [], "total_item_count": 0},
        }
    # Step 2: rasterise to one PIL image per page.
    images = []
    clean_url = url.split("?", 1)[0].lower()
    try:
        if clean_url.endswith(".pdf"):
            images = convert_from_bytes(file_bytes)
        elif any(clean_url.endswith(ext) for ext in [".png", ".jpg", ".jpeg", ".tiff", ".bmp"]):
            images = [Image.open(BytesIO(file_bytes))]
        else:
            # Unknown extension: assume PDF, otherwise give up quietly.
            try:
                images = convert_from_bytes(file_bytes)
            except Exception:
                images = []
    except Exception:
        images = []
    # Step 3: OCR + parse each page independently.
    pagewise = []
    cumulative_token_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
    for page_no, page_img in enumerate(images, start=1):
        try:
            binarized = preprocess_image(page_img)
            cells = image_to_tsv_cells(binarized)
            rows = group_cells_into_rows(cells, y_tolerance=12)
            rows_texts = [" ".join(c["text"] for c in r) for r in rows]
            totals = detect_subtotals_and_totals(rows_texts)
            parsed_items = parse_rows_with_columns(rows, cells)
            parsed_items = [p for p in parsed_items if not TOTAL_KEYWORDS.search(p.get("item_name", ""))]
            parsed_items = dedupe_items(parsed_items)
            # Step 4: escalate to Gemini when the heuristics look unreliable.
            call_llm = False
            if GEMINI_API_KEY and parsed_items:
                inconsistent = sum(
                    1 for it in parsed_items
                    if abs(it["item_quantity"] * it["item_rate"] - it["item_amount"])
                    > max(2.0, 0.03 * (it["item_amount"] or 1.0))
                )
                if inconsistent > max(1, len(parsed_items) // 6) or len(parsed_items) > 18:
                    call_llm = True
            if call_llm:
                refined, token_u = refine_with_gemini(parsed_items, " ".join(rows_texts))
                parsed_items = refined
                for k in cumulative_token_usage:
                    cumulative_token_usage[k] += token_u.get(k, 0)
            # Step 5: final scrub of header/footer rows and zero amounts.
            final = []
            for it in parsed_items:
                nm = it.get("item_name", "")
                if not nm or HEADER_LIKE.search(nm) or FOOTER_KEYWORDS.search(nm):
                    continue
                if re.search(r"page\s+of|printed\s+on|printed:", nm, re.I):
                    continue
                if float(it.get("item_amount", 0)) <= 0:
                    continue
                final.append(it)
            # Step 6: classify the page from its full text.
            page_type = "Bill Detail"
            page_txt = " ".join(rows_texts).lower()
            if any(x in page_txt for x in ["pharmacy", "medicine", "tablet"]):
                page_type = "Pharmacy"
            if "final bill" in page_txt or "grand total" in page_txt:
                page_type = "Final Bill"
            pagewise.append({"page_no": str(page_no), "page_type": page_type, "bill_items": final})
        except Exception:
            # Never fail the whole request because one page blew up.
            pagewise.append({"page_no": str(page_no), "page_type": "Bill Detail", "bill_items": []})
            continue
    total_item_count = sum(len(p.get("bill_items", [])) for p in pagewise)
    return {
        "is_success": True,
        "token_usage": cumulative_token_usage,
        "data": {"pagewise_line_items": pagewise, "total_item_count": total_item_count},
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
493
 
 
 
 
494
@app.post("/debug-tsv")
async def debug_tsv(payload: BillRequest):
    """Debug helper: OCR the first page of a document and return raw word cells.

    Fixes: the bare ``except:`` is narrowed to ``Exception`` (it previously
    swallowed ``KeyboardInterrupt``/``SystemExit`` too), and the PDF/image
    decoding step is now guarded so a corrupt file yields a JSON error instead
    of an unhandled 500.
    """
    doc_url = payload.document
    try:
        resp = requests.get(doc_url, timeout=20)
        if resp.status_code != 200:
            return {"error": "Download failed"}
        file_bytes = resp.content
    except Exception:
        return {"error": "Download failed"}

    # Rasterise only the first page.
    clean_url = doc_url.split("?", 1)[0].lower()
    try:
        if clean_url.endswith(".pdf"):
            img = convert_from_bytes(file_bytes)[0]
        else:
            img = Image.open(BytesIO(file_bytes))
    except Exception:
        return {"error": "Could not decode document"}

    proc = preprocess_image(img)
    # Return raw OCR cells for debugging.
    return {"cells": image_to_tsv_cells(proc)}
518
 
519
 
520
@app.get("/")
def health_check():
    """Liveness probe that also documents how to call the extractor."""
    return {
        "status": "ok",
        "message": "Bill extraction API (v2) live.",
        "hint": "POST /extract-bill-data with {'document':'<url>'}",
    }
 
 
 
1
+ """
2
+ Bajaj Finserv Datathon – Bill Extraction Service
3
+ Clean, modular and human-written version (Option A)
4
+ Maintains your exact logic but reorganized for readability and robustness.
5
+ """
6
+
7
  import os
8
  import re
9
  import json
10
  from io import BytesIO
11
  from typing import List, Dict, Any, Optional, Tuple
12
 
13
+ import cv2
14
+ import numpy as np
15
  import requests
16
  from PIL import Image
17
  from pdf2image import convert_from_bytes
18
+ from fastapi import FastAPI
19
+ from pydantic import BaseModel
20
  import pytesseract
21
  from pytesseract import Output
 
 
22
  import google.generativeai as genai
23
 
24
+
25
+ # -------------------------------------------------------
26
+ # GEMINI CONFIG
27
+ # -------------------------------------------------------
28
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
29
+ GEMINI_MODEL = "gemini-2.5-flash"
30
+
31
  if GEMINI_API_KEY:
32
  genai.configure(api_key=GEMINI_API_KEY)
33
 
34
+
35
+ # -------------------------------------------------------
36
+ # FASTAPI APP
37
+ # -------------------------------------------------------
38
+ app = FastAPI(title="Bajaj Datathon - Bill Extractor (Clean vA)")
39
+
40
 
41
  class BillRequest(BaseModel):
42
  document: str
43
 
44
+
45
+ # -------------------------------------------------------
46
+ # REGEX + CONSTANTS
47
+ # -------------------------------------------------------
48
  NUM_RE = re.compile(r"[-+]?\d{1,3}(?:[,0-9]*)(?:\.\d+)?")
49
+
50
+ TOTAL_KEYS = re.compile(
51
+ r"(grand\s*total|net\s*payable|total\s*amount|amount\s*payable|bill\s*amount|"
52
+ r"final\s*amount|balance\s*due|sub\s*total|subtotal|round\s*off)",
53
  re.I
54
  )
 
 
55
 
56
+ HEADER_HINT = re.compile(
57
+ r"^(consultation|room|nursing|surgery|radiology|laboratory|charges|services|investigation|package|section)$",
58
+ re.I
59
+ )
60
+
61
+ FOOTER_HINT = re.compile(r"(page|printed|date|time|am|pm|printed on)", re.I)
62
+
63
+
64
+ # =======================================================
65
+ # UTILITY HELPERS
66
+ # =======================================================
67
+
68
def normalize_number(raw: Optional[str]) -> Optional[float]:
    """Convert OCR number-like text (e.g. '1,250.00', '(150)') into a float.

    Parenthesised values use accounting notation and come back negative.
    Returns None for empty or unparseable input.

    Fix: narrowed the bare ``except:`` (which also swallowed
    ``KeyboardInterrupt``/``SystemExit``) to ``ValueError``, the only error
    ``float()`` raises on a str.
    """
    if not raw:
        return None

    # Keep digits, signs, separators and parentheses only.
    text = re.sub(r"[^\d\-\+\,\.\(\)]", "", str(raw)).strip()
    if not text:
        return None

    # Handle negative (accounting) format: (150.00)
    is_negative = text.startswith("(") and text.endswith(")")
    if is_negative:
        text = text[1:-1]

    try:
        value = float(text.replace(",", ""))
    except ValueError:
        return None
    return -value if is_negative else value
87
+
88
+
89
def is_numeric(text: str) -> bool:
    """True when the string contains anything resembling a number."""
    return bool(NUM_RE.search(str(text)))


def clean_item_name(text: str) -> str:
    """Normalizes the left-side description of an item."""
    cleaned = text.replace("—", "-")          # em-dash -> ASCII hyphen
    cleaned = re.sub(r"\s+", " ", cleaned).strip(" -:,.")
    # Repair common OCR slips in service codes ("SG01" -> "SG1", "RR 2" -> "RR-2").
    cleaned = re.sub(r"\bSG0?(\d+)\b", r"SG\1", cleaned, flags=re.I)
    cleaned = re.sub(r"\b(RR)[\s\-]*2\b", r"RR-2", cleaned, flags=re.I)
    return cleaned.strip()
101
+
102
+
103
+ # =======================================================
104
+ # IMAGE PROCESSING
105
+ # =======================================================
106
+
107
def pil_to_cv(pil: Image.Image) -> np.ndarray:
    """Convert a PIL image to an OpenCV BGR (or unchanged grayscale) array."""
    arr = np.array(pil)
    if arr.ndim == 2:
        return arr
    return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)


def preprocess_image(pil_img: Image.Image) -> np.ndarray:
    """Resize, denoise & binarize image to improve OCR accuracy."""
    pil_img = pil_img.convert("RGB")
    width, height = pil_img.size

    # Upscale small scans so Tesseract has enough pixels per glyph.
    if width < 1500:
        factor = 1500 / float(width)
        pil_img = pil_img.resize((int(width * factor), int(height * factor)), Image.LANCZOS)

    gray = cv2.cvtColor(pil_to_cv(pil_img), cv2.COLOR_BGR2GRAY)
    gray = cv2.fastNlMeansDenoising(gray, h=10)

    # Adaptive threshold handles uneven lighting; Otsu is the fallback.
    try:
        binary = cv2.adaptiveThreshold(
            gray, 255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY,
            41, 15
        )
    except Exception:
        _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

    return cv2.morphologyEx(binary, cv2.MORPH_OPEN, np.ones((1, 1), np.uint8))
138
 
139
+
140
+ # =======================================================
141
+ # OCR TSV PARSING
142
+ # =======================================================
143
+
144
def run_tesseract(cv_img: np.ndarray) -> List[Dict[str, Any]]:
    """Extracts word-level bounding boxes and confidence from image.

    Fixes: confidence parsing is wrapped in try/except and also recognises the
    integer ``-1`` that newer pytesseract versions emit (the old check only
    matched the string "-1", so an odd value could make ``float()`` raise and
    kill the whole page); the bare ``except`` on the OCR call is narrowed to
    ``Exception``.
    """
    try:
        data = pytesseract.image_to_data(cv_img, output_type=Output.DICT, config="--psm 6")
    except Exception:
        # --psm 6 not accepted: retry with the default page segmentation.
        data = pytesseract.image_to_data(cv_img, output_type=Output.DICT)

    cells = []
    for i in range(len(data["text"])):
        txt = str(data["text"][i]).strip()
        if not txt:
            continue

        try:
            raw_conf = data["conf"][i]
            conf = float(raw_conf) if raw_conf not in (None, "", "-1", -1) else -1.0
        except Exception:
            conf = -1.0

        left = int(data["left"][i])
        top = int(data["top"][i])
        w = int(data["width"][i])
        h = int(data["height"][i])

        cells.append({
            "text": txt,
            "conf": conf,
            "left": left,
            "top": top,
            "width": w,
            "height": h,
            "center_x": left + w / 2,
            "center_y": top + h / 2,
        })

    return cells
178
 
179
+
180
+ # =======================================================
181
+ # ROW GROUPING + MERGING
182
+ # =======================================================
183
+
184
def group_cells(cells: List[Dict[str, Any]], tol: int = 12) -> List[List[Dict[str, Any]]]:
    """Groups words into horizontal text rows.

    A word joins the current row while its vertical centre is within ``tol``
    pixels of the row's first word; each finished row is sorted left-to-right.
    """
    if not cells:
        return []

    ordered = sorted(cells, key=lambda c: (c["center_y"], c["center_x"]))
    rows: List[List[Dict[str, Any]]] = []
    bucket = [ordered[0]]
    anchor = ordered[0]["center_y"]

    for cell in ordered[1:]:
        if abs(cell["center_y"] - anchor) <= tol:
            bucket.append(cell)
        else:
            rows.append(sorted(bucket, key=lambda c: c["left"]))
            bucket = [cell]
            anchor = cell["center_y"]

    rows.append(sorted(bucket, key=lambda c: c["left"]))
    return rows
203
 
204
+
205
def merge_multiline_descriptions(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
    """
    Some items have description on one line and numbers on the next.
    This merges them into a single row.
    """
    if not rows:
        return rows

    out: List[List[Dict[str, Any]]] = []
    idx = 0

    while idx < len(rows):
        row = rows[idx]
        text_only = not any(is_numeric(c["text"]) for c in row)

        if text_only and idx + 1 < len(rows):
            follower = rows[idx + 1]
            if any(is_numeric(c["text"]) for c in follower):
                # Re-home the text cells just left of the numeric row so the
                # left-to-right sort keeps the description first.
                anchor = min(c["left"] for c in follower) - 50
                combined = []
                for shift, cell in enumerate(row):
                    moved = dict(cell)
                    moved["left"] = anchor + shift * 15
                    moved["center_x"] = moved["left"] + moved["width"] / 2
                    combined.append(moved)
                combined.extend(follower)
                out.append(sorted(combined, key=lambda c: c["left"]))
                idx += 2
                continue

        out.append(row)
        idx += 1

    return out
250
 
251
+
252
+ # =======================================================
253
+ # COLUMN DETECTION
254
+ # =======================================================
255
+
256
def detect_column_centers(cells: List[Dict[str, Any]], max_cols=4) -> List[float]:
    """Estimate the x-positions of numeric columns by clustering large gaps."""
    xs = sorted(c["center_x"] for c in cells if is_numeric(c["text"]))

    if not xs:
        return []
    if len(xs) == 1:
        return xs

    gaps = [b - a for a, b in zip(xs, xs[1:])]
    # A split happens where the gap clearly exceeds the typical spacing.
    cutoff = max(30, np.mean(gaps) + 0.6 * np.std(gaps))

    clusters = [[xs[0]]]
    for gap, x in zip(gaps, xs[1:]):
        if gap > cutoff and len(clusters) < max_cols:
            clusters.append([x])
        else:
            clusters[-1].append(x)

    return sorted(np.median(c) for c in clusters)[:max_cols]


def nearest_column(x: float, centers: List[float]) -> int:
    """Index of the column centre closest to ``x`` (centers must be non-empty)."""
    return int(np.argmin([abs(x - c) for c in centers]))
286
 
287
+
288
+ # =======================================================
289
+ # ROW PARSER (MAIN LOGIC)
290
+ # =======================================================
291
+
292
def parse_rows(rows: List[List[Dict[str, Any]]], cells: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Extract structured line items using detected columns.

    Fix: when no numeric columns were detected, ``numeric_buckets`` was empty
    and the else-branch column index ``len(col_centers) - 1 == -1`` made
    ``numeric_buckets[-1].append(...)`` raise ``KeyError`` on the first numeric
    row, losing the whole page.  Such rows are now parsed with a simple
    rightmost-number-is-the-amount fallback instead.
    """
    items = []

    rows = merge_multiline_descriptions(rows)
    col_centers = detect_column_centers(cells, max_cols=4)

    for row in rows:
        tokens = [c["text"] for c in row]
        if not tokens:
            continue

        joined = " ".join(tokens).lower()

        # Skip footer lines like "Page 1/4"
        if FOOTER_HINT.search(joined) and not any(is_numeric(t) for t in tokens):
            continue

        # Skip headings that do not contain numbers
        if not any(is_numeric(t) for t in tokens):
            continue

        # --- Split the row into name text and per-column numbers ---
        left_parts = []
        numeric_buckets = {i: [] for i in range(len(col_centers))}
        fallback_numbers = []
        for c in row:
            t = c["text"]
            if not is_numeric(t):
                left_parts.append(t)
            elif col_centers:
                numeric_buckets[nearest_column(c["center_x"], col_centers)].append(t)
            else:
                fallback_numbers.append(t)

        name = clean_item_name(" ".join(left_parts))
        num_cols = len(col_centers)

        def bucket(idx):
            # Last token wins when a column holds several numbers.
            vals = numeric_buckets.get(idx)
            return vals[-1] if vals else None

        # Column order, right to left: amount, rate, qty.
        if num_cols:
            amount = normalize_number(bucket(num_cols - 1))
            rate = normalize_number(bucket(num_cols - 2)) if num_cols >= 2 else None
            qty = normalize_number(bucket(num_cols - 3)) if num_cols >= 3 else None
        else:
            # No columns: last number on the row is the amount, the one
            # before it (if any) the rate.
            amount = normalize_number(fallback_numbers[-1]) if fallback_numbers else None
            rate = normalize_number(fallback_numbers[-2]) if len(fallback_numbers) >= 2 else None
            qty = None

        # Fallback: rightmost numeric token anywhere on the row.
        if amount is None:
            for t in reversed(tokens):
                if is_numeric(t):
                    amount = normalize_number(t)
                    break

        # Infer qty from amount / rate when the ratio is near an integer.
        if qty is None and amount and rate:
            q_est = amount / rate
            rounded = round(q_est)
            if abs(q_est - rounded) <= 0.2:
                qty = rounded

        if qty is None:
            qty = 1.0

        if (rate is None or rate == 0) and qty and amount:
            rate = round(amount / qty, 2)

        if amount is None or amount <= 0:
            continue

        if HEADER_HINT.search(name):
            continue

        items.append({
            "item_name": name or "UNKNOWN",
            "item_amount": float(round(amount, 2)),
            "item_rate": float(round(rate or 0.0, 2)),
            "item_quantity": float(qty)
        })

    return items
370
+
371
+
372
+ # =======================================================
373
+ # DEDUPE ITEMS + DETECT TOTALS
374
+ # =======================================================
375
+
376
def dedupe(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop repeated line items, keeping the first occurrence of each.

    Two items count as duplicates when they share the same
    case-insensitive name (truncated to 120 chars) and the same amount
    rounded to 2 decimals. Relative order of survivors is preserved.
    """
    unique: List[Dict[str, Any]] = []
    observed = set()
    for entry in items:
        fingerprint = (entry["item_name"].lower()[:120], round(entry["item_amount"], 2))
        if fingerprint in observed:
            continue
        observed.add(fingerprint)
        unique.append(entry)
    return unique
387
 
388
+
389
+ # =======================================================
390
+ # OPTIONAL: GEMINI REFINEMENT
391
+ # =======================================================
392
+
393
def refine_with_llm(items: List[Dict[str, Any]], text: str):
    """Ask Gemini to clean up OCR-extracted bill items.

    Sends the page text plus heuristically detected items to the model and
    expects a JSON array of corrected items back. On any failure (no API
    key, network error, malformed JSON response) the original items are
    returned unchanged — refinement is strictly best-effort.

    Args:
        items: Detected line items, each with item_name/amount/rate/quantity.
        text: Raw page text (truncated to 3000 chars before sending).

    Returns:
        (items, usage) tuple. `usage` is a token-count dict; token counts
        are not reported by this path today and are always zero.
    """
    empty_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}

    if not GEMINI_API_KEY:
        return items, empty_usage

    try:
        prompt = (
            "You are a precise bill-item cleaner. Fix broken names, validate qty = amount/rate, "
            "and remove any invalid rows. Return JSON array only.\n\n"
            f"Full text: '''{text[:3000]}'''\n"
            f"Detected items: {json.dumps(items)}"
        )

        # Fix: the config section defines GEMINI_MODEL_NAME; the previous
        # reference to an undefined GEMINI_MODEL raised a NameError that the
        # bare `except` silently swallowed, so refinement never actually ran.
        model = genai.GenerativeModel(GEMINI_MODEL_NAME)
        response = model.generate_content(prompt)

        raw = response.text.strip()
        # The model sometimes wraps its JSON in a markdown code fence.
        raw = raw.replace("```json", "").replace("```", "")

        parsed = json.loads(raw)
        # Guard against the model returning a non-array payload.
        if not isinstance(parsed, list):
            return items, empty_usage

        final_items = []
        for obj in parsed:
            final_items.append({
                "item_name": str(obj.get("item_name", "")).strip(),
                "item_amount": float(obj.get("item_amount", 0)),
                "item_rate": float(obj.get("item_rate", 0)),
                "item_quantity": float(obj.get("item_quantity", 1)),
            })

        return final_items, empty_usage

    except Exception:
        # Never let LLM issues break extraction; fall back to raw items.
        return items, empty_usage
426
+
427
+
428
+ # =======================================================
429
+ # MAIN API ENDPOINT
430
+ # =======================================================
431
+
432
@app.post("/extract-bill-data")
async def extract_bill_data(payload: BillRequest):
    """Download a bill (PDF or image), OCR it page by page, and return
    structured line items.

    Response shape:
        {
          "is_success": bool,
          "token_usage": {...},   # aggregated Gemini usage (zeros today)
          "data": {
            "pagewise_line_items": [
                {"page_no": str, "page_type": "Bill Detail", "bill_items": [...]},
                ...
            ],
            "total_item_count": int
          }
        }
    """
    failure = {
        "is_success": False,
        "token_usage": {},
        "data": {"pagewise_line_items": [], "total_item_count": 0}
    }

    # ---------------------------------------------------
    # 1. DOWNLOAD FILE
    # ---------------------------------------------------
    try:
        resp = requests.get(payload.document, headers={"User-Agent": "Mozilla"}, timeout=30)
        resp.raise_for_status()
        data_bytes = resp.content
    except Exception:  # was a bare except; any download problem means failure
        return failure

    # ---------------------------------------------------
    # 2. LOAD PAGES (PDF / IMAGE)
    # ---------------------------------------------------
    # Strip the query string so signed URLs still match on extension.
    url_no_query = payload.document.split("?", 1)[0].lower()
    try:
        if url_no_query.endswith(".pdf"):
            pages = convert_from_bytes(data_bytes)
        else:
            pages = [Image.open(BytesIO(data_bytes))]
    except Exception:  # undecodable/corrupt document
        return failure

    # ---------------------------------------------------
    # 3. PROCESS EACH PAGE
    # ---------------------------------------------------
    results = []
    gemini_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}

    for idx, page in enumerate(pages, start=1):
        try:
            proc = preprocess_image(page)
            cells = run_tesseract(proc)
            rows = group_cells(cells)

            page_text = " ".join(" ".join(c["text"] for c in r) for r in rows).lower()

            items = dedupe(parse_rows(rows, cells))

            # Only invoke Gemini when a meaningful fraction of rows fail the
            # qty * rate ≈ amount sanity check (tolerance: max(2, 3%)).
            use_llm = False
            if GEMINI_API_KEY and items:
                inconsistent = sum(
                    1 for it in items
                    if abs(it["item_quantity"] * it["item_rate"] - it["item_amount"])
                    > max(2, 0.03 * it["item_amount"])
                )
                if inconsistent > max(1, len(items) // 6):
                    use_llm = True

            if use_llm:
                items, usage = refine_with_llm(items, page_text)
                for k in gemini_usage:
                    # .get() keeps aggregation safe even if a usage key is missing
                    gemini_usage[k] += usage.get(k, 0)

            results.append({
                "page_no": str(idx),
                "page_type": "Bill Detail",
                "bill_items": items,
            })

        except Exception:
            # A single broken page must not sink the whole request.
            results.append({
                "page_no": str(idx),
                "page_type": "Bill Detail",
                "bill_items": []
            })

    total_count = sum(len(p["bill_items"]) for p in results)

    return {
        "is_success": True,
        "token_usage": gemini_usage,
        "data": {
            "pagewise_line_items": results,
            "total_item_count": total_count
        }
    }
522
+
523
 
524
+ # -------------------------------------------------------
525
+ # RAW TSV DEBUG ENDPOINT
526
+ # -------------------------------------------------------
527
@app.post("/debug-tsv")
async def debug_tsv(payload: BillRequest):
    """Debug helper: OCR the first page of the document and return the raw
    Tesseract cells, without any row grouping or item parsing.

    Returns {"error": ...} on download or decode failure instead of a 500.
    """
    try:
        resp = requests.get(payload.document, timeout=20)
        resp.raise_for_status()
        data = resp.content
    except Exception:  # was a bare except
        return {"error": "Unable to download"}

    url = payload.document.split("?", 1)[0].lower()

    try:
        if url.endswith(".pdf"):
            img = convert_from_bytes(data)[0]  # first page only
        else:
            img = Image.open(BytesIO(data))
    except Exception:
        # Previously an unreadable file escaped as an unhandled 500.
        return {"error": "Unable to read document"}

    proc = preprocess_image(img)
    return {"cells": run_tesseract(proc)}
 
 
 
545
 
546
 
547
@app.get("/")
def root():
    """Health-check endpoint confirming the service is up."""
    status_payload = {"status": "ok", "message": "Bill extraction API running"}
    return status_payload
550
+
551
+