Sathvik-kota committed on
Commit
d246b8c
·
verified ·
1 Parent(s): e568983

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +41 -188
app.py CHANGED
@@ -1,22 +1,10 @@
1
- # Enhanced Bill Extraction API
2
- # Designed for Bajaj Datathon: accurate line item + subtotal + total extraction
3
- #
4
- # Key improvements:
5
- # 1. Explicit subtotal/total detection and preservation
6
- # 2. Double-count prevention via fingerprinting
7
- # 3. Item-sum vs bill-total validation
8
- # 4. Confidence scoring and anomaly detection
9
- # 5. Enhanced preprocessing for table structures
10
- # 6. Gemini-powered structural validation
11
-
12
  import os
13
  import re
14
  import json
15
  import logging
16
  from io import BytesIO
17
- from typing import List, Dict, Any, Optional, Tuple, Set
18
- from dataclasses import dataclass, asdict
19
- from collections import defaultdict
20
 
21
  from fastapi import FastAPI
22
  from pydantic import BaseModel
@@ -40,29 +28,15 @@ try:
40
  except Exception:
41
  vision = None
42
 
43
- try:
44
- import google.generativeai as genai
45
- except Exception:
46
- genai = None
47
-
48
  # -------------------------------------------------------------------------
49
  # Configuration
50
  # -------------------------------------------------------------------------
51
  OCR_ENGINE = os.getenv("OCR_ENGINE", "tesseract").lower()
52
- GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
53
- GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME", "gemini-2.0-flash")
54
  AWS_REGION = os.getenv("AWS_REGION", "us-east-1")
55
  TESSERACT_PSM = os.getenv("TESSERACT_PSM", "6")
56
 
57
  logging.basicConfig(level=logging.INFO)
58
- logger = logging.getLogger("bill-extractor-v2")
59
-
60
# Configure the Gemini SDK once at import time, but only when both the
# API key and the (optionally installed) library are present.  Failures
# are logged and swallowed so the rest of the service still starts.
if GEMINI_API_KEY and genai is not None:
    try:
        genai.configure(api_key=GEMINI_API_KEY)
        logger.info("Gemini configured")
    except Exception as e:
        logger.warning("Gemini config failed: %s", e)
66
 
67
  # Lazy clients
68
  _textract_client = None
@@ -85,7 +59,7 @@ def vision_client():
85
  return _vision_client
86
 
87
  # -------------------------------------------------------------------------
88
- # Data Models
89
  # -------------------------------------------------------------------------
90
  @dataclass
91
  class BillLineItem:
@@ -94,15 +68,19 @@ class BillLineItem:
94
  item_quantity: float = 1.0
95
  item_rate: float = 0.0
96
  item_amount: float = 0.0
97
- confidence: float = 1.0 # 0-1 confidence score
98
- source_row: str = "" # raw OCR text for debugging
99
- is_description_continuation: bool = False # multi-line item flag
 
100
 
101
  def to_dict(self) -> Dict[str, Any]:
102
- d = asdict(self)
103
- d.pop("source_row", None) # exclude raw text from output
104
- d.pop("is_description_continuation", None)
105
- return d
 
 
 
106
 
107
  @dataclass
108
  class BillTotal:
@@ -119,26 +97,25 @@ class BillTotal:
119
  class ExtractedPage:
120
  """Page-level extraction result"""
121
  page_no: int
122
- page_type: str # "Bill Detail", "Header", "Footer", etc.
123
  line_items: List[BillLineItem]
124
  bill_totals: BillTotal
125
- page_confidence: float = 1.0
126
 
127
  def to_dict(self) -> Dict[str, Any]:
 
128
  return {
129
  "page_no": self.page_no,
130
  "page_type": self.page_type,
131
  "line_items": [item.to_dict() for item in self.line_items],
132
  "bill_totals": self.bill_totals.to_dict(),
133
- "page_confidence": round(self.page_confidence, 3),
134
  }
135
 
136
  # -------------------------------------------------------------------------
137
- # Regular Expressions (Enhanced)
138
  # -------------------------------------------------------------------------
139
  NUM_RE = re.compile(r"[-+]?\d{1,3}(?:[,0-9]*)(?:\.\d+)?")
140
 
141
- # Total/Subtotal keywords (improved detection)
142
  TOTAL_KEYWORDS = re.compile(
143
  r"\b(grand\s+total|net\s+payable|total\s+(?:amount|due)|amount\s+payable|bill\s+amount|"
144
  r"final\s+(?:amount|total)|balance\s+due|amount\s+due|total\s+payable|payable)\b",
@@ -164,22 +141,20 @@ FOOTER_KEYWORDS = re.compile(
164
  HEADER_KEYWORDS = [
165
  "description", "qty", "qty/hrs", "hrs", "rate", "unit price", "discount",
166
  "net", "amt", "amount", "price", "total", "sl.no", "s.no", "item", "service",
167
- "consultation", "patient", "invoice", "bill", "charges"
168
  ]
169
 
170
  # -------------------------------------------------------------------------
171
  # Text Cleaning & Normalization
172
  # -------------------------------------------------------------------------
173
def sanitize_ocr_text(s: Optional[str]) -> str:
    """Normalize raw OCR output into clean, printable ASCII text.

    Replaces unicode dashes and non-breaking spaces, strips other
    non-printable/non-ASCII characters, normalizes line endings,
    collapses runs of horizontal whitespace, and corrects a few common
    OCR misreads ("qiy" -> "qty", "deseription" -> "description").
    Returns "" for None or empty input.
    """
    if not s:
        return ""
    # Em/en dash -> ASCII hyphen; NBSP -> ordinary space.
    for src, dst in (("\u2014", "-"), ("\u2013", "-"), ("\u00A0", " ")):
        s = s.replace(src, dst)
    # Blank out anything outside tab/LF/CR and printable ASCII.
    s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E]", " ", s)
    # Normalize Windows/old-Mac line endings to "\n".
    s = s.replace("\r\n", "\n").replace("\r", "\n")
    # Collapse spaces/tabs into a single space (newlines untouched).
    s = re.sub(r"[ \t]+", " ", s)
    # Fix frequent OCR character confusions (case-insensitive, so the
    # replacement also lowercases e.g. "QTY" -> "qty" as before).
    s = re.sub(r"\b(qiy|qty|oty|gty)\b", "qty", s, flags=re.I)
    s = re.sub(r"\b(deseription|descriptin|desription)\b", "description", s, flags=re.I)
    return s.strip()
@@ -192,13 +167,11 @@ def normalize_num_str(s: Optional[str], allow_zero: bool = False) -> Optional[fl
192
  if s == "":
193
  return None
194
 
195
- # Handle parentheses (negative indicator)
196
  negative = False
197
  if s.startswith("(") and s.endswith(")"):
198
  negative = True
199
  s = s[1:-1]
200
 
201
- # Remove non-numeric chars except decimal/comma
202
  s = re.sub(r"[^\d\-\+\,\.\(\)]", "", s)
203
  s = s.replace(",", "")
204
 
@@ -223,7 +196,7 @@ def clean_item_name(s: str) -> str:
223
  s = s.replace("—", "-").replace("–", "-")
224
  s = re.sub(r"\s+", " ", s)
225
  s = s.strip(" -:,.=()[]{}|\\")
226
- s = re.sub(r"\bOR\b", "DR", s) # OCR OR -> DR
227
  return s.strip()
228
 
229
  # -------------------------------------------------------------------------
@@ -236,27 +209,20 @@ def item_fingerprint(item: BillLineItem) -> Tuple[str, float]:
236
  return (name_norm, amount_rounded)
237
 
238
  def dedupe_items_advanced(items: List[BillLineItem]) -> List[BillLineItem]:
239
- """
240
- Remove duplicates while preserving highest-confidence versions.
241
- Handles multi-line descriptions by checking sequential items.
242
- """
243
  if not items:
244
  return []
245
 
246
- # Remove exact duplicates (same fingerprint)
247
  seen: Dict[Tuple, BillLineItem] = {}
248
  for item in items:
249
  fp = item_fingerprint(item)
250
  if fp not in seen or item.confidence > seen[fp].confidence:
251
  seen[fp] = item
252
 
253
- # Remove high-similarity continuation rows (likely description wrapping)
254
  final = []
255
  for item in seen.values():
256
  if item.is_description_continuation:
257
- # Check if very similar to previous item
258
  if final and abs(float(final[-1].item_amount) - float(item.item_amount)) < 0.01:
259
- # Likely continuation; merge
260
  final[-1].item_name = (final[-1].item_name + " " + item.item_name).strip()
261
  continue
262
  final.append(item)
@@ -266,27 +232,24 @@ def dedupe_items_advanced(items: List[BillLineItem]) -> List[BillLineItem]:
266
  # -------------------------------------------------------------------------
267
  # Total/Subtotal Detection
268
  # -------------------------------------------------------------------------
 
 
 
 
 
 
269
  def detect_totals_in_rows(rows: List[List[Dict[str, Any]]]) -> Tuple[Optional[float], Optional[float], Optional[float], Optional[float]]:
270
- """
271
- Scan rows for subtotal, tax, discount, final total.
272
- Returns: (subtotal, tax, discount, final_total)
273
- """
274
  subtotal = None
275
  tax = None
276
  discount = None
277
  final_total = None
278
 
279
- rows_text = []
280
  for row in rows:
281
  row_text = " ".join([c["text"] for c in row])
282
- rows_text.append((row_text, row))
283
-
284
- # Scan for keywords
285
- for row_text, row in rows_text:
286
  row_lower = row_text.lower()
287
  tokens = row_text.split()
288
 
289
- # Extract number from row
290
  amounts = []
291
  for t in tokens:
292
  if is_numeric_token(t):
@@ -297,10 +260,8 @@ def detect_totals_in_rows(rows: List[List[Dict[str, Any]]]) -> Tuple[Optional[fl
297
  if not amounts:
298
  continue
299
 
300
- # Use rightmost/largest amount typically
301
  amount = max(amounts)
302
 
303
- # Keyword matching
304
  if FINAL_TOTAL_KEYWORDS.search(row_lower):
305
  final_total = amount
306
  elif SUBTOTAL_KEYWORDS.search(row_lower):
@@ -312,17 +273,10 @@ def detect_totals_in_rows(rows: List[List[Dict[str, Any]]]) -> Tuple[Optional[fl
312
 
313
  return subtotal, tax, discount, final_total
314
 
315
- FINAL_TOTAL_KEYWORDS = re.compile(
316
- r"\b(grand\s+total|final\s+(?:total|amount)|total\s+(?:due|payable|amount)|"
317
- r"net\s+payable|amount\s+(?:due|payable)|balance\s+due|payable)\b",
318
- re.I
319
- )
320
-
321
  # -------------------------------------------------------------------------
322
  # Image Preprocessing
323
  # -------------------------------------------------------------------------
324
  def pil_to_cv2(img: Image.Image) -> Any:
325
- """Convert PIL to OpenCV"""
326
  arr = np.array(img)
327
  if arr.ndim == 2:
328
  return arr
@@ -333,30 +287,25 @@ def preprocess_image_for_tesseract(pil_img: Image.Image, target_w: int = 1500) -
333
  pil_img = pil_img.convert("RGB")
334
  w, h = pil_img.size
335
 
336
- # Upscale if too small
337
  if w < target_w:
338
  scale = target_w / float(w)
339
  pil_img = pil_img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
340
 
341
  cv_img = pil_to_cv2(pil_img)
342
 
343
- # Grayscale
344
  if cv_img.ndim == 3:
345
  gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
346
  else:
347
  gray = cv_img
348
 
349
- # Denoise
350
  gray = cv2.fastNlMeansDenoising(gray, h=10)
351
 
352
- # Adaptive thresholding (better for tables with shadows)
353
  try:
354
  bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
355
  cv2.THRESH_BINARY, 41, 15)
356
  except Exception:
357
  _, bw = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
358
 
359
- # Morphological cleanup
360
  kernel = np.ones((2, 2), np.uint8)
361
  bw = cv2.morphologyEx(bw, cv2.MORPH_CLOSE, kernel)
362
  bw = cv2.morphologyEx(bw, cv2.MORPH_OPEN, kernel)
@@ -395,7 +344,7 @@ def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
395
 
396
  cells.append({
397
  "text": txt,
398
- "conf": max(0.0, conf) / 100.0, # normalize to 0-1
399
  "left": left, "top": top, "width": width, "height": height,
400
  "center_x": center_x, "center_y": center_y
401
  })
@@ -427,7 +376,7 @@ def group_cells_into_rows(cells: List[Dict[str, Any]], y_tolerance: int = 12) ->
427
  return rows
428
 
429
  # -------------------------------------------------------------------------
430
- # Column Detection (Enhanced)
431
  # -------------------------------------------------------------------------
432
  def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 6) -> List[float]:
433
  """Detect x-positions of numeric columns"""
@@ -439,7 +388,6 @@ def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 6) ->
439
  if len(xs) == 1:
440
  return xs
441
 
442
- # Cluster columns by gap analysis
443
  gaps = [xs[i+1] - xs[i] for i in range(len(xs)-1)]
444
  mean_gap = float(np.mean(gaps))
445
  std_gap = float(np.std(gaps)) if len(gaps) > 1 else 0.0
@@ -469,34 +417,28 @@ def assign_token_to_column(token_x: float, column_centers: List[float]) -> Optio
469
  return int(np.argmin(distances))
470
 
471
  # -------------------------------------------------------------------------
472
- # Row Parsing (Enhanced for accuracy)
473
  # -------------------------------------------------------------------------
474
  def parse_rows_with_columns(
475
  rows: List[List[Dict[str, Any]]],
476
  page_cells: List[Dict[str, Any]],
477
  page_text: str = ""
478
  ) -> List[BillLineItem]:
479
- """
480
- Parse rows into line items with improved accuracy.
481
- Handles multi-line descriptions and uncertain quantities.
482
- """
483
  items = []
484
  column_centers = detect_numeric_columns(page_cells, max_columns=6)
485
 
486
- for row_idx, row in enumerate(rows):
487
  tokens = [c["text"] for c in row]
488
  row_text = " ".join(tokens)
489
  row_lower = row_text.lower()
490
 
491
- # Skip footers/headers
492
  if FOOTER_KEYWORDS.search(row_lower) and not any(is_numeric_token(t) for t in tokens):
493
  continue
494
 
495
- # Require at least one numeric token
496
  if not any(is_numeric_token(t) for t in tokens):
497
  continue
498
 
499
- # Extract amounts
500
  numeric_values = []
501
  for t in tokens:
502
  if is_numeric_token(t):
@@ -509,7 +451,6 @@ def parse_rows_with_columns(
509
 
510
  numeric_values = sorted(list(set(numeric_values)), reverse=True)
511
 
512
- # Column-based parsing
513
  if column_centers:
514
  left_text_parts = []
515
  numeric_buckets = {i: [] for i in range(len(column_centers))}
@@ -530,13 +471,11 @@ def parse_rows_with_columns(
530
  item_name = " ".join(left_text_parts).strip()
531
  item_name = clean_item_name(item_name) if item_name else "UNKNOWN"
532
 
533
- # Extract from columns (right-most is typically amount)
534
  num_cols = len(column_centers)
535
  amount = None
536
  rate = None
537
  qty = None
538
 
539
- # Try rightmost column first (usually total amount)
540
  if num_cols >= 1:
541
  bucket = numeric_buckets.get(num_cols - 1, [])
542
  if bucket:
@@ -544,25 +483,21 @@ def parse_rows_with_columns(
544
  amount = normalize_num_str(amt_str, allow_zero=False)
545
 
546
  if amount is None:
547
- # Fallback: take largest numeric value
548
  for v in numeric_values:
549
  if v > 0:
550
  amount = v
551
  break
552
 
553
- # Try second-to-right for rate
554
  if num_cols >= 2:
555
  bucket = numeric_buckets.get(num_cols - 2, [])
556
  if bucket:
557
  rate = normalize_num_str(bucket[-1][0], allow_zero=False)
558
 
559
- # Try third-to-right for quantity
560
  if num_cols >= 3:
561
  bucket = numeric_buckets.get(num_cols - 3, [])
562
  if bucket:
563
  qty = normalize_num_str(bucket[-1][0], allow_zero=False)
564
 
565
- # Smart qty/rate inference
566
  if amount and not qty and not rate and numeric_values:
567
  for cand in numeric_values:
568
  if cand <= 0.1 or cand >= amount:
@@ -574,7 +509,6 @@ def parse_rows_with_columns(
574
  rate = cand
575
  break
576
 
577
- # Derive missing values
578
  if qty and rate is None and amount and amount != 0:
579
  rate = amount / qty
580
  elif rate and qty is None and amount and amount != 0:
@@ -582,7 +516,6 @@ def parse_rows_with_columns(
582
  elif amount and qty and rate is None:
583
  rate = amount / qty if qty != 0 else 0.0
584
 
585
- # Defaults
586
  if qty is None:
587
  qty = 1.0
588
  if rate is None:
@@ -590,7 +523,6 @@ def parse_rows_with_columns(
590
  if amount is None:
591
  amount = qty * rate if qty and rate else 0.0
592
 
593
- # Finalize
594
  if amount > 0:
595
  confidence = np.mean([c.get("conf", 0.85) for c in row]) if row else 0.85
596
  items.append(BillLineItem(
@@ -602,7 +534,6 @@ def parse_rows_with_columns(
602
  source_row=row_text,
603
  ))
604
  else:
605
- # Fallback: simple parsing without columns
606
  numeric_idxs = [i for i, t in enumerate(tokens) if is_numeric_token(t)]
607
  if not numeric_idxs:
608
  continue
@@ -628,45 +559,10 @@ def parse_rows_with_columns(
628
  return items
629
 
630
  # -------------------------------------------------------------------------
631
- # Accuracy Validation
632
- # -------------------------------------------------------------------------
633
def validate_totals(
    line_items: List[BillLineItem],
    bill_totals: BillTotal,
    tolerance_pct: float = 2.0
) -> Tuple[float, str]:
    """
    Validate extracted items sum vs bill total.

    Compares the sum of extracted line-item amounts against the detected
    final total (when present) and scores how closely they agree.

    Args:
        line_items: extracted bill line items (item_amount is summed).
        bill_totals: detected totals; only final_total_amount is used here.
        tolerance_pct: max percentage difference still counted as a match.

    Returns: (accuracy_score 0-100, validation_msg)
    """
    if not line_items:
        return 0.0, "No line items extracted"

    items_sum = sum(item.item_amount for item in line_items)

    # If we detected a final total, compare
    if bill_totals.final_total_amount is not None:
        final_total = bill_totals.final_total_amount
        diff = abs(items_sum - final_total)
        # Bug fix: previously a zero (or negative) final_total forced
        # diff_pct to 0, reporting a perfect match even when items_sum
        # was nonzero.  Use abs() in the denominator and treat a zero
        # stated total with nonzero items as a full mismatch.
        if final_total != 0:
            diff_pct = diff / abs(final_total) * 100
        elif diff == 0:
            diff_pct = 0.0
        else:
            diff_pct = 100.0

        if diff_pct <= tolerance_pct:
            score = 100.0
            msg = f"✓ Extracted total ({items_sum:.2f}) matches bill total ({final_total:.2f})"
        else:
            # Scale score based on how close
            score = max(0.0, 100.0 - (diff_pct * 5))
            msg = f"⚠ Mismatch: items_sum={items_sum:.2f}, bill_total={final_total:.2f}, diff={diff_pct:.1f}%"

        return score, msg

    return 85.0, f"No bill total detected; items_sum={items_sum:.2f}"
664
-
665
- # -------------------------------------------------------------------------
666
- # Main OCR Pipelines (Tesseract)
667
  # -------------------------------------------------------------------------
668
  def ocr_with_tesseract(file_bytes: bytes) -> List[ExtractedPage]:
669
- """Enhanced Tesseract pipeline"""
670
  pages_out = []
671
 
672
  try:
@@ -681,36 +577,28 @@ def ocr_with_tesseract(file_bytes: bytes) -> List[ExtractedPage]:
681
 
682
  for idx, pil_img in enumerate(images, start=1):
683
  try:
684
- # Preprocess & extract
685
  proc = preprocess_image_for_tesseract(pil_img)
686
  cells = image_to_tsv_cells(proc)
687
  rows = group_cells_into_rows(cells, y_tolerance=12)
688
 
689
- # Get page text
690
  page_text = " ".join([" ".join([c["text"] for c in r]) for r in rows])
691
 
692
- # Detect totals early
693
  subtotal, tax, discount, final_total = detect_totals_in_rows(rows)
694
 
695
- # Parse line items
696
  items = parse_rows_with_columns(rows, cells, page_text)
697
 
698
- # Deduplicate
699
  items = dedupe_items_advanced(items)
700
 
701
- # Filter (exclude totals/subtotals)
702
  filtered_items = []
703
  for item in items:
704
  name_lower = item.item_name.lower()
705
 
706
- # Skip if name matches total keywords
707
  if TOTAL_KEYWORDS.search(name_lower) or SUBTOTAL_KEYWORDS.search(name_lower):
708
  continue
709
 
710
  if item.item_amount > 0:
711
  filtered_items.append(item)
712
 
713
- # Create bill totals object
714
  bill_totals = BillTotal(
715
  subtotal_amount=subtotal,
716
  tax_amount=tax,
@@ -718,10 +606,6 @@ def ocr_with_tesseract(file_bytes: bytes) -> List[ExtractedPage]:
718
  final_total_amount=final_total,
719
  )
720
 
721
- # Validate
722
- accuracy, val_msg = validate_totals(filtered_items, bill_totals)
723
- logger.info(f"Page {idx}: {val_msg}")
724
-
725
  page_conf = np.mean([item.confidence for item in filtered_items]) if filtered_items else 0.8
726
 
727
  pages_out.append(ExtractedPage(
@@ -747,26 +631,23 @@ def ocr_with_tesseract(file_bytes: bytes) -> List[ExtractedPage]:
747
  # -------------------------------------------------------------------------
748
  # FastAPI App
749
  # -------------------------------------------------------------------------
750
- app = FastAPI(title="Enhanced Bill Extractor (Datathon v2)")
751
 
752
class BillRequest(BaseModel):
    """Request payload for /extract-bill-data: a single document locator."""
    document: str  # file://path or http(s) URL
754
 
755
class BillResponse(BaseModel):
    """Response envelope for /extract-bill-data."""
    is_success: bool  # False when file load, download, or OCR failed
    error: Optional[str] = None  # human-readable failure reason, if any
    data: Dict[str, Any]  # {"pagewise_line_items": [...], "total_item_count": N}
    accuracy_score: float  # 0-100, from validate_totals
    validation_message: str  # summary message from validate_totals
    token_usage: Dict[str, int]  # LLM token accounting (zeroed in all visible paths)
762
 
763
  @app.post("/extract-bill-data", response_model=BillResponse)
764
  async def extract_bill_data(payload: BillRequest):
765
- """Main extraction endpoint"""
766
  doc_url = payload.document
767
  file_bytes = None
768
 
769
- # Load file
770
  if doc_url.startswith("file://"):
771
  local_path = doc_url.replace("file://", "")
772
  try:
@@ -777,8 +658,6 @@ async def extract_bill_data(payload: BillRequest):
777
  is_success=False,
778
  error=f"Local file read failed: {e}",
779
  data={"pagewise_line_items": [], "total_item_count": 0},
780
- accuracy_score=0.0,
781
- validation_message="File load failed",
782
  token_usage={"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
783
  )
784
  else:
@@ -790,8 +669,6 @@ async def extract_bill_data(payload: BillRequest):
790
  is_success=False,
791
  error=f"Download failed (status={resp.status_code})",
792
  data={"pagewise_line_items": [], "total_item_count": 0},
793
- accuracy_score=0.0,
794
- validation_message="HTTP error",
795
  token_usage={"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
796
  )
797
  file_bytes = resp.content
@@ -800,8 +677,6 @@ async def extract_bill_data(payload: BillRequest):
800
  is_success=False,
801
  error=f"HTTP error: {e}",
802
  data={"pagewise_line_items": [], "total_item_count": 0},
803
- accuracy_score=0.0,
804
- validation_message="Network error",
805
  token_usage={"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
806
  )
807
 
@@ -810,46 +685,28 @@ async def extract_bill_data(payload: BillRequest):
810
  is_success=False,
811
  error="No file bytes",
812
  data={"pagewise_line_items": [], "total_item_count": 0},
813
- accuracy_score=0.0,
814
- validation_message="Empty file",
815
  token_usage={"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
816
  )
817
 
818
- # Extract
819
  logger.info(f"Processing with engine: {OCR_ENGINE}")
820
  try:
821
  if OCR_ENGINE == "tesseract":
822
  pages = ocr_with_tesseract(file_bytes)
823
  else:
824
- # Fallback to tesseract
825
  pages = ocr_with_tesseract(file_bytes)
826
  except Exception as e:
827
  logger.exception("OCR failed: %s", e)
828
  pages = []
829
 
830
- # Prepare response
831
  total_items = sum(len(p.line_items) for p in pages)
832
  pages_dict = [p.to_dict() for p in pages]
833
 
834
- # Calculate overall accuracy
835
- all_items = [item for p in pages for item in p.line_items]
836
- all_totals = BillTotal(
837
- subtotal_amount=sum(p.bill_totals.subtotal_amount or 0 for p in pages) or None,
838
- tax_amount=sum(p.bill_totals.tax_amount or 0 for p in pages) or None,
839
- discount_amount=sum(p.bill_totals.discount_amount or 0 for p in pages) or None,
840
- final_total_amount=sum(p.bill_totals.final_total_amount or 0 for p in pages) or None,
841
- )
842
-
843
- overall_acc, msg = validate_totals(all_items, all_totals)
844
-
845
  return BillResponse(
846
  is_success=True,
847
  data={
848
  "pagewise_line_items": pages_dict,
849
  "total_item_count": total_items,
850
  },
851
- accuracy_score=overall_acc,
852
- validation_message=msg,
853
  token_usage={"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
854
  )
855
 
@@ -858,10 +715,6 @@ def health():
858
  return {
859
  "status": "ok",
860
  "engine": OCR_ENGINE,
861
- "message": "Enhanced Bill Extractor (Datathon v2 - High Accuracy Mode)",
862
  "hint": "POST /extract-bill-data with {'document': '<url or file://path>'}",
863
  }
864
-
865
# Dev entry point: serve the FastAPI app with uvicorn when run directly.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8080)
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import re
3
  import json
4
  import logging
5
  from io import BytesIO
6
+ from typing import List, Dict, Any, Optional, Tuple
7
+ from dataclasses import dataclass, asdict, field
 
8
 
9
  from fastapi import FastAPI
10
  from pydantic import BaseModel
 
28
  except Exception:
29
  vision = None
30
 
 
 
 
 
 
31
  # -------------------------------------------------------------------------
32
  # Configuration
33
  # -------------------------------------------------------------------------
34
  OCR_ENGINE = os.getenv("OCR_ENGINE", "tesseract").lower()
 
 
35
  AWS_REGION = os.getenv("AWS_REGION", "us-east-1")
36
  TESSERACT_PSM = os.getenv("TESSERACT_PSM", "6")
37
 
38
  logging.basicConfig(level=logging.INFO)
39
+ logger = logging.getLogger("bill-extractor")
 
 
 
 
 
 
 
40
 
41
  # Lazy clients
42
  _textract_client = None
 
59
  return _vision_client
60
 
61
  # -------------------------------------------------------------------------
62
+ # Data Models (Clean Output)
63
  # -------------------------------------------------------------------------
64
  @dataclass
65
  class BillLineItem:
 
68
  item_quantity: float = 1.0
69
  item_rate: float = 0.0
70
  item_amount: float = 0.0
71
+ # Internal fields (not exported)
72
+ confidence: float = field(default=1.0, repr=False)
73
+ source_row: str = field(default="", repr=False)
74
+ is_description_continuation: bool = field(default=False, repr=False)
75
 
76
  def to_dict(self) -> Dict[str, Any]:
77
+ """Export only public fields"""
78
+ return {
79
+ "item_name": self.item_name,
80
+ "item_quantity": self.item_quantity,
81
+ "item_rate": self.item_rate,
82
+ "item_amount": self.item_amount,
83
+ }
84
 
85
  @dataclass
86
  class BillTotal:
 
97
  class ExtractedPage:
98
  """Page-level extraction result"""
99
  page_no: int
100
+ page_type: str
101
  line_items: List[BillLineItem]
102
  bill_totals: BillTotal
103
+ page_confidence: float = field(default=1.0, repr=False) # Internal
104
 
105
  def to_dict(self) -> Dict[str, Any]:
106
+ """Export clean output (no confidence/validation)"""
107
  return {
108
  "page_no": self.page_no,
109
  "page_type": self.page_type,
110
  "line_items": [item.to_dict() for item in self.line_items],
111
  "bill_totals": self.bill_totals.to_dict(),
 
112
  }
113
 
114
  # -------------------------------------------------------------------------
115
+ # Regular Expressions
116
  # -------------------------------------------------------------------------
117
  NUM_RE = re.compile(r"[-+]?\d{1,3}(?:[,0-9]*)(?:\.\d+)?")
118
 
 
119
  TOTAL_KEYWORDS = re.compile(
120
  r"\b(grand\s+total|net\s+payable|total\s+(?:amount|due)|amount\s+payable|bill\s+amount|"
121
  r"final\s+(?:amount|total)|balance\s+due|amount\s+due|total\s+payable|payable)\b",
 
141
  HEADER_KEYWORDS = [
142
  "description", "qty", "qty/hrs", "hrs", "rate", "unit price", "discount",
143
  "net", "amt", "amount", "price", "total", "sl.no", "s.no", "item", "service",
 
144
  ]
145
 
146
  # -------------------------------------------------------------------------
147
  # Text Cleaning & Normalization
148
  # -------------------------------------------------------------------------
149
  def sanitize_ocr_text(s: Optional[str]) -> str:
150
+ """Clean OCR text"""
151
  if not s:
152
  return ""
153
  s = s.replace("\u2014", "-").replace("\u2013", "-")
154
+ s = s.replace("\u00A0", " ")
155
  s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E]", " ", s)
156
  s = s.replace("\r\n", "\n").replace("\r", "\n")
157
  s = re.sub(r"[ \t]+", " ", s)
 
158
  s = re.sub(r"\b(qiy|qty|oty|gty)\b", "qty", s, flags=re.I)
159
  s = re.sub(r"\b(deseription|descriptin|desription)\b", "description", s, flags=re.I)
160
  return s.strip()
 
167
  if s == "":
168
  return None
169
 
 
170
  negative = False
171
  if s.startswith("(") and s.endswith(")"):
172
  negative = True
173
  s = s[1:-1]
174
 
 
175
  s = re.sub(r"[^\d\-\+\,\.\(\)]", "", s)
176
  s = s.replace(",", "")
177
 
 
196
  s = s.replace("—", "-").replace("–", "-")
197
  s = re.sub(r"\s+", " ", s)
198
  s = s.strip(" -:,.=()[]{}|\\")
199
+ s = re.sub(r"\bOR\b", "DR", s)
200
  return s.strip()
201
 
202
  # -------------------------------------------------------------------------
 
209
  return (name_norm, amount_rounded)
210
 
211
  def dedupe_items_advanced(items: List[BillLineItem]) -> List[BillLineItem]:
212
+ """Remove duplicates while preserving highest-confidence versions"""
 
 
 
213
  if not items:
214
  return []
215
 
 
216
  seen: Dict[Tuple, BillLineItem] = {}
217
  for item in items:
218
  fp = item_fingerprint(item)
219
  if fp not in seen or item.confidence > seen[fp].confidence:
220
  seen[fp] = item
221
 
 
222
  final = []
223
  for item in seen.values():
224
  if item.is_description_continuation:
 
225
  if final and abs(float(final[-1].item_amount) - float(item.item_amount)) < 0.01:
 
226
  final[-1].item_name = (final[-1].item_name + " " + item.item_name).strip()
227
  continue
228
  final.append(item)
 
232
  # -------------------------------------------------------------------------
233
  # Total/Subtotal Detection
234
  # -------------------------------------------------------------------------
235
+ FINAL_TOTAL_KEYWORDS = re.compile(
236
+ r"\b(grand\s+total|final\s+(?:total|amount)|total\s+(?:due|payable|amount)|"
237
+ r"net\s+payable|amount\s+(?:due|payable)|balance\s+due|payable)\b",
238
+ re.I
239
+ )
240
+
241
  def detect_totals_in_rows(rows: List[List[Dict[str, Any]]]) -> Tuple[Optional[float], Optional[float], Optional[float], Optional[float]]:
242
+ """Scan rows for subtotal, tax, discount, final total"""
 
 
 
243
  subtotal = None
244
  tax = None
245
  discount = None
246
  final_total = None
247
 
 
248
  for row in rows:
249
  row_text = " ".join([c["text"] for c in row])
 
 
 
 
250
  row_lower = row_text.lower()
251
  tokens = row_text.split()
252
 
 
253
  amounts = []
254
  for t in tokens:
255
  if is_numeric_token(t):
 
260
  if not amounts:
261
  continue
262
 
 
263
  amount = max(amounts)
264
 
 
265
  if FINAL_TOTAL_KEYWORDS.search(row_lower):
266
  final_total = amount
267
  elif SUBTOTAL_KEYWORDS.search(row_lower):
 
273
 
274
  return subtotal, tax, discount, final_total
275
 
 
 
 
 
 
 
276
  # -------------------------------------------------------------------------
277
  # Image Preprocessing
278
  # -------------------------------------------------------------------------
279
  def pil_to_cv2(img: Image.Image) -> Any:
 
280
  arr = np.array(img)
281
  if arr.ndim == 2:
282
  return arr
 
287
  pil_img = pil_img.convert("RGB")
288
  w, h = pil_img.size
289
 
 
290
  if w < target_w:
291
  scale = target_w / float(w)
292
  pil_img = pil_img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
293
 
294
  cv_img = pil_to_cv2(pil_img)
295
 
 
296
  if cv_img.ndim == 3:
297
  gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
298
  else:
299
  gray = cv_img
300
 
 
301
  gray = cv2.fastNlMeansDenoising(gray, h=10)
302
 
 
303
  try:
304
  bw = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
305
  cv2.THRESH_BINARY, 41, 15)
306
  except Exception:
307
  _, bw = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
308
 
 
309
  kernel = np.ones((2, 2), np.uint8)
310
  bw = cv2.morphologyEx(bw, cv2.MORPH_CLOSE, kernel)
311
  bw = cv2.morphologyEx(bw, cv2.MORPH_OPEN, kernel)
 
344
 
345
  cells.append({
346
  "text": txt,
347
+ "conf": max(0.0, conf) / 100.0,
348
  "left": left, "top": top, "width": width, "height": height,
349
  "center_x": center_x, "center_y": center_y
350
  })
 
376
  return rows
377
 
378
  # -------------------------------------------------------------------------
379
+ # Column Detection
380
  # -------------------------------------------------------------------------
381
  def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 6) -> List[float]:
382
  """Detect x-positions of numeric columns"""
 
388
  if len(xs) == 1:
389
  return xs
390
 
 
391
  gaps = [xs[i+1] - xs[i] for i in range(len(xs)-1)]
392
  mean_gap = float(np.mean(gaps))
393
  std_gap = float(np.std(gaps)) if len(gaps) > 1 else 0.0
 
417
  return int(np.argmin(distances))
418
 
419
  # -------------------------------------------------------------------------
420
+ # Row Parsing
421
  # -------------------------------------------------------------------------
422
  def parse_rows_with_columns(
423
  rows: List[List[Dict[str, Any]]],
424
  page_cells: List[Dict[str, Any]],
425
  page_text: str = ""
426
  ) -> List[BillLineItem]:
427
+ """Parse rows into line items"""
 
 
 
428
  items = []
429
  column_centers = detect_numeric_columns(page_cells, max_columns=6)
430
 
431
+ for row in rows:
432
  tokens = [c["text"] for c in row]
433
  row_text = " ".join(tokens)
434
  row_lower = row_text.lower()
435
 
 
436
  if FOOTER_KEYWORDS.search(row_lower) and not any(is_numeric_token(t) for t in tokens):
437
  continue
438
 
 
439
  if not any(is_numeric_token(t) for t in tokens):
440
  continue
441
 
 
442
  numeric_values = []
443
  for t in tokens:
444
  if is_numeric_token(t):
 
451
 
452
  numeric_values = sorted(list(set(numeric_values)), reverse=True)
453
 
 
454
  if column_centers:
455
  left_text_parts = []
456
  numeric_buckets = {i: [] for i in range(len(column_centers))}
 
471
  item_name = " ".join(left_text_parts).strip()
472
  item_name = clean_item_name(item_name) if item_name else "UNKNOWN"
473
 
 
474
  num_cols = len(column_centers)
475
  amount = None
476
  rate = None
477
  qty = None
478
 
 
479
  if num_cols >= 1:
480
  bucket = numeric_buckets.get(num_cols - 1, [])
481
  if bucket:
 
483
  amount = normalize_num_str(amt_str, allow_zero=False)
484
 
485
  if amount is None:
 
486
  for v in numeric_values:
487
  if v > 0:
488
  amount = v
489
  break
490
 
 
491
  if num_cols >= 2:
492
  bucket = numeric_buckets.get(num_cols - 2, [])
493
  if bucket:
494
  rate = normalize_num_str(bucket[-1][0], allow_zero=False)
495
 
 
496
  if num_cols >= 3:
497
  bucket = numeric_buckets.get(num_cols - 3, [])
498
  if bucket:
499
  qty = normalize_num_str(bucket[-1][0], allow_zero=False)
500
 
 
501
  if amount and not qty and not rate and numeric_values:
502
  for cand in numeric_values:
503
  if cand <= 0.1 or cand >= amount:
 
509
  rate = cand
510
  break
511
 
 
512
  if qty and rate is None and amount and amount != 0:
513
  rate = amount / qty
514
  elif rate and qty is None and amount and amount != 0:
 
516
  elif amount and qty and rate is None:
517
  rate = amount / qty if qty != 0 else 0.0
518
 
 
519
  if qty is None:
520
  qty = 1.0
521
  if rate is None:
 
523
  if amount is None:
524
  amount = qty * rate if qty and rate else 0.0
525
 
 
526
  if amount > 0:
527
  confidence = np.mean([c.get("conf", 0.85) for c in row]) if row else 0.85
528
  items.append(BillLineItem(
 
534
  source_row=row_text,
535
  ))
536
  else:
 
537
  numeric_idxs = [i for i, t in enumerate(tokens) if is_numeric_token(t)]
538
  if not numeric_idxs:
539
  continue
 
559
  return items
560
 
561
  # -------------------------------------------------------------------------
562
+ # Tesseract OCR Pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
563
  # -------------------------------------------------------------------------
564
  def ocr_with_tesseract(file_bytes: bytes) -> List[ExtractedPage]:
565
+ """Tesseract pipeline"""
566
  pages_out = []
567
 
568
  try:
 
577
 
578
  for idx, pil_img in enumerate(images, start=1):
579
  try:
 
580
  proc = preprocess_image_for_tesseract(pil_img)
581
  cells = image_to_tsv_cells(proc)
582
  rows = group_cells_into_rows(cells, y_tolerance=12)
583
 
 
584
  page_text = " ".join([" ".join([c["text"] for c in r]) for r in rows])
585
 
 
586
  subtotal, tax, discount, final_total = detect_totals_in_rows(rows)
587
 
 
588
  items = parse_rows_with_columns(rows, cells, page_text)
589
 
 
590
  items = dedupe_items_advanced(items)
591
 
 
592
  filtered_items = []
593
  for item in items:
594
  name_lower = item.item_name.lower()
595
 
 
596
  if TOTAL_KEYWORDS.search(name_lower) or SUBTOTAL_KEYWORDS.search(name_lower):
597
  continue
598
 
599
  if item.item_amount > 0:
600
  filtered_items.append(item)
601
 
 
602
  bill_totals = BillTotal(
603
  subtotal_amount=subtotal,
604
  tax_amount=tax,
 
606
  final_total_amount=final_total,
607
  )
608
 
 
 
 
 
609
  page_conf = np.mean([item.confidence for item in filtered_items]) if filtered_items else 0.8
610
 
611
  pages_out.append(ExtractedPage(
 
631
  # -------------------------------------------------------------------------
632
  # FastAPI App
633
  # -------------------------------------------------------------------------
634
+ app = FastAPI(title="Enhanced Bill Extractor (Clean Output)")
635
 
636
  class BillRequest(BaseModel):
637
+ document: str
638
 
639
  class BillResponse(BaseModel):
640
  is_success: bool
641
  error: Optional[str] = None
642
  data: Dict[str, Any]
 
 
643
  token_usage: Dict[str, int]
644
 
645
  @app.post("/extract-bill-data", response_model=BillResponse)
646
  async def extract_bill_data(payload: BillRequest):
647
+ """Main extraction endpoint (clean output)"""
648
  doc_url = payload.document
649
  file_bytes = None
650
 
 
651
  if doc_url.startswith("file://"):
652
  local_path = doc_url.replace("file://", "")
653
  try:
 
658
  is_success=False,
659
  error=f"Local file read failed: {e}",
660
  data={"pagewise_line_items": [], "total_item_count": 0},
 
 
661
  token_usage={"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
662
  )
663
  else:
 
669
  is_success=False,
670
  error=f"Download failed (status={resp.status_code})",
671
  data={"pagewise_line_items": [], "total_item_count": 0},
 
 
672
  token_usage={"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
673
  )
674
  file_bytes = resp.content
 
677
  is_success=False,
678
  error=f"HTTP error: {e}",
679
  data={"pagewise_line_items": [], "total_item_count": 0},
 
 
680
  token_usage={"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
681
  )
682
 
 
685
  is_success=False,
686
  error="No file bytes",
687
  data={"pagewise_line_items": [], "total_item_count": 0},
 
 
688
  token_usage={"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
689
  )
690
 
 
691
  logger.info(f"Processing with engine: {OCR_ENGINE}")
692
  try:
693
  if OCR_ENGINE == "tesseract":
694
  pages = ocr_with_tesseract(file_bytes)
695
  else:
 
696
  pages = ocr_with_tesseract(file_bytes)
697
  except Exception as e:
698
  logger.exception("OCR failed: %s", e)
699
  pages = []
700
 
 
701
  total_items = sum(len(p.line_items) for p in pages)
702
  pages_dict = [p.to_dict() for p in pages]
703
 
 
 
 
 
 
 
 
 
 
 
 
704
  return BillResponse(
705
  is_success=True,
706
  data={
707
  "pagewise_line_items": pages_dict,
708
  "total_item_count": total_items,
709
  },
 
 
710
  token_usage={"total_tokens": 0, "input_tokens": 0, "output_tokens": 0},
711
  )
712
 
 
715
  return {
716
  "status": "ok",
717
  "engine": OCR_ENGINE,
718
+ "message": "Enhanced Bill Extractor (Clean Output Mode)",
719
  "hint": "POST /extract-bill-data with {'document': '<url or file://path>'}",
720
  }