import os
import io
from typing import List

import gradio as gr

# docTR imports (PyTorch backend)
from doctr.io import DocumentFile
from doctr.models import ocr_predictor

# ---------- One-time model bootstrap (CPU-friendly) ----------
# Ensure torch runs in CPU mode on Spaces; docTR auto-detects the backend.
# You can optionally pin thread counts for stability on small CPU runners:
os.environ.setdefault("OMP_NUM_THREADS", "4")
os.environ.setdefault("MKL_NUM_THREADS", "4")

MODEL = ocr_predictor(pretrained=True)  # DBNet + CRNN (docTR defaults) on PyTorch
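# A minimal sketch (assumption: the default docTR architectures are good enough for
# invoices). ocr_predictor also accepts explicit architecture names if you want to
# trade speed for accuracy, e.g.:
#   MODEL = ocr_predictor(det_arch="db_resnet50", reco_arch="crnn_vgg16_bn", pretrained=True)
# Check the docTR documentation for the architectures available in your installed version.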
def _collect_text_from_export(exported: dict) -> str:
    """Flatten a docTR exported structure into newline-separated text per page."""
    pages: List[dict] = exported.get("pages", [])
    text_pages: List[str] = []
    for page in pages:
        page_lines = []
        for block in page.get("blocks", []):
            for line in block.get("lines", []):
                # Join word values in the line; fall back robustly
                words = [w.get("value", "") for w in line.get("words", []) if isinstance(w, dict)]
                line_text = " ".join([w for w in words if w])
                if line_text.strip():
                    page_lines.append(line_text)
        text_pages.append("\n".join(page_lines).strip())
    # Join pages with a page delimiter
    return ("\n\n" + ("─" * 32) + " PAGE BREAK " + ("─" * 32) + "\n\n").join(
        [tp for tp in text_pages if tp]
    ).strip()
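# Illustrative example (hand-built dict mirroring the subset of result.export()
# that this helper actually reads):
#   _collect_text_from_export(
#       {"pages": [{"blocks": [{"lines": [{"words": [{"value": "Invoice"}, {"value": "No: 42"}]}]}]}]}
#   )  -> "Invoice No: 42"
# Multiple pages are joined with the "PAGE BREAK" delimiter above.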
def run_ocr(file: gr.File) -> str:
    if file is None:
        return "No file received."
    name = (file.name or "").lower()

    # Load as DocumentFile (handles PNG/JPG/PDF)
    if name.endswith(".pdf"):
        # Render PDF pages via the pdfium backend under the hood (CPU OK)
        doc = DocumentFile.from_pdf(file=file.name)
    else:
        # Single-image fallback; also works for TIFF/PNG/JPG
        doc = DocumentFile.from_images([file.name])

    # Inference
    result = MODEL(doc)
    exported = result.export()
    text = _collect_text_from_export(exported)
    print("Extracted Text:\n", text)
    if not text:
        return "No text detected."

    result_json = invoice_text_to_json(text)
    print(json.dumps(result_json, indent=2))
    string_json = json.dumps(result_json, indent=2)
    return string_json
import re
import json
from typing import List, Dict, Any
import copy

import numpy as np
import torch
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
# ----------------------------- Schema -----------------------------
SCHEMA_JSON: Dict[str, Any] = {
    "invoice_header": {
        "car_number": None,
        "shipment_number": None,
        "shipping_point": None,
        "currency": None,
        "invoice_number": None,
        "invoice_date": None,
        "order_number": None,
        "customer_order_number": None,
        "our_order_number": None,
        "sales_order_number": None,
        "purchase_order_number": None,
        "order_date": None,
        "supplier_name": None,
        "supplier_address": None,
        "supplier_phone": None,
        "supplier_email": None,
        "supplier_tax_id": None,
        "customer_name": None,
        "customer_address": None,
        "customer_phone": None,
        "customer_email": None,
        "customer_tax_id": None,
        "ship_to_name": None,
        "ship_to_address": None,
        "bill_to_name": None,
        "bill_to_address": None,
        "remit_to_name": None,
        "remit_to_address": None,
        "tax_id": None,
        "tax_registration_number": None,
        "vat_number": None,
        "payment_terms": None,
        "payment_method": None,
        "payment_reference": None,
        "bank_account_number": None,
        "iban": None,
        "swift_code": None,
        "total_before_tax": None,
        "tax_amount": None,
        "tax_rate": None,
        "shipping_charges": None,
        "discount": None,
        "total_due": None,
        "amount_paid": None,
        "balance_due": None,
        "due_date": None,
        "invoice_status": None,
        "reference_number": None,
        "project_code": None,
        "department": None,
        "contact_person": None,
        "notes": None,
        "additional_info": None
    },
    "line_items": [
        {
            "quantity": None,
            "units": None,
            "description": None,
            "footage": None,
            "price": None,
            "amount": None,
            "notes": None
        }
    ]
}

STATIC_HEADERS: List[str] = list(SCHEMA_JSON["invoice_header"].keys())
# Synonym map
SYN2KEY: Dict[str, str] = {
    "invoice no": "invoice_number",
    "invoice number": "invoice_number",
    "invoice#": "invoice_number",
    "inv no": "invoice_number",
    "inv#": "invoice_number",
    "invoice date": "invoice_date",
    "date of invoice": "invoice_date",
    "po no": "purchase_order_number",
    "po number": "purchase_order_number",
    "purchase order": "purchase_order_number",
    "order no": "order_number",
    "order number": "order_number",
    "sales order": "sales_order_number",
    "customer order": "customer_order_number",
    "our order": "our_order_number",
    "due date": "due_date",
    "date of supply": "order_date",
    "gstin": "supplier_tax_id",
    "gstin no": "supplier_tax_id",
    "tax id": "tax_id",
    "vat number": "vat_number",
    "tax registration number": "tax_registration_number",
    "place of supply": "shipping_point",
    "state code": "additional_info",
    "taxable value": "total_before_tax",
    "total value": "total_due",
    "total amount": "total_due",
    "amount due": "total_due",
    "bank": "bank_account_number",
    "account no": "bank_account_number",
    "account number": "bank_account_number",
    "ifs code": "swift_code",
    "ifsc": "payment_reference",
    "swift code": "swift_code",
    "iban": "iban",
    "e-way bill no": "reference_number",
    "eway bill": "reference_number",
    "dispatched via": "additional_info",
    "documents dispatched through": "additional_info",
    "kind attn": "contact_person",
    "billed to": "bill_to_name",
    "receiver": "bill_to_name",
    "shipped to": "ship_to_name",
    "consignee": "ship_to_name",
}
def norm(s: str) -> str:
    return re.sub(r"\s+", " ", s).strip()


def deep_copy_schema() -> Dict[str, Any]:
    return json.loads(json.dumps(SCHEMA_JSON))


def extract_candidates(text: str) -> Dict[str, str]:
    cands: Dict[str, str] = {}
    for raw in text.splitlines():
        line = raw.strip().strip("|").strip()
        if not line:
            continue
        if ":" in line:
            if "|" in raw:
                parts = [p.strip() for p in raw.split("|") if p.strip()]
                for cell in parts:
                    if ":" in cell:
                        k, v = cell.split(":", 1)
                        cands[norm(k)] = norm(v)
            else:
                k, v = line.split(":", 1)
                cands[norm(k)] = norm(v)
    for raw in text.splitlines():
        m = re.search(r"\b(Taxable\s+Value|Total\s+Value|Total\s+Amount|Amount\s+Due)\b[:\s]*([0-9][0-9,]*(?:\.[0-9]{2})?)", raw, re.I)
        if m:
            k = norm(m.group(1))
            v = norm(m.group(2))
            cands[k] = v
    return cands
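# Illustrative example:
#   extract_candidates("Invoice No: INV-001 | Kind Attn: J. Doe\nTotal Amount 1,250.00")
#   -> {"Invoice No": "INV-001", "Kind Attn": "J. Doe", "Total Amount": "1,250.00"}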
def regex_extract_all(text: str) -> Dict[str, str]:
    out: Dict[str, str] = {}

    m = re.search(r"\bInvoice\s*(?:No\.?|Number|#)\s*[:\-]?\s*([A-Z0-9\-\/]+)", text, re.I)
    if m: out["invoice_number"] = m.group(1)

    m = re.search(r"\bInvoice\s*Date\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
    if m: out["invoice_date"] = m.group(1)

    m = re.search(r"\bPO\s*(?:No\.?|Number)?\s*[:\-]?\s*([A-Z0-9\-\/]+)", text, re.I)
    if m: out["purchase_order_number"] = m.group(1)

    m = re.search(r"\bPO\s*Date\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
    if m: out["order_date"] = m.group(1)
    if "order_date" not in out:
        m = re.search(r"\bDate\s*of\s*Supply\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
        if m: out["order_date"] = m.group(1)

    m = re.search(r"\bPlace\s*of\s*Supply\s*[:\-]?\s*([A-Za-z0-9 ,\-\(\)]+)", text, re.I)
    if m: out["shipping_point"] = m.group(1).strip(" |")

    m = re.search(r"\bGSTIN\s*(?:No\.?)?\s*[:\-]?\s*([A-Z0-9]{15})", text, re.I)
    if m: out["supplier_tax_id"] = m.group(1)

    m = re.search(r"\bTaxable\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
    if m: out["total_before_tax"] = m.group(1).replace(",", "")

    cgst = re.search(r"\bCGST\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
    sgst = re.search(r"\bSGST\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
    if cgst and sgst:
        try:
            tax_total = float(cgst.group(1).replace(",", "")) + float(sgst.group(1).replace(",", ""))
            out["tax_amount"] = f"{tax_total:.2f}"
            cgstp = re.search(r"\bCGST\s*%?\s*[:\-]?\s*([0-9]+(?:\.[0-9]+)?)", text, re.I)
            sgstp = re.search(r"\bSGST\s*%?\s*[:\-]?\s*([0-9]+(?:\.[0-9]+)?)", text, re.I)
            if cgstp and sgstp:
                try:
                    rate = float(cgstp.group(1)) + float(sgstp.group(1))
                    out["tax_rate"] = f"{rate:g}"
                except Exception:
                    pass
        except Exception:
            pass

    m = re.search(r"\bE[-\s]?Way\s*bill\s*no\.?\s*[:\-]?\s*([0-9 ]+)", text, re.I)
    if m: out["reference_number"] = m.group(1).strip()

    return out
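# Illustrative example:
#   regex_extract_all("Invoice No: ABC-123\nTaxable Value: 1,000.00")
#   -> {"invoice_number": "ABC-123", "total_before_tax": "1000.00"}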
def extract_bank_block(text: str) -> Dict[str, str]:
    bank: Dict[str, str] = {}
    m = re.search(r"\bAccount\s*Name\s*:\s*(.+)", text, re.I)
    if m: bank["supplier_name"] = m.group(1).strip()
    m = re.search(r"\bAccount\s*(?:No|Number)\s*:\s*([A-Za-z0-9\- ]+)", text, re.I)
    if m: bank["bank_account_number"] = m.group(1).strip()
    m = re.search(r"\bBank\s*:\s*([A-Za-z0-9 ,\-\(\)&]+)", text, re.I)
    if m:
        bank["additional_info"] = "Bank: " + m.group(1).strip()
    m = re.search(r"\bIFSC?\s*Code\s*:\s*([A-Za-z0-9]+)", text, re.I)
    if m: bank["payment_reference"] = m.group(1).strip()
    m = re.search(r"\bSWIFT\s*Code\s*:\s*([A-Za-z0-9]+)", text, re.I)
    if m: bank["swift_code"] = m.group(1).strip()
    branch = re.search(r"\bBranch\s*:\s*(.+)", text, re.I)
    micr = re.search(r"\bMICR\s*Code\s*:\s*([0-9]+)", text, re.I)
    extra_bits = []
    if branch: extra_bits.append("Branch: " + branch.group(1).strip())
    if micr: extra_bits.append("MICR: " + micr.group(1).strip())
    if extra_bits:
        bank["additional_info"] = ((bank.get("additional_info") + " | ") if bank.get("additional_info") else "") + " | ".join(extra_bits)
    return bank
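# Illustrative example:
#   extract_bank_block("Account No: 1234567890\nIFSC Code: HDFC0001234")
#   -> {"bank_account_number": "1234567890", "payment_reference": "HDFC0001234"}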
def _has_real_items(items) -> bool:
    return (
        isinstance(items, list)
        and any(
            isinstance(row, dict)
            and any(val not in (None, "", "null") for val in row.values())
            for row in items
        )
    )
def parse_line_items(text: str) -> List[Dict[str, Any]]:
    """
    Dynamic, header-agnostic line-item extractor.
    - Auto-detects the header row (no hardcoded labels)
    - Supports pipe '|' tables, multi-space/tab tables, and stacked/vertical layouts
    - Fuzzy-maps arbitrary headers to: description, quantity, units, price, amount
    - Stitches wrapped descriptions; stops at totals/subtotals
    """
    import re
    from typing import List, Dict, Any
    import torch
    from sentence_transformers import SentenceTransformer, util

    # ---- local helpers (encapsulated; no external edits required) ----
    def _tokenize_row(row: str) -> List[str]:
        if "|" in row:
            toks = [c.strip(" -") for c in row.split("|")]
        else:
            toks = re.split(r"\t+| {2,}", row)
            toks = [c.strip(" -") for c in toks]
        return [t for t in toks if t]

    def _looks_like_separator(row: str) -> bool:
        return bool(re.fullmatch(r"[-=–—\s]+", row))

    def _numlike(s: str) -> bool:
        return bool(re.fullmatch(r"[₹$€]?\s*\d[\d,]*(?:\.\d+)?", s.strip()))

    def _normalize_num(s: str | None) -> str | None:
        if not s:
            return None
        return s.replace(",", "").replace("₹", "").replace("$", "").replace("€", "").strip() or None

    STOP = re.compile(r"^\s*(subtotal|tax|vat|gst|cgst|sgst|igst|total\b|grand total|amount due|balance due)\b", re.I)

    # Canonical targets + synonyms (broad, non-brittle)
    CANON = ["description", "quantity", "units", "price", "amount"]
    SYN = {
        "description": ["description", "item", "details", "product", "material", "article", "part no", "part", "goods desc"],
        "quantity": ["qty", "quantity", "qnty", "pcs", "pieces", "units qty", "ordered qty"],
        "units": ["uom", "unit", "units", "measure", "type", "pkg", "pack", "u/m"],
        "price": ["rate", "price", "unit price", "cost", "u/price", "list price"],
        "amount": ["amount", "total", "line total", "ext price", "net", "value", "extended"]
    }

    def _find_header_idx(lines: List[str]) -> int:
        """Heuristic header detection for horizontal tables."""
        for i, row in enumerate(lines):
            if _looks_like_separator(row):
                continue
            toks = _tokenize_row(row)
            if len(toks) < 3:
                continue
            # low numeric density
            if sum(_numlike(t) for t in toks) > len(toks) // 2:
                continue
            # at least 2 synonym hits
            hits = 0
            lowt = [t.lower() for t in toks]
            for t in lowt:
                for syns in SYN.values():
                    if any(s in t for s in syns):
                        hits += 1
                        break
            if hits >= 2:
                return i
        return -1

    def _map_headers_dynamic(header_tokens: List[str], model) -> Dict[int, str]:
        """
        Map arbitrary header tokens to canonical keys via:
          1) direct/synonym substring matches
          2) semantic similarity (best match)
        """
        mapped: Dict[int, str] = {}
        used = set()
        low = [h.lower() for h in header_tokens]
        # 1) substring / synonyms
        for j, h in enumerate(low):
            for key, syns in SYN.items():
                if any(s in h for s in syns):
                    if key not in used:
                        mapped[j] = key
                        used.add(key)
                    break
        # 2) semantic backstop for unmapped
        remaining = [j for j in range(len(header_tokens)) if j not in mapped]
        if remaining:
            label_texts, label_keys = [], []
            for k, syns in SYN.items():
                for s in syns + [k]:
                    label_texts.append(s)
                    label_keys.append(k)
            h_emb = model.encode([header_tokens[i] for i in remaining], normalize_embeddings=True)
            l_emb = model.encode(label_texts, normalize_embeddings=True)
            sim = util.cos_sim(torch.tensor(h_emb), torch.tensor(l_emb)).cpu().numpy()
            for ri, j in enumerate(remaining):
                k_best = int(sim[ri].argmax())
                key = label_keys[k_best]
                if key not in used:
                    mapped[j] = key
                    used.add(key)
        return mapped

    def _parse_horizontal(lines: List[str]) -> List[Dict[str, Any]]:
        """Parse pipe/whitespace horizontal tables with dynamic headers."""
        header_idx = _find_header_idx(lines)
        if header_idx == -1:
            return []
        header_tokens = _tokenize_row(lines[header_idx])
        # lazy singleton on the function for perf (no external changes)
        if not hasattr(parse_line_items, "_sent_model"):
            parse_line_items._sent_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # type: ignore[attr-defined]
        sm = parse_line_items._sent_model  # type: ignore[attr-defined]
        idx2key = _map_headers_dynamic(header_tokens, sm)

        items: List[Dict[str, Any]] = []
        for row in lines[header_idx + 1:]:
            if _looks_like_separator(row):
                continue
            if STOP.search(row):
                break
            toks = _tokenize_row(row)
            # continuation-line heuristic (wrapped description)
            if (len(toks) == 1 or len(toks) < (max(idx2key.keys(), default=-1) + 1)) and items:
                last = items[-1]
                prev = (last.get("description") or "").strip()
                last["description"] = (prev + " " + toks[0]).strip() if toks else prev
                continue
            rowd = {"description": None, "quantity": None, "units": None,
                    "price": None, "amount": None, "footage": None, "notes": None}
            for j, tok in enumerate(toks):
                key = idx2key.get(j)
                if not key:
                    continue
                val = tok.strip()
                if key in ("quantity", "price", "amount"):
                    val = _normalize_num(val)
                rowd[key] = val or rowd.get(key)
            if rowd["quantity"] and rowd["units"]:
                rowd["footage"] = f'{rowd["quantity"]} {rowd["units"]}'
            if any(rowd.get(k) for k in ("description", "amount", "price")):
                items.append(rowd)
        # prune empties
        return [it for it in items if any(v for k, v in it.items() if k != "notes")]

    def _parse_vertical(text: str) -> List[Dict[str, Any]]:
        """
        Deterministic stacked/vertical parser for blocks like:
            Description
            Type
            Quantity
            Rate
            Amount
            <desc1>
            <type1>
            <qty1>
            <rate1>
            <amt1>
            <desc2> ...
        Stops at totals/subtotals.
        """
        lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
        if not lines:
            return []

        # Find the exact 5-label header block (order-agnostic but contiguous)
        LABELS = ["description", "type", "quantity", "rate", "amount"]

        def is_label(s: str) -> str | None:
            t = s.lower()
            if re.fullmatch(r"[₹$€]?\s*\d[\d,]*(?:\.\d+)?", t):
                return None
            if "desc" in t or "item" in t or "product" in t or "material" in t or "article" in t:
                return "description"
            if "type" in t or "uom" in t or "unit" in t or "units" in t:
                return "type"
            if "qty" in t or "quantity" in t:
                return "quantity"
            if "rate" in t or "price" in t or "unit price" in t:
                return "rate"
            if "amount" in t or "total" in t:
                return "amount"
            return None

        start = -1
        for i in range(len(lines) - 4):
            block = lines[i:i + 5]
            mapped = [is_label(x) for x in block]
            if None not in mapped and len(set(mapped)) == 5:
                start = i
                header_keys = mapped  # e.g. ["description", "type", "quantity", "rate", "amount"]
                break
        if start == -1:
            return []

        # Build a position -> canonical map in this exact order
        pos2key = {idx: key for idx, key in enumerate(header_keys)}

        # Consume values in chunks of 5
        items: List[Dict[str, Any]] = []
        i = start + 5
        STOP = re.compile(r"^\s*(subtotal|tax|vat|gst|cgst|sgst|igst|total\b|grand total|amount due|balance due)\b", re.I)

        def norm_num(s: str | None) -> str | None:
            if not s:
                return None
            return s.replace(",", "").replace("₹", "").replace("$", "").replace("€", "").strip() or None

        while i + 4 < len(lines):
            if STOP.search(lines[i]):  # hit totals, bail
                break
            chunk = lines[i:i + 5]
            row = {"description": None, "units": None, "quantity": None,
                   "price": None, "amount": None, "footage": None, "notes": None}
            # map chunk by discovered order
            for j, val in enumerate(chunk):
                key = pos2key[j]
                if key == "type":
                    row["units"] = val  # map "Type" -> "units"
                elif key == "quantity":
                    row["quantity"] = norm_num(val)
                elif key == "rate":
                    row["price"] = norm_num(val)
                elif key == "amount":
                    row["amount"] = norm_num(val)
                elif key == "description":
                    row["description"] = val
            if row["quantity"] and row["units"]:
                row["footage"] = f'{row["quantity"]} {row["units"]}'
            # minimal acceptance: description or amount or price
            if any(row.get(k) for k in ("description", "amount", "price")):
                items.append(row)
            i += 5
        return items

    # ---- main body ----
    raw_lines = [ln.rstrip() for ln in text.splitlines()]
    lines = [ln for ln in raw_lines if ln.strip()]
    if not lines:
        return []
    # 1) Try horizontal first
    items = _parse_horizontal(lines)
    if items:
        return items
    # 2) Fall back to vertical/stacked
    items = _parse_vertical(text)
    return items
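# Illustrative example (pipe-table layout; these headers are all caught by the synonym
# lists, so the embedding backstop is not exercised, though the MiniLM model is still
# loaded lazily on first use):
#   parse_line_items("Description | Qty | Rate | Amount\nSteel pipe | 10 | 5.00 | 50.00")
#   -> [{"description": "Steel pipe", "quantity": "10", "units": None,
#        "price": "5.00", "amount": "50.00", "footage": None, "notes": None}]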
def semantic_map_candidates(candidates: Dict[str, str], static_headers: List[str], thresh: float, sentence_model) -> Dict[str, str]:
    if not candidates:
        return {}
    cand_keys = list(candidates.keys())
    mapped: Dict[str, str] = {}
    leftovers: Dict[str, str] = {}
    for k, v in candidates.items():
        lk = k.lower()
        lk_norm = re.sub(r"[^a-z0-9]+", " ", lk).strip()
        hit = None
        for syn, key in SYN2KEY.items():
            if syn in lk_norm:
                hit = key
                break
        if hit:
            mapped[hit] = v
        else:
            leftovers[k] = v
    if leftovers:
        cand_emb = sentence_model.encode(list(leftovers.keys()), normalize_embeddings=True)
        head_emb = sentence_model.encode(static_headers, normalize_embeddings=True)
        M = util.cos_sim(torch.tensor(cand_emb), torch.tensor(head_emb)).cpu().numpy()
        keys_left = list(leftovers.keys())
        for i, ck in enumerate(keys_left):
            j = int(np.argmax(M[i]))
            score = float(M[i][j])
            if score >= thresh:
                mapped[static_headers[j]] = leftovers[ck]
    return mapped
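# Illustrative example (both keys resolve via SYN2KEY; keys without a synonym hit
# fall through to the embedding-similarity backstop):
#   semantic_map_candidates({"Invoice No": "INV-9", "Kind Attn": "J. Doe"},
#                           STATIC_HEADERS, 0.60, sentence_model)
#   -> {"invoice_number": "INV-9", "contact_person": "J. Doe"}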
def build_prompt(invoice_text: str, mapped_hints: Dict[str, str], items_hints: List[Dict[str, Any]]) -> str:
    instruction = (
        'Use this schema:\n'
        '{\n'
        '  "invoice_header": {\n'
        '    "car_number": "string or null",\n'
        '    "shipment_number": "string or null",\n'
        '    "shipping_point": "string or null",\n'
        '    "currency": "string or null",\n'
        '    "invoice_number": "string or null",\n'
        '    "invoice_date": "string or null",\n'
        '    "order_number": "string or null",\n'
        '    "customer_order_number": "string or null",\n'
        '    "our_order_number": "string or null",\n'
        '    "sales_order_number": "string or null",\n'
        '    "purchase_order_number": "string or null",\n'
        '    "order_date": "string or null",\n'
        '    "supplier_name": "string or null",\n'
        '    "supplier_address": "string or null",\n'
        '    "supplier_phone": "string or null",\n'
        '    "supplier_email": "string or null",\n'
        '    "supplier_tax_id": "string or null",\n'
        '    "customer_name": "string or null",\n'
        '    "customer_address": "string or null",\n'
        '    "customer_phone": "string or null",\n'
        '    "customer_email": "string or null",\n'
        '    "customer_tax_id": "string or null",\n'
        '    "ship_to_name": "string or null",\n'
        '    "ship_to_address": "string or null",\n'
        '    "bill_to_name": "string or null",\n'
        '    "bill_to_address": "string or null",\n'
        '    "remit_to_name": "string or null",\n'
        '    "remit_to_address": "string or null",\n'
        '    "tax_id": "string or null",\n'
        '    "tax_registration_number": "string or null",\n'
        '    "vat_number": "string or null",\n'
        '    "payment_terms": "string or null",\n'
        '    "payment_method": "string or null",\n'
        '    "payment_reference": "string or null",\n'
        '    "bank_account_number": "string or null",\n'
        '    "iban": "string or null",\n'
        '    "swift_code": "string or null",\n'
        '    "total_before_tax": "string or null",\n'
        '    "tax_amount": "string or null",\n'
        '    "tax_rate": "string or null",\n'
        '    "shipping_charges": "string or null",\n'
        '    "discount": "string or null",\n'
        '    "total_due": "string or null",\n'
        '    "amount_paid": "string or null",\n'
        '    "balance_due": "string or null",\n'
        '    "due_date": "string or null",\n'
        '    "invoice_status": "string or null",\n'
        '    "reference_number": "string or null",\n'
        '    "project_code": "string or null",\n'
        '    "department": "string or null",\n'
        '    "contact_person": "string or null",\n'
        '    "notes": "string or null",\n'
        '    "additional_info": "string or null"\n'
        '  },\n'
        '  "line_items": [\n'
        '    {\n'
        '      "quantity": "string or null",\n'
        '      "units": "string or null",\n'
        '      "description": "string or null",\n'
        '      "footage": "string or null",\n'
        '      "price": "string or null",\n'
        '      "amount": "string or null",\n'
        '      "notes": "string or null"\n'
        '    }\n'
        '  ]\n'
        '}\n'
        'If a field is missing for a line item or header, use null. '
        'Do not invent fields. Do not add any header or shipment data to any line item. '
        'Return ONLY the JSON object, no explanation.\n'
    )
    hints = ""
    if mapped_hints:
        hints += "\nHints (header):\n" + " ".join([f"#{k}: {v}" for k, v in mapped_hints.items()])
    if items_hints:
        try:
            hints += "\nHints (line_items):\n" + json.dumps(items_hints, ensure_ascii=False)
        except Exception:
            pass
    return instruction + "\nInvoice Text:\n" + invoice_text.strip() + hints
def strict_json(text: str) -> Dict[str, Any]:
    try:
        return json.loads(text)
    except Exception:
        pass
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        try:
            return json.loads(text[start:end + 1])
        except Exception:
            pass
    raise ValueError("Model did not return valid JSON.")
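# Illustrative example: tolerates chatter around the JSON object.
#   strict_json('Sure, here you go: {"invoice_header": {}} Thanks!')
#   -> {"invoice_header": {}}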
def merge_schema(rule_json: Dict[str, Any], model_json: Dict[str, Any]) -> Dict[str, Any]:
    final = copy.deepcopy(rule_json)

    # --- headers (rules win where present) ---
    hdr = final["invoice_header"]
    mdl_hdr = (model_json.get("invoice_header") or {})
    for k in hdr.keys():
        if hdr[k] in [None, "", "null"]:
            v = mdl_hdr.get(k, None)
            if v not in [None, "", "null"]:
                hdr[k] = v

    # --- line_items (prefer parsed items -> model -> empty) ---
    rule_items = rule_json.get("line_items") or []
    model_items = model_json.get("line_items") or []
    if _has_real_items(rule_items):
        final["line_items"] = rule_items
    elif _has_real_items(model_items):
        final["line_items"] = model_items
    else:
        final["line_items"] = []
    return final
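# Illustrative example: rule-based values win; model output only fills gaps.
#   rule = deep_copy_schema();  rule["invoice_header"]["invoice_number"] = "INV-1"
#   model = deep_copy_schema(); model["invoice_header"]["invoice_number"] = "WRONG"
#   model["invoice_header"]["currency"] = "USD"
#   merged = merge_schema(rule, model)
#   merged["invoice_header"]["invoice_number"]  -> "INV-1"
#   merged["invoice_header"]["currency"]        -> "USD"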
def _prune_empty_items(payload: Dict[str, Any]) -> Dict[str, Any]:
    items = payload.get("line_items")
    if isinstance(items, list):
        payload["line_items"] = [
            it for it in items
            if isinstance(it, dict) and any(v not in (None, "", "null") for v in it.values())
        ]
    return payload
# ---------------------- MAIN FUNCTION ----------------------
def invoice_text_to_json(
    invoice_text: str,
    threshold: float = 0.60,
    max_new_tokens: int = 512
) -> Dict[str, Any]:
    # Models are loaded on every call; cache them at module level for production use.
    sentence_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    json_converter = pipeline("text2text-generation", model="yahyakhoder/MD2JSON-T5-small-V1")

    txt = invoice_text

    # 1) Deterministic extraction
    candidates = extract_candidates(txt)
    hard = regex_extract_all(txt)
    bank = extract_bank_block(txt)
    items = parse_line_items(txt)
    print("Extracted line items:", items)

    sem_mapped = semantic_map_candidates(candidates, STATIC_HEADERS, threshold, sentence_model)
    header_found: Dict[str, Any] = {}
    header_found.update(sem_mapped)
    header_found.update(hard)
    header_found.update(bank)

    # 2) Build RULE JSON (schema-shaped, rules filled)
    rule_json = deep_copy_schema()
    if _has_real_items(items):
        rule_json["line_items"] = items
    else:
        rule_json["line_items"] = []
    for k, v in header_found.items():
        if k in rule_json["invoice_header"]:
            rule_json["invoice_header"][k] = v

    # 3) MD2JSON generation with strong hints
    prompt = build_prompt(txt, header_found, items)
    gen = json_converter(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]
    try:
        model_json = strict_json(gen)
    except Exception:
        model_json = deep_copy_schema()  # model failed; keep the empty shape

    # 4) Final merge (rules win)
    final_json = merge_schema(rule_json, model_json)
    final_json = _prune_empty_items(final_json)
    return final_json
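# Illustrative usage (bypasses OCR; the first call downloads the MiniLM and
# MD2JSON-T5 models from the Hugging Face Hub):
#   sample = "Invoice No: INV-001\nInvoice Date: 01/02/2024\nTotal Amount: 500.00"
#   print(json.dumps(invoice_text_to_json(sample), indent=2))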
from typing import Optional

# ----- Unified dispatcher (supersedes run_ocr above) -----
def run_pipeline(file: Optional[gr.File], raw_txt: Optional[str]) -> str:
    """
    Orchestrates two intake lanes:
      1) If raw_txt is provided (non-empty), skip OCR and map the text directly to the schema.
      2) Otherwise, run OCR on the uploaded file and map the result to the schema.
    """
    raw_txt = (raw_txt or "").strip()

    # Lane A: Raw text -> JSON
    if raw_txt:
        try:
            result_json = invoice_text_to_json(raw_txt)
            return json.dumps(result_json, indent=2, ensure_ascii=False)
        except Exception as e:
            return f"Error while converting pasted text to JSON schema: {e}"

    # Lane B: File -> OCR -> JSON
    if not file:
        return "No input received. Upload an image/PDF or paste raw text."
    try:
        name = (file.name or "").lower()
        # Load as DocumentFile (handles PNG/JPG/PDF)
        if name.endswith(".pdf"):
            doc = DocumentFile.from_pdf(file=file.name)
        else:
            doc = DocumentFile.from_images([file.name])
        # Inference
        result = MODEL(doc)
        exported = result.export()
        text = _collect_text_from_export(exported)
        if not text:
            return "No text detected by OCR."
        result_json = invoice_text_to_json(text)
        return json.dumps(result_json, indent=2, ensure_ascii=False)
    except Exception as e:
        return f"OCR pipeline error: {e}"
# ---------- Gradio UI ----------
TITLE = "docTR OCR — Text Extractor"
DESC = (
    "Upload an image or PDF, or paste raw text. Uses docTR for OCR, or maps raw text directly to the invoice JSON schema."
)

with gr.Blocks(theme="soft", title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}\n{DESC}")

    with gr.Tabs():
        with gr.Tab("Upload File"):
            inp = gr.File(
                label="Upload image/PDF",
                file_types=[".png", ".jpg", ".jpeg", ".tif", ".tiff", ".pdf"]
            )
            # keep symmetrical inputs for single-click wiring
            raw_txt_hidden = gr.Textbox(visible=False)
        with gr.Tab("Paste Raw Text"):
            raw_txt = gr.Textbox(
                label="Paste raw invoice text (we’ll map it directly to the JSON schema)",
                lines=18,
                placeholder="Paste the OCR’d/plain text of the invoice here…"
            )
            file_hidden = gr.File(visible=False)

    out = gr.Code(label="Extracted JSON", language="json")
    run_btn = gr.Button("Generate JSON", variant="primary")

    # One button -> unified function; we pass both lanes (visible/hidden)
    run_btn.click(
        fn=run_pipeline,
        inputs=[inp, raw_txt],
        outputs=out,
    )

    gr.Markdown(
        "ℹ️ **Usage:** Prefer *Paste Raw Text* when you already have text. "
        "If both a file and text are provided, the pasted text takes priority."
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True, show_error=True)