Sathvik-kota commited on
Commit
84d69f8
·
verified ·
1 Parent(s): b086ce8

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +351 -133
app.py CHANGED
@@ -101,6 +101,8 @@ def clean_name_text(s: str) -> str:
101
  s = s.strip(" -:,.")
102
  s = re.sub(r"\bSG0?(\d+)\b", r"SG\1", s, flags=re.I)
103
  s = re.sub(r"\b(RR)[\s\-]*2\b", r"RR-2", s, flags=re.I)
 
 
104
  return s.strip()
105
 
106
  # ---------------- image preprocessing ----------------
@@ -138,6 +140,8 @@ def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
138
  n = len(o.get("text", []))
139
  for i in range(n):
140
  raw = o["text"][i]
 
 
141
  txt = str(raw).strip()
142
  if not txt:
143
  continue
@@ -145,22 +149,13 @@ def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
145
  conf = float(o["conf"][i]) if o["conf"][i] not in (None, "", "-1") else -1.0
146
  except Exception:
147
  conf = -1.0
148
- left = int(o["left"][i])
149
- top = int(o["top"][i])
150
- width = int(o["width"][i])
151
- height = int(o["height"][i])
152
  center_y = top + height / 2.0
153
  center_x = left + width / 2.0
154
- cells.append({
155
- "text": txt,
156
- "conf": conf,
157
- "left": left,
158
- "top": top,
159
- "width": width,
160
- "height": height,
161
- "center_y": center_y,
162
- "center_x": center_x
163
- })
164
  return cells
165
 
166
  # ---------------- grouping & merge helpers ----------------
@@ -192,6 +187,7 @@ def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[st
192
  row = rows[i]
193
  tokens = [c["text"] for c in row]
194
  has_num = any(is_numeric_token(t) for t in tokens)
 
195
  if not has_num and i + 1 < len(rows):
196
  next_row = rows[i+1]
197
  next_tokens = [c["text"] for c in next_row]
@@ -210,12 +206,34 @@ def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[st
210
  merged.append(sorted(merged_row, key=lambda cc: cc["left"]))
211
  i += 2
212
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  merged.append(row)
214
  i += 1
215
  return merged
216
 
217
  # ---------------- numeric column detection ----------------
218
- # >>> FIX START replaced rigid 50px with adaptive clustering
219
  def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 4) -> List[float]:
220
  xs = [c["center_x"] for c in cells if is_numeric_token(c["text"])]
221
  if not xs:
@@ -225,9 +243,10 @@ def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 4) ->
225
  return [xs[0]]
226
  gaps = [xs[i+1] - xs[i] for i in range(len(xs)-1)]
227
  mean_gap = float(np.mean(gaps))
228
- std_gap = float(np.std(gaps)) if len(gaps) > 1 else 0
229
  gap_thresh = max(30.0, mean_gap + 0.6 * std_gap)
230
- clusters, curr = [], [xs[0]]
 
231
  for i, g in enumerate(gaps):
232
  if g > gap_thresh and len(clusters) < (max_columns - 1):
233
  clusters.append(curr)
@@ -239,7 +258,6 @@ def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 4) ->
239
  if len(centers) > max_columns:
240
  centers = centers[-max_columns:]
241
  return sorted(centers)
242
- # >>> FIX END
243
 
244
  def assign_token_to_column(token_x: float, column_centers: List[float]) -> Optional[int]:
245
  if not column_centers:
@@ -248,7 +266,6 @@ def assign_token_to_column(token_x: float, column_centers: List[float]) -> Optio
248
  return int(np.argmin(distances))
249
 
250
  # ---------------- parsing rows into items ----------------
251
-
252
  def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
253
  parsed_items = []
254
  rows = merge_multiline_names(rows)
@@ -258,42 +275,44 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
258
  tokens = [c["text"] for c in row]
259
  if not tokens:
260
  continue
 
 
 
261
  if all(not is_numeric_token(t) for t in tokens):
262
  continue
263
 
264
- # >>> FIX START — build numeric token list for inference
265
  numeric_values = []
266
  for t in tokens:
267
  if is_numeric_token(t):
268
  v = normalize_num_str(t)
269
  if v is not None:
270
  numeric_values.append(float(v))
271
- # >>> FIX END
 
272
 
273
  if column_centers:
274
  left_text_parts = []
275
  numeric_bucket_map = {i: [] for i in range(len(column_centers))}
276
-
277
  for c in row:
278
  t = c["text"]
 
279
  if is_numeric_token(t):
280
- col_idx = assign_token_to_column(c["center_x"], column_centers)
281
  if col_idx is None:
282
- numeric_bucket_map[len(column_centers)-1].append(t)
283
  else:
284
  numeric_bucket_map[col_idx].append(t)
285
  else:
286
  left_text_parts.append(t)
287
-
288
  raw_name = " ".join(left_text_parts).strip()
289
- name = clean_name_text(raw_name)
290
 
291
  num_cols = len(column_centers)
292
  def get_bucket(idx):
293
  vals = numeric_bucket_map.get(idx, [])
294
  return vals[-1] if vals else None
295
 
296
- # base extraction
297
  amount = normalize_num_str(get_bucket(num_cols - 1)) if num_cols >= 1 else None
298
  rate = normalize_num_str(get_bucket(num_cols - 2)) if num_cols >= 2 else None
299
  qty = normalize_num_str(get_bucket(num_cols - 3)) if num_cols >= 3 else None
@@ -302,78 +321,127 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
302
  for t in reversed(tokens):
303
  if is_numeric_token(t):
304
  amount = normalize_num_str(t)
305
- break
 
306
 
307
- # >>> FIX STARTstrong inference block
308
  if amount is not None and numeric_values:
309
- # Look for: amount / candidate_rate integer
310
  for cand in numeric_values:
311
- if cand == 0 or cand == amount:
 
 
 
 
 
 
 
 
 
 
 
312
  continue
313
- ratio = amount / cand
314
  r = round(ratio)
315
- if 1 <= r <= 200 and abs(ratio - r) <= max(0.04*r, 0.2):
316
- rate = cand
317
- qty = float(r)
318
- break
319
- # >>> FIX END
320
-
321
- # fallback inference
322
- if (rate is None or rate == 0) and qty:
 
 
 
 
323
  try:
324
- rate = amount / qty
325
- except:
 
 
 
326
  pass
327
 
 
328
  if qty is None:
329
  qty = 1.0
330
 
331
- # cleanup
332
- try: amount = float(round(amount,2))
333
- except: continue
334
- try: rate = float(round(rate,2)) if rate else 0.0
335
- except: rate = 0.0
336
- try: qty = float(qty)
337
- except: qty = 1.0
 
 
 
 
 
 
338
 
339
  parsed_items.append({
340
  "item_name": name if name else "UNKNOWN",
341
  "item_amount": amount,
342
- "item_rate": rate,
343
- "item_quantity": qty
344
  })
345
 
346
  else:
347
- numeric_idxs = [i for i,t in enumerate(tokens) if is_numeric_token(t)]
348
  if not numeric_idxs:
349
  continue
350
-
351
  last = numeric_idxs[-1]
352
- amount = normalize_num_str(tokens[last])
353
- if amount is None:
 
 
 
354
  continue
 
355
 
356
- name = clean_name_text(" ".join(tokens[:last]).strip())
357
- rate = 0.0
358
- qty = 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
 
360
- # >>> FIX START — fallback inference also upgraded
361
- for cand in numeric_values:
362
- if cand == 0 or cand == amount:
363
- continue
364
- ratio = amount / cand
365
- r = round(ratio)
366
- if 1 <= r <= 200 and abs(ratio - r) <= max(0.04*r, 0.2):
367
- rate = cand
368
- qty = float(r)
369
- break
370
- # >>> FIX END
371
 
372
  parsed_items.append({
373
- "item_name": name,
374
- "item_amount": float(round(amount,2)),
375
- "item_rate": float(round(rate,2)),
376
- "item_quantity": float(qty)
377
  })
378
 
379
  return parsed_items
@@ -384,69 +452,144 @@ def dedupe_items(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
384
  out = []
385
  for it in items:
386
  nm = re.sub(r"\s+", " ", it["item_name"].lower()).strip()
387
- key = (nm[:120], round(it["item_amount"], 2))
388
  if key in seen:
389
  continue
390
  seen.add(key)
391
  out.append(it)
392
  return out
393
 
394
- # ---------------- Gemini refinement (unchanged) ----------------
395
- def refine_with_gemini(page_items: List[Dict[str, Any]], page_text: str = ""):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
  zero_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
397
  if not GEMINI_API_KEY or genai is None:
398
  return page_items, zero_usage
399
-
400
  try:
401
  safe_text = sanitize_ocr_text(page_text)
402
  system_prompt = (
403
- "You are a strict bill-extraction cleaner. Return ONLY a JSON array."
 
 
404
  )
405
  user_prompt = (
406
  f"page_text='''{safe_text}'''\n"
407
  f"items = {json.dumps(page_items, ensure_ascii=False)}\n\n"
408
- "Return only the cleaned JSON array."
 
 
 
 
 
409
  )
410
-
411
  model = genai.GenerativeModel(GEMINI_MODEL_NAME)
412
- response = model.generate_content([
413
- {"role": "system", "parts": [system_prompt]},
414
- {"role": "user", "parts": [user_prompt]}
415
- ], temperature=0.0)
416
-
 
 
 
417
  raw = response.text.strip()
418
  if raw.startswith("```"):
419
- raw = raw.split("```")[1]
 
420
  parsed = json.loads(raw)
421
-
422
  if isinstance(parsed, list):
423
  cleaned = []
424
  for obj in parsed:
425
  try:
426
  cleaned.append({
427
- "item_name": str(obj.get("item_name","")).strip(),
428
- "item_amount": float(obj.get("item_amount",0)),
429
- "item_rate": float(obj.get("item_rate",0)),
430
- "item_quantity": float(obj.get("item_quantity",1)),
431
  })
432
- except:
433
  continue
434
  return cleaned, zero_usage
435
-
436
  return page_items, zero_usage
437
-
438
  except Exception:
439
  return page_items, zero_usage
440
 
441
  # ---------------- header heuristics & final filter ----------------
442
- def final_item_filter(item, known_page_headers):
443
- name = item["item_name"].lower()
444
- amt = item["item_amount"]
445
- if amt <= 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446
  return False
447
- if FOOTER_KEYWORDS.search(name):
448
  return False
449
- if any(h in name for h in known_page_headers):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  return False
451
  return True
452
 
@@ -455,42 +598,117 @@ def final_item_filter(item, known_page_headers):
455
  async def extract_bill_data(payload: BillRequest):
456
  doc_url = payload.document
457
  try:
458
- resp = requests.get(doc_url, timeout=30)
 
 
 
459
  file_bytes = resp.content
460
- except:
461
- return {"is_success": False, "data": {}}
462
 
463
- if doc_url.lower().endswith(".pdf"):
464
- images = convert_from_bytes(file_bytes)
465
- else:
466
- images = [Image.open(BytesIO(file_bytes))]
 
 
 
 
 
 
 
 
 
 
467
 
468
  pagewise = []
469
- total_items = 0
470
-
471
- for idx, img in enumerate(images, start=1):
472
- proc = preprocess_image(img)
473
- cells = image_to_tsv_cells(proc)
474
- rows = group_cells_into_rows(cells)
475
-
476
- rows_text = [" ".join([c["text"] for c in r]) for r in rows]
477
- parsed = parse_rows_with_columns(rows, cells)
478
-
479
- pagewise.append({
480
- "page_no": str(idx),
481
- "page_type": "Bill Detail",
482
- "bill_items": parsed
483
- })
484
- total_items += len(parsed)
485
-
486
- return {
487
- "is_success": True,
488
- "data": {
489
- "pagewise_line_items": pagewise,
490
- "total_item_count": total_items
491
- }
492
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
 
494
  @app.get("/")
495
- def health():
496
- return {"status": "ok"}
 
 
 
 
101
  s = s.strip(" -:,.")
102
  s = re.sub(r"\bSG0?(\d+)\b", r"SG\1", s, flags=re.I)
103
  s = re.sub(r"\b(RR)[\s\-]*2\b", r"RR-2", s, flags=re.I)
104
+ # fix common OCR mistakes for doctor prefixes
105
+ s = re.sub(r"\bOR\b", "DR", s) # sometimes OCR turns 'DR' -> 'OR'
106
  return s.strip()
107
 
108
  # ---------------- image preprocessing ----------------
 
140
  n = len(o.get("text", []))
141
  for i in range(n):
142
  raw = o["text"][i]
143
+ if raw is None:
144
+ continue
145
  txt = str(raw).strip()
146
  if not txt:
147
  continue
 
149
  conf = float(o["conf"][i]) if o["conf"][i] not in (None, "", "-1") else -1.0
150
  except Exception:
151
  conf = -1.0
152
+ left = int(o.get("left", [0])[i])
153
+ top = int(o.get("top", [0])[i])
154
+ width = int(o.get("width", [0])[i])
155
+ height = int(o.get("height", [0])[i])
156
  center_y = top + height / 2.0
157
  center_x = left + width / 2.0
158
+ cells.append({"text": txt, "conf": conf, "left": left, "top": top, "width": width, "height": height, "center_y": center_y, "center_x": center_x})
 
 
 
 
 
 
 
 
 
159
  return cells
160
 
161
  # ---------------- grouping & merge helpers ----------------
 
187
  row = rows[i]
188
  tokens = [c["text"] for c in row]
189
  has_num = any(is_numeric_token(t) for t in tokens)
190
+ # if row looks pure text and next row contains numbers but short left text tokens, merge
191
  if not has_num and i + 1 < len(rows):
192
  next_row = rows[i+1]
193
  next_tokens = [c["text"] for c in next_row]
 
206
  merged.append(sorted(merged_row, key=lambda cc: cc["left"]))
207
  i += 2
208
  continue
209
+ # Additional merge: If a row ends with a trailing token like a doctor's name line with single token and next row also text, merge (helps names split across 2+ lines)
210
+ if not has_num and i + 1 < len(rows):
211
+ next_row = rows[i+1]
212
+ next_tokens = [c["text"] for c in next_row]
213
+ next_has_num = any(is_numeric_token(t) for t in next_tokens)
214
+ if not next_has_num and len(tokens) <= 3 and len(next_tokens) <= 3:
215
+ # merge both textual lines into one (keeps relative left ordering by shifting)
216
+ merged_row = []
217
+ min_left = min((c["left"] for c in next_row + row), default=0)
218
+ offset = 10
219
+ for c in row + next_row:
220
+ newc = c.copy()
221
+ if newc["left"] > min_left:
222
+ newc["left"] = newc["left"]
223
+ else:
224
+ newc["left"] = min_left - offset
225
+ newc["center_x"] = newc["left"] + newc.get("width", 0) / 2.0
226
+ merged_row.append(newc)
227
+ offset += 5
228
+ merged.append(sorted(merged_row, key=lambda cc: cc["left"]))
229
+ i += 2
230
+ continue
231
  merged.append(row)
232
  i += 1
233
  return merged
234
 
235
  # ---------------- numeric column detection ----------------
236
+ # >>> CHANGE: adaptive clustering (restored to conservative adaptive threshold)
237
  def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 4) -> List[float]:
238
  xs = [c["center_x"] for c in cells if is_numeric_token(c["text"])]
239
  if not xs:
 
243
  return [xs[0]]
244
  gaps = [xs[i+1] - xs[i] for i in range(len(xs)-1)]
245
  mean_gap = float(np.mean(gaps))
246
+ std_gap = float(np.std(gaps)) if len(gaps) > 1 else 0.0
247
  gap_thresh = max(30.0, mean_gap + 0.6 * std_gap)
248
+ clusters = []
249
+ curr = [xs[0]]
250
  for i, g in enumerate(gaps):
251
  if g > gap_thresh and len(clusters) < (max_columns - 1):
252
  clusters.append(curr)
 
258
  if len(centers) > max_columns:
259
  centers = centers[-max_columns:]
260
  return sorted(centers)
 
261
 
262
  def assign_token_to_column(token_x: float, column_centers: List[float]) -> Optional[int]:
263
  if not column_centers:
 
266
  return int(np.argmin(distances))
267
 
268
  # ---------------- parsing rows into items ----------------
 
269
  def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
270
  parsed_items = []
271
  rows = merge_multiline_names(rows)
 
275
  tokens = [c["text"] for c in row]
276
  if not tokens:
277
  continue
278
+ joined_lower = " ".join(tokens).lower()
279
+ if FOOTER_KEYWORDS.search(joined_lower) and not any(is_numeric_token(t) for t in tokens):
280
+ continue
281
  if all(not is_numeric_token(t) for t in tokens):
282
  continue
283
 
284
+ # gather numeric candidates (unique, filtered)
285
  numeric_values = []
286
  for t in tokens:
287
  if is_numeric_token(t):
288
  v = normalize_num_str(t)
289
  if v is not None:
290
  numeric_values.append(float(v))
291
+ # de-duplicate and sort descending (larger candidates first)
292
+ numeric_values = sorted(list({int(x) if float(x).is_integer() else x for x in numeric_values}), reverse=True)
293
 
294
  if column_centers:
295
  left_text_parts = []
296
  numeric_bucket_map = {i: [] for i in range(len(column_centers))}
 
297
  for c in row:
298
  t = c["text"]
299
+ cx = c["center_x"]
300
  if is_numeric_token(t):
301
+ col_idx = assign_token_to_column(cx, column_centers)
302
  if col_idx is None:
303
+ numeric_bucket_map[len(column_centers) - 1].append(t)
304
  else:
305
  numeric_bucket_map[col_idx].append(t)
306
  else:
307
  left_text_parts.append(t)
 
308
  raw_name = " ".join(left_text_parts).strip()
309
+ name = clean_name_text(raw_name) if raw_name else ""
310
 
311
  num_cols = len(column_centers)
312
  def get_bucket(idx):
313
  vals = numeric_bucket_map.get(idx, [])
314
  return vals[-1] if vals else None
315
 
 
316
  amount = normalize_num_str(get_bucket(num_cols - 1)) if num_cols >= 1 else None
317
  rate = normalize_num_str(get_bucket(num_cols - 2)) if num_cols >= 2 else None
318
  qty = normalize_num_str(get_bucket(num_cols - 3)) if num_cols >= 3 else None
 
321
  for t in reversed(tokens):
322
  if is_numeric_token(t):
323
  amount = normalize_num_str(t)
324
+ if amount is not None:
325
+ break
326
 
327
+ # >>> CHANGE: safer inference skip tiny candidates like 1, enforce qty bounds, require close ratio
328
  if amount is not None and numeric_values:
329
+ # Only accept candidate as rate if candidate >= 2 (or amount is tiny) and candidate < amount
330
  for cand in numeric_values:
331
+ try:
332
+ cand_float = float(cand)
333
+ except:
334
+ continue
335
+ if cand_float <= 1.0:
336
+ continue
337
+ if amount <= 5 and cand_float < 1.0:
338
+ continue
339
+ if cand_float >= amount:
340
+ continue
341
+ ratio = amount / cand_float if cand_float else None
342
+ if ratio is None:
343
  continue
 
344
  r = round(ratio)
345
+ if r < 1 or r > 200:
346
+ continue
347
+ # require relative closeness threshold (adaptive)
348
+ if abs(ratio - r) <= max(0.03 * r, 0.15):
349
+ # Accept only if qty reasonable (<=100)
350
+ if r <= 100:
351
+ rate = cand_float
352
+ qty = float(r)
353
+ break
354
+
355
+ # fallback compute rate if qty found but rate missing
356
+ if (rate is None or rate == 0) and qty and qty != 0 and amount is not None:
357
  try:
358
+ candidate_rate = amount / qty
359
+ # require candidate_rate > 1 (avoid tiny rates) and reasonable
360
+ if candidate_rate >= 2:
361
+ rate = candidate_rate
362
+ except Exception:
363
  pass
364
 
365
+ # final defaults
366
  if qty is None:
367
  qty = 1.0
368
 
369
+ # normalize and sanity-check
370
+ try:
371
+ amount = float(round(amount, 2))
372
+ except Exception:
373
+ continue
374
+ try:
375
+ rate = float(round(rate, 2)) if rate is not None else 0.0
376
+ except Exception:
377
+ rate = 0.0
378
+ try:
379
+ qty = float(qty)
380
+ except Exception:
381
+ qty = 1.0
382
 
383
  parsed_items.append({
384
  "item_name": name if name else "UNKNOWN",
385
  "item_amount": amount,
386
+ "item_rate": rate if rate is not None else 0.0,
387
+ "item_quantity": qty if qty is not None else 1.0,
388
  })
389
 
390
  else:
391
+ numeric_idxs = [i for i, t in enumerate(tokens) if is_numeric_token(t)]
392
  if not numeric_idxs:
393
  continue
 
394
  last = numeric_idxs[-1]
395
+ amt = normalize_num_str(tokens[last])
396
+ if amt is None:
397
+ continue
398
+ name = " ".join(tokens[:last]).strip()
399
+ if not name:
400
  continue
401
+ rate = None; qty = None
402
 
403
+ # try to pick rate/qty from previous numeric tokens (right-to-left)
404
+ # and use the safer inference logic (ignore candidate == 1)
405
+ right_nums = []
406
+ for i in numeric_idxs:
407
+ v = normalize_num_str(tokens[i])
408
+ if v is not None:
409
+ right_nums.append(float(v))
410
+ right_nums = sorted(list({int(x) if float(x).is_integer() else x for x in right_nums}), reverse=True)
411
+
412
+ # attempt direct mapping: last numeric = amount, previous maybe rate / qty
413
+ if len(right_nums) >= 2:
414
+ cand = right_nums[1]
415
+ if float(cand) > 1 and float(cand) < float(amt):
416
+ # check ratio
417
+ ratio = float(amt) / float(cand) if cand else None
418
+ if ratio:
419
+ r = round(ratio)
420
+ if 1 <= r <= 200 and abs(ratio - r) <= max(0.03 * r, 0.15) and r <= 100:
421
+ rate = float(cand)
422
+ qty = float(r)
423
+ # fallback: conservative search like above
424
+ if rate is None and right_nums:
425
+ for cand in right_nums:
426
+ if cand <= 1.0 or cand >= float(amt):
427
+ continue
428
+ ratio = float(amt) / float(cand)
429
+ r = round(ratio)
430
+ if 1 <= r <= 100 and abs(ratio - r) <= max(0.03 * r, 0.15):
431
+ rate = float(cand)
432
+ qty = float(r)
433
+ break
434
 
435
+ if qty is None:
436
+ qty = 1.0
437
+ if rate is None:
438
+ rate = 0.0
 
 
 
 
 
 
 
439
 
440
  parsed_items.append({
441
+ "item_name": clean_name_text(name),
442
+ "item_amount": float(round(amt, 2)),
443
+ "item_rate": float(round(rate, 2)),
444
+ "item_quantity": float(qty),
445
  })
446
 
447
  return parsed_items
 
452
  out = []
453
  for it in items:
454
  nm = re.sub(r"\s+", " ", it["item_name"].lower()).strip()
455
+ key = (nm[:120], round(float(it["item_amount"]), 2))
456
  if key in seen:
457
  continue
458
  seen.add(key)
459
  out.append(it)
460
  return out
461
 
462
+ def detect_subtotals_and_totals(rows_texts: List[str]) -> Dict[str, Optional[float]]:
463
+ subtotal = None; final = None
464
+ for rt in rows_texts[::-1]:
465
+ if not rt or rt.strip() == "":
466
+ continue
467
+ if TOTAL_KEYWORDS.search(rt):
468
+ m = NUM_RE.search(rt)
469
+ if m:
470
+ v = normalize_num_str(m.group(0))
471
+ if v is None:
472
+ continue
473
+ if re.search(r"sub", rt, re.I):
474
+ if subtotal is None: subtotal = float(round(v, 2))
475
+ else:
476
+ if final is None: final = float(round(v, 2))
477
+ return {"subtotal": subtotal, "final_total": final}
478
+
479
+ # ---------------- Gemini refinement (deterministic) ----------------
480
+ def refine_with_gemini(page_items: List[Dict[str, Any]], page_text: str = "") -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
481
  zero_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
482
  if not GEMINI_API_KEY or genai is None:
483
  return page_items, zero_usage
 
484
  try:
485
  safe_text = sanitize_ocr_text(page_text)
486
  system_prompt = (
487
+ "You are a strict bill-extraction cleaner. Return ONLY a JSON array (no explanation, no backticks). "
488
+ "Each entry must be an object with keys: item_name (string), item_amount (float), item_rate (float), item_quantity (float). "
489
+ "Do NOT include subtotal or total lines as items. Do not invent items; only clean/fix/normalize the given items."
490
  )
491
  user_prompt = (
492
  f"page_text='''{safe_text}'''\n"
493
  f"items = {json.dumps(page_items, ensure_ascii=False)}\n\n"
494
+ "Example:\n"
495
+ "items = [{'item_name':'Consultation Charge | DR PREETHI','item_amount':300.0,'item_rate':0.0,'item_quantity':300.0},\n"
496
+ " {'item_name':'Description Qty / Hrs Consultation Rate Discount Net Amt','item_amount':1950.0,'item_rate':1950.0,'item_quantity':1.0}]\n"
497
+ "=>\n"
498
+ "[{'item_name':'Consultation Charge | DR PREETHI MARY JOSEPH','item_amount':300.0,'item_rate':300.0,'item_quantity':1.0}]\n\n"
499
+ "Return only the cleaned JSON array of items."
500
  )
 
501
  model = genai.GenerativeModel(GEMINI_MODEL_NAME)
502
+ response = model.generate_content(
503
+ [
504
+ {"role": "system", "parts": [system_prompt]},
505
+ {"role": "user", "parts": [user_prompt]},
506
+ ],
507
+ temperature=0.0,
508
+ max_output_tokens=1000,
509
+ )
510
  raw = response.text.strip()
511
  if raw.startswith("```"):
512
+ raw = re.sub(r"^```[a-zA-Z]*", "", raw)
513
+ raw = re.sub(r"```$", "", raw).strip()
514
  parsed = json.loads(raw)
 
515
  if isinstance(parsed, list):
516
  cleaned = []
517
  for obj in parsed:
518
  try:
519
  cleaned.append({
520
+ "item_name": str(obj.get("item_name", "")).strip(),
521
+ "item_amount": float(obj.get("item_amount", 0.0)),
522
+ "item_rate": float(obj.get("item_rate", 0.0) or 0.0),
523
+ "item_quantity": float(obj.get("item_quantity", 1.0) or 1.0),
524
  })
525
+ except Exception:
526
  continue
527
  return cleaned, zero_usage
 
528
  return page_items, zero_usage
 
529
  except Exception:
530
  return page_items, zero_usage
531
 
532
  # ---------------- header heuristics & final filter ----------------
533
+ def looks_like_header_text(txt: str, top_of_page: bool = False) -> bool:
534
+ if not txt:
535
+ return False
536
+ t = re.sub(r"\s+", " ", txt.strip().lower())
537
+ # exact phrase blacklist
538
+ if any(h == t for h in HEADER_PHRASES):
539
+ return True
540
+ hits = sum(1 for k in HEADER_KEYWORDS if k in t)
541
+ if hits >= 2:
542
+ return True
543
+ tokens = re.split(r"[\s\|,/:]+", t)
544
+ key_hit_count = sum(1 for tok in tokens if tok in HEADER_KEYWORDS)
545
+ if key_hit_count >= 3:
546
+ return True
547
+ if top_of_page and len(tokens) <= 10 and key_hit_count >= 2:
548
+ return True
549
+ if ("rate" in t or "net" in t) and "amt" in t and not any(ch.isdigit() for ch in t):
550
+ return True
551
+ if t.startswith("description") or t.startswith("qty") or t.startswith("qty /"):
552
+ return True
553
+ return False
554
+
555
+ def final_item_filter(item: Dict[str, Any], known_page_headers: List[str] = [], other_item_names: List[str] = []) -> bool:
556
+ name = (item.get("item_name") or "").strip()
557
+ if not name:
558
+ return False
559
+ ln = name.lower()
560
+ # header exact detection
561
+ for h in known_page_headers:
562
+ if h and h.strip() and h.strip().lower() in ln:
563
+ return False
564
+ if FOOTER_KEYWORDS.search(ln):
565
+ return False
566
+ if item.get("item_amount", 0) > 1_000_000:
567
  return False
568
+ if len(name) <= 2 and not re.search(r"[a-zA-Z]", name):
569
  return False
570
+ # avoid pure section headers (short & header words)
571
+ words = ln.split()
572
+ header_word_hits = sum(1 for k in HEADER_KEYWORDS if k in ln)
573
+ if header_word_hits >= 1 and len(words) <= 3:
574
+ # if page contains more detailed items with 'room'/'rent'/'nursing' etc, remove this generic header
575
+ lower_other = " ".join(other_item_names).lower()
576
+ if any(k in lower_other for k in ["room", "rent", "nursing", "ward", "surgeon", "anaes", "ot", "charges", "procedure", "radiology"]):
577
+ return False
578
+ # also if name is exactly one of the short header words, drop
579
+ if ln in ("charge", "charges", "services", "consultation", "room", "radiology", "surgery"):
580
+ return False
581
+ # drop non-informative labels even if they have amount (summary rows)
582
+ if len(words) <= 4 and re.search(r"\b(charges|services|room|radiolog|laborat|surgery|procedure|rent|nursing)\b", ln):
583
+ # try to detect if it's a summary (presence of other more specific items)
584
+ lower_other = " ".join(other_item_names).lower()
585
+ if any(tok in lower_other for tok in ["rent", "room", "ward", "nursing", "surgeon", "anaes", "ot"]):
586
+ return False
587
+ if float(item.get("item_amount", 0)) <= 0.0:
588
+ return False
589
+ # sanity check rate vs amount
590
+ rate = float(item.get("item_rate", 0) or 0)
591
+ amt = float(item.get("item_amount", 0) or 0)
592
+ if rate and rate > amt * 10 and amt < 10000:
593
  return False
594
  return True
595
 
 
598
  async def extract_bill_data(payload: BillRequest):
599
  doc_url = payload.document
600
  try:
601
+ headers = {"User-Agent": "Mozilla/5.0"}
602
+ resp = requests.get(doc_url, headers=headers, timeout=30)
603
+ if resp.status_code != 200:
604
+ raise RuntimeError(f"download failed status={resp.status_code}")
605
  file_bytes = resp.content
606
+ except Exception:
607
+ return {"is_success": False, "token_usage": {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}, "data": {"pagewise_line_items": [], "total_item_count": 0}}
608
 
609
+ images = []
610
+ clean_url = doc_url.split("?", 1)[0].lower()
611
+ try:
612
+ if clean_url.endswith(".pdf"):
613
+ images = convert_from_bytes(file_bytes)
614
+ elif any(clean_url.endswith(ext) for ext in [".png", ".jpg", ".jpeg", ".tiff", ".bmp"]):
615
+ images = [Image.open(BytesIO(file_bytes))]
616
+ else:
617
+ try:
618
+ images = convert_from_bytes(file_bytes)
619
+ except Exception:
620
+ images = []
621
+ except Exception:
622
+ images = []
623
 
624
  pagewise = []
625
+ cumulative_token_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
626
+
627
+ for idx, page_img in enumerate(images, start=1):
628
+ try:
629
+ proc = preprocess_image(page_img)
630
+ cells = image_to_tsv_cells(proc)
631
+ rows = group_cells_into_rows(cells, y_tolerance=12)
632
+ rows_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]
633
+
634
+ # === HEADER PREFILTER: remove header-like rows anywhere on page ===
635
+ rows_filtered = []
636
+ for i, (r, rt) in enumerate(zip(rows, rows_texts)):
637
+ top_flag = (i < 6)
638
+ rt_norm = sanitize_ocr_text(rt).lower()
639
+ if looks_like_header_text(rt_norm, top_of_page=top_flag):
640
+ continue
641
+ if any(h in rt_norm for h in HEADER_PHRASES):
642
+ continue
643
+ rows_filtered.append(r)
644
+ # recompute row texts and a simple page_text
645
+ rows = rows_filtered
646
+ rows_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]
647
+ page_text = sanitize_ocr_text(" ".join(rows_texts))
648
+
649
+ # detect page-level top headers (for final filtering)
650
+ top_headers = []
651
+ for i, rt in enumerate(rows_texts[:6]):
652
+ if looks_like_header_text(rt, top_of_page=(i < 4)):
653
+ top_headers.append(rt.strip().lower())
654
+
655
+ parsed_items = parse_rows_with_columns(rows, cells)
656
+
657
+ # ALWAYS attempt Gemini refinement if available (deterministic settings)
658
+ refined_items, token_u = refine_with_gemini(parsed_items, page_text)
659
+ for k in cumulative_token_usage:
660
+ cumulative_token_usage[k] += token_u.get(k, 0)
661
+
662
+ # Prepare other_item_names for contextual filtering (helps remove generic section headers)
663
+ other_item_names = [it.get("item_name","") for it in refined_items]
664
+
665
+ # final cleaning & dedupe
666
+ cleaned = [p for p in refined_items if final_item_filter(p, known_page_headers=top_headers, other_item_names=other_item_names)]
667
+ cleaned = dedupe_items(cleaned)
668
+ cleaned = [p for p in cleaned if not looks_like_header_text(p["item_name"].lower())]
669
+
670
+ page_type = "Bill Detail"
671
+ page_txt = page_text.lower()
672
+ if any(x in page_txt for x in ["pharmacy", "medicine", "tablet"]):
673
+ page_type = "Pharmacy"
674
+ if "final bill" in page_txt or "grand total" in page_txt:
675
+ page_type = "Final Bill"
676
+
677
+ pagewise.append({"page_no": str(idx), "page_type": page_type, "bill_items": cleaned})
678
+ except Exception:
679
+ pagewise.append({"page_no": str(idx), "page_type": "Bill Detail", "bill_items": []})
680
+ continue
681
+
682
+ total_item_count = sum(len(p.get("bill_items", [])) for p in pagewise)
683
+ if not GEMINI_API_KEY or genai is None:
684
+ cumulative_token_usage["warning_no_gemini"] = 1
685
+
686
+ return {"is_success": True, "token_usage": cumulative_token_usage, "data": {"pagewise_line_items": pagewise, "total_item_count": total_item_count}}
687
+
688
+ # ---------------- debug TSV ----------------
689
+ @app.post("/debug-tsv")
690
+ async def debug_tsv(payload: BillRequest):
691
+ doc_url = payload.document
692
+ try:
693
+ resp = requests.get(doc_url, timeout=20)
694
+ if resp.status_code != 200:
695
+ return {"error": "Download failed"}
696
+ file_bytes = resp.content
697
+ except Exception:
698
+ return {"error": "Download failed"}
699
+ clean_url = doc_url.split("?", 1)[0].lower()
700
+ if clean_url.endswith(".pdf"):
701
+ imgs = convert_from_bytes(file_bytes)
702
+ img = imgs[0]
703
+ else:
704
+ img = Image.open(BytesIO(file_bytes))
705
+ proc = preprocess_image(img)
706
+ cells = image_to_tsv_cells(proc)
707
+ return {"cells": cells}
708
 
709
  @app.get("/")
710
+ def health_check():
711
+ msg = "Bill extraction API (final) live."
712
+ if not GEMINI_API_KEY or genai is None:
713
+ msg += " (No GEMINI_API_KEY/configured SDK — LLM refinement skipped.)"
714
+ return {"status": "ok", "message": msg, "hint": "POST /extract-bill-data with {'document':'<url>'}"}