"""docTR OCR → invoice-JSON extraction Gradio app.

Two intake lanes:
  1. Upload an image/PDF → docTR OCR → text.
  2. Paste raw invoice text (skips OCR).
The text is then mapped into a fixed invoice JSON schema by combining
rule-based extraction (regex + synonym tables), semantic header matching
(sentence-transformers), and an MD2JSON text2text model, with the
rule-based results taking precedence in the final merge.
"""

import copy
import io
import json
import os
import re
from functools import lru_cache
from typing import Any, Dict, List, Optional

import gradio as gr
import numpy as np
import torch

# docTR imports (PyTorch backend)
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline

# ---------- One-time model bootstrap (CPU-friendly) ----------
# docTR auto-detects the torch backend; pin thread counts for stability on
# small CPU runners (e.g. Hugging Face Spaces).
os.environ.setdefault("OMP_NUM_THREADS", "4")
os.environ.setdefault("MKL_NUM_THREADS", "4")

MODEL = ocr_predictor(pretrained=True)  # DBNet + CRNN (default) on PyTorch


@lru_cache(maxsize=1)
def _get_sentence_model() -> SentenceTransformer:
    """Lazily load (once) the MiniLM sentence encoder used for header matching."""
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")


@lru_cache(maxsize=1)
def _get_json_converter():
    """Lazily load (once) the MD2JSON text2text-generation pipeline."""
    return pipeline("text2text-generation", model="yahyakhoder/MD2JSON-T5-small-V1")


def _collect_text_from_export(exported: dict) -> str:
    """Flatten docTR exported structure into newline-separated text per page."""
    pages: List[dict] = exported.get("pages", [])
    text_pages: List[str] = []
    for page in pages:
        page_lines = []
        for block in page.get("blocks", []):
            for line in block.get("lines", []):
                # Join word values in the line; fallback robustly
                words = [w.get("value", "") for w in line.get("words", []) if isinstance(w, dict)]
                line_text = " ".join([w for w in words if w])
                if line_text.strip():
                    page_lines.append(line_text)
        text_pages.append("\n".join(page_lines).strip())
    # Join pages with a page delimiter
    return ("\n\n" + ("─" * 32) + " PAGE BREAK " + ("─" * 32) + "\n\n").join(
        [tp for tp in text_pages if tp]
    ).strip()


def run_ocr(file: gr.File) -> str:
    """Legacy single-lane entry point: OCR a file and map it to schema JSON.

    Superseded by run_pipeline() (which also accepts pasted raw text) but kept
    for backward compatibility with any external callers.
    """
    if file is None:
        return "No file received."
    name = (file.name or "").lower()
    # Load as DocumentFile (handles PNG/JPG/PDF)
    if name.endswith(".pdf"):
        # Render PDF pages via pdfium backend under the hood (CPU OK)
        doc = DocumentFile.from_pdf(file=file.name)
    else:
        # Single image fallback; also works for TIFF/PNG/JPG
        doc = DocumentFile.from_images([file.name])
    # Inference
    result = MODEL(doc)
    exported = result.export()
    text = _collect_text_from_export(exported)
    print("Extracted Text:\n", text)
    if not text:
        return "No text detected."
    result_json = invoice_text_to_json(text)
    print(json.dumps(result_json, indent=2))
    return json.dumps(result_json, indent=2)


# ----------------------------- Schema -----------------------------
SCHEMA_JSON: Dict[str, Any] = {
    "invoice_header": {
        "car_number": None,
        "shipment_number": None,
        "shipping_point": None,
        "currency": None,
        "invoice_number": None,
        "invoice_date": None,
        "order_number": None,
        "customer_order_number": None,
        "our_order_number": None,
        "sales_order_number": None,
        "purchase_order_number": None,
        "order_date": None,
        "supplier_name": None,
        "supplier_address": None,
        "supplier_phone": None,
        "supplier_email": None,
        "supplier_tax_id": None,
        "customer_name": None,
        "customer_address": None,
        "customer_phone": None,
        "customer_email": None,
        "customer_tax_id": None,
        "ship_to_name": None,
        "ship_to_address": None,
        "bill_to_name": None,
        "bill_to_address": None,
        "remit_to_name": None,
        "remit_to_address": None,
        "tax_id": None,
        "tax_registration_number": None,
        "vat_number": None,
        "payment_terms": None,
        "payment_method": None,
        "payment_reference": None,
        "bank_account_number": None,
        "iban": None,
        "swift_code": None,
        "total_before_tax": None,
        "tax_amount": None,
        "tax_rate": None,
        "shipping_charges": None,
        "discount": None,
        "total_due": None,
        "amount_paid": None,
        "balance_due": None,
        "due_date": None,
        "invoice_status": None,
        "reference_number": None,
        "project_code": None,
        "department": None,
        "contact_person": None,
        "notes": None,
        "additional_info": None,
    },
    "line_items": [
        {
            "quantity": None,
            "units": None,
            "description": None,
            "footage": None,
            "price": None,
            "amount": None,
            "notes": None,
        }
    ],
}

STATIC_HEADERS: List[str] = list(SCHEMA_JSON["invoice_header"].keys())

# Synonym map: normalized label substring -> canonical schema key.
SYN2KEY: Dict[str, str] = {
    "invoice no": "invoice_number",
    "invoice number": "invoice_number",
    "invoice#": "invoice_number",
    "inv no": "invoice_number",
    "inv#": "invoice_number",
    "invoice date": "invoice_date",
    "date of invoice": "invoice_date",
    "po no": "purchase_order_number",
    "po number": "purchase_order_number",
    "purchase order": "purchase_order_number",
    "order no": "order_number",
    "order number": "order_number",
    "sales order": "sales_order_number",
    "customer order": "customer_order_number",
    "our order": "our_order_number",
    "due date": "due_date",
    "date of supply": "order_date",
    "gstin": "supplier_tax_id",
    "gstin no": "supplier_tax_id",
    "tax id": "tax_id",
    "vat number": "vat_number",
    "tax registration number": "tax_registration_number",
    "place of supply": "shipping_point",
    "state code": "additional_info",
    "taxable value": "total_before_tax",
    "total value": "total_due",
    "total amount": "total_due",
    "amount due": "total_due",
    "bank": "bank_account_number",
    "account no": "bank_account_number",
    "account number": "bank_account_number",
    "ifs code": "swift_code",
    "ifsc": "payment_reference",
    "swift code": "swift_code",
    "iban": "iban",
    "e-way bill no": "reference_number",
    "eway bill": "reference_number",
    "dispatched via": "additional_info",
    "documents dispatched through": "additional_info",
    "kind attn": "contact_person",
    "billed to": "bill_to_name",
    "receiver": "bill_to_name",
    "shipped to": "ship_to_name",
    "consignee": "ship_to_name",
}


def norm(s: str) -> str:
    """Collapse runs of whitespace and strip surrounding space."""
    return re.sub(r"\s+", " ", s).strip()


def deep_copy_schema() -> Dict[str, Any]:
    """Return a fresh, fully-independent copy of SCHEMA_JSON."""
    return json.loads(json.dumps(SCHEMA_JSON))


def extract_candidates(text: str) -> Dict[str, str]:
    """Collect 'label: value' candidate pairs from free text and pipe tables."""
    cands: Dict[str, str] = {}
    for raw in text.splitlines():
        line = raw.strip().strip("|").strip()
        if not line:
            continue
        if ":" in line:
            if "|" in raw:
                # Pipe-delimited row: each cell may hold its own "k: v" pair.
                parts = [p.strip() for p in raw.split("|") if p.strip()]
                for cell in parts:
                    if ":" in cell:
                        k, v = cell.split(":", 1)
                        cands[norm(k)] = norm(v)
            else:
                k, v = line.split(":", 1)
                cands[norm(k)] = norm(v)
    # Second pass: totals that may appear without a colon separator.
    for raw in text.splitlines():
        m = re.search(
            r"\b(Taxable\s+Value|Total\s+Value|Total\s+Amount|Amount\s+Due)\b[:\s]*([0-9][0-9,]*(?:\.[0-9]{2})?)",
            raw,
            re.I,
        )
        if m:
            cands[norm(m.group(1))] = norm(m.group(2))
    return cands


def regex_extract_all(text: str) -> Dict[str, str]:
    """Deterministic regex extraction of high-confidence header fields."""
    out: Dict[str, str] = {}
    m = re.search(r"\bInvoice\s*(?:No\.?|Number|#)\s*[:\-]?\s*([A-Z0-9\-\/]+)", text, re.I)
    if m:
        out["invoice_number"] = m.group(1)
    m = re.search(r"\bInvoice\s*Date\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
    if m:
        out["invoice_date"] = m.group(1)
    m = re.search(r"\bPO\s*(?:No\.?|Number)?\s*[:\-]?\s*([A-Z0-9\-\/]+)", text, re.I)
    if m:
        out["purchase_order_number"] = m.group(1)
    m = re.search(r"\bPO\s*Date\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
    if m:
        out["order_date"] = m.group(1)
    if "order_date" not in out:
        m = re.search(
            r"\bDate\s*of\s*Supply\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I
        )
        if m:
            out["order_date"] = m.group(1)
    m = re.search(r"\bPlace\s*of\s*Supply\s*[:\-]?\s*([A-Za-z0-9 ,\-\(\)]+)", text, re.I)
    if m:
        out["shipping_point"] = m.group(1).strip(" |")
    m = re.search(r"\bGSTIN\s*(?:No\.?)?\s*[:\-]?\s*([A-Z0-9]{15})", text, re.I)
    if m:
        out["supplier_tax_id"] = m.group(1)
    m = re.search(r"\bTaxable\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
    if m:
        out["total_before_tax"] = m.group(1).replace(",", "")
    # Indian GST: total tax = CGST + SGST; rate likewise summed.
    cgst = re.search(r"\bCGST\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
    sgst = re.search(r"\bSGST\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
    if cgst and sgst:
        try:
            tax_total = float(cgst.group(1).replace(",", "")) + float(sgst.group(1).replace(",", ""))
            out["tax_amount"] = f"{tax_total:.2f}"
            cgstp = re.search(r"\bCGST\s*%?\s*[:\-]?\s*([0-9]+(?:\.[0-9]+)?)", text, re.I)
            sgstp = re.search(r"\bSGST\s*%?\s*[:\-]?\s*([0-9]+(?:\.[0-9]+)?)", text, re.I)
            if cgstp and sgstp:
                try:
                    rate = float(cgstp.group(1)) + float(sgstp.group(1))
                    out["tax_rate"] = f"{rate:g}"
                except ValueError:
                    pass
        except ValueError:
            pass
    m = re.search(r"\bE[-\s]?Way\s*bill\s*no\.?\s*[:\-]?\s*([0-9 ]+)", text, re.I)
    if m:
        out["reference_number"] = m.group(1).strip()
    return out


def extract_bank_block(text: str) -> Dict[str, str]:
    """Extract bank/remittance details (account, IFSC, SWIFT, branch, MICR)."""
    bank: Dict[str, str] = {}
    m = re.search(r"\bAccount\s*Name\s*:\s*(.+)", text, re.I)
    if m:
        bank["supplier_name"] = m.group(1).strip()
    m = re.search(r"\bAccount\s*(?:No|Number)\s*:\s*([A-Za-z0-9\- ]+)", text, re.I)
    if m:
        bank["bank_account_number"] = m.group(1).strip()
    m = re.search(r"\bBank\s*:\s*([A-Za-z0-9 ,\-\(\)&]+)", text, re.I)
    if m:
        bank["additional_info"] = ("Bank: " + m.group(1).strip())
    m = re.search(r"\bIFSC?\s*Code\s*:\s*([A-Za-z0-9]+)", text, re.I)
    if m:
        bank["payment_reference"] = m.group(1).strip()
    m = re.search(r"\bSWIFT\s*Code\s*:\s*([A-Za-z0-9]+)", text, re.I)
    if m:
        bank["swift_code"] = m.group(1).strip()
    branch = re.search(r"\bBranch\s*:\s*(.+)", text, re.I)
    micr = re.search(r"\bMICR\s*Code\s*:\s*([0-9]+)", text, re.I)
    extra_bits = []
    if branch:
        extra_bits.append("Branch: " + branch.group(1).strip())
    if micr:
        extra_bits.append("MICR: " + micr.group(1).strip())
    if extra_bits:
        # Append to any existing additional_info with a " | " separator.
        bank["additional_info"] = (
            (bank.get("additional_info") + " | ") if bank.get("additional_info") else ""
        ) + " | ".join(extra_bits)
    return bank


def _has_real_items(items) -> bool:
    """True if `items` is a list containing at least one non-empty row dict."""
    return (
        isinstance(items, list)
        and any(
            isinstance(row, dict)
            and any(val not in (None, "", "null") for val in row.values())
            for row in items
        )
    )


def parse_line_items(text: str) -> List[Dict[str, Any]]:
    """
    Dynamic, header-agnostic line-item extractor.

    - Auto-detects header row (no hardcoded labels)
    - Supports pipe '|' tables, multi-space/tab tables, and stacked/vertical layouts
    - Fuzzy maps arbitrary headers to: description, quantity, units, price, amount
    - Stitches wrapped descriptions; stops at totals/subtotals
    """

    # ---- local helpers (encapsulated; no external edits required) ----
    def _tokenize_row(row: str) -> List[str]:
        if "|" in row:
            toks = [c.strip(" -") for c in row.split("|")]
        else:
            toks = re.split(r"\t+| {2,}", row)
            toks = [c.strip(" -") for c in toks]
        return [t for t in toks if t]

    def _looks_like_separator(row: str) -> bool:
        return bool(re.fullmatch(r"[-=–—\s]+", row))

    def _numlike(s: str) -> bool:
        return bool(re.fullmatch(r"[₹$€]?\s*\d[\d,]*(?:\.\d+)?", s.strip()))

    def _normalize_num(s: Optional[str]) -> Optional[str]:
        if not s:
            return None
        return s.replace(",", "").replace("₹", "").replace("$", "").replace("€", "").strip() or None

    # Rows matching this end the table (totals/tax footer).
    STOP = re.compile(
        r"^\s*(subtotal|tax|vat|gst|cgst|sgst|igst|total\b|grand total|amount due|balance due)\b",
        re.I,
    )

    # Canonical target keys and their broad, non-brittle synonyms.
    SYN = {
        "description": ["description", "item", "details", "product", "material", "article", "part no", "part", "goods desc"],
        "quantity": ["qty", "quantity", "qnty", "pcs", "pieces", "units qty", "ordered qty"],
        "units": ["uom", "unit", "units", "measure", "type", "pkg", "pack", "u/m"],
        "price": ["rate", "price", "unit price", "cost", "u/price", "list price"],
        "amount": ["amount", "total", "line total", "ext price", "net", "value", "extended"],
    }

    def _find_header_idx(lines: List[str]) -> int:
        """Heuristic header detection for horizontal tables."""
        for i, row in enumerate(lines):
            if _looks_like_separator(row):
                continue
            toks = _tokenize_row(row)
            if len(toks) < 3:
                continue
            # low numeric density (a header row is mostly words)
            if sum(_numlike(t) for t in toks) > len(toks) // 2:
                continue
            # at least 2 synonym hits
            hits = 0
            lowt = [t.lower() for t in toks]
            for t in lowt:
                for syns in SYN.values():
                    if any(s in t for s in syns):
                        hits += 1
                        break
            if hits >= 2:
                return i
        return -1

    def _map_headers_dynamic(header_tokens: List[str], model) -> Dict[int, str]:
        """
        Map arbitrary header tokens to canonical keys via:
          1) direct/synonym contains
          2) semantic similarity (best match)
        """
        mapped: Dict[int, str] = {}
        used = set()
        low = [h.lower() for h in header_tokens]
        # 1) substring / synonyms
        for j, h in enumerate(low):
            for key, syns in SYN.items():
                if any(s in h for s in syns):
                    if key not in used:
                        mapped[j] = key
                        used.add(key)
                    break
        # 2) semantic backstop for unmapped
        remaining = [j for j in range(len(header_tokens)) if j not in mapped]
        if remaining:
            label_texts, label_keys = [], []
            for k, syns in SYN.items():
                for s in syns + [k]:
                    label_texts.append(s)
                    label_keys.append(k)
            h_emb = model.encode([header_tokens[i] for i in remaining], normalize_embeddings=True)
            l_emb = model.encode(label_texts, normalize_embeddings=True)
            sim = util.cos_sim(torch.tensor(h_emb), torch.tensor(l_emb)).cpu().numpy()
            for ri, j in enumerate(remaining):
                k_best = int(sim[ri].argmax())
                key = label_keys[k_best]
                if key not in used:
                    mapped[j] = key
                    used.add(key)
        return mapped

    def _parse_horizontal(lines: List[str]) -> List[Dict[str, Any]]:
        """Parse pipe/whitespace horizontal tables with dynamic headers."""
        header_idx = _find_header_idx(lines)
        if header_idx == -1:
            return []
        header_tokens = _tokenize_row(lines[header_idx])
        # Shared cached singleton (loaded once per process).
        sm = _get_sentence_model()
        idx2key = _map_headers_dynamic(header_tokens, sm)
        items: List[Dict[str, Any]] = []
        for row in lines[header_idx + 1:]:
            if _looks_like_separator(row):
                continue
            if STOP.search(row):
                break
            toks = _tokenize_row(row)
            # continuation-line heuristic (wrapped description)
            if (len(toks) == 1 or len(toks) < (max(idx2key.keys(), default=-1) + 1)) and items:
                last = items[-1]
                prev = (last.get("description") or "").strip()
                last["description"] = (prev + " " + toks[0]).strip() if toks else prev
                continue
            rowd = {
                "description": None, "quantity": None, "units": None,
                "price": None, "amount": None, "footage": None, "notes": None,
            }
            for j, tok in enumerate(toks):
                key = idx2key.get(j)
                if not key:
                    continue
                val = tok.strip()
                if key in ("quantity", "price", "amount"):
                    val = _normalize_num(val)
                rowd[key] = val or rowd.get(key)
            if rowd["quantity"] and rowd["units"]:
                rowd["footage"] = f'{rowd["quantity"]} {rowd["units"]}'
            if any(rowd.get(k) for k in ("description", "amount", "price")):
                items.append(rowd)
        # prune empties
        return [it for it in items if any(v for k, v in it.items() if k != "notes")]

    def _parse_vertical(text: str) -> List[Dict[str, Any]]:
        """
        Deterministic stacked/vertical parser for blocks like:
            Description
            Type
            Quantity
            Rate
            Amount
            ...
        Stops at totals/subtotals.
        """
        lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
        if not lines:
            return []

        def is_label(s: str) -> Optional[str]:
            t = s.lower()
            if re.fullmatch(r"[₹$€]?\s*\d[\d,]*(?:\.\d+)?", t):
                return None
            if "desc" in t or "item" in t or "product" in t or "material" in t or "article" in t:
                return "description"
            if "type" in t or "uom" in t or "unit" in t or "units" in t:
                return "type"
            if "qty" in t or "quantity" in t:
                return "quantity"
            if "rate" in t or "price" in t or "unit price" in t:
                return "rate"
            if "amount" in t or "total" in t:
                return "amount"
            return None

        # Find a contiguous block of 5 distinct labels (order-agnostic).
        start = -1
        header_keys: List[str] = []
        for i in range(len(lines) - 4):
            block = lines[i:i + 5]
            mapped = [is_label(x) for x in block]
            if None not in mapped and len(set(mapped)) == 5:
                start = i
                header_keys = mapped  # e.g. ["description","type","quantity","rate","amount"]
                break
        if start == -1:
            return []

        # Build a position→canonical map in this exact order
        pos2key = {idx: key for idx, key in enumerate(header_keys)}

        def norm_num(s: Optional[str]) -> Optional[str]:
            if not s:
                return None
            return s.replace(",", "").replace("₹", "").replace("$", "").replace("€", "").strip() or None

        # Consume values in chunks of 5
        items: List[Dict[str, Any]] = []
        i = start + 5
        while i + 4 < len(lines):
            if STOP.search(lines[i]):
                # hit totals, bail
                break
            chunk = lines[i:i + 5]
            row = {
                "description": None, "units": None, "quantity": None,
                "price": None, "amount": None, "footage": None, "notes": None,
            }
            # map chunk by discovered order
            for j, val in enumerate(chunk):
                key = pos2key[j]
                if key == "type":
                    row["units"] = val  # map "Type" -> "units"
                elif key == "quantity":
                    row["quantity"] = norm_num(val)
                elif key == "rate":
                    row["price"] = norm_num(val)
                elif key == "amount":
                    row["amount"] = norm_num(val)
                elif key == "description":
                    row["description"] = val
            if row["quantity"] and row["units"]:
                row["footage"] = f'{row["quantity"]} {row["units"]}'
            # minimal acceptance: description or amount or price
            if any(row.get(k) for k in ("description", "amount", "price")):
                items.append(row)
            i += 5
        return items

    # ---- main body ----
    raw_lines = [ln.rstrip() for ln in text.splitlines()]
    lines = [ln for ln in raw_lines if ln.strip()]
    if not lines:
        return []
    # 1) Try horizontal first
    items = _parse_horizontal(lines)
    if items:
        return items
    # 2) Fallback to vertical/stacked
    return _parse_vertical(text)


def semantic_map_candidates(
    candidates: Dict[str, str],
    static_headers: List[str],
    thresh: float,
    sentence_model,
) -> Dict[str, str]:
    """Map candidate labels to schema keys: synonym table first, then embeddings.

    Only embedding matches scoring >= `thresh` (cosine similarity) are kept.
    """
    if not candidates:
        return {}
    mapped: Dict[str, str] = {}
    leftovers: Dict[str, str] = {}
    for k, v in candidates.items():
        lk_norm = re.sub(r"[^a-z0-9]+", " ", k.lower()).strip()
        hit = None
        for syn, key in SYN2KEY.items():
            if syn in lk_norm:
                hit = key
                break
        if hit:
            mapped[hit] = v
        else:
            leftovers[k] = v
    if leftovers:
        cand_emb = sentence_model.encode(list(leftovers.keys()), normalize_embeddings=True)
        head_emb = sentence_model.encode(static_headers, normalize_embeddings=True)
        M = util.cos_sim(torch.tensor(cand_emb), torch.tensor(head_emb)).cpu().numpy()
        for i, ck in enumerate(leftovers.keys()):
            j = int(np.argmax(M[i]))
            score = float(M[i][j])
            if score >= thresh:
                mapped[static_headers[j]] = leftovers[ck]
    return mapped


def build_prompt(
    invoice_text: str,
    mapped_hints: Dict[str, str],
    items_hints: List[Dict[str, Any]],
) -> str:
    """Build the MD2JSON prompt: schema description + invoice text + hints.

    The schema listing is generated from SCHEMA_JSON so the prompt can never
    drift out of sync with the canonical schema definition.
    """
    header_fields = ",\n".join(
        f'    "{k}": "string or null"' for k in SCHEMA_JSON["invoice_header"]
    )
    item_fields = ",\n".join(
        f'      "{k}": "string or null"' for k in SCHEMA_JSON["line_items"][0]
    )
    instruction = (
        "Use this schema:\n"
        "{\n"
        '  "invoice_header": {\n'
        f"{header_fields}\n"
        "  },\n"
        '  "line_items": [\n'
        "    {\n"
        f"{item_fields}\n"
        "    }\n"
        "  ]\n"
        "}\n"
        "If a field is missing for a line item or header, use null. "
        "Do not invent fields. Do not add any header or shipment data to any line item. "
        "Return ONLY the JSON object, no explanation.\n"
    )
    hints = ""
    if mapped_hints:
        hints += "\nHints (header):\n" + " ".join(
            [f"#{k}: {v}" for k, v in mapped_hints.items()]
        )
    if items_hints:
        try:
            hints += "\nHints (line_items):\n" + json.dumps(items_hints, ensure_ascii=False)
        except (TypeError, ValueError):
            # Hints are best-effort; never let serialization break the prompt.
            pass
    return instruction + "\nInvoice Text:\n" + invoice_text.strip() + hints


def strict_json(text: str) -> Dict[str, Any]:
    """Parse JSON, falling back to the outermost {...} span.

    Raises:
        ValueError: if no valid JSON object can be recovered.
    """
    try:
        return json.loads(text)
    except ValueError:
        pass
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        try:
            return json.loads(text[start:end + 1])
        except ValueError:
            pass
    raise ValueError("Model did not return valid JSON.")


def merge_schema(rule_json: Dict[str, Any], model_json: Dict[str, Any]) -> Dict[str, Any]:
    """Merge model output into rule output; rule-based values always win."""
    final = copy.deepcopy(rule_json)
    # --- headers (rules win where present) ---
    hdr = final["invoice_header"]
    mdl_hdr = (model_json.get("invoice_header") or {})
    for k in hdr.keys():
        if hdr[k] in [None, "", "null"]:
            v = mdl_hdr.get(k, None)
            if v not in [None, "", "null"]:
                hdr[k] = v
    # --- line_items (prefer parsed items -> model -> empty) ---
    rule_items = rule_json.get("line_items") or []
    model_items = model_json.get("line_items") or []
    if _has_real_items(rule_items):
        final["line_items"] = rule_items
    elif _has_real_items(model_items):
        final["line_items"] = model_items
    else:
        final["line_items"] = []
    return final


def _prune_empty_items(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Drop line-item rows whose values are all empty/None/'null'."""
    items = payload.get("line_items")
    if isinstance(items, list):
        payload["line_items"] = [
            it for it in items
            if isinstance(it, dict) and any(v not in (None, "", "null") for v in it.values())
        ]
    return payload


# ---------------------- MAIN FUNCTION ----------------------
def invoice_text_to_json(
    invoice_text: str,
    threshold: float = 0.60,
    max_new_tokens: int = 512,
) -> Dict[str, Any]:
    """Convert raw invoice text into the schema-shaped JSON dict.

    Combines deterministic extraction (regex/synonyms/table parsing),
    semantic header mapping, and MD2JSON generation; rule results win
    in the final merge.
    """
    # Models are cached process-wide; repeated calls do not reload them.
    sentence_model = _get_sentence_model()
    json_converter = _get_json_converter()
    txt = invoice_text

    # 1) Deterministic extraction
    candidates = extract_candidates(txt)
    hard = regex_extract_all(txt)
    bank = extract_bank_block(txt)
    items = parse_line_items(txt)
    print("Extracted line items:", items)
    sem_mapped = semantic_map_candidates(candidates, STATIC_HEADERS, threshold, sentence_model)
    # Later updates override earlier ones: hard regex and bank details
    # take precedence over semantic guesses.
    header_found: Dict[str, Any] = {}
    header_found.update(sem_mapped)
    header_found.update(hard)
    header_found.update(bank)

    # 2) Build RULE JSON (schema-shaped, rules filled)
    rule_json = deep_copy_schema()
    rule_json["line_items"] = items if _has_real_items(items) else []
    for k, v in header_found.items():
        if k in rule_json["invoice_header"]:
            rule_json["invoice_header"][k] = v

    # 3) MD2JSON generation with strong hints
    prompt = build_prompt(txt, header_found, items)
    gen = json_converter(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]
    try:
        model_json = strict_json(gen)
    except Exception:
        model_json = deep_copy_schema()  # model failed; keep empty shape

    # 4) Final merge (rules win)
    final_json = merge_schema(rule_json, model_json)
    final_json = _prune_empty_items(final_json)
    return final_json


# ----- unified dispatcher (replaces old run_ocr in the UI) -----
def run_pipeline(file: Optional[gr.File], raw_txt: Optional[str]) -> str:
    """
    Orchestrates two intake lanes:
      1) If raw_txt is provided (non-empty), skip OCR → directly map to schema.
      2) Else, run OCR on the uploaded file and map to schema.
    """
    raw_txt = (raw_txt or "").strip()

    # Lane A: Raw text → JSON
    if raw_txt:
        try:
            result_json = invoice_text_to_json(raw_txt)
            return json.dumps(result_json, indent=2, ensure_ascii=False)
        except Exception as e:
            return f"Error while converting pasted text to JSON schema: {e}"

    # Lane B: File → OCR → JSON
    if not file:
        return "No input received. Upload an image/PDF or paste raw text."
    try:
        name = (file.name or "").lower()
        # Load as DocumentFile (handles PNG/JPG/PDF)
        if name.endswith(".pdf"):
            doc = DocumentFile.from_pdf(file=file.name)
        else:
            doc = DocumentFile.from_images([file.name])
        # Inference
        result = MODEL(doc)
        exported = result.export()
        text = _collect_text_from_export(exported)
        if not text:
            return "No text detected by OCR."
        result_json = invoice_text_to_json(text)
        return json.dumps(result_json, indent=2, ensure_ascii=False)
    except Exception as e:
        return f"OCR pipeline error: {e}"


# ---------- Gradio UI ----------
TITLE = "docTR OCR — Text Extractor"
DESC = (
    "Upload an image or PDF OR paste raw text. Uses docTR for OCR or directly maps raw text to the invoice JSON schema."
)

with gr.Blocks(theme="soft", title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}\n{DESC}")
    with gr.Tabs():
        with gr.Tab("Upload File"):
            inp = gr.File(
                label="Upload image/PDF",
                file_types=[".png", ".jpg", ".jpeg", ".tif", ".tiff", ".pdf"],
            )
            # keep symmetrical inputs for single-click wiring
            raw_txt_hidden = gr.Textbox(visible=False)
        with gr.Tab("Paste Raw Text"):
            raw_txt = gr.Textbox(
                label="Paste raw invoice text (we’ll map directly to JSON schema)",
                lines=18,
                placeholder="Paste the OCR’d/plain text of the invoice here…",
            )
            file_hidden = gr.File(visible=False)
    out = gr.Code(label="Extracted JSON", language="json")
    run_btn = gr.Button("Generate JSON", variant="primary")
    # One button → unified function; we pass both lanes (visible/hidden)
    run_btn.click(
        fn=run_pipeline,
        inputs=[inp, raw_txt],
        outputs=out,
    )
    gr.Markdown(
        "ℹ️ **Usage:** Prefer *Paste Raw Text* when you already have text. "
        "If both file and text are provided, we’ll **prioritize the pasted text**."
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True, show_error=True)