Sathvik-kota committed on
Commit
811fc30
·
verified ·
1 Parent(s): 8e9daea

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +81 -259
app.py CHANGED
@@ -11,10 +11,10 @@ from io import BytesIO
11
  from typing import List, Dict, Any, Optional, Tuple
12
 
13
  import uvicorn
14
- from fastapi import FastAPI
15
  from pydantic import BaseModel
16
  import requests
17
- from PIL import Image, ImageOps
18
  from pdf2image import convert_from_bytes
19
  import pytesseract
20
  from pytesseract import Output
@@ -54,7 +54,7 @@ TOTAL_KEYWORDS = re.compile(
54
  r"(grand\s*total|net\s*payable|total\s*amount|amount\s*payable|bill\s*amount|final\s*amount|balance\s*due|sub\s*total|subtotal|total)",
55
  re.I,
56
  )
57
- FOOTER_KEYWORDS = re.compile(r"(page|printed on|printed:|date:|time:|am|pm|signature)", re.I)
58
 
59
  HEADER_KEYWORDS = [
60
  "description", "qty", "hrs", "rate", "discount", "net", "amt", "amount",
@@ -71,10 +71,9 @@ HEADER_PHRASES = [
71
  HEADER_PHRASES = [h.lower() for h in HEADER_PHRASES]
72
 
73
  # ---------------- small utilities ----------------
74
- def sanitize_ocr_text(s: Optional[str]) -> str:
75
  if not s:
76
  return ""
77
- s = str(s)
78
  s = s.replace("\u2014", "-").replace("\u2013", "-")
79
  s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E]", " ", s)
80
  s = s.replace("\r\n", "\n").replace("\r", "\n")
@@ -88,8 +87,6 @@ def normalize_num_str(s: Optional[str]) -> Optional[float]:
88
  s = str(s).strip()
89
  if s == "":
90
  return None
91
- # common OCR fixes in numeric strings (O -> 0, , as thousands)
92
- s = s.replace("O", "0").replace("o", "0").replace("l", "1")
93
  s = re.sub(r"[^\d\-\+\,\.\(\)]", "", s)
94
  negative = False
95
  if s.startswith("(") and s.endswith(")"):
@@ -110,8 +107,6 @@ def is_numeric_token(t: Optional[str]) -> bool:
110
  return bool(t and NUM_RE.search(str(t)))
111
 
112
  def looks_like_date_num(s: str) -> bool:
113
- if not s:
114
- return False
115
  s_digits = re.sub(r"[^\d]", "", s or "")
116
  if len(s_digits) >= 7:
117
  if s_digits.endswith(("2025","2024","2023","2022","2026")):
@@ -123,50 +118,17 @@ def looks_like_date_num(s: str) -> bool:
123
  pass
124
  return False
125
 
126
- def collapse_repeated_chars(s: str) -> str:
127
- # collapse runs of repeated punctuation/letters that are OCR artifacts
128
- s = re.sub(r"([^\w\s])\1{2,}", r"\1", s)
129
- s = re.sub(r"([A-Za-z])\1{3,}", r"\1\1", s)
130
- return s
131
-
132
  def clean_name_text(s: str) -> str:
133
- if not s:
134
- return ""
135
- s = str(s)
136
- s = s.replace("—", "-").replace("–", "-")
137
- s = collapse_repeated_chars(s)
138
- s = re.sub(r"[_\|]{2,}", " ", s)
139
- s = re.sub(r"[^\x00-\x7F]+", " ", s) # remove non-ascii weird chars
140
- s = re.sub(r"[\[\]\{\}\(\)]", " ", s)
141
  s = re.sub(r"\s+", " ", s)
142
  s = s.strip(" -:,.=")
143
- # fix common OCR 'OR' -> 'DR' when it's standalone uppercase
 
144
  s = re.sub(r"\bOR\b", "DR", s)
145
- # remove trailing artifacts like 'x' placed between qty and name
146
- s = re.sub(r"\s+x\s*$", "", s, flags=re.I)
147
  s = s.strip()
148
  return s
149
 
150
- def is_probable_garbage_name(name: str) -> bool:
151
- if not name:
152
- return True
153
- n = name.strip()
154
- # too short or too many non-alpha
155
- alpha_count = len(re.findall(r"[A-Za-z]", n))
156
- digit_count = len(re.findall(r"\d", n))
157
- non_word = len(re.findall(r"[^\w\s]", n))
158
- if alpha_count == 0:
159
- return True
160
- if len(n) < 2:
161
- return True
162
- # if >50% of chars are non-alnum, garbage
163
- if non_word / max(1, len(n)) > 0.45:
164
- return True
165
- # if digits dominate and look not like code/date
166
- if digit_count / max(1, len(n)) > 0.6 and not looks_like_date_num(n):
167
- return True
168
- return False
169
-
170
  # ---------------- image preprocessing ----------------
171
  def pil_to_cv2(img: Image.Image) -> Any:
172
  arr = np.array(img)
@@ -182,56 +144,30 @@ def preprocess_image(pil_img: Image.Image) -> Any:
182
  if w < target_w:
183
  scale = target_w / float(w)
184
  pil_img = pil_img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
185
-
186
- # convert to gray + CLAHE (adaptive contrast)
187
  cv_img = pil_to_cv2(pil_img)
 
188
  if cv_img.ndim == 3:
189
  gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
190
  else:
191
  gray = cv_img
192
-
193
- # unsigned int conversion
194
- gray = np.asarray(gray, dtype=np.uint8)
195
-
196
- # CLAHE for contrast enhancement
197
- try:
198
- clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
199
- gray = clahe.apply(gray)
200
- except Exception:
201
- pass
202
-
203
- # denoise
204
- try:
205
- gray = cv2.fastNlMeansDenoising(gray, h=10)
206
- except Exception:
207
- pass
208
-
209
- # adaptive threshold
210
  try:
211
  bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
212
  cv2.THRESH_BINARY, 41, 15)
213
  except Exception:
214
- _, bw = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
215
-
216
- # morphological operations to remove tiny noise and thin grid lines
217
- kernel_close = cv2.getStructuringElement(cv2.MORPH_RECT, (3,3))
218
- bw = cv2.morphologyEx(bw, cv2.MORPH_CLOSE, kernel_close)
219
- kernel_open = cv2.getStructuringElement(cv2.MORPH_RECT, (1,1))
220
- bw = cv2.morphologyEx(bw, cv2.MORPH_OPEN, kernel_open)
221
-
222
  return bw
223
 
224
- # ---------------- OCR TSV (word-level) ----------------
225
- OCR_CONF_THRESHOLD = 30.0 # drop tokens with confidence less than this (if provided by tesseract)
226
-
227
  def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
228
- # pytesseract can accept numpy arrays or PIL images
229
  try:
230
  o = pytesseract.image_to_data(cv_img, output_type=Output.DICT, config="--psm 6")
231
  except Exception:
232
  o = pytesseract.image_to_data(cv_img, output_type=Output.DICT)
233
-
234
- cells: List[Dict[str, Any]] = []
235
  n = len(o.get("text", []))
236
  for i in range(n):
237
  raw = o["text"][i]
@@ -240,40 +176,18 @@ def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
240
  txt = str(raw).strip()
241
  if not txt:
242
  continue
243
- # try to parse confidence (tesseract returns strings sometimes)
244
- conf_raw = o.get("conf", [None]*n)[i]
245
  try:
246
- conf = float(conf_raw) if conf_raw not in (None, "", "-1") else -1.0
247
  except Exception:
248
  conf = -1.0
249
-
250
- # skip very-low-confidence tokens (reduce garbage)
251
- if conf >= 0 and conf < OCR_CONF_THRESHOLD:
252
- continue
253
-
254
- left = int(o.get("left", [0]*n)[i])
255
- top = int(o.get("top", [0]*n)[i])
256
- width = int(o.get("width", [0]*n)[i])
257
- height = int(o.get("height", [0]*n)[i])
258
  center_y = top + height / 2.0
259
  center_x = left + width / 2.0
260
-
261
- # normalize numeric OCR artifacts inside token
262
- if re.search(r"[0-9]", txt):
263
- # quick fixes
264
- txt = txt.replace("O", "0").replace("o", "0").replace("l", "1")
265
- txt = re.sub(r"[^0-9\.\,\-\(\)]", lambda m: "" if m.group(0).isspace() else m.group(0), txt)
266
-
267
- cells.append({
268
- "text": txt,
269
- "conf": conf,
270
- "left": left,
271
- "top": top,
272
- "width": width,
273
- "height": height,
274
- "center_y": center_y,
275
- "center_x": center_x
276
- })
277
  return cells
278
 
279
  # ---------------- grouping & merging helpers ----------------
@@ -281,7 +195,7 @@ def group_cells_into_rows(cells: List[Dict[str, Any]], y_tolerance: int = 12) ->
281
  if not cells:
282
  return []
283
  sorted_cells = sorted(cells, key=lambda c: (c["center_y"], c["center_x"]))
284
- rows: List[List[Dict[str, Any]]] = []
285
  current = [sorted_cells[0]]
286
  last_y = sorted_cells[0]["center_y"]
287
  for c in sorted_cells[1:]:
@@ -299,13 +213,12 @@ def group_cells_into_rows(cells: List[Dict[str, Any]], y_tolerance: int = 12) ->
299
  def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
300
  if not rows:
301
  return rows
302
- merged: List[List[Dict[str, Any]]] = []
303
  i = 0
304
  while i < len(rows):
305
  row = rows[i]
306
  tokens = [c["text"] for c in row]
307
  has_num = any(is_numeric_token(t) for t in tokens)
308
- # Merge a full-text row with the next numeric row if appropriate
309
  if not has_num and i + 1 < len(rows):
310
  next_row = rows[i+1]
311
  next_tokens = [c["text"] for c in next_row]
@@ -313,34 +226,34 @@ def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[st
313
  if next_has_num and len(tokens) >= 2 and len([t for t in next_tokens if not is_numeric_token(t)]) <= 3:
314
  merged_row = []
315
  min_left = min((c["left"] for c in next_row), default=0)
316
- offset = 0
317
  for c in row:
318
  newc = c.copy()
319
- # Shift text to left of numeric columns
320
- newc["left"] = min_left - 20 + offset
321
- newc["center_x"] = newc["left"] + newc["width"] / 2.0
322
  merged_row.append(newc)
323
- offset += 8
324
  merged_row.extend(next_row)
325
  merged.append(sorted(merged_row, key=lambda cc: cc["left"]))
326
  i += 2
327
  continue
328
- # Merge two short text-only rows (e.g. split names)
329
  if not has_num and i + 1 < len(rows):
330
  next_row = rows[i+1]
331
  next_tokens = [c["text"] for c in next_row]
332
  next_has_num = any(is_numeric_token(t) for t in next_tokens)
333
  if not next_has_num and len(tokens) <= 3 and len(next_tokens) <= 3:
334
- combined = row + next_row
335
- min_left = min((c["left"] for c in combined), default=0)
336
  merged_row = []
337
- for c in combined:
 
 
338
  newc = c.copy()
339
- # move slightly to align left
340
- if newc["left"] <= min_left:
341
- newc["left"] = min_left
342
- newc["center_x"] = newc["left"] + newc["width"] / 2.0
 
343
  merged_row.append(newc)
 
344
  merged.append(sorted(merged_row, key=lambda cc: cc["left"]))
345
  i += 2
346
  continue
@@ -350,9 +263,10 @@ def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[st
350
 
351
  # ---------------- numeric column detection ----------------
352
  def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 4) -> List[float]:
353
- xs = sorted([c["center_x"] for c in cells if is_numeric_token(c["text"])])
354
  if not xs:
355
  return []
 
356
  if len(xs) == 1:
357
  return [xs[0]]
358
  gaps = [xs[i+1] - xs[i] for i in range(len(xs)-1)]
@@ -379,91 +293,6 @@ def assign_token_to_column(token_x: float, column_centers: List[float]) -> Optio
379
  distances = [abs(token_x - cx) for cx in column_centers]
380
  return int(np.argmin(distances))
381
 
382
- # ---------------- item validation & repair ----------------
383
- MAX_REASONABLE_QTY = 100.0
384
- MAX_REASONABLE_RATE = 1_000_000.0
385
-
386
- def validate_and_fix_item(item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
387
- """
388
- Ensure amount/rate/qty are reasonable. Try to fix obvious OCR-caused errors.
389
- Return None if the item should be discarded as garbage.
390
- """
391
- # sanitize name
392
- name = clean_name_text(item.get("item_name", "") or "")
393
- if is_probable_garbage_name(name):
394
- # reject obviously garbage names
395
- return None
396
-
397
- amt = item.get("item_amount", 0.0) or 0.0
398
- rate = item.get("item_rate", 0.0) or 0.0
399
- qty = item.get("item_quantity", 0.0) or 0.0
400
-
401
- # sanity caps
402
- try:
403
- amt = float(amt)
404
- except Exception:
405
- return None
406
- try:
407
- rate = float(rate)
408
- except Exception:
409
- rate = 0.0
410
- try:
411
- qty = float(qty)
412
- except Exception:
413
- qty = 1.0
414
-
415
- # If qty is ridiculously large -> likely OCR error. Reset to 1 and set rate=amount if rate invalid
416
- if qty > MAX_REASONABLE_QTY:
417
- logger.debug("Qty %s too large for '%s' — resetting to 1", qty, name)
418
- qty = 1.0
419
- if rate <= 0 or rate > amt * 10:
420
- rate = amt
421
-
422
- # If rate > amt but rate is extremely large -> swap/assume misplace: if rate*qty approximates amt, fine.
423
- if rate > amt and qty > 0:
424
- if abs(rate * qty - amt) > max(0.05 * amt, 1.0):
425
- # If rate bigger than amount and doesn't fit, assume rate was missing -> set rate = amt/qty if meaningful
426
- try:
427
- candidate = amt / qty if qty else amt
428
- if 0 < candidate <= MAX_REASONABLE_RATE:
429
- logger.debug("Adjusting rate for '%s' from %s to %s", name, rate, candidate)
430
- rate = candidate
431
- except Exception:
432
- pass
433
-
434
- # If rate == 0 but qty>0 and amt>0 try infer simple integer ratio from numeric candidates already done upstream,
435
- # fallback: set rate = amt (qty assumed 1)
436
- if (rate == 0 or rate is None) and qty and qty > 0:
437
- if qty == 1 or not (amt / qty).is_integer():
438
- # simply compute rate
439
- try:
440
- candidate_rate = amt / qty
441
- if candidate_rate > 0 and candidate_rate <= MAX_REASONABLE_RATE:
442
- rate = round(candidate_rate, 2)
443
- except Exception:
444
- rate = 0.0
445
-
446
- # final sanity: negative/zero amounts dropped
447
- if amt <= 0.0:
448
- return None
449
- if qty <= 0:
450
- qty = 1.0
451
- # clamp qty to reasonable
452
- if qty > MAX_REASONABLE_QTY:
453
- qty = 1.0
454
-
455
- # Round sensible values
456
- amt = float(round(amt, 2))
457
- rate = float(round(rate, 2)) if rate is not None else 0.0
458
- qty = float(round(qty, 3))
459
-
460
- return {
461
- "item_name": name,
462
- "item_amount": amt,
463
- "item_rate": rate,
464
- "item_quantity": qty
465
- }
466
-
467
  # ---------------- Gemini refinement (deterministic) ----------------
468
  def refine_with_gemini(page_items: List[Dict[str, Any]], page_text: str = "") -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
469
  zero_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
@@ -515,7 +344,7 @@ def refine_with_gemini(page_items: List[Dict[str, Any]], page_text: str = "") ->
515
 
516
  # ---------------- parsing rows into items (modified) ----------------
517
  def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
518
- parsed_items: List[Dict[str, Any]] = []
519
  rows = merge_multiline_names(rows)
520
  column_centers = detect_numeric_columns(page_cells, max_columns=4)
521
 
@@ -524,15 +353,9 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
524
  if not tokens:
525
  continue
526
  joined_lower = " ".join(tokens).lower()
527
-
528
- # skip obvious footers and headers
529
  if FOOTER_KEYWORDS.search(joined_lower) and not any(is_numeric_token(t) for t in tokens):
530
  continue
531
  if all(not is_numeric_token(t) for t in tokens):
532
- # if a pure-text row but looks like header -> skip
533
- if looks_like_header_text(joined_lower):
534
- continue
535
- # otherwise we may have description-only rows (handled by merge_multiline_names)
536
  continue
537
 
538
  numeric_values = []
@@ -543,7 +366,6 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
543
  v = normalize_num_str(t)
544
  if v is not None:
545
  numeric_values.append(float(v))
546
- # unique sorted descending
547
  numeric_values = sorted({int(x) if float(x).is_integer() else x for x in numeric_values}, reverse=True)
548
 
549
  if column_centers:
@@ -553,7 +375,7 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
553
  t = c["text"]
554
  if is_numeric_token(t) and not looks_like_date_num(t):
555
  col_idx = assign_token_to_column(c["center_x"], column_centers)
556
- if col_idx is None or col_idx < 0:
557
  numeric_bucket_map[len(column_centers) - 1].append(t)
558
  else:
559
  numeric_bucket_map[col_idx].append(t)
@@ -571,16 +393,13 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
571
  rate = normalize_num_str(get_bucket(num_cols - 2)) if num_cols >= 2 else None
572
  qty = normalize_num_str(get_bucket(num_cols - 3)) if num_cols >= 3 else None
573
 
574
- # fallback: last numeric token as amount
575
  if amount is None:
576
  for t in reversed(tokens):
577
  if is_numeric_token(t) and not looks_like_date_num(t):
578
- candidate = normalize_num_str(t)
579
- if candidate is not None:
580
- amount = candidate
581
  break
582
 
583
- # try to infer rate & qty from numeric_values
584
  if amount is not None and numeric_values:
585
  for cand in numeric_values:
586
  try:
@@ -600,33 +419,43 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
600
  if r < 1 or r > 200:
601
  continue
602
  if abs(ratio - r) <= max(0.03 * r, 0.15):
603
- # reasonable integer quantity
604
- qty = float(r)
605
- rate = cand_float
606
- break
607
 
608
- # additional fallback if rate missing but qty exists
609
  if (rate is None or rate == 0) and qty and qty != 0 and amount is not None:
610
  try:
611
  candidate_rate = amount / qty
612
- if 0 < candidate_rate <= 1e7:
613
  rate = candidate_rate
614
  except Exception:
615
  pass
616
 
617
- # default quantity = 1 if unknown
618
  if qty is None:
619
  qty = 1.0
620
 
621
- # final rounding / validation via helper
622
- raw_item = {"item_name": name if name else "UNKNOWN", "item_amount": amount or 0.0,
623
- "item_rate": rate or 0.0, "item_quantity": qty or 1.0}
624
- fixed = validate_and_fix_item(raw_item)
625
- if fixed:
626
- parsed_items.append(fixed)
627
- # else skip
 
 
 
 
 
 
 
 
 
 
 
 
 
628
  else:
629
- # fallback parsing if no clear numeric columns
630
  numeric_idxs = [i for i, t in enumerate(tokens) if is_numeric_token(t) and not looks_like_date_num(t)]
631
  if not numeric_idxs:
632
  continue
@@ -637,8 +466,8 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
637
  name = " ".join(tokens[:last]).strip()
638
  if not name:
639
  continue
 
640
 
641
- # gather numeric candidates on the right to infer rate/qty
642
  right_nums = []
643
  for i in numeric_idxs:
644
  v = normalize_num_str(tokens[i])
@@ -646,7 +475,6 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
646
  right_nums.append(float(v))
647
  right_nums = sorted({int(x) if float(x).is_integer() else x for x in right_nums}, reverse=True)
648
 
649
- rate = None; qty = None
650
  if len(right_nums) >= 2:
651
  cand = right_nums[1]
652
  if 1 < cand < float(amt):
@@ -672,21 +500,22 @@ def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[D
672
  if rate is None:
673
  rate = 0.0
674
 
675
- raw_item = {"item_name": clean_name_text(name), "item_amount": float(round(amt,2)),
676
- "item_rate": float(round(rate,2)), "item_quantity": float(qty)}
677
- fixed = validate_and_fix_item(raw_item)
678
- if fixed:
679
- parsed_items.append(fixed)
 
680
 
681
  return parsed_items
682
 
683
  # ---------------- dedupe & totals ----------------
684
  def dedupe_items(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
685
  seen = set()
686
- out: List[Dict[str, Any]] = []
687
  for it in items:
688
  nm = re.sub(r"\s+", " ", it["item_name"].lower()).strip()
689
- key = (nm[:120], round(float(it.get("item_amount", 0.0) or 0.0), 2))
690
  if key in seen:
691
  continue
692
  seen.add(key)
@@ -771,13 +600,10 @@ def final_item_filter(item: Dict[str, Any], known_page_headers: List[str] = [],
771
  qty = float(item.get("item_quantity", 0) or 0)
772
  if qty <= 0:
773
  return False
774
- if rate and rate > amt * 10 and amt < 10000:
775
  return False
776
  if amt <= 0.0:
777
  return False
778
- # must contain at least one alphabetic char
779
- if not re.search(r"[A-Za-z]", name):
780
- return False
781
  return True
782
 
783
  # ---------------- main endpoint ----------------
@@ -824,7 +650,7 @@ async def extract_bill_data(payload: BillRequest):
824
  "token_usage": {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
825
  }
826
 
827
- images: List[Image.Image] = []
828
  clean_url = doc_url.split("?", 1)[0].lower()
829
  try:
830
  if clean_url.endswith(".pdf"):
@@ -847,8 +673,6 @@ async def extract_bill_data(payload: BillRequest):
847
  try:
848
  proc = preprocess_image(page_img)
849
  cells = image_to_tsv_cells(proc)
850
- if not cells:
851
- logger.debug("No OCR cells extracted for page %s", idx)
852
  rows = group_cells_into_rows(cells, y_tolerance=12)
853
  rows_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]
854
 
@@ -874,7 +698,6 @@ async def extract_bill_data(payload: BillRequest):
874
 
875
  parsed_items = parse_rows_with_columns(rows, cells)
876
 
877
- # Use Gemini only if configured
878
  refined_items, token_u = refine_with_gemini(parsed_items, page_text)
879
  for k in cumulative_token_usage:
880
  cumulative_token_usage[k] += token_u.get(k, 0)
@@ -942,5 +765,4 @@ async def run_all_samples():
942
  logger.exception("run_all_samples failed: %s", e)
943
  return {"status": "error", "error": str(e)}
944
 
945
- if __name__ == "__main__":
946
- uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
 
11
  from typing import List, Dict, Any, Optional, Tuple
12
 
13
  import uvicorn
14
+ from fastapi import FastAPI, BackgroundTasks
15
  from pydantic import BaseModel
16
  import requests
17
+ from PIL import Image
18
  from pdf2image import convert_from_bytes
19
  import pytesseract
20
  from pytesseract import Output
 
54
  r"(grand\s*total|net\s*payable|total\s*amount|amount\s*payable|bill\s*amount|final\s*amount|balance\s*due|sub\s*total|subtotal|total)",
55
  re.I,
56
  )
57
+ FOOTER_KEYWORDS = re.compile(r"(page|printed on|printed:|date:|time:|am|pm)", re.I)
58
 
59
  HEADER_KEYWORDS = [
60
  "description", "qty", "hrs", "rate", "discount", "net", "amt", "amount",
 
71
  HEADER_PHRASES = [h.lower() for h in HEADER_PHRASES]
72
 
73
  # ---------------- small utilities ----------------
74
+ def sanitize_ocr_text(s: str) -> str:
75
  if not s:
76
  return ""
 
77
  s = s.replace("\u2014", "-").replace("\u2013", "-")
78
  s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E]", " ", s)
79
  s = s.replace("\r\n", "\n").replace("\r", "\n")
 
87
  s = str(s).strip()
88
  if s == "":
89
  return None
 
 
90
  s = re.sub(r"[^\d\-\+\,\.\(\)]", "", s)
91
  negative = False
92
  if s.startswith("(") and s.endswith(")"):
 
107
  return bool(t and NUM_RE.search(str(t)))
108
 
109
  def looks_like_date_num(s: str) -> bool:
 
 
110
  s_digits = re.sub(r"[^\d]", "", s or "")
111
  if len(s_digits) >= 7:
112
  if s_digits.endswith(("2025","2024","2023","2022","2026")):
 
118
  pass
119
  return False
120
 
 
 
 
 
 
 
121
  def clean_name_text(s: str) -> str:
122
+ s = s.replace("—", "-")
 
 
 
 
 
 
 
123
  s = re.sub(r"\s+", " ", s)
124
  s = s.strip(" -:,.=")
125
+ s = re.sub(r"\s+x$", "", s, flags=re.I)
126
+ s = re.sub(r"[\)\}\]]+$", "", s)
127
  s = re.sub(r"\bOR\b", "DR", s)
128
+ s = s.strip(" -:,.")
 
129
  s = s.strip()
130
  return s
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  # ---------------- image preprocessing ----------------
133
  def pil_to_cv2(img: Image.Image) -> Any:
134
  arr = np.array(img)
 
144
  if w < target_w:
145
  scale = target_w / float(w)
146
  pil_img = pil_img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
 
 
147
  cv_img = pil_to_cv2(pil_img)
148
+ # grayscale and denoise
149
  if cv_img.ndim == 3:
150
  gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
151
  else:
152
  gray = cv_img
153
+ gray = cv2.fastNlMeansDenoising(gray, h=10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  try:
155
  bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
156
  cv2.THRESH_BINARY, 41, 15)
157
  except Exception:
158
+ _, bw = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
159
+ kernel = np.ones((1,1), np.uint8)
160
+ bw = cv2.morphologyEx(bw, cv2.MORPH_OPEN, kernel)
 
 
 
 
 
161
  return bw
162
 
163
+ # ---------------- OCR TSV ----------------
 
 
164
  def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
165
+ # pytesseract expects either a PIL image or numpy array
166
  try:
167
  o = pytesseract.image_to_data(cv_img, output_type=Output.DICT, config="--psm 6")
168
  except Exception:
169
  o = pytesseract.image_to_data(cv_img, output_type=Output.DICT)
170
+ cells = []
 
171
  n = len(o.get("text", []))
172
  for i in range(n):
173
  raw = o["text"][i]
 
176
  txt = str(raw).strip()
177
  if not txt:
178
  continue
 
 
179
  try:
180
+ conf = float(o["conf"][i]) if o["conf"][i] not in (None, "", "-1") else -1.0
181
  except Exception:
182
  conf = -1.0
183
+ left = int(o.get("left", [0])[i])
184
+ top = int(o.get("top", [0])[i])
185
+ width = int(o.get("width", [0])[i])
186
+ height = int(o.get("height", [0])[i])
 
 
 
 
 
187
  center_y = top + height / 2.0
188
  center_x = left + width / 2.0
189
+ cells.append({"text": txt, "conf": conf, "left": left, "top": top,
190
+ "width": width, "height": height, "center_y": center_y, "center_x": center_x})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  return cells
192
 
193
  # ---------------- grouping & merging helpers ----------------
 
195
  if not cells:
196
  return []
197
  sorted_cells = sorted(cells, key=lambda c: (c["center_y"], c["center_x"]))
198
+ rows = []
199
  current = [sorted_cells[0]]
200
  last_y = sorted_cells[0]["center_y"]
201
  for c in sorted_cells[1:]:
 
213
  def merge_multiline_names(rows: List[List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
214
  if not rows:
215
  return rows
216
+ merged = []
217
  i = 0
218
  while i < len(rows):
219
  row = rows[i]
220
  tokens = [c["text"] for c in row]
221
  has_num = any(is_numeric_token(t) for t in tokens)
 
222
  if not has_num and i + 1 < len(rows):
223
  next_row = rows[i+1]
224
  next_tokens = [c["text"] for c in next_row]
 
226
  if next_has_num and len(tokens) >= 2 and len([t for t in next_tokens if not is_numeric_token(t)]) <= 3:
227
  merged_row = []
228
  min_left = min((c["left"] for c in next_row), default=0)
229
+ offset = 10
230
  for c in row:
231
  newc = c.copy()
232
+ newc["left"] = min_left - offset
233
+ newc["center_x"] = newc["left"] + newc.get("width", 0) / 2.0
 
234
  merged_row.append(newc)
235
+ offset += 10
236
  merged_row.extend(next_row)
237
  merged.append(sorted(merged_row, key=lambda cc: cc["left"]))
238
  i += 2
239
  continue
 
240
  if not has_num and i + 1 < len(rows):
241
  next_row = rows[i+1]
242
  next_tokens = [c["text"] for c in next_row]
243
  next_has_num = any(is_numeric_token(t) for t in next_tokens)
244
  if not next_has_num and len(tokens) <= 3 and len(next_tokens) <= 3:
 
 
245
  merged_row = []
246
+ min_left = min((c["left"] for c in next_row + row), default=0)
247
+ offset = 10
248
+ for c in row + next_row:
249
  newc = c.copy()
250
+ if newc["left"] > min_left:
251
+ newc["left"] = newc["left"]
252
+ else:
253
+ newc["left"] = min_left - offset
254
+ newc["center_x"] = newc["left"] + newc.get("width", 0) / 2.0
255
  merged_row.append(newc)
256
+ offset += 5
257
  merged.append(sorted(merged_row, key=lambda cc: cc["left"]))
258
  i += 2
259
  continue
 
263
 
264
  # ---------------- numeric column detection ----------------
265
  def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 4) -> List[float]:
266
+ xs = [c["center_x"] for c in cells if is_numeric_token(c["text"])]
267
  if not xs:
268
  return []
269
+ xs = sorted(xs)
270
  if len(xs) == 1:
271
  return [xs[0]]
272
  gaps = [xs[i+1] - xs[i] for i in range(len(xs)-1)]
 
293
  distances = [abs(token_x - cx) for cx in column_centers]
294
  return int(np.argmin(distances))
295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  # ---------------- Gemini refinement (deterministic) ----------------
297
  def refine_with_gemini(page_items: List[Dict[str, Any]], page_text: str = "") -> Tuple[List[Dict[str, Any]], Dict[str, int]]:
298
  zero_usage = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
 
344
 
345
  # ---------------- parsing rows into items (modified) ----------------
346
  def parse_rows_with_columns(rows: List[List[Dict[str, Any]]], page_cells: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
347
+ parsed_items = []
348
  rows = merge_multiline_names(rows)
349
  column_centers = detect_numeric_columns(page_cells, max_columns=4)
350
 
 
353
  if not tokens:
354
  continue
355
  joined_lower = " ".join(tokens).lower()
 
 
356
  if FOOTER_KEYWORDS.search(joined_lower) and not any(is_numeric_token(t) for t in tokens):
357
  continue
358
  if all(not is_numeric_token(t) for t in tokens):
 
 
 
 
359
  continue
360
 
361
  numeric_values = []
 
366
  v = normalize_num_str(t)
367
  if v is not None:
368
  numeric_values.append(float(v))
 
369
  numeric_values = sorted({int(x) if float(x).is_integer() else x for x in numeric_values}, reverse=True)
370
 
371
  if column_centers:
 
375
  t = c["text"]
376
  if is_numeric_token(t) and not looks_like_date_num(t):
377
  col_idx = assign_token_to_column(c["center_x"], column_centers)
378
+ if col_idx is None:
379
  numeric_bucket_map[len(column_centers) - 1].append(t)
380
  else:
381
  numeric_bucket_map[col_idx].append(t)
 
393
  rate = normalize_num_str(get_bucket(num_cols - 2)) if num_cols >= 2 else None
394
  qty = normalize_num_str(get_bucket(num_cols - 3)) if num_cols >= 3 else None
395
 
 
396
  if amount is None:
397
  for t in reversed(tokens):
398
  if is_numeric_token(t) and not looks_like_date_num(t):
399
+ amount = normalize_num_str(t)
400
+ if amount is not None:
 
401
  break
402
 
 
403
  if amount is not None and numeric_values:
404
  for cand in numeric_values:
405
  try:
 
419
  if r < 1 or r > 200:
420
  continue
421
  if abs(ratio - r) <= max(0.03 * r, 0.15):
422
+ if r <= 100:
423
+ rate = cand_float
424
+ qty = float(r)
425
+ break
426
 
 
427
  if (rate is None or rate == 0) and qty and qty != 0 and amount is not None:
428
  try:
429
  candidate_rate = amount / qty
430
+ if candidate_rate >= 2:
431
  rate = candidate_rate
432
  except Exception:
433
  pass
434
 
 
435
  if qty is None:
436
  qty = 1.0
437
 
438
+ try:
439
+ amount = float(round(amount, 2))
440
+ except:
441
+ continue
442
+ try:
443
+ rate = float(round(rate, 2)) if rate is not None else 0.0
444
+ except:
445
+ rate = 0.0
446
+ try:
447
+ qty = float(qty)
448
+ except:
449
+ qty = 1.0
450
+
451
+ parsed_items.append({
452
+ "item_name": name if name else "UNKNOWN",
453
+ "item_amount": amount,
454
+ "item_rate": rate if rate is not None else 0.0,
455
+ "item_quantity": qty if qty is not None else 1.0,
456
+ })
457
+
458
  else:
 
459
  numeric_idxs = [i for i, t in enumerate(tokens) if is_numeric_token(t) and not looks_like_date_num(t)]
460
  if not numeric_idxs:
461
  continue
 
466
  name = " ".join(tokens[:last]).strip()
467
  if not name:
468
  continue
469
+ rate = None; qty = None
470
 
 
471
  right_nums = []
472
  for i in numeric_idxs:
473
  v = normalize_num_str(tokens[i])
 
475
  right_nums.append(float(v))
476
  right_nums = sorted({int(x) if float(x).is_integer() else x for x in right_nums}, reverse=True)
477
 
 
478
  if len(right_nums) >= 2:
479
  cand = right_nums[1]
480
  if 1 < cand < float(amt):
 
500
  if rate is None:
501
  rate = 0.0
502
 
503
+ parsed_items.append({
504
+ "item_name": clean_name_text(name),
505
+ "item_amount": float(round(amt, 2)),
506
+ "item_rate": float(round(rate, 2)),
507
+ "item_quantity": float(qty),
508
+ })
509
 
510
  return parsed_items
511
 
512
  # ---------------- dedupe & totals ----------------
513
  def dedupe_items(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
514
  seen = set()
515
+ out = []
516
  for it in items:
517
  nm = re.sub(r"\s+", " ", it["item_name"].lower()).strip()
518
+ key = (nm[:120], round(float(it["item_amount"]), 2))
519
  if key in seen:
520
  continue
521
  seen.add(key)
 
600
  qty = float(item.get("item_quantity", 0) or 0)
601
  if qty <= 0:
602
  return False
603
+ if rate and rate > amt:
604
  return False
605
  if amt <= 0.0:
606
  return False
 
 
 
607
  return True
608
 
609
  # ---------------- main endpoint ----------------
 
650
  "token_usage": {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
651
  }
652
 
653
+ images = []
654
  clean_url = doc_url.split("?", 1)[0].lower()
655
  try:
656
  if clean_url.endswith(".pdf"):
 
673
  try:
674
  proc = preprocess_image(page_img)
675
  cells = image_to_tsv_cells(proc)
 
 
676
  rows = group_cells_into_rows(cells, y_tolerance=12)
677
  rows_texts = [" ".join([c["text"] for c in r]).strip() for r in rows]
678
 
 
698
 
699
  parsed_items = parse_rows_with_columns(rows, cells)
700
 
 
701
  refined_items, token_u = refine_with_gemini(parsed_items, page_text)
702
  for k in cumulative_token_usage:
703
  cumulative_token_usage[k] += token_u.get(k, 0)
 
765
  logger.exception("run_all_samples failed: %s", e)
766
  return {"status": "error", "error": str(e)}
767
 
768
+