# Source: Hugging Face Space by KarthiEz — commit "Update app.py" (1607f1d, verified)
import os
import io
from typing import List

import gradio as gr

# docTR imports (PyTorch backend)
from doctr.io import DocumentFile
from doctr.models import ocr_predictor

# ---------- One-time model bootstrap (CPU-friendly) ----------
# Ensure torch runs in CPU mode on Spaces; docTR auto-detects backend.
# You can optionally pin threads for stability on small CPU runners:
os.environ.setdefault("OMP_NUM_THREADS", "4")
os.environ.setdefault("MKL_NUM_THREADS", "4")

# Loaded once at import time and reused by every request handler below.
MODEL = ocr_predictor(pretrained=True)  # DBNet + CRNN (default) on PyTorch
def _collect_text_from_export(exported: dict) -> str:
"""Flatten docTR exported structure into newline-separated text per page."""
pages: List[dict] = exported.get("pages", [])
text_pages: List[str] = []
for page in pages:
page_lines = []
for block in page.get("blocks", []):
for line in block.get("lines", []):
# Join word values in the line; fallback robustly
words = [w.get("value", "") for w in line.get("words", []) if isinstance(w, dict)]
line_text = " ".join([w for w in words if w])
if line_text.strip():
page_lines.append(line_text)
text_pages.append("\n".join(page_lines).strip())
# Join pages with a page delimiter
return ("\n\n" + ("─" * 32) + " PAGE BREAK " + ("─" * 32) + "\n\n").join(
[tp for tp in text_pages if tp]
).strip()
def run_ocr(file: gr.File) -> str:
    """OCR an uploaded image/PDF and return the mapped invoice JSON as a string.

    NOTE(review): superseded by run_pipeline below, which wires the same OCR
    path into the UI; this function is no longer referenced by the Gradio app.
    It calls invoice_text_to_json and json, which are defined/imported later
    in the module — fine at call time, once the module has fully loaded.
    """
    if file is None:
        return "No file received."
    name = (file.name or "").lower()
    # Load as DocumentFile (handles PNG/JPG/PDF)
    if name.endswith(".pdf"):
        # Render PDF pages via pdfium backend under the hood (CPU OK)
        doc = DocumentFile.from_pdf(file=file.name)
    else:
        # Single image fallback; also works for TIFF/PNG/JPG
        doc = DocumentFile.from_images([file.name])
    # Inference
    result = MODEL(doc)
    exported = result.export()
    text = _collect_text_from_export(exported)
    print("Extracted Text:\n", text)
    if not text:
        return "No text detected."
    result_json = invoice_text_to_json(text)
    print(json.dumps(result_json, indent=2))
    string_json = json.dumps(result_json, indent=2)
    return string_json
import re
import json
from typing import List, Dict, Any
import copy
import numpy as np
import torch
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
# ----------------------------- Schema -----------------------------
# Target output shape: a flat "invoice_header" of string-or-None fields plus a
# list of "line_items" rows. deep_copy_schema() hands out fresh copies of this
# template; the module-level dict itself is never mutated.
SCHEMA_JSON: Dict[str, Any] = {
    "invoice_header": {
        "car_number": None,
        "shipment_number": None,
        "shipping_point": None,
        "currency": None,
        "invoice_number": None,
        "invoice_date": None,
        "order_number": None,
        "customer_order_number": None,
        "our_order_number": None,
        "sales_order_number": None,
        "purchase_order_number": None,
        "order_date": None,
        "supplier_name": None,
        "supplier_address": None,
        "supplier_phone": None,
        "supplier_email": None,
        "supplier_tax_id": None,
        "customer_name": None,
        "customer_address": None,
        "customer_phone": None,
        "customer_email": None,
        "customer_tax_id": None,
        "ship_to_name": None,
        "ship_to_address": None,
        "bill_to_name": None,
        "bill_to_address": None,
        "remit_to_name": None,
        "remit_to_address": None,
        "tax_id": None,
        "tax_registration_number": None,
        "vat_number": None,
        "payment_terms": None,
        "payment_method": None,
        "payment_reference": None,
        "bank_account_number": None,
        "iban": None,
        "swift_code": None,
        "total_before_tax": None,
        "tax_amount": None,
        "tax_rate": None,
        "shipping_charges": None,
        "discount": None,
        "total_due": None,
        "amount_paid": None,
        "balance_due": None,
        "due_date": None,
        "invoice_status": None,
        "reference_number": None,
        "project_code": None,
        "department": None,
        "contact_person": None,
        "notes": None,
        "additional_info": None
    },
    "line_items": [
        {
            "quantity": None,
            "units": None,
            "description": None,
            "footage": None,
            "price": None,
            "amount": None,
            "notes": None
        }
    ]
}
# Header keys in schema order; used as the semantic-matching target list.
STATIC_HEADERS: List[str] = list(SCHEMA_JSON["invoice_header"].keys())

# Synonym map: normalized candidate-key substring -> schema header key.
# Matching is substring-based (see semantic_map_candidates), so short entries
# like "bank" also catch longer labels such as "bank name".
SYN2KEY: Dict[str, str] = {
    "invoice no": "invoice_number",
    "invoice number": "invoice_number",
    "invoice#": "invoice_number",
    "inv no": "invoice_number",
    "inv#": "invoice_number",
    "invoice date": "invoice_date",
    "date of invoice": "invoice_date",
    "po no": "purchase_order_number",
    "po number": "purchase_order_number",
    "purchase order": "purchase_order_number",
    "order no": "order_number",
    "order number": "order_number",
    "sales order": "sales_order_number",
    "customer order": "customer_order_number",
    "our order": "our_order_number",
    "due date": "due_date",
    "date of supply": "order_date",
    "gstin": "supplier_tax_id",
    "gstin no": "supplier_tax_id",
    "tax id": "tax_id",
    "vat number": "vat_number",
    "tax registration number": "tax_registration_number",
    "place of supply": "shipping_point",
    "state code": "additional_info",
    "taxable value": "total_before_tax",
    "total value": "total_due",
    "total amount": "total_due",
    "amount due": "total_due",
    "bank": "bank_account_number",
    "account no": "bank_account_number",
    "account number": "bank_account_number",
    "ifs code": "swift_code",
    # NOTE(review): "ifsc" maps to payment_reference while "ifs code" maps to
    # swift_code — looks inconsistent; confirm intended target before changing.
    "ifsc": "payment_reference",
    "swift code": "swift_code",
    "iban": "iban",
    "e-way bill no": "reference_number",
    "eway bill": "reference_number",
    "dispatched via": "additional_info",
    "documents dispatched through": "additional_info",
    "kind attn": "contact_person",
    "billed to": "bill_to_name",
    "receiver": "bill_to_name",
    "shipped to": "ship_to_name",
    "consignee": "ship_to_name",
}
def norm(s: str) -> str:
    """Collapse runs of whitespace to single spaces and strip the ends."""
    return " ".join(s.split())
def deep_copy_schema() -> Dict[str, Any]:
    """Return an independent deep copy of SCHEMA_JSON.

    Uses copy.deepcopy (already imported at module level) instead of the
    original json.loads(json.dumps(...)) round-trip: identical result for
    this JSON-safe template, without serializing to a string first.
    """
    return copy.deepcopy(SCHEMA_JSON)
def extract_candidates(text: str) -> Dict[str, str]:
    """Collect key:value candidate pairs from colon-separated lines.

    Pipe-delimited rows are split into cells first, each cell contributing
    its own key:value pair. A second pass picks up monetary totals written
    as "<label> <number>" anywhere in the text.
    """
    found: Dict[str, str] = {}
    for raw in text.splitlines():
        line = raw.strip().strip("|").strip()
        if not line or ":" not in line:
            continue
        if "|" in raw:
            # Table row: every non-empty cell may carry a key:value pair.
            for cell in (p.strip() for p in raw.split("|")):
                if cell and ":" in cell:
                    key, val = cell.split(":", 1)
                    found[norm(key)] = norm(val)
        else:
            key, val = line.split(":", 1)
            found[norm(key)] = norm(val)
    # Second pass: totals that may appear without a colon separator.
    money = re.compile(
        r"\b(Taxable\s+Value|Total\s+Value|Total\s+Amount|Amount\s+Due)\b[:\s]*([0-9][0-9,]*(?:\.[0-9]{2})?)",
        re.I,
    )
    for raw in text.splitlines():
        hit = money.search(raw)
        if hit:
            found[norm(hit.group(1))] = norm(hit.group(2))
    return found
def regex_extract_all(text: str) -> Dict[str, str]:
    """Extract header fields from invoice text via hard-coded regexes.

    Returns a dict keyed by schema header names; only fields actually found
    appear. Monetary values have thousands separators stripped.

    Fix: the two bare `except:` clauses around the float conversions are
    narrowed to ValueError — a bare except also swallows SystemExit and
    KeyboardInterrupt, which must propagate.
    """
    out: Dict[str, str] = {}
    m = re.search(r"\bInvoice\s*(?:No\.?|Number|#)\s*[:\-]?\s*([A-Z0-9\-\/]+)", text, re.I)
    if m: out["invoice_number"] = m.group(1)
    m = re.search(r"\bInvoice\s*Date\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
    if m: out["invoice_date"] = m.group(1)
    m = re.search(r"\bPO\s*(?:No\.?|Number)?\s*[:\-]?\s*([A-Z0-9\-\/]+)", text, re.I)
    if m: out["purchase_order_number"] = m.group(1)
    m = re.search(r"\bPO\s*Date\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
    if m: out["order_date"] = m.group(1)
    if "order_date" not in out:
        # Fallback label used on Indian GST invoices.
        m = re.search(r"\bDate\s*of\s*Supply\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
        if m: out["order_date"] = m.group(1)
    m = re.search(r"\bPlace\s*of\s*Supply\s*[:\-]?\s*([A-Za-z0-9 ,\-\(\)]+)", text, re.I)
    if m: out["shipping_point"] = m.group(1).strip(" |")
    m = re.search(r"\bGSTIN\s*(?:No\.?)?\s*[:\-]?\s*([A-Z0-9]{15})", text, re.I)
    if m: out["supplier_tax_id"] = m.group(1)
    m = re.search(r"\bTaxable\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
    if m: out["total_before_tax"] = m.group(1).replace(",", "")
    # GST totals: tax_amount = CGST + SGST values; tax_rate = CGST% + SGST%.
    cgst = re.search(r"\bCGST\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
    sgst = re.search(r"\bSGST\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
    if cgst and sgst:
        try:
            tax_total = float(cgst.group(1).replace(",", "")) + float(sgst.group(1).replace(",", ""))
            out["tax_amount"] = f"{tax_total:.2f}"
            cgstp = re.search(r"\bCGST\s*%?\s*[:\-]?\s*([0-9]+(?:\.[0-9]+)?)", text, re.I)
            sgstp = re.search(r"\bSGST\s*%?\s*[:\-]?\s*([0-9]+(?:\.[0-9]+)?)", text, re.I)
            if cgstp and sgstp:
                try:
                    rate = float(cgstp.group(1)) + float(sgstp.group(1))
                    out["tax_rate"] = f"{rate:g}"
                except ValueError:
                    pass  # malformed percentage; skip tax_rate only
        except ValueError:
            pass  # malformed CGST/SGST value; skip tax_amount
    m = re.search(r"\bE[-\s]?Way\s*bill\s*no\.?\s*[:\-]?\s*([0-9 ]+)", text, re.I)
    if m: out["reference_number"] = m.group(1).strip()
    return out
def extract_bank_block(text: str) -> Dict[str, str]:
bank: Dict[str, str] = {}
m = re.search(r"\bAccount\s*Name\s*:\s*(.+)", text, re.I)
if m: bank["supplier_name"] = m.group(1).strip()
m = re.search(r"\bAccount\s*(?:No|Number)\s*:\s*([A-Za-z0-9\- ]+)", text, re.I)
if m: bank["bank_account_number"] = m.group(1).strip()
m = re.search(r"\bBank\s*:\s*([A-Za-z0-9 ,\-\(\)&]+)", text, re.I)
if m:
bank["additional_info"] = ("Bank: " + m.group(1).strip())
m = re.search(r"\bIFSC?\s*Code\s*:\s*([A-Za-z0-9]+)", text, re.I)
if m: bank["payment_reference"] = m.group(1).strip()
m = re.search(r"\bSWIFT\s*Code\s*:\s*([A-Za-z0-9]+)", text, re.I)
if m: bank["swift_code"] = m.group(1).strip()
branch = re.search(r"\bBranch\s*:\s*(.+)", text, re.I)
micr = re.search(r"\bMICR\s*Code\s*:\s*([0-9]+)", text, re.I)
extra_bits = []
if branch: extra_bits.append("Branch: " + branch.group(1).strip())
if micr: extra_bits.append("MICR: " + micr.group(1).strip())
if extra_bits:
bank["additional_info"] = ((bank.get("additional_info") + " | ") if bank.get("additional_info") else "") + " | ".join(extra_bits)
return bank
def _has_real_items(items) -> bool:
return (
isinstance(items, list)
and any(
isinstance(row, dict)
and any(val not in (None, "", "null") for val in row.values())
for row in items
)
)
def parse_line_items(text: str) -> List[Dict[str, Any]]:
    """
    Dynamic, header-agnostic line-item extractor.
    - Auto-detects header row (no hardcoded labels)
    - Supports pipe '|' tables, multi-space/tab tables, and stacked/vertical layouts
    - Fuzzy maps arbitrary headers to: description, quantity, units, price, amount
    - Stitches wrapped descriptions; stops at totals/subtotals
    """
    # Local imports keep this function self-contained (heavy deps loaded lazily).
    import re
    from typing import List, Dict, Any
    import torch
    from sentence_transformers import SentenceTransformer, util
    # ---- local helpers (encapsulated; no external edits required) ----
    def _tokenize_row(row: str) -> List[str]:
        # Prefer pipe cells; otherwise split on tabs / runs of 2+ spaces.
        if "|" in row:
            toks = [c.strip(" -") for c in row.split("|")]
        else:
            toks = re.split(r"\t+| {2,}", row)
            toks = [c.strip(" -") for c in toks]
        return [t for t in toks if t]
    def _looks_like_separator(row: str) -> bool:
        # Rows made only of dashes/equals/dashes variants are table rules.
        return bool(re.fullmatch(r"[-=–—\s]+", row))
    def _numlike(s: str) -> bool:
        # Optionally currency-prefixed number with thousands separators.
        return bool(re.fullmatch(r"[₹$€]?\s*\d[\d,]*(?:\.\d+)?", s.strip()))
    def _normalize_num(s: str | None) -> str | None:
        if not s: return None
        return s.replace(",", "").replace("₹", "").replace("$", "").replace("€", "").strip() or None
    # Rows starting with any totals keyword terminate item parsing.
    STOP = re.compile(r"^\s*(subtotal|tax|vat|gst|cgst|sgst|igst|total\b|grand total|amount due|balance due)\b", re.I)
    # Canonical targets + synonyms (broad, non-brittle)
    CANON = ["description", "quantity", "units", "price", "amount"]
    SYN = {
        "description": ["description", "item", "details", "product", "material", "article", "part no", "part", "goods desc"],
        "quantity": ["qty", "quantity", "qnty", "pcs", "pieces", "units qty", "ordered qty"],
        "units": ["uom", "unit", "units", "measure", "type", "pkg", "pack", "u/m"],
        "price": ["rate", "price", "unit price", "cost", "u/price", "list price"],
        "amount": ["amount", "total", "line total", "ext price", "net", "value", "extended"]
    }
    def _find_header_idx(lines: List[str]) -> int:
        """Heuristic header detection for horizontal tables."""
        for i, row in enumerate(lines):
            if _looks_like_separator(row):
                continue
            toks = _tokenize_row(row)
            if len(toks) < 3:
                continue
            # low numeric density
            if sum(_numlike(t) for t in toks) > len(toks) // 2:
                continue
            # at least 2 synonym hits
            hits = 0
            lowt = [t.lower() for t in toks]
            for t in lowt:
                for syns in SYN.values():
                    if any(s in t for s in syns):
                        hits += 1
                        break
            if hits >= 2:
                return i
        return -1
    def _map_headers_dynamic(header_tokens: List[str], model) -> Dict[int, str]:
        """
        Map arbitrary header tokens to canonical keys via:
        1) direct/synonym contains
        2) semantic similarity (best match)
        """
        mapped: Dict[int, str] = {}
        used = set()
        low = [h.lower() for h in header_tokens]
        # 1) substring / synonyms
        for j, h in enumerate(low):
            for key, syns in SYN.items():
                if any(s in h for s in syns):
                    if key not in used:
                        mapped[j] = key
                        used.add(key)
                    break
        # 2) semantic backstop for unmapped
        remaining = [j for j in range(len(header_tokens)) if j not in mapped]
        if remaining:
            label_texts, label_keys = [], []
            for k, syns in SYN.items():
                for s in syns + [k]:
                    label_texts.append(s)
                    label_keys.append(k)
            h_emb = model.encode([header_tokens[i] for i in remaining], normalize_embeddings=True)
            l_emb = model.encode(label_texts, normalize_embeddings=True)
            sim = util.cos_sim(torch.tensor(h_emb), torch.tensor(l_emb)).cpu().numpy()
            for ri, j in enumerate(remaining):
                k_best = int(sim[ri].argmax())
                key = label_keys[k_best]
                if key not in used:
                    mapped[j] = key
                    used.add(key)
        return mapped
    def _parse_horizontal(lines: List[str]) -> List[Dict[str, Any]]:
        """Parse pipe/whitespace horizontal tables with dynamic headers."""
        header_idx = _find_header_idx(lines)
        if header_idx == -1:
            return []
        header_tokens = _tokenize_row(lines[header_idx])
        # lazy singleton on the function for perf (no external changes)
        if not hasattr(parse_line_items, "_sent_model"):
            parse_line_items._sent_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # type: ignore[attr-defined]
        sm = parse_line_items._sent_model  # type: ignore[attr-defined]
        idx2key = _map_headers_dynamic(header_tokens, sm)
        items: List[Dict[str, Any]] = []
        for row in lines[header_idx + 1:]:
            if _looks_like_separator(row):
                continue
            if STOP.search(row):
                break
            toks = _tokenize_row(row)
            # continuation-line heuristic (wrapped description)
            if (len(toks) == 1 or len(toks) < (max(idx2key.keys(), default=-1) + 1)) and items:
                last = items[-1]
                prev = (last.get("description") or "").strip()
                last["description"] = (prev + " " + toks[0]).strip() if toks else prev
                continue
            rowd = {"description": None, "quantity": None, "units": None,
                    "price": None, "amount": None, "footage": None, "notes": None}
            for j, tok in enumerate(toks):
                key = idx2key.get(j)
                if not key:
                    continue
                val = tok.strip()
                if key in ("quantity", "price", "amount"):
                    val = _normalize_num(val)
                rowd[key] = val or rowd.get(key)
            if rowd["quantity"] and rowd["units"]:
                rowd["footage"] = f'{rowd["quantity"]} {rowd["units"]}'
            if any(rowd.get(k) for k in ("description", "amount", "price")):
                items.append(rowd)
        # prune empties
        return [it for it in items if any(v for k, v in it.items() if k != "notes")]
    def _parse_vertical(text: str) -> List[Dict[str, Any]]:
        """
        Deterministic stacked/vertical parser for blocks like:
        Description
        Type
        Quantity
        Rate
        Amount
        <desc1>
        <type1>
        <qty1>
        <rate1>
        <amt1>
        <desc2> ...
        Stops at totals/subtotals.
        """
        lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
        if not lines:
            return []
        # Find the exact 5-label header block (order-agnostic but contiguous)
        LABELS = ["description", "type", "quantity", "rate", "amount"]
        def is_label(s: str) -> str | None:
            t = s.lower()
            # Numeric lines can never be header labels.
            if re.fullmatch(r"[₹$€]?\s*\d[\d,]*(?:\.\d+)?", t):
                return None
            if "desc" in t or "item" in t or "product" in t or "material" in t or "article" in t:
                return "description"
            if "type" in t or "uom" in t or "unit" in t or "units" in t:
                return "type"
            if "qty" in t or "quantity" in t:
                return "quantity"
            if "rate" in t or "price" in t or "unit price" in t:
                return "rate"
            if "amount" in t or "total" in t:
                return "amount"
            return None
        start = -1
        for i in range(len(lines) - 4):
            block = lines[i:i+5]
            mapped = [is_label(x) for x in block]
            if None not in mapped and len(set(mapped)) == 5:
                start = i
                header_keys = mapped  # e.g. ["description","type","quantity","rate","amount"]
                break
        if start == -1:
            return []
        # Build a position→canonical map in this exact order
        pos2key = {idx: key for idx, key in enumerate(header_keys)}
        # Consume values in chunks of 5
        items: List[Dict[str, Any]] = []
        i = start + 5
        STOP = re.compile(r"^\s*(subtotal|tax|vat|gst|cgst|sgst|igst|total\b|grand total|amount due|balance due)\b", re.I)
        def norm_num(s: str | None) -> str | None:
            if not s: return None
            return s.replace(",", "").replace("₹", "").replace("$", "").replace("€", "").strip() or None
        while i + 4 < len(lines):
            if STOP.search(lines[i]):  # hit totals, bail
                break
            chunk = lines[i:i+5]
            row = {"description": None, "units": None, "quantity": None,
                   "price": None, "amount": None, "footage": None, "notes": None}
            # map chunk by discovered order
            for j, val in enumerate(chunk):
                key = pos2key[j]
                if key == "type":
                    row["units"] = val  # map "Type" -> "units"
                elif key == "quantity":
                    row["quantity"] = norm_num(val)
                elif key == "rate":
                    row["price"] = norm_num(val)
                elif key == "amount":
                    row["amount"] = norm_num(val)
                elif key == "description":
                    row["description"] = val
            if row["quantity"] and row["units"]:
                row["footage"] = f'{row["quantity"]} {row["units"]}'
            # minimal acceptance: description or amount or price
            if any(row.get(k) for k in ("description", "amount", "price")):
                items.append(row)
            i += 5
        return items
    # ---- main body ----
    raw_lines = [ln.rstrip() for ln in text.splitlines()]
    lines = [ln for ln in raw_lines if ln.strip()]
    if not lines:
        return []
    # 1) Try horizontal first
    items = _parse_horizontal(lines)
    if items:
        return items
    # 2) Fallback to vertical/stacked
    items = _parse_vertical(text)
    return items
def semantic_map_candidates(candidates: Dict[str, str], static_headers: List[str], thresh: float, sentence_model) -> Dict[str, str]:
    """Map free-form candidate keys to schema header keys.

    First tries the SYN2KEY synonym table (substring match against a
    lowercase, alphanumeric-normalized key); anything left over is matched
    to static_headers by sentence-embedding cosine similarity and kept only
    when the best score reaches `thresh`.

    Fix: removed the unused local `cand_keys` (and the redundant `lk`
    intermediate). sentence_model is only invoked when synonym matching
    leaves leftovers.
    """
    if not candidates:
        return {}
    mapped: Dict[str, str] = {}
    leftovers: Dict[str, str] = {}
    for k, v in candidates.items():
        # Normalize so "Invoice-No." still hits the "invoice no" synonym.
        lk_norm = re.sub(r"[^a-z0-9]+", " ", k.lower()).strip()
        hit = None
        for syn, key in SYN2KEY.items():
            if syn in lk_norm:
                hit = key
                break
        if hit:
            mapped[hit] = v
        else:
            leftovers[k] = v
    if leftovers:
        # Semantic backstop: embed leftover keys and headers, take argmax.
        cand_emb = sentence_model.encode(list(leftovers.keys()), normalize_embeddings=True)
        head_emb = sentence_model.encode(static_headers, normalize_embeddings=True)
        M = util.cos_sim(torch.tensor(cand_emb), torch.tensor(head_emb)).cpu().numpy()
        for i, ck in enumerate(leftovers):
            j = int(np.argmax(M[i]))
            score = float(M[i][j])
            if score >= thresh:
                mapped[static_headers[j]] = leftovers[ck]
    return mapped
def build_prompt(invoice_text: str, mapped_hints: Dict[str, str], items_hints: List[Dict[str, Any]]) -> str:
    """Assemble the text2text prompt for the MD2JSON model.

    Prompt = fixed schema instruction + raw invoice text + optional hint
    sections built from the deterministic extraction results.

    Fix: the bare `except:` around json.dumps is narrowed to
    (TypeError, ValueError) — serialization failures only; hints are
    best-effort, so they are dropped rather than aborting prompt building.
    """
    instruction = (
        'Use this schema:\n'
        '{\n'
        ' "invoice_header": {\n'
        ' "car_number": "string or null",\n'
        ' "shipment_number": "string or null",\n'
        ' "shipping_point": "string or null",\n'
        ' "currency": "string or null",\n'
        ' "invoice_number": "string or null",\n'
        ' "invoice_date": "string or null",\n'
        ' "order_number": "string or null",\n'
        ' "customer_order_number": "string or null",\n'
        ' "our_order_number": "string or null",\n'
        ' "sales_order_number": "string or null",\n'
        ' "purchase_order_number": "string or null",\n'
        ' "order_date": "string or null",\n'
        ' "supplier_name": "string or null",\n'
        ' "supplier_address": "string or null",\n'
        ' "supplier_phone": "string or null",\n'
        ' "supplier_email": "string or null",\n'
        ' "supplier_tax_id": "string or null",\n'
        ' "customer_name": "string or null",\n'
        ' "customer_address": "string or null",\n'
        ' "customer_phone": "string or null",\n'
        ' "customer_email": "string or null",\n'
        ' "customer_tax_id": "string or null",\n'
        ' "ship_to_name": "string or null",\n'
        ' "ship_to_address": "string or null",\n'
        ' "bill_to_name": "string or null",\n'
        ' "bill_to_address": "string or null",\n'
        ' "remit_to_name": "string or null",\n'
        ' "remit_to_address": "string or null",\n'
        ' "tax_id": "string or null",\n'
        ' "tax_registration_number": "string or null",\n'
        ' "vat_number": "string or null",\n'
        ' "payment_terms": "string or null",\n'
        ' "payment_method": "string or null",\n'
        ' "payment_reference": "string or null",\n'
        ' "bank_account_number": "string or null",\n'
        ' "iban": "string or null",\n'
        ' "swift_code": "string or null",\n'
        ' "total_before_tax": "string or null",\n'
        ' "tax_amount": "string or null",\n'
        ' "tax_rate": "string or null",\n'
        ' "shipping_charges": "string or null",\n'
        ' "discount": "string or null",\n'
        ' "total_due": "string or null",\n'
        ' "amount_paid": "string or null",\n'
        ' "balance_due": "string or null",\n'
        ' "due_date": "string or null",\n'
        ' "invoice_status": "string or null",\n'
        ' "reference_number": "string or null",\n'
        ' "project_code": "string or null",\n'
        ' "department": "string or null",\n'
        ' "contact_person": "string or null",\n'
        ' "notes": "string or null",\n'
        ' "additional_info": "string or null"\n'
        ' },\n'
        ' "line_items": [\n'
        ' {\n'
        ' "quantity": "string or null",\n'
        ' "units": "string or null",\n'
        ' "description": "string or null",\n'
        ' "footage": "string or null",\n'
        ' "price": "string or null",\n'
        ' "amount": "string or null",\n'
        ' "notes": "string or null"\n'
        ' }\n'
        ' ]\n'
        '}\n'
        'If a field is missing for a line item or header, use null. '
        'Do not invent fields. Do not add any header or shipment data to any line item. '
        'Return ONLY the JSON object, no explanation.\n'
    )
    hints = ""
    if mapped_hints:
        hints += "\nHints (header):\n" + " ".join([f"#{k}: {v}" for k, v in mapped_hints.items()])
    if items_hints:
        try:
            hints += "\nHints (line_items):\n" + json.dumps(items_hints, ensure_ascii=False)
        except (TypeError, ValueError):
            pass  # unserializable hint content; omit the hint section
    return instruction + "\nInvoice Text:\n" + invoice_text.strip() + hints
def strict_json(text: str) -> Dict[str, Any]:
    """Parse model output as JSON, salvaging the outermost {...} span when
    the raw text carries leading/trailing noise.

    Fix: the two bare `except:` clauses are narrowed to
    json.JSONDecodeError — a bare except also swallows SystemExit and
    KeyboardInterrupt.

    Raises:
        ValueError: when no parseable JSON object can be recovered.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # Fall back to the widest brace-delimited span.
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        try:
            return json.loads(text[start:end+1])
        except json.JSONDecodeError:
            pass
    raise ValueError("Model did not return valid JSON.")
def merge_schema(rule_json: Dict[str, Any], model_json: Dict[str, Any]) -> Dict[str, Any]:
    """Combine rule-based and model JSON; rule-derived values win.

    Header fields: a model value fills a slot only when the rule value is
    empty. Line items: parsed (rule) items are preferred, model items are
    the fallback, otherwise an empty list.
    """
    merged = copy.deepcopy(rule_json)
    empties = (None, "", "null")
    # --- headers (rules win where present) ---
    header = merged["invoice_header"]
    model_header = model_json.get("invoice_header") or {}
    for field in header:
        if header[field] in empties:
            candidate = model_header.get(field, None)
            if candidate not in empties:
                header[field] = candidate
    # --- line_items (prefer parsed items -> model -> empty) ---
    parsed_items = rule_json.get("line_items") or []
    generated_items = model_json.get("line_items") or []
    if _has_real_items(parsed_items):
        merged["line_items"] = parsed_items
    elif _has_real_items(generated_items):
        merged["line_items"] = generated_items
    else:
        merged["line_items"] = []
    return merged
def _prune_empty_items(payload: Dict[str, Any]) -> Dict[str, Any]:
items = payload.get("line_items")
if isinstance(items, list):
payload["line_items"] = [
it for it in items
if isinstance(it, dict) and any(v not in (None, "", "null") for v in it.values())
]
return payload
# ---------------------- MAIN FUNCTION ----------------------
def invoice_text_to_json(
    invoice_text: str,
    threshold: float = 0.60,
    max_new_tokens: int = 512
) -> Dict[str, Any]:
    """Convert raw invoice text into the schema-shaped JSON dict.

    Pipeline:
      1. Deterministic extraction (key:value candidates, hard regexes,
         bank block, line items).
      2. Semantic mapping of leftover candidate keys to schema headers.
      3. MD2JSON T5 generation with the deterministic results as hints.
      4. Merge, with rule-based values winning over model output.

    Args:
        invoice_text: OCR'd or pasted invoice text.
        threshold: cosine-similarity cutoff for semantic header mapping.
        max_new_tokens: generation budget for the T5 converter.

    Fix: the embedding model and T5 pipeline were reloaded on EVERY call;
    they are now cached as lazy singletons on the function, matching the
    pattern already used by parse_line_items._sent_model.
    """
    if not hasattr(invoice_text_to_json, "_models"):
        invoice_text_to_json._models = (  # type: ignore[attr-defined]
            SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2"),
            pipeline("text2text-generation", model="yahyakhoder/MD2JSON-T5-small-V1"),
        )
    sentence_model, json_converter = invoice_text_to_json._models  # type: ignore[attr-defined]
    txt = invoice_text
    # 1) Deterministic extraction
    candidates = extract_candidates(txt)
    hard = regex_extract_all(txt)
    bank = extract_bank_block(txt)
    items = parse_line_items(txt)
    print("Extracted line items:", items)
    sem_mapped = semantic_map_candidates(candidates, STATIC_HEADERS, threshold, sentence_model)
    # Later updates win: regex hits override semantic guesses, bank last.
    header_found: Dict[str, Any] = {}
    header_found.update(sem_mapped)
    header_found.update(hard)
    header_found.update(bank)
    # 2) Build RULE JSON (schema-shaped, rules filled)
    rule_json = deep_copy_schema()
    if _has_real_items(items):
        rule_json["line_items"] = items
    else:
        rule_json["line_items"] = []
    for k, v in header_found.items():
        if k in rule_json["invoice_header"]:
            rule_json["invoice_header"][k] = v
    # 3) MD2JSON generation with strong hints
    prompt = build_prompt(txt, header_found, items)
    gen = json_converter(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]
    try:
        model_json = strict_json(gen)
    except Exception:
        model_json = deep_copy_schema()  # model failed; keep empty shape
    # 4) Final merge (rules win)
    final_json = merge_schema(rule_json, model_json)
    final_json = _prune_empty_items(final_json)
    return final_json
from typing import Optional
# ----- replace old run_ocr with unified dispatcher -----
def run_pipeline(file: Optional[gr.File], raw_txt: Optional[str]) -> str:
    """
    Orchestrates two intake lanes:
    1) If raw_txt is provided (non-empty), skip OCR → directly map to schema.
    2) Else, run OCR on the uploaded file and map to schema.

    Returns either pretty-printed JSON or a human-readable error string
    (the Gradio output widget renders whichever comes back).
    """
    raw_txt = (raw_txt or "").strip()
    # Lane A: Raw text → JSON (takes priority over an uploaded file)
    if raw_txt:
        try:
            result_json = invoice_text_to_json(raw_txt)
            return json.dumps(result_json, indent=2, ensure_ascii=False)
        except Exception as e:
            return f"Error while converting pasted text to JSON schema: {e}"
    # Lane B: File → OCR → JSON
    if not file:
        return "No input received. Upload an image/PDF or paste raw text."
    try:
        name = (file.name or "").lower()
        # Load as DocumentFile (handles PNG/JPG/PDF)
        if name.endswith(".pdf"):
            doc = DocumentFile.from_pdf(file=file.name)
        else:
            doc = DocumentFile.from_images([file.name])
        # Inference
        result = MODEL(doc)
        exported = result.export()
        text = _collect_text_from_export(exported)
        if not text:
            return "No text detected by OCR."
        result_json = invoice_text_to_json(text)
        return json.dumps(result_json, indent=2, ensure_ascii=False)
    except Exception as e:
        # Top-level UI boundary: surface the failure as a message, not a crash.
        return f"OCR pipeline error: {e}"
# ---------- Gradio UI ----------
TITLE = "docTR OCR — Text Extractor"
DESC = (
    "Upload an image or PDF OR paste raw text. Uses docTR for OCR or directly maps raw text to the invoice JSON schema."
)
with gr.Blocks(theme="soft", title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}\n{DESC}")
    with gr.Tabs():
        with gr.Tab("Upload File"):
            inp = gr.File(
                label="Upload image/PDF",
                file_types=[".png", ".jpg", ".jpeg", ".tif", ".tiff", ".pdf"]
            )
            # keep symmetrical inputs for single-click wiring
            # NOTE(review): raw_txt_hidden is created but never wired to the
            # click handler — presumably a placeholder; confirm before removing.
            raw_txt_hidden = gr.Textbox(visible=False)
        with gr.Tab("Paste Raw Text"):
            raw_txt = gr.Textbox(
                label="Paste raw invoice text (we’ll map directly to JSON schema)",
                lines=18,
                placeholder="Paste the OCR’d/plain text of the invoice here…"
            )
            # NOTE(review): file_hidden is likewise unused by the handler.
            file_hidden = gr.File(visible=False)
    out = gr.Code(label="Extracted JSON", language="json")
    run_btn = gr.Button("Generate JSON", variant="primary")
    # One button → unified function; we pass both lanes (visible/hidden).
    # run_pipeline prioritizes pasted text over the uploaded file.
    run_btn.click(
        fn=run_pipeline,
        inputs=[inp, raw_txt],
        outputs=out,
    )
    gr.Markdown(
        "ℹ️ **Usage:** Prefer *Paste Raw Text* when you already have text. "
        "If both file and text are provided, we’ll **prioritize the pasted text**."
    )

# Entry point for `python app.py`; share=True also exposes a public Gradio link.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True, show_error=True)