Spaces:
Sleeping
Sleeping
| # Enhanced Bill Extraction API (Improved Name Detection) | |
| # Focused on: Accurate item name extraction with intelligent cleaning | |
| # | |
| # Improvements: | |
| # 1. Advanced name normalization and cleaning | |
| # 2. OCR error correction for common names | |
| # 3. Smart multi-word item detection | |
| # 4. Context-aware name validation | |
| # 5. Medical/pharmacy/retail term recognition | |
| # 6. Remove junk characters and formatting | |
| # 7. Consolidate similar names (fuzzy matching) | |
| import os | |
| import re | |
| import json | |
| import logging | |
| from io import BytesIO | |
| from typing import List, Dict, Any, Optional, Tuple | |
| from dataclasses import dataclass, asdict, field | |
| from difflib import SequenceMatcher | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| import requests | |
| from PIL import Image | |
| from pdf2image import convert_from_bytes | |
| import numpy as np | |
| import cv2 | |
| import pytesseract | |
| from pytesseract import Output | |
| try: | |
| import boto3 | |
| except Exception: | |
| boto3 = None | |
| try: | |
| from google.cloud import vision | |
| except Exception: | |
| vision = None | |
| # ------------------------------------------------------------------------- | |
| # Configuration | |
| # ------------------------------------------------------------------------- | |
# Runtime configuration, read once from the environment at import time.
OCR_ENGINE = os.environ.get("OCR_ENGINE", "tesseract").lower()  # OCR backend name, lower-cased
AWS_REGION = os.environ.get("AWS_REGION", "us-east-1")  # region for the Textract client
TESSERACT_PSM = os.environ.get("TESSERACT_PSM", "6")  # passed to tesseract as "--psm <value>"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("bill-extractor-improved")

# Lazily created cloud OCR clients; see textract_client() / vision_client().
_textract_client = None
_vision_client = None
def textract_client():
    """Return the process-wide AWS Textract client, creating it on first use.

    Raises:
        RuntimeError: if boto3 could not be imported.
    """
    global _textract_client
    if _textract_client is not None:
        return _textract_client
    if boto3 is None:
        raise RuntimeError("boto3 not installed")
    _textract_client = boto3.client("textract", region_name=AWS_REGION)
    return _textract_client
def vision_client():
    """Return the process-wide Google Cloud Vision client, creating it on first use.

    Raises:
        RuntimeError: if google-cloud-vision could not be imported.
    """
    global _vision_client
    if _vision_client is not None:
        return _vision_client
    if vision is None:
        raise RuntimeError("google-cloud-vision not installed")
    _vision_client = vision.ImageAnnotatorClient()
    return _vision_client
| # ------------------------------------------------------------------------- | |
| # Header Detection for Tables | |
| # ------------------------------------------------------------------------- | |
# Single words that suggest a row is a table header / patient-info line rather
# than a priced row; detect_totals_in_rows() counts them per row.
HEADER_KEYWORDS = [
    "description", "qty", "hrs", "rate", "discount", "net", "amt", "amount",
    "consultation", "address", "sex", "age", "mobile", "patient", "category",
    "doctor", "dr", "invoice", "bill", "subtotal", "total", "charges", "service"
]
# Complete (lower-case) table-header lines; a row containing any of them is
# treated as a header.
HEADER_PHRASES = [
    "description qty / hrs consultation rate discount net amt",
    "description qty / hrs rate discount net amt",
    "description qty / hrs rate net amt",
    "description qty hrs rate discount net amt",
]
# Lower-case and de-duplicate (order-preserving), so future mixed-case or
# repeated entries still behave.
HEADER_PHRASES = list(dict.fromkeys(h.lower() for h in HEADER_PHRASES))
| # ------------------------------------------------------------------------- | |
| # Enhanced Name Correction Dictionary | |
| # ------------------------------------------------------------------------- | |
# Lower-case OCR token -> canonical text, consumed by correct_ocr_errors().
# Three kinds of entry live here:
#   * misspellings  ("consuitation" -> "Consultation", "capsuie" -> "Capsule")
#   * expansions    ("cbc" -> "Complete Blood Count (CBC)", "lipid" -> "Lipid Profile")
#   * case-only / no-op entries ("tablet" -> "Tablet", "ml" -> "ml")
# correct_ocr_errors() iterates these entries in insertion order, so the order
# of this literal can affect results.
OCR_CORRECTIONS = {
    # Medical terms (OCR misreadings of "consultation")
    "consuitation": "Consultation",
    "consulation": "Consultation",
    "consultatior": "Consultation",
    "consultaion": "Consultation",
    "consultion": "Consultation",
    "consultaon": "Consultation",
    "consuftation": "Consultation",
    # Lab tests (mostly abbreviation/short-name expansions)
    "cbc": "Complete Blood Count (CBC)",
    "lft": "Liver Function Test (LFT)",
    "rft": "Renal Function Test (RFT)",
    "thyroid": "Thyroid Profile",
    "lipid": "Lipid Profile",
    "sugar": "Blood Sugar Test",
    "glucose": "Blood Glucose",
    "haemoglobin": "Hemoglobin",
    "hemoglobin": "Hemoglobin",
    # Procedures
    "xray": "X-Ray",
    "x-ray": "X-Ray",
    "xra": "X-Ray",
    "ctscan": "CT Scan",
    "ct-scan": "CT Scan",
    "ultrasound": "Ultrasound",
    "mri": "MRI Scan",
    "ecg": "ECG",
    "ekg": "ECG",
    # Medicines
    "amoxicilin": "Amoxicillin",
    "amoxicilen": "Amoxicillin",
    "antibiotic": "Antibiotic",
    "paracetamol": "Paracetamol",
    "cough-syrup": "Cough Syrup",
    "coughsyrup": "Cough Syrup",
    # Pharmacy (dosage forms / units; mostly case-only)
    "strip": "Strip",
    "tablet": "Tablet",
    "capsuie": "Capsule",
    "capsule": "Capsule",
    "bottle": "Bottle",
    "ml": "ml",
    # Pharmacy/Retail (pack units; case-only)
    "pack": "Pack",
    "box": "Box",
    "blister": "Blister",
    "nos": "Nos",
    "pcs": "Pcs",
}
| # Medical/pharmacy keywords to recognize item types | |
# Category vocabularies used by validate_name() to score item names.
MEDICAL_KEYWORDS = set(
    "consultation check-up checkup visit appointment "
    "diagnosis treatment examination exam".split()
)
LAB_TEST_KEYWORDS = set(
    "test cbc lft rft blood urine stool sample "
    "profile thyroid lipid glucose hemoglobin sugar "
    "covid screening culture pathology".split()
)
PROCEDURE_KEYWORDS = set(
    "xray x-ray scan ultrasound ct mri echo ecg "
    "procedure surgery operation imaging radiography "
    "endoscopy colonoscopy sonography".split()
)
MEDICINE_KEYWORDS = set(
    "tablet capsule strip bottle syrup cream ointment "
    "injection medicine drug antibiotic paracetamol "
    "aspirin cough vitamin supplement".split()
)
| # ------------------------------------------------------------------------- | |
| # Data Models | |
| # ------------------------------------------------------------------------- | |
@dataclass
class BillLineItem:
    """One priced line of a bill.

    The class must be a dataclass: the internal fields use ``field(...)``
    defaults and callers construct it with keyword arguments. Only the four
    ``item_*`` fields are exported by :meth:`to_dict`.
    """
    item_name: str
    item_quantity: float = 1.0
    item_rate: float = 0.0
    item_amount: float = 0.0
    # Internal fields (not exported, hidden from repr)
    confidence: float = field(default=1.0, repr=False)  # OCR confidence of the source row
    source_row: str = field(default="", repr=False)  # raw OCR text the item was parsed from
    # NOTE(review): never set by the parsing code visible in this file.
    is_description_continuation: bool = field(default=False, repr=False)
    name_confidence: float = field(default=1.0, repr=False)  # name-specific confidence

    def to_dict(self) -> Dict[str, Any]:
        """Return only the public ``item_*`` fields as a plain dict."""
        return {
            "item_name": self.item_name,
            "item_quantity": self.item_quantity,
            "item_rate": self.item_rate,
            "item_amount": self.item_amount,
        }
@dataclass
class BillTotal:
    """Bill-level summary amounts; any of them may be unknown (None).

    Must be a dataclass because :meth:`to_dict` relies on ``asdict``.
    """
    subtotal_amount: Optional[float] = None
    tax_amount: Optional[float] = None
    discount_amount: Optional[float] = None
    final_total_amount: Optional[float] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the known amounts only (fields that are None are omitted)."""
        return {k: v for k, v in asdict(self).items() if v is not None}
@dataclass
class ExtractedPage:
    """Extraction result for one page of a bill.

    A dataclass so it can be built with keyword arguments and so
    ``page_confidence``'s ``field(...)`` default takes effect.
    """
    page_no: int  # 1-based page number
    page_type: str
    line_items: List[BillLineItem]
    bill_totals: BillTotal
    page_confidence: float = field(default=1.0, repr=False)  # internal, not exported

    def to_dict(self) -> Dict[str, Any]:
        """Export the public page payload (page_confidence is omitted)."""
        return {
            "page_no": self.page_no,
            "page_type": self.page_type,
            "line_items": [item.to_dict() for item in self.line_items],
            "bill_totals": self.bill_totals.to_dict(),
        }
| # ------------------------------------------------------------------------- | |
| # Advanced Name Cleaning & Validation | |
| # ------------------------------------------------------------------------- | |
def correct_ocr_errors(text: str, corrections: Optional[Dict[str, str]] = None) -> str:
    """Fix common OCR misreadings in an item name.

    Args:
        text: Item name to correct.
        corrections: Mapping of lower-case OCR token -> replacement. Defaults to
            the module-level ``OCR_CORRECTIONS`` table.

    Returns:
        If the whole stripped name (case-insensitive) is a key, its replacement
        verbatim. Otherwise the stripped name with every whole-word,
        case-insensitive occurrence of a key replaced. The replacement is
        upper-cased when the matched word was all upper-case. Empty input
        returns "".
    """
    if not text:
        return ""
    table = OCR_CORRECTIONS if corrections is None else corrections
    text = text.strip()

    exact = table.get(text.lower())
    if exact is not None:
        return exact

    for wrong, correct in table.items():
        key = wrong.lower()
        # Whole-word match only: plain substring matching rewrote parts of
        # other words (e.g. "nos" inside "diagnosis", "box" inside "xbox").
        pattern = r"\b" + re.escape(key) + r"\b"
        correct_lower = correct.lower()
        # Expansion entries ("lipid" -> "Lipid Profile") are applied only on a
        # whole-name match; inside a longer name they duplicate words
        # ("lipid profile" -> "Lipid Profile profile").
        if correct_lower != key and re.search(pattern, correct_lower):
            continue
        text = re.sub(
            pattern,
            lambda m, repl=correct: repl.upper() if m.group(0).isupper() else repl,
            text,
            flags=re.IGNORECASE,
        )
    return text
def normalize_name(s: str) -> str:
    """Clean a raw OCR item name into a presentable form.

    Steps: turn ``|`` separators into spaces and pad ``/``; collapse
    whitespace; strip separator/bracket junk from both ends; apply
    ``correct_ocr_errors``; apply ``capitalize_name``; drop immediately
    repeated words (an OCR doubling such as "Consultation Consultation").

    Returns "UNKNOWN" for empty input or when nothing survives cleaning.
    """
    if not s:
        return "UNKNOWN"

    # 1. Separators first, so the whitespace collapse below also normalises
    #    the spaces they introduce (this also covers "||").
    s = s.replace("|", " ").replace("/", " / ")
    # 2. Collapse whitespace runs and trim.
    s = re.sub(r"\s+", " ", s).strip()
    # 3. Remove leading/trailing junk.
    s = s.strip(" -:,.=()[]{}|\\/")
    # 4. OCR error correction, then 5. capitalisation.
    s = correct_ocr_errors(s)
    s = capitalize_name(s)

    # 6. Drop adjacent duplicate words only. The previous rule kept duplicates
    #    while fewer than 3 distinct words had been seen and dropped
    #    legitimate later repeats (e.g. the second "/" in "Tab 500 / 10 / Strip").
    words = []
    for word in s.split():
        if not words or word.lower() != words[-1].lower():
            words.append(word)
    s = " ".join(words).strip()
    return s or "UNKNOWN"
def capitalize_name(s: str) -> str:
    """Title-case an item name while keeping known abbreviations upper-case.

    Each word is capitalised with ``str.capitalize``. Connector/unit words
    ("for", "the", "ml", ...) are lower-cased unless they are the first word.
    Any alphabetic run inside a word that spells a known abbreviation
    ("cbc", "ecg", "covid", ...) is upper-cased, so "(cbc)" becomes "(CBC)"
    and "covid-19" becomes "COVID-19".
    Whitespace runs are collapsed to single spaces; empty input is returned as-is.
    """
    if not s:
        return s

    all_caps = {"CBC", "LFT", "RFT", "ECG", "EKG", "MRI", "CT", "COVID", "GST", "SGST", "CGST"}
    small_words = {"for", "the", "and", "or", "in", "of", "to", "a", "an", "ml", "mg", "mg/ml"}

    def _upper_abbreviation(match):
        run = match.group(0)
        return run.upper() if run.upper() in all_caps else run

    result = []
    for word in s.split():
        lowered = word.lower()
        if result and lowered in small_words:
            result.append(lowered)
        else:
            # Abbreviations are fixed *after* capitalize(). Upper-casing them
            # first, as before, was undone because capitalize() lower-cases
            # everything but the first character.
            result.append(re.sub(r"[A-Za-z]+", _upper_abbreviation, word.capitalize()))
    return " ".join(result)
def validate_name(name: str, context_amount: float = 0) -> Tuple[str, float]:
    """Clean an item name and score how plausible it is as a bill item.

    Args:
        name: Normalised item name ("UNKNOWN" / empty means no name).
        context_amount: The line's amount. A positive amount below a
            category-specific cap raises the confidence.

    Returns:
        ``(cleaned_name, confidence)`` with confidence in [0, 1]. Returns
        ``("UNKNOWN", 0.0)`` when there is no usable name, including when
        ``remove_redundant_text`` strips it down to nothing.
    """
    if not name or name == "UNKNOWN":
        return "UNKNOWN", 0.0

    # Clean first, so the heuristics below score the name that is actually
    # returned, and so a name made only of filler words ("Item", "Qty")
    # becomes UNKNOWN instead of an empty string.
    cleaned = remove_redundant_text(name)
    if not cleaned:
        return "UNKNOWN", 0.0
    name_lower = cleaned.lower()

    def mentions_any(keywords) -> bool:
        # Whole-word match (optional plural "s"). Substring tests let short
        # keywords such as "ct" fire inside "injection" or "doctor".
        return any(
            re.search(r"\b" + re.escape(kw) + r"s?\b", name_lower) for kw in keywords
        )

    # (keywords, base confidence, amount cap, confidence when 0 < amount < cap)
    categories = (
        (MEDICAL_KEYWORDS, 0.95, 2000, 0.98),
        (LAB_TEST_KEYWORDS, 0.92, 5000, 0.96),
        (PROCEDURE_KEYWORDS, 0.90, 10000, 0.94),
        (MEDICINE_KEYWORDS, 0.88, 500, 0.92),
    )
    confidence = 0.85  # default when no category matches
    for keywords, base, amount_cap, in_range in categories:
        if mentions_any(keywords):
            confidence = in_range if 0 < context_amount < amount_cap else base
            break

    if len(cleaned) < 3:
        confidence *= 0.7  # very short names are suspicious
    elif 5 <= len(cleaned) <= 50:
        confidence = min(1.0, confidence + 0.05)  # reasonable length bonus
    return cleaned, min(1.0, confidence)
def remove_redundant_text(name: str) -> str:
    """Strip filler words and dangling punctuation from an item name.

    Removes whole-word labels (item/name/description/service/product,
    ref/reference with optional ":" or ".", qty/quantity, unit/units), then
    collapses whitespace and trims leading/trailing dashes, a trailing lone
    "x" and a trailing comma. May return "" when the name consisted only of
    filler; falsy input is returned unchanged.
    """
    if not name:
        return name

    word_patterns = (
        r"\b(?:item|name|description|service|product)\b",
        # Trailing \b added so "Ref" is no longer cut out of "Referral".
        r"\b(?:reference|ref)\b\.?\s*:?\s*",
        r"\b(?:qty|quantity)\b",
        r"\b(?:unit|units)\b",
    )
    for pattern in word_patterns:
        name = re.sub(pattern, " ", name, flags=re.I)
    # Removed words leave gaps; collapse them before trimming the edges.
    name = re.sub(r"\s+", " ", name).strip()

    # Edge clean-up. These patterns assume the text is already stripped; the
    # old versions required trailing whitespace and so could never match.
    edge_patterns = (
        r"^-+\s*|\s*-+$",  # leading/trailing dashes
        r"\s+x$",          # trailing multiplication "x"
        r"\s*,$",          # trailing comma
    )
    previous = None
    while name != previous:  # repeat until stable, e.g. "Tab x -"
        previous = name
        for pattern in edge_patterns:
            name = re.sub(pattern, "", name, flags=re.I).strip()
    return name
def merge_similar_names(
    items: "List[BillLineItem]",
    similarity_threshold: float = 0.85,
    match_amounts: bool = True,
) -> "List[BillLineItem]":
    """Collapse items whose names are near-duplicates (fuzzy OCR variants).

    Grouping is greedy. Each item not yet absorbed seeds a group and absorbs
    every later item whose lower-cased name has a ``difflib.SequenceMatcher``
    ratio above *similarity_threshold* against the seed. From each group the
    item with the longest name is kept; ties go to higher ``name_confidence``,
    then the earliest item. Results are in seed order.

    Args:
        items: Line items (anything with ``item_name``, ``item_amount`` and
            ``name_confidence``).
        similarity_threshold: Minimum name similarity (exclusive) to merge.
        match_amounts: When True (default), items are merged only if their
            amounts are equal to 2 decimals. Distinct charges with similar names
            ("Paracetamol 500" at 20.00 vs "Paracetamol 650" at 30.00) are then
            not collapsed into one. Pass False for name-only merging.

    Returns:
        The merged list; inputs with fewer than two items are returned as-is.
    """
    if len(items) <= 1:
        return items

    merged = []
    absorbed = set()
    for i, seed in enumerate(items):
        if i in absorbed:
            continue
        seed_name = seed.item_name.lower()
        seed_amount = round(float(seed.item_amount), 2)
        group = [seed]
        for j in range(i + 1, len(items)):
            if j in absorbed:
                continue
            other = items[j]
            if match_amounts and round(float(other.item_amount), 2) != seed_amount:
                continue
            similarity = SequenceMatcher(None, seed_name, other.item_name.lower()).ratio()
            if similarity > similarity_threshold:
                group.append(other)
                absorbed.add(j)
        merged.append(max(group, key=lambda x: (len(x.item_name), x.name_confidence)))
    return merged
| # ------------------------------------------------------------------------- | |
| # Regular Expressions (Enhanced) | |
| # ------------------------------------------------------------------------- | |
# Digit-led numeric substring: optional sign, 1-3 leading digits, any further
# digits/commas, optional decimal part. is_numeric_token() applies it with
# .search(), so it can match anywhere inside a token.
NUM_RE = re.compile(r"[-+]?\d{1,3}(?:[,0-9]*)(?:\.\d+)?")
# Labels of the amount actually payable (grand total, net payable, ...).
TOTAL_KEYWORDS = re.compile(
    r"\b(grand\s+total|net\s+payable|total\s+(?:amount|due)|amount\s+payable|bill\s+amount|"
    r"final\s+(?:amount|total)|balance\s+due|amount\s+due|total\s+payable|payable)\b",
    re.I
)
# Subtotal labels ("sub total", "sub-total", "subtotal", "items total", ...).
SUBTOTAL_KEYWORDS = re.compile(
    r"\b(sub\s*[\-\s]?total|subtotal|sub\s+total|items\s+total|line\s+items\s+total)\b",
    re.I
)
# Tax labels (generic tax, GST components, VAT).
TAX_KEYWORDS = re.compile(
    r"\b(tax|gst|vat|sgst|cgst|igst|sales\s+tax|service\s+tax)\b",
    re.I
)
# Discount-style deductions.
DISCOUNT_KEYWORDS = re.compile(
    r"\b(discount|rebate|deduction)\b",
    re.I
)
# Page-footer / metadata markers.
# NOTE(review): unlike the patterns above this has no \b anchors, so it also
# matches inside longer words (e.g. "date" in "update", "time" in "sometimes").
FOOTER_KEYWORDS = re.compile(
    r"(page|printed\s+on|printed|date|time|signature|authorized|terms|conditions)",
    re.I
)
| # ------------------------------------------------------------------------- | |
| # Text Cleaning & Normalization | |
| # ------------------------------------------------------------------------- | |
def sanitize_ocr_text(s: Optional[str]) -> str:
    """Normalise raw OCR output into clean, single-spaced ASCII text.

    Em/en dashes become "-", non-breaking spaces become " ", and every other
    character outside printable ASCII (plus tab/newline/carriage return)
    becomes a space. Line endings are unified to newlines and runs of spaces
    or tabs are collapsed. Common OCR misreadings of the table headers "qty"
    and "description" are fixed. Returns "" for None or empty input.
    """
    if not s:
        return ""
    text = s.translate(str.maketrans({"\u2014": "-", "\u2013": "-", "\u00A0": " "}))
    text = re.sub(r"[^\x09\x0A\x0D\x20-\x7E]", " ", text)
    for old, new in (("\r\n", "\n"), ("\r", "\n")):
        text = text.replace(old, new)
    text = re.sub(r"[ \t]+", " ", text)
    text = re.sub(r"\b(qiy|qty|oty|gty)\b", "qty", text, flags=re.I)
    text = re.sub(r"\b(deseription|descriptin|desription)\b", "description", text, flags=re.I)
    return text.strip()
def normalize_num_str(s: Optional[str], allow_zero: bool = False) -> Optional[float]:
    """Parse a numeric OCR token into a float.

    Handles thousands separators ("1,234.50") and accounting negatives
    ("(200)" -> -200.0). Handles an alphabetic prefix such as "Rs." / "INR" /
    "No.", whose dot must not be read as a decimal point ("Rs.150" -> 150.0,
    not 0.15). Handles the "/-" amount suffix ("500/-" -> 500.0).

    Returns:
        The value, or None for None/empty/unparseable input, and for zero
        unless *allow_zero* is True.
    """
    if s is None:
        return None
    s = str(s).strip()
    if not s:
        return None

    negative = False
    if s.startswith("(") and s.endswith(")"):
        negative = True
        s = s[1:-1].strip()

    # Drop a leading alphabetic prefix together with its abbreviation dot.
    s = re.sub(r"^[A-Za-z]+\.?\s*", "", s)
    # Drop the "/-" suffix; otherwise the kept "-" makes float() fail.
    s = re.sub(r"/-$", "", s)
    s = re.sub(r"[^\d\-\+\,\.\(\)]", "", s)
    s = s.replace(",", "")
    if s in ("", "-", "+"):
        return None

    try:
        val = float(s)
    except ValueError:
        return None
    if negative:
        val = -val
    if val == 0 and not allow_zero:
        return None
    return val
def is_numeric_token(t: Optional[str]) -> bool:
    """Return True when *t* is non-empty and contains a ``NUM_RE`` match.

    ``search`` is used, so any token containing a digit run qualifies
    (e.g. "500mg" or "B12" count as numeric).
    """
    if not t:
        return False
    return NUM_RE.search(str(t)) is not None
| # ------------------------------------------------------------------------- | |
| # Item Fingerprinting | |
| # ------------------------------------------------------------------------- | |
def item_fingerprint(item: BillLineItem) -> Tuple[str, float]:
    """Return the exact-duplicate key ``(name, amount)`` for *item*.

    The name is lower-cased, whitespace-collapsed, stripped and truncated to
    100 characters; the amount is rounded to 2 decimals.
    """
    name_key = re.sub(r"\s+", " ", item.item_name.lower()).strip()
    return name_key[:100], round(float(item.item_amount), 2)
def dedupe_items_advanced(items: List[BillLineItem]) -> List[BillLineItem]:
    """Remove duplicate line items, then merge near-duplicate names.

    Items sharing an ``item_fingerprint`` collapse to the one with the highest
    ``confidence``. It takes the position of the first occurrence, and the
    earliest item wins ties. The survivors are passed through
    ``merge_similar_names`` with a 0.85 similarity threshold.
    """
    if not items:
        return []
    best_by_fingerprint: Dict[Tuple, BillLineItem] = {}
    for candidate in items:
        key = item_fingerprint(candidate)
        incumbent = best_by_fingerprint.get(key)
        if incumbent is None or candidate.confidence > incumbent.confidence:
            best_by_fingerprint[key] = candidate
    return merge_similar_names(list(best_by_fingerprint.values()), similarity_threshold=0.85)
| # ------------------------------------------------------------------------- | |
| # Total Detection | |
| # ------------------------------------------------------------------------- | |
# Labels marking the final payable total. detect_totals_in_rows() uses this
# pattern (not TOTAL_KEYWORDS) to pick the final_total value. The bare word
# "payable" on its own is enough to match.
FINAL_TOTAL_KEYWORDS = re.compile(
    r"\b(grand\s+total|final\s+(?:total|amount)|total\s+(?:due|payable|amount)|"
    r"net\s+payable|amount\s+(?:due|payable)|balance\s+due|payable)\b",
    re.I
)
def detect_totals_in_rows(rows: List[List[Dict[str, Any]]]) -> Tuple[Optional[float], Optional[float], Optional[float], Optional[float]]:
    """Scan OCR rows for subtotal, tax, discount and final-total amounts.

    A row is skipped when it contains a known header phrase or at least three
    ``HEADER_KEYWORDS`` as whole words. Otherwise the largest number in the row
    is taken as its amount, and the row is classified (first match wins) as
    subtotal, final total, tax or discount. A later matching row overrides an
    earlier one.

    Returns:
        ``(subtotal, tax, discount, final_total)``; each is None if not found.
    """
    subtotal = tax = discount = final_total = None
    for row in rows:
        row_text = " ".join(c["text"] for c in row)
        row_lower = row_text.lower()

        if any(phrase in row_lower for phrase in HEADER_PHRASES):
            continue
        # Whole-word counting. Substring counting let short keywords inside
        # other words ("dr" in "hundred", "net" in "internet", "age" in
        # "package") push genuine total rows over the threshold.
        row_words = set(re.findall(r"[a-z]+", row_lower))
        if sum(1 for h in HEADER_KEYWORDS if h in row_words) >= 3:
            continue

        amounts = [
            v
            for v in (normalize_num_str(t, allow_zero=True) for t in row_text.split() if is_numeric_token(t))
            if v is not None
        ]
        if not amounts:
            continue
        amount = max(amounts)

        # Subtotal is tested before the final total: "Sub Total Amount" also
        # contains "total amount", which FINAL_TOTAL_KEYWORDS would claim.
        if SUBTOTAL_KEYWORDS.search(row_lower):
            subtotal = amount
        elif FINAL_TOTAL_KEYWORDS.search(row_lower):
            final_total = amount
        elif TAX_KEYWORDS.search(row_lower):
            tax = amount
        elif DISCOUNT_KEYWORDS.search(row_lower):
            discount = amount
    return subtotal, tax, discount, final_total
| # ------------------------------------------------------------------------- | |
| # Image Preprocessing | |
| # ------------------------------------------------------------------------- | |
def pil_to_cv2(img: Image.Image) -> Any:
    """Convert a PIL image to an OpenCV-style numpy array.

    Single-channel images are returned as-is. Anything else is converted
    with ``COLOR_RGB2BGR``, so colour input is assumed to be 3-channel RGB.
    """
    pixels = np.array(img)
    return pixels if pixels.ndim == 2 else cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
def preprocess_image_for_tesseract(pil_img: Image.Image, target_w: int = 1500) -> Any:
    """Prepare a page image for Tesseract and return a binary OpenCV image.

    Pipeline: convert to RGB; upscale (LANCZOS, aspect preserved) when narrower
    than *target_w*; greyscale; non-local-means denoise; adaptive Gaussian
    threshold (Otsu as a fallback if that call raises); then a 2x2
    morphological close followed by an open.
    """
    rgb = pil_img.convert("RGB")
    width, height = rgb.size
    if width < target_w:
        factor = target_w / float(width)
        rgb = rgb.resize((int(width * factor), int(height * factor)), Image.LANCZOS)

    as_cv = pil_to_cv2(rgb)
    gray = cv2.cvtColor(as_cv, cv2.COLOR_BGR2GRAY) if as_cv.ndim == 3 else as_cv
    gray = cv2.fastNlMeansDenoising(gray, h=10)

    try:
        binary = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                       cv2.THRESH_BINARY, 41, 15)
    except Exception:
        _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

    kernel = np.ones((2, 2), np.uint8)
    for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
        binary = cv2.morphologyEx(binary, operation, kernel)
    return binary
def image_to_tsv_cells(cv_img: Any) -> List[Dict[str, Any]]:
    """Run Tesseract on *cv_img* and return one dict per recognised word.

    Each dict has ``text``, ``conf`` (Tesseract confidence / 100, where -1 or
    unparseable values become 0.0), the box (``left``, ``top``, ``width``,
    ``height``) and its centre (``center_x``, ``center_y``). Empty entries are
    skipped. If the call with the configured ``--psm`` raises, it is retried
    once without the config.
    """
    try:
        data = pytesseract.image_to_data(cv_img, output_type=Output.DICT, config=f"--psm {TESSERACT_PSM}")
    except Exception:
        data = pytesseract.image_to_data(cv_img, output_type=Output.DICT)

    def _confidence(i: int) -> float:
        try:
            raw_conf = data.get("conf", [])[i]
            return float(raw_conf) if raw_conf not in (None, "", "-1") else -1.0
        except Exception:
            return -1.0

    cells = []
    for i, raw in enumerate(data.get("text", [])):
        text = "" if raw is None else str(raw).strip()
        if not text:
            continue
        left = int(data.get("left", [0])[i])
        top = int(data.get("top", [0])[i])
        width = int(data.get("width", [0])[i])
        height = int(data.get("height", [0])[i])
        cells.append({
            "text": text,
            "conf": max(0.0, _confidence(i)) / 100.0,
            "left": left, "top": top, "width": width, "height": height,
            "center_x": left + width / 2.0,
            "center_y": top + height / 2.0,
        })
    return cells
def group_cells_into_rows(cells: List[Dict[str, Any]], y_tolerance: int = 12) -> List[List[Dict[str, Any]]]:
    """Cluster OCR cells into text lines by vertical position.

    Cells are visited top-to-bottom, with ties broken left-to-right. A cell joins
    the current line while its ``center_y`` is within *y_tolerance* of the
    line's running mean ``center_y``; otherwise it starts a new line. Each
    returned line is ordered by ``left``. Empty input gives [].
    """
    if not cells:
        return []

    def left_to_right(line):
        return sorted(line, key=lambda cell: cell["left"])

    ordered = sorted(cells, key=lambda cell: (cell["center_y"], cell["center_x"]))
    lines = []
    line = [ordered[0]]
    line_y = ordered[0]["center_y"]
    for cell in ordered[1:]:
        if abs(cell["center_y"] - line_y) <= y_tolerance:
            line.append(cell)
            # running mean of the line's centre y
            line_y = (line_y * (len(line) - 1) + cell["center_y"]) / len(line)
        else:
            lines.append(left_to_right(line))
            line = [cell]
            line_y = cell["center_y"]
    lines.append(left_to_right(line))
    return lines
| # ------------------------------------------------------------------------- | |
| # Column Detection | |
| # ------------------------------------------------------------------------- | |
def detect_numeric_columns(cells: List[Dict[str, Any]], max_columns: int = 6) -> List[float]:
    """Estimate the x-centres of the numeric columns on a page.

    The unique ``center_x`` values of numeric cells are sorted. A column
    boundary is placed wherever the gap to the next value exceeds
    ``max(35, mean_gap + 0.7 * std_gap)``. When more than ``max_columns - 1``
    gaps qualify, only the widest ones are used.

    Returns:
        The median x of each resulting cluster, left to right. Returns [] when
        there are no numeric cells, and the single value when there is one.
    """
    xs = sorted({c["center_x"] for c in cells if is_numeric_token(c["text"])})
    if len(xs) <= 1:
        return xs

    gaps = [right - left for left, right in zip(xs, xs[1:])]
    mean_gap = float(np.mean(gaps))
    std_gap = float(np.std(gaps)) if len(gaps) > 1 else 0.0
    gap_thresh = max(35.0, mean_gap + 0.7 * std_gap)

    split_after = [i for i, g in enumerate(gaps) if g > gap_thresh]
    max_splits = max(0, max_columns - 1)
    if len(split_after) > max_splits:
        # Keep the widest gaps. Taking the first qualifying gaps from the left
        # (the old behaviour) spends the split budget on left-hand gaps and
        # merges the right-most columns (typically the amounts) together.
        widest = sorted(split_after, key=lambda i: gaps[i], reverse=True)[:max_splits]
        split_after = sorted(widest)

    clusters = []
    start = 0
    for i in split_after:
        clusters.append(xs[start:i + 1])
        start = i + 1
    clusters.append(xs[start:])
    return [float(np.median(cluster)) for cluster in clusters]
def assign_token_to_column(token_x: float, column_centers: List[float]) -> Optional[int]:
    """Return the index of the column centre nearest to *token_x*.

    Ties go to the lower index. Returns None when there are no columns.
    """
    if not column_centers:
        return None
    return min(range(len(column_centers)), key=lambda idx: abs(token_x - column_centers[idx]))
| # ------------------------------------------------------------------------- | |
| # Row Parsing (Improved Name Handling) | |
| # ------------------------------------------------------------------------- | |
# Whole-word page-footer / metadata markers. Rows containing them, such as
# "Page 1 of 2" or "Printed on 12/05/2024", carry numbers but are not line
# items (the date would otherwise parse to an amount of 12052024).
_FOOTER_ROW_RE = re.compile(
    r"\b(?:page|printed\s+on|printed|date|time|signature|authorized|terms|conditions)\b",
    re.I,
)


def _row_confidence(row: List[Dict[str, Any]]) -> float:
    """Mean OCR ``conf`` of the row's cells (0.85 when missing), clamped to [0, 1]."""
    confidence = np.mean([c.get("conf", 0.85) for c in row]) if row else 0.85
    return min(1.0, max(0.0, confidence))


def _last_bucket_value(buckets: Dict[int, List[Tuple[str, float]]], col_idx: int) -> Optional[float]:
    """Parse the right-most token assigned to column *col_idx*; None if empty/unparseable."""
    bucket = buckets.get(col_idx, [])
    return normalize_num_str(bucket[-1][0], allow_zero=False) if bucket else None


def _reconcile_line_values(
    amount: Optional[float],
    rate: Optional[float],
    qty: Optional[float],
    numeric_values: List[float],
) -> Tuple[float, float, float]:
    """Fill in whichever of amount / rate / qty is missing.

    *numeric_values* are the row's distinct non-zero numbers, largest first.
    Returns ``(amount, rate, qty)``; qty defaults to 1.0, rate to 0.0, and
    amount to qty*rate (or 0.0).
    """
    if amount is None:
        # No amount-column value: fall back to the largest positive number.
        amount = next((v for v in numeric_values if v > 0), None)
    if amount and not qty and not rate and numeric_values:
        # Look for a rate that divides the amount into a small integer quantity.
        for cand in numeric_values:
            if cand <= 0.1 or cand >= amount:
                continue
            ratio = amount / cand
            r = round(ratio)
            if 1 <= r <= 100 and abs(ratio - r) <= 0.15 * r:
                qty = float(r)
                rate = cand
                break
    if qty and rate is None and amount:
        rate = amount / qty
    elif rate and qty is None and amount:
        qty = amount / rate
    if qty is None:
        qty = 1.0
    if rate is None:
        rate = 0.0
    if amount is None:
        amount = qty * rate if qty and rate else 0.0
    return amount, rate, qty


def _parse_column_row(
    row: List[Dict[str, Any]],
    row_text: str,
    column_centers: List[float],
    numeric_values: List[float],
) -> Optional[BillLineItem]:
    """Build an item by mapping numeric cells to the page's numeric columns.

    The right-most column is read as amount, the one before it as rate, and
    the one before that as qty; non-numeric cells form the name. Returns None
    unless the amount is positive.
    """
    name_parts = []
    buckets: Dict[int, List[Tuple[str, float]]] = {i: [] for i in range(len(column_centers))}
    for cell in row:
        text = cell["text"]
        if is_numeric_token(text):
            col_idx = assign_token_to_column(cell["center_x"], column_centers)
            if col_idx is None:
                col_idx = len(column_centers) - 1
            buckets[col_idx].append((text, cell.get("conf", 1.0)))
        else:
            name_parts.append(text)

    raw_name = " ".join(name_parts).strip()
    item_name = normalize_name(raw_name) if raw_name else "UNKNOWN"

    num_cols = len(column_centers)
    # Indices below 0 simply miss in `buckets`, so pages with fewer than three
    # columns yield None for rate/qty.
    amount = _last_bucket_value(buckets, num_cols - 1)
    rate = _last_bucket_value(buckets, num_cols - 2)
    qty = _last_bucket_value(buckets, num_cols - 3)
    amount, rate, qty = _reconcile_line_values(amount, rate, qty, numeric_values)
    if not (amount > 0):
        return None

    validated_name, name_conf = validate_name(item_name, context_amount=amount)
    return BillLineItem(
        item_name=validated_name,
        item_quantity=float(qty),
        item_rate=float(round(rate, 2)),
        item_amount=float(round(amount, 2)),
        confidence=_row_confidence(row),
        source_row=row_text,
        name_confidence=name_conf,
    )


def _parse_free_row(
    row: List[Dict[str, Any]],
    tokens: List[str],
    row_text: str,
) -> Optional[BillLineItem]:
    """Fallback when no columns were detected.

    The last numeric token is taken as the amount and everything before it
    as the name. Returns None when no amount can be parsed.
    """
    numeric_idxs = [i for i, t in enumerate(tokens) if is_numeric_token(t)]
    if not numeric_idxs:
        return None
    last = numeric_idxs[-1]
    amount = normalize_num_str(tokens[last], allow_zero=False)
    if amount is None:
        return None
    raw_name = " ".join(tokens[:last]).strip()
    name = normalize_name(raw_name) if raw_name else "UNKNOWN"
    validated_name, name_conf = validate_name(name, context_amount=amount)
    return BillLineItem(
        item_name=validated_name,
        item_quantity=1.0,
        item_rate=0.0,
        item_amount=float(round(amount, 2)),
        confidence=_row_confidence(row),
        source_row=row_text,
        name_confidence=name_conf,
    )


def parse_rows_with_columns(
    rows: List[List[Dict[str, Any]]],
    page_cells: List[Dict[str, Any]],
    page_text: str = ""
) -> List[BillLineItem]:
    """Turn OCR rows into line items.

    Args:
        rows: OCR cells grouped into lines (each line ordered left-to-right).
        page_cells: All cells on the page; used to detect numeric columns.
        page_text: Accepted for interface compatibility; currently unused.

    Rows without any numeric token, or with whole-word footer/metadata
    markers (page numbers, print dates, ...), are skipped. When numeric
    columns are found, amount/rate/qty come from the right-most columns
    (see ``_parse_column_row``); otherwise the last number is the amount.
    """
    items = []
    column_centers = detect_numeric_columns(page_cells, max_columns=6)
    for row in rows:
        tokens = [c["text"] for c in row]
        row_text = " ".join(tokens)
        # A row with no number cannot be a priced line item.
        if not any(is_numeric_token(t) for t in tokens):
            continue
        # The old footer check only applied to rows without numbers (already
        # skipped above), so "Page 1 of 2" style rows became items.
        if _FOOTER_ROW_RE.search(row_text):
            continue

        numeric_values = sorted(
            {
                float(v)
                for v in (normalize_num_str(t, allow_zero=False) for t in tokens if is_numeric_token(t))
                if v is not None
            },
            reverse=True,
        )
        if not numeric_values:
            continue

        if column_centers:
            item = _parse_column_row(row, row_text, column_centers, numeric_values)
        else:
            item = _parse_free_row(row, tokens, row_text)
        if item is not None:
            items.append(item)
    return items
| # ------------------------------------------------------------------------- | |
| # Tesseract OCR Pipeline | |
| # ------------------------------------------------------------------------- | |
# An item name that is only a total label ("Total", "Net Total", "Sub Total", ...).
_BARE_TOTAL_NAME = re.compile(r"(?:grand\s+|net\s+|bill\s+|sub\s*)?total", re.I)


def _is_summary_line(item_name: str) -> bool:
    """True if *item_name* looks like a total/tax/discount row rather than a charge.

    Total labels match anywhere in the name. Tax and discount labels match only
    at the start, so a name like "Consultation (incl. GST)" is kept. These rows
    are reported through ``bill_totals``, not as line items.
    """
    name_lower = item_name.lower().strip()
    if (TOTAL_KEYWORDS.search(name_lower)
            or SUBTOTAL_KEYWORDS.search(name_lower)
            or FINAL_TOTAL_KEYWORDS.search(name_lower)):
        return True
    if _BARE_TOTAL_NAME.fullmatch(name_lower):
        return True
    return bool(TAX_KEYWORDS.match(name_lower) or DISCOUNT_KEYWORDS.match(name_lower))


def _ocr_single_page(page_no: int, pil_img: Image.Image) -> ExtractedPage:
    """OCR one page image and build its ExtractedPage (exceptions propagate)."""
    proc = preprocess_image_for_tesseract(pil_img)
    cells = image_to_tsv_cells(proc)
    rows = group_cells_into_rows(cells, y_tolerance=12)
    page_text = " ".join(" ".join(c["text"] for c in r) for r in rows)

    subtotal, tax, discount, final_total = detect_totals_in_rows(rows)
    items = dedupe_items_advanced(parse_rows_with_columns(rows, cells, page_text))
    line_items = [
        item for item in items
        if item.item_amount > 0 and not _is_summary_line(item.item_name)
    ]

    page_conf = float(np.mean([item.confidence for item in line_items])) if line_items else 0.8
    return ExtractedPage(
        page_no=page_no,
        page_type="Bill Detail",
        line_items=line_items,
        bill_totals=BillTotal(
            subtotal_amount=subtotal,
            tax_amount=tax,
            discount_amount=discount,
            final_total_amount=final_total,
        ),
        page_confidence=page_conf,
    )


def ocr_with_tesseract(file_bytes: bytes) -> List[ExtractedPage]:
    """Run the Tesseract pipeline over a PDF or single image.

    The bytes are first decoded as a PDF (one image per page). If that fails,
    they are opened as a single raster image. A page whose processing raises is
    logged and returned as an empty page with ``page_confidence=0.0``, so one
    bad page does not abort the document.

    Returns:
        One ExtractedPage per page (1-based ``page_no``), or [] if the bytes
        could not be decoded at all.
    """
    try:
        images = convert_from_bytes(file_bytes)
    except Exception as pdf_err:
        # Not a PDF (or the PDF backend failed): try a single image instead.
        logger.debug("Tesseract: not decodable as PDF (%s); trying as image", pdf_err)
        try:
            images = [Image.open(BytesIO(file_bytes))]
        except Exception as e:
            logger.exception("Tesseract: file open failed: %s", e)
            return []

    pages_out = []
    for idx, pil_img in enumerate(images, start=1):
        try:
            pages_out.append(_ocr_single_page(idx, pil_img))
        except Exception as e:
            logger.exception("Tesseract page %d failed: %s", idx, e)
            pages_out.append(ExtractedPage(
                page_no=idx,
                page_type="Bill Detail",
                line_items=[],
                bill_totals=BillTotal(),
                page_confidence=0.0,
            ))
    return pages_out
| # ------------------------------------------------------------------------- | |
| # FastAPI App | |
| # ------------------------------------------------------------------------- | |
# FastAPI application object. Path operations must be registered on it
# (e.g. with @app.post / @app.get) to be reachable.
app = FastAPI(title="Enhanced Bill Extractor (Improved Names)")
class BillRequest(BaseModel):
    """Request body for the bill-extraction endpoint."""
    # Location of the bill: a URL fetched over HTTP, or "file://<path>" for a
    # file on the server's filesystem.
    document: str
class BillResponse(BaseModel):
    """Response envelope returned by the bill-extraction endpoint."""
    # False on failure, e.g. when the document could not be fetched or read.
    is_success: bool
    # Token accounting keys (total/input/output); reported as zeros by this OCR pipeline.
    token_usage: Dict[str, int]
    # {"pagewise_line_items": [<page dicts>], "total_item_count": <int>}
    data: Dict[str, Any]
# This OCR-only pipeline reports no token usage.
_ZERO_TOKEN_USAGE = {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}


def _failed_bill_response() -> BillResponse:
    """Build the uniform ``is_success=False`` response with an empty payload."""
    return BillResponse(
        is_success=False,
        token_usage=dict(_ZERO_TOKEN_USAGE),
        data={"pagewise_line_items": [], "total_item_count": 0},
    )


def _read_document_bytes(doc_url: str) -> Optional[bytes]:
    """Return the raw bytes behind *doc_url*, or None on any failure (logged).

    ``file://<path>`` reads a local file; anything else is fetched with an
    HTTP GET (30 s timeout). Only a 200 response is accepted.
    """
    if doc_url.startswith("file://"):
        # SECURITY NOTE(review): this reads any path the server process can
        # access, straight from request input. Restrict or disable file://
        # before exposing the service to untrusted clients.
        local_path = doc_url[len("file://"):]  # strip the prefix only, not every occurrence
        try:
            with open(local_path, "rb") as f:
                return f.read()
        except (OSError, ValueError) as e:
            logger.warning("Could not read local document %r: %s", local_path, e)
            return None
    try:
        resp = requests.get(doc_url, headers={"User-Agent": "Mozilla/5.0"}, timeout=30)
    except Exception as e:  # best-effort: any download error becomes is_success=False
        logger.warning("Download failed for %r: %s", doc_url, e)
        return None
    if resp.status_code != 200:
        logger.warning("Download of %r returned HTTP %s", doc_url, resp.status_code)
        return None
    return resp.content


@app.post("/extract-bill-data", response_model=BillResponse)
async def extract_bill_data(payload: BillRequest):
    """Fetch the bill in ``payload.document`` and return its OCR'd line items.

    The download and the CPU-bound OCR run in FastAPI's threadpool, so they
    do not block the event loop. The response has ``is_success=False`` when
    the document cannot be obtained, when OCR raises, or when no page could
    be decoded. Otherwise it has per-page items and the total item count.
    """
    from fastapi.concurrency import run_in_threadpool

    file_bytes = await run_in_threadpool(_read_document_bytes, payload.document)
    if not file_bytes:
        return _failed_bill_response()

    if OCR_ENGINE != "tesseract":
        # Only the tesseract pipeline is wired up here.
        logger.warning("OCR_ENGINE=%r not supported here; using tesseract", OCR_ENGINE)
    logger.info("Processing with engine: %s", OCR_ENGINE)
    try:
        pages = await run_in_threadpool(ocr_with_tesseract, file_bytes)
    except Exception as e:
        logger.exception("OCR failed: %s", e)
        return _failed_bill_response()
    if not pages:
        logger.warning("No pages could be decoded from the document")
        return _failed_bill_response()

    return BillResponse(
        is_success=True,
        token_usage=dict(_ZERO_TOKEN_USAGE),
        data={
            "pagewise_line_items": [p.to_dict() for p in pages],
            "total_item_count": sum(len(p.line_items) for p in pages),
        },
    )
@app.get("/")
@app.get("/health")
def health():
    """Liveness probe: report status, the configured OCR engine and a usage hint.

    NOTE(review): no route decorator was visible in the reviewed source, so the
    function was never reachable. The paths "/" and "/health" are a best
    guess; confirm against the deployment's health-check configuration.
    """
    return {
        "status": "ok",
        "engine": OCR_ENGINE,
        "message": "Enhanced Bill Extractor (Improved Name Detection)",
        "hint": "POST /extract-bill-data with {'document': '<url or file://path>'}",
    }