KarthiEz commited on
Commit
17e6b8d
·
verified ·
1 Parent(s): 0616ae1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +818 -558
app.py CHANGED
@@ -1,558 +1,818 @@
1
- import os
2
- import io
3
- from typing import List
4
- import gradio as gr
5
- # docTR imports (PyTorch backend)
6
- from doctr.io import DocumentFile
7
- from doctr.models import ocr_predictor
8
-
9
- # ---------- One-time model bootstrap (CPU-friendly) ----------
10
- # Ensure torch runs in CPU mode on Spaces; docTR auto-detects backend.
11
- # You can optionally pin threads for stability on small CPU runners:
12
- os.environ.setdefault("OMP_NUM_THREADS", "4")
13
- os.environ.setdefault("MKL_NUM_THREADS", "4")
14
-
15
- MODEL = ocr_predictor(pretrained=True) # DBNet + CRNN (default) on PyTorch
16
-
17
- def _collect_text_from_export(exported: dict) -> str:
18
- """Flatten docTR exported structure into newline-separated text per page."""
19
- pages: List[dict] = exported.get("pages", [])
20
- text_pages: List[str] = []
21
-
22
- for page in pages:
23
- page_lines = []
24
- for block in page.get("blocks", []):
25
- for line in block.get("lines", []):
26
- # Join word values in the line; fallback robustly
27
- words = [w.get("value", "") for w in line.get("words", []) if isinstance(w, dict)]
28
- line_text = " ".join([w for w in words if w])
29
- if line_text.strip():
30
- page_lines.append(line_text)
31
- text_pages.append("\n".join(page_lines).strip())
32
-
33
- # Join pages with a page delimiter
34
- return ("\n\n" + ("─" * 32) + " PAGE BREAK " + ("─" * 32) + "\n\n").join(
35
- [tp for tp in text_pages if tp]
36
- ).strip()
37
-
38
- def run_ocr(file: gr.File) -> str:
39
- if file is None:
40
- return "No file received."
41
-
42
- name = (file.name or "").lower()
43
-
44
- # Load as DocumentFile (handles PNG/JPG/PDF)
45
- if name.endswith(".pdf"):
46
- # Render PDF pages via pdfium backend under the hood (CPU OK)
47
- doc = DocumentFile.from_pdf(file=file.name)
48
- else:
49
- # Single image fallback; also works for TIFF/PNG/JPG
50
- doc = DocumentFile.from_images([file.name])
51
-
52
- # Inference
53
- result = MODEL(doc)
54
- exported = result.export()
55
- text = _collect_text_from_export(exported)
56
- print("Extracted Text:\n", text)
57
-
58
- if not text:
59
- return "No text detected."
60
- result_json = invoice_text_to_json(text)
61
- print(json.dumps(result_json, indent=2))
62
- string_json = json.dumps(result_json, indent=2)
63
- return string_json
64
-
65
- import re
66
- import json
67
- from typing import List, Dict, Any
68
- import copy
69
- import numpy as np
70
- import torch
71
- from transformers import pipeline
72
- from sentence_transformers import SentenceTransformer, util
73
-
74
- # ----------------------------- Schema -----------------------------
75
- SCHEMA_JSON: Dict[str, Any] = {
76
- "invoice_header": {
77
- "car_number": None,
78
- "shipment_number": None,
79
- "shipping_point": None,
80
- "currency": None,
81
- "invoice_number": None,
82
- "invoice_date": None,
83
- "order_number": None,
84
- "customer_order_number": None,
85
- "our_order_number": None,
86
- "sales_order_number": None,
87
- "purchase_order_number": None,
88
- "order_date": None,
89
- "supplier_name": None,
90
- "supplier_address": None,
91
- "supplier_phone": None,
92
- "supplier_email": None,
93
- "supplier_tax_id": None,
94
- "customer_name": None,
95
- "customer_address": None,
96
- "customer_phone": None,
97
- "customer_email": None,
98
- "customer_tax_id": None,
99
- "ship_to_name": None,
100
- "ship_to_address": None,
101
- "bill_to_name": None,
102
- "bill_to_address": None,
103
- "remit_to_name": None,
104
- "remit_to_address": None,
105
- "tax_id": None,
106
- "tax_registration_number": None,
107
- "vat_number": None,
108
- "payment_terms": None,
109
- "payment_method": None,
110
- "payment_reference": None,
111
- "bank_account_number": None,
112
- "iban": None,
113
- "swift_code": None,
114
- "total_before_tax": None,
115
- "tax_amount": None,
116
- "tax_rate": None,
117
- "shipping_charges": None,
118
- "discount": None,
119
- "total_due": None,
120
- "amount_paid": None,
121
- "balance_due": None,
122
- "due_date": None,
123
- "invoice_status": None,
124
- "reference_number": None,
125
- "project_code": None,
126
- "department": None,
127
- "contact_person": None,
128
- "notes": None,
129
- "additional_info": None
130
- },
131
- "line_items": [
132
- {
133
- "quantity": None,
134
- "units": None,
135
- "description": None,
136
- "footage": None,
137
- "price": None,
138
- "amount": None,
139
- "notes": None
140
- }
141
- ]
142
- }
143
- STATIC_HEADERS: List[str] = list(SCHEMA_JSON["invoice_header"].keys())
144
-
145
- # Synonym map
146
- SYN2KEY: Dict[str, str] = {
147
- "invoice no": "invoice_number",
148
- "invoice number": "invoice_number",
149
- "invoice#": "invoice_number",
150
- "inv no": "invoice_number",
151
- "inv#": "invoice_number",
152
- "invoice date": "invoice_date",
153
- "date of invoice": "invoice_date",
154
- "po no": "purchase_order_number",
155
- "po number": "purchase_order_number",
156
- "purchase order": "purchase_order_number",
157
- "order no": "order_number",
158
- "order number": "order_number",
159
- "sales order": "sales_order_number",
160
- "customer order": "customer_order_number",
161
- "our order": "our_order_number",
162
- "due date": "due_date",
163
- "date of supply": "order_date",
164
- "gstin": "supplier_tax_id",
165
- "gstin no": "supplier_tax_id",
166
- "tax id": "tax_id",
167
- "vat number": "vat_number",
168
- "tax registration number": "tax_registration_number",
169
- "place of supply": "shipping_point",
170
- "state code": "additional_info",
171
- "taxable value": "total_before_tax",
172
- "total value": "total_due",
173
- "total amount": "total_due",
174
- "amount due": "total_due",
175
- "bank": "bank_account_number",
176
- "account no": "bank_account_number",
177
- "account number": "bank_account_number",
178
- "ifs code": "swift_code",
179
- "ifsc": "payment_reference",
180
- "swift code": "swift_code",
181
- "iban": "iban",
182
- "e-way bill no": "reference_number",
183
- "eway bill": "reference_number",
184
- "dispatched via": "additional_info",
185
- "documents dispatched through": "additional_info",
186
- "kind attn": "contact_person",
187
- "billed to": "bill_to_name",
188
- "receiver": "bill_to_name",
189
- "shipped to": "ship_to_name",
190
- "consignee": "ship_to_name",
191
- }
192
-
193
- def norm(s: str) -> str:
194
- return re.sub(r"\s+", " ", s).strip()
195
-
196
- def deep_copy_schema() -> Dict[str, Any]:
197
- return json.loads(json.dumps(SCHEMA_JSON))
198
-
199
- def extract_candidates(text: str) -> Dict[str, str]:
200
- cands: Dict[str, str] = {}
201
- for raw in text.splitlines():
202
- line = raw.strip().strip("|").strip()
203
- if not line:
204
- continue
205
- if ":" in line:
206
- if "|" in raw:
207
- parts = [p.strip() for p in raw.split("|") if p.strip()]
208
- for cell in parts:
209
- if ":" in cell:
210
- k, v = cell.split(":", 1)
211
- cands[norm(k)] = norm(v)
212
- else:
213
- k, v = line.split(":", 1)
214
- cands[norm(k)] = norm(v)
215
- for raw in text.splitlines():
216
- m = re.search(r"\b(Taxable\s+Value|Total\s+Value|Total\s+Amount|Amount\s+Due)\b[:\s]*([0-9][0-9,]*(?:\.[0-9]{2})?)", raw, re.I)
217
- if m:
218
- k = norm(m.group(1))
219
- v = norm(m.group(2))
220
- cands[k] = v
221
- return cands
222
-
223
- def regex_extract_all(text: str) -> Dict[str, str]:
224
- out: Dict[str, str] = {}
225
- m = re.search(r"\bInvoice\s*(?:No\.?|Number|#)\s*[:\-]?\s*([A-Z0-9\-\/]+)", text, re.I)
226
- if m: out["invoice_number"] = m.group(1)
227
- m = re.search(r"\bInvoice\s*Date\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
228
- if m: out["invoice_date"] = m.group(1)
229
- m = re.search(r"\bPO\s*(?:No\.?|Number)?\s*[:\-]?\s*([A-Z0-9\-\/]+)", text, re.I)
230
- if m: out["purchase_order_number"] = m.group(1)
231
- m = re.search(r"\bPO\s*Date\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
232
- if m: out["order_date"] = m.group(1)
233
- if "order_date" not in out:
234
- m = re.search(r"\bDate\s*of\s*Supply\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
235
- if m: out["order_date"] = m.group(1)
236
- m = re.search(r"\bPlace\s*of\s*Supply\s*[:\-]?\s*([A-Za-z0-9 ,\-\(\)]+)", text, re.I)
237
- if m: out["shipping_point"] = m.group(1).strip(" |")
238
- m = re.search(r"\bGSTIN\s*(?:No\.?)?\s*[:\-]?\s*([A-Z0-9]{15})", text, re.I)
239
- if m: out["supplier_tax_id"] = m.group(1)
240
- m = re.search(r"\bTaxable\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
241
- if m: out["total_before_tax"] = m.group(1).replace(",", "")
242
- cgst = re.search(r"\bCGST\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
243
- sgst = re.search(r"\bSGST\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
244
- if cgst and sgst:
245
- try:
246
- tax_total = float(cgst.group(1).replace(",", "")) + float(sgst.group(1).replace(",", ""))
247
- out["tax_amount"] = f"{tax_total:.2f}"
248
- cgstp = re.search(r"\bCGST\s*%?\s*[:\-]?\s*([0-9]+(?:\.[0-9]+)?)", text, re.I)
249
- sgstp = re.search(r"\bSGST\s*%?\s*[:\-]?\s*([0-9]+(?:\.[0-9]+)?)", text, re.I)
250
- if cgstp and sgstp:
251
- try:
252
- rate = float(cgstp.group(1)) + float(sgstp.group(1))
253
- out["tax_rate"] = f"{rate:g}"
254
- except:
255
- pass
256
- except:
257
- pass
258
- m = re.search(r"\bE[-\s]?Way\s*bill\s*no\.?\s*[:\-]?\s*([0-9 ]+)", text, re.I)
259
- if m: out["reference_number"] = m.group(1).strip()
260
- return out
261
-
262
- def extract_bank_block(text: str) -> Dict[str, str]:
263
- bank: Dict[str, str] = {}
264
- m = re.search(r"\bAccount\s*Name\s*:\s*(.+)", text, re.I)
265
- if m: bank["supplier_name"] = m.group(1).strip()
266
- m = re.search(r"\bAccount\s*(?:No|Number)\s*:\s*([A-Za-z0-9\- ]+)", text, re.I)
267
- if m: bank["bank_account_number"] = m.group(1).strip()
268
- m = re.search(r"\bBank\s*:\s*([A-Za-z0-9 ,\-\(\)&]+)", text, re.I)
269
- if m:
270
- bank["additional_info"] = ("Bank: " + m.group(1).strip())
271
- m = re.search(r"\bIFSC?\s*Code\s*:\s*([A-Za-z0-9]+)", text, re.I)
272
- if m: bank["payment_reference"] = m.group(1).strip()
273
- m = re.search(r"\bSWIFT\s*Code\s*:\s*([A-Za-z0-9]+)", text, re.I)
274
- if m: bank["swift_code"] = m.group(1).strip()
275
- branch = re.search(r"\bBranch\s*:\s*(.+)", text, re.I)
276
- micr = re.search(r"\bMICR\s*Code\s*:\s*([0-9]+)", text, re.I)
277
- extra_bits = []
278
- if branch: extra_bits.append("Branch: " + branch.group(1).strip())
279
- if micr: extra_bits.append("MICR: " + micr.group(1).strip())
280
- if extra_bits:
281
- bank["additional_info"] = ((bank.get("additional_info") + " | ") if bank.get("additional_info") else "") + " | ".join(extra_bits)
282
- return bank
283
-
284
- def parse_line_items(text: str) -> List[Dict[str, Any]]:
285
- items: List[Dict[str, Any]] = []
286
- lines = [ln for ln in text.splitlines() if ln.strip()]
287
- header_idx = -1
288
- for i, ln in enumerate(lines):
289
- if ("|") in ln and ("Description" in ln and ("Qty" in ln or "QTY" in ln)) and ("Rate" in ln or "Price" in ln) and ("Total" in ln):
290
- header_idx = i
291
- break
292
- if header_idx == -1:
293
- return items
294
- headers = [c.strip().lower() for c in lines[header_idx].split("|")]
295
- headers = [h for h in headers if h and set(h) - set("-")]
296
- for j in range(header_idx + 1, len(lines)):
297
- row = lines[j]
298
- if row.strip().startswith("|") and row.count("|") >= 2:
299
- cells = [c.strip() for c in row.split("|")]
300
- cells = [c for c in cells if c and set(c) - set("-")]
301
- if len(cells) < 3:
302
- continue
303
- rowd = {"quantity": None, "units": None, "description": None, "footage": None, "price": None, "amount": None, "notes": None}
304
- def idx_of(name_parts: List[str]) -> int:
305
- for k, h in enumerate(headers):
306
- if any(p in h for p in name_parts):
307
- return k
308
- return -1
309
- i_desc = idx_of(["description", "item"])
310
- i_qty = idx_of(["qty", "quantity"])
311
- i_uom = idx_of(["uom", "unit"])
312
- i_rate = idx_of(["rate", "price"])
313
- i_amt = idx_of(["total value", "amount", "total"])
314
- def safe(i: int) -> str:
315
- return cells[i] if 0 <= i < len(cells) else ""
316
- if i_desc != -1: rowd["description"] = safe(i_desc) or None
317
- if i_qty != -1: rowd["quantity"] = safe(i_qty) or None
318
- if i_uom != -1: rowd["units"] = safe(i_uom) or None
319
- if i_rate != -1: rowd["price"] = safe(i_rate) or None
320
- if i_amt != -1: rowd["amount"] = safe(i_amt) or None
321
- if rowd["units"] and rowd["quantity"]:
322
- rowd["footage"] = f'{rowd["quantity"]} {rowd["units"]}'
323
- items.append(rowd)
324
- else:
325
- if j > header_idx + 1:
326
- break
327
- return items
328
-
329
- def semantic_map_candidates(candidates: Dict[str, str], static_headers: List[str], thresh: float, sentence_model) -> Dict[str, str]:
330
- if not candidates:
331
- return {}
332
- cand_keys = list(candidates.keys())
333
- mapped: Dict[str, str] = {}
334
- leftovers: Dict[str, str] = {}
335
- for k, v in candidates.items():
336
- lk = k.lower()
337
- lk_norm = re.sub(r"[^a-z0-9]+", " ", lk).strip()
338
- hit = None
339
- for syn, key in SYN2KEY.items():
340
- if syn in lk_norm:
341
- hit = key
342
- break
343
- if hit:
344
- mapped[hit] = v
345
- else:
346
- leftovers[k] = v
347
- if leftovers:
348
- cand_emb = sentence_model.encode(list(leftovers.keys()), normalize_embeddings=True)
349
- head_emb = sentence_model.encode(static_headers, normalize_embeddings=True)
350
- M = util.cos_sim(torch.tensor(cand_emb), torch.tensor(head_emb)).cpu().numpy()
351
- keys_left = list(leftovers.keys())
352
- for i, ck in enumerate(keys_left):
353
- j = int(np.argmax(M[i]))
354
- score = float(M[i][j])
355
- if score >= thresh:
356
- mapped[static_headers[j]] = leftovers[ck]
357
- return mapped
358
-
359
- def build_prompt(invoice_text: str, mapped_hints: Dict[str, str], items_hints: List[Dict[str, Any]]) -> str:
360
- instruction = (
361
- 'Use this schema:\n'
362
- '{\n'
363
- ' "invoice_header": {\n'
364
- ' "car_number": "string or null",\n'
365
- ' "shipment_number": "string or null",\n'
366
- ' "shipping_point": "string or null",\n'
367
- ' "currency": "string or null",\n'
368
- ' "invoice_number": "string or null",\n'
369
- ' "invoice_date": "string or null",\n'
370
- ' "order_number": "string or null",\n'
371
- ' "customer_order_number": "string or null",\n'
372
- ' "our_order_number": "string or null",\n'
373
- ' "sales_order_number": "string or null",\n'
374
- ' "purchase_order_number": "string or null",\n'
375
- ' "order_date": "string or null",\n'
376
- ' "supplier_name": "string or null",\n'
377
- ' "supplier_address": "string or null",\n'
378
- ' "supplier_phone": "string or null",\n'
379
- ' "supplier_email": "string or null",\n'
380
- ' "supplier_tax_id": "string or null",\n'
381
- ' "customer_name": "string or null",\n'
382
- ' "customer_address": "string or null",\n'
383
- ' "customer_phone": "string or null",\n'
384
- ' "customer_email": "string or null",\n'
385
- ' "customer_tax_id": "string or null",\n'
386
- ' "ship_to_name": "string or null",\n'
387
- ' "ship_to_address": "string or null",\n'
388
- ' "bill_to_name": "string or null",\n'
389
- ' "bill_to_address": "string or null",\n'
390
- ' "remit_to_name": "string or null",\n'
391
- ' "remit_to_address": "string or null",\n'
392
- ' "tax_id": "string or null",\n'
393
- ' "tax_registration_number": "string or null",\n'
394
- ' "vat_number": "string or null",\n'
395
- ' "payment_terms": "string or null",\n'
396
- ' "payment_method": "string or null",\n'
397
- ' "payment_reference": "string or null",\n'
398
- ' "bank_account_number": "string or null",\n'
399
- ' "iban": "string or null",\n'
400
- ' "swift_code": "string or null",\n'
401
- ' "total_before_tax": "string or null",\n'
402
- ' "tax_amount": "string or null",\n'
403
- ' "tax_rate": "string or null",\n'
404
- ' "shipping_charges": "string or null",\n'
405
- ' "discount": "string or null",\n'
406
- ' "total_due": "string or null",\n'
407
- ' "amount_paid": "string or null",\n'
408
- ' "balance_due": "string or null",\n'
409
- ' "due_date": "string or null",\n'
410
- ' "invoice_status": "string or null",\n'
411
- ' "reference_number": "string or null",\n'
412
- ' "project_code": "string or null",\n'
413
- ' "department": "string or null",\n'
414
- ' "contact_person": "string or null",\n'
415
- ' "notes": "string or null",\n'
416
- ' "additional_info": "string or null"\n'
417
- ' },\n'
418
- ' "line_items": [\n'
419
- ' {\n'
420
- ' "quantity": "string or null",\n'
421
- ' "units": "string or null",\n'
422
- ' "description": "string or null",\n'
423
- ' "footage": "string or null",\n'
424
- ' "price": "string or null",\n'
425
- ' "amount": "string or null",\n'
426
- ' "notes": "string or null"\n'
427
- ' }\n'
428
- ' ]\n'
429
- '}\n'
430
- 'If a field is missing for a line item or header, use null. '
431
- 'Do not invent fields. Do not add any header or shipment data to any line item. '
432
- 'Return ONLY the JSON object, no explanation.\n'
433
- )
434
- hints = ""
435
- if mapped_hints:
436
- hints += "\nHints (header):\n" + " ".join([f"#{k}: {v}" for k, v in mapped_hints.items()])
437
- if items_hints:
438
- try:
439
- hints += "\nHints (line_items):\n" + json.dumps(items_hints, ensure_ascii=False)
440
- except:
441
- pass
442
- return instruction + "\nInvoice Text:\n" + invoice_text.strip() + hints
443
-
444
- def strict_json(text: str) -> Dict[str, Any]:
445
- try:
446
- return json.loads(text)
447
- except:
448
- pass
449
- start = text.find("{")
450
- end = text.rfind("}")
451
- if start != -1 and end != -1 and end > start:
452
- try:
453
- return json.loads(text[start:end+1])
454
- except:
455
- pass
456
- raise ValueError("Model did not return valid JSON.")
457
-
458
- def merge_schema(rule_json: Dict[str, Any], model_json: Dict[str, Any]) -> Dict[str, Any]:
459
- final = copy.deepcopy(rule_json)
460
- hdr = final["invoice_header"]
461
- mdl_hdr = (model_json.get("invoice_header") or {})
462
- for k in hdr.keys():
463
- if hdr[k] in [None, "", "null"]:
464
- v = mdl_hdr.get(k, None)
465
- if v not in [None, "", "null"]:
466
- hdr[k] = v
467
- if final["line_items"] and any(any(v for v in row.values() if v not in [None, "", "null"]) for row in final["line_items"]):
468
- pass
469
- else:
470
- mdl_items = model_json.get("line_items")
471
- if isinstance(mdl_items, list) and mdl_items:
472
- final["line_items"] = mdl_items
473
- return final
474
-
475
- # ---------------------- MAIN FUNCTION ----------------------
476
- def invoice_text_to_json(
477
- invoice_text: str,
478
- threshold: float = 0.60,
479
- max_new_tokens: int = 512
480
- ) -> Dict[str, Any]:
481
- # Load models once (cache if you like for production)
482
- sentence_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
483
- json_converter = pipeline("text2text-generation", model="yahyakhoder/MD2JSON-T5-small-V1")
484
-
485
- txt = invoice_text
486
-
487
- # 1) Deterministic extraction
488
- candidates = extract_candidates(txt)
489
- hard = regex_extract_all(txt)
490
- bank = extract_bank_block(txt)
491
- items = parse_line_items(txt)
492
- sem_mapped = semantic_map_candidates(candidates, STATIC_HEADERS, threshold, sentence_model)
493
- header_found: Dict[str, Any] = {}
494
- header_found.update(sem_mapped)
495
- header_found.update(hard)
496
- header_found.update(bank)
497
-
498
- # 2) Build RULE JSON (schema-shaped, rules filled)
499
- rule_json = deep_copy_schema()
500
- for k, v in header_found.items():
501
- if k in rule_json["invoice_header"]:
502
- rule_json["invoice_header"][k] = v
503
- if items:
504
- rule_json["line_items"] = items
505
-
506
- # 3) MD2JSON generation with strong hints
507
- prompt = build_prompt(txt, header_found, items)
508
- gen = json_converter(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]
509
- try:
510
- model_json = strict_json(gen)
511
- except Exception as e:
512
- model_json = deep_copy_schema() # model failed; keep empty shape
513
-
514
- # 4) Final merge (rules win)
515
- final_json = merge_schema(rule_json, model_json)
516
- return final_json
517
-
518
- # ---------- Gradio UI ----------
519
- TITLE = "docTR OCR — Text Extractor"
520
- DESC = (
521
- "Upload an image or PDF. This Space uses Mindee's docTR (PyTorch backend) to detect & recognize text, "
522
- "and returns plain text per page. CPU-friendly and ready for enterprise prototyping."
523
- )
524
-
525
- with gr.Blocks(theme="soft", title=TITLE) as demo:
526
- gr.Markdown(f"# {TITLE}\n{DESC}")
527
-
528
- with gr.Row():
529
- inp = gr.File(label="Upload image/PDF", file_types=[".png", ".jpg", ".jpeg", ".tif", ".tiff", ".pdf"])
530
- out = gr.Code(label="Extracted JSON", language="json")
531
-
532
-
533
- run_btn = gr.Button("Run OCR", variant="primary")
534
- run_btn.click(fn=run_ocr, inputs=inp, outputs=out)
535
-
536
- gr.Examples(
537
- examples=[
538
- # You can drop a couple of public sample URLs here if desired,
539
- # but Spaces won't auto-download without code. Leave empty by default.
540
- ],
541
- inputs=inp,
542
- outputs=out,
543
- cache_examples=False,
544
- label="(Optional) Examples"
545
- )
546
-
547
- gr.Markdown(
548
- "Tip: For multi-page PDFs, the output shows a **PAGE BREAK** separator between pages.\n"
549
- "For production pipelines, capture this output and route it to your parsing/LLM layer."
550
- )
551
-
552
- if __name__ == "__main__":
553
- demo.launch(
554
- server_name="0.0.0.0",
555
- server_port=7860,
556
- share=True,
557
- show_error=True
558
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ from typing import List
4
+ import gradio as gr
5
+ # docTR imports (PyTorch backend)
6
+ from doctr.io import DocumentFile
7
+ from doctr.models import ocr_predictor
8
+
9
+ # ---------- One-time model bootstrap (CPU-friendly) ----------
10
+ # Ensure torch runs in CPU mode on Spaces; docTR auto-detects backend.
11
+ # You can optionally pin threads for stability on small CPU runners:
12
+ os.environ.setdefault("OMP_NUM_THREADS", "4")
13
+ os.environ.setdefault("MKL_NUM_THREADS", "4")
14
+
15
+ MODEL = ocr_predictor(pretrained=True) # DBNet + CRNN (default) on PyTorch
16
+
17
+ def _collect_text_from_export(exported: dict) -> str:
18
+ """Flatten docTR exported structure into newline-separated text per page."""
19
+ pages: List[dict] = exported.get("pages", [])
20
+ text_pages: List[str] = []
21
+
22
+ for page in pages:
23
+ page_lines = []
24
+ for block in page.get("blocks", []):
25
+ for line in block.get("lines", []):
26
+ # Join word values in the line; fallback robustly
27
+ words = [w.get("value", "") for w in line.get("words", []) if isinstance(w, dict)]
28
+ line_text = " ".join([w for w in words if w])
29
+ if line_text.strip():
30
+ page_lines.append(line_text)
31
+ text_pages.append("\n".join(page_lines).strip())
32
+
33
+ # Join pages with a page delimiter
34
+ return ("\n\n" + ("─" * 32) + " PAGE BREAK " + ("─" * 32) + "\n\n").join(
35
+ [tp for tp in text_pages if tp]
36
+ ).strip()
37
+
38
+ def run_ocr(file: gr.File) -> str:
39
+ if file is None:
40
+ return "No file received."
41
+
42
+ name = (file.name or "").lower()
43
+
44
+ # Load as DocumentFile (handles PNG/JPG/PDF)
45
+ if name.endswith(".pdf"):
46
+ # Render PDF pages via pdfium backend under the hood (CPU OK)
47
+ doc = DocumentFile.from_pdf(file=file.name)
48
+ else:
49
+ # Single image fallback; also works for TIFF/PNG/JPG
50
+ doc = DocumentFile.from_images([file.name])
51
+
52
+ # Inference
53
+ result = MODEL(doc)
54
+ exported = result.export()
55
+ text = _collect_text_from_export(exported)
56
+ print("Extracted Text:\n", text)
57
+
58
+ if not text:
59
+ return "No text detected."
60
+ result_json = invoice_text_to_json(text)
61
+ print(json.dumps(result_json, indent=2))
62
+ string_json = json.dumps(result_json, indent=2)
63
+ return string_json
64
+
65
+ import re
66
+ import json
67
+ from typing import List, Dict, Any
68
+ import copy
69
+ import numpy as np
70
+ import torch
71
+ from transformers import pipeline
72
+ from sentence_transformers import SentenceTransformer, util
73
+
74
+ # ----------------------------- Schema -----------------------------
75
+ SCHEMA_JSON: Dict[str, Any] = {
76
+ "invoice_header": {
77
+ "car_number": None,
78
+ "shipment_number": None,
79
+ "shipping_point": None,
80
+ "currency": None,
81
+ "invoice_number": None,
82
+ "invoice_date": None,
83
+ "order_number": None,
84
+ "customer_order_number": None,
85
+ "our_order_number": None,
86
+ "sales_order_number": None,
87
+ "purchase_order_number": None,
88
+ "order_date": None,
89
+ "supplier_name": None,
90
+ "supplier_address": None,
91
+ "supplier_phone": None,
92
+ "supplier_email": None,
93
+ "supplier_tax_id": None,
94
+ "customer_name": None,
95
+ "customer_address": None,
96
+ "customer_phone": None,
97
+ "customer_email": None,
98
+ "customer_tax_id": None,
99
+ "ship_to_name": None,
100
+ "ship_to_address": None,
101
+ "bill_to_name": None,
102
+ "bill_to_address": None,
103
+ "remit_to_name": None,
104
+ "remit_to_address": None,
105
+ "tax_id": None,
106
+ "tax_registration_number": None,
107
+ "vat_number": None,
108
+ "payment_terms": None,
109
+ "payment_method": None,
110
+ "payment_reference": None,
111
+ "bank_account_number": None,
112
+ "iban": None,
113
+ "swift_code": None,
114
+ "total_before_tax": None,
115
+ "tax_amount": None,
116
+ "tax_rate": None,
117
+ "shipping_charges": None,
118
+ "discount": None,
119
+ "total_due": None,
120
+ "amount_paid": None,
121
+ "balance_due": None,
122
+ "due_date": None,
123
+ "invoice_status": None,
124
+ "reference_number": None,
125
+ "project_code": None,
126
+ "department": None,
127
+ "contact_person": None,
128
+ "notes": None,
129
+ "additional_info": None
130
+ },
131
+ "line_items": [
132
+ {
133
+ "quantity": None,
134
+ "units": None,
135
+ "description": None,
136
+ "footage": None,
137
+ "price": None,
138
+ "amount": None,
139
+ "notes": None
140
+ }
141
+ ]
142
+ }
143
+ STATIC_HEADERS: List[str] = list(SCHEMA_JSON["invoice_header"].keys())
144
+
145
+ # Synonym map
146
+ SYN2KEY: Dict[str, str] = {
147
+ "invoice no": "invoice_number",
148
+ "invoice number": "invoice_number",
149
+ "invoice#": "invoice_number",
150
+ "inv no": "invoice_number",
151
+ "inv#": "invoice_number",
152
+ "invoice date": "invoice_date",
153
+ "date of invoice": "invoice_date",
154
+ "po no": "purchase_order_number",
155
+ "po number": "purchase_order_number",
156
+ "purchase order": "purchase_order_number",
157
+ "order no": "order_number",
158
+ "order number": "order_number",
159
+ "sales order": "sales_order_number",
160
+ "customer order": "customer_order_number",
161
+ "our order": "our_order_number",
162
+ "due date": "due_date",
163
+ "date of supply": "order_date",
164
+ "gstin": "supplier_tax_id",
165
+ "gstin no": "supplier_tax_id",
166
+ "tax id": "tax_id",
167
+ "vat number": "vat_number",
168
+ "tax registration number": "tax_registration_number",
169
+ "place of supply": "shipping_point",
170
+ "state code": "additional_info",
171
+ "taxable value": "total_before_tax",
172
+ "total value": "total_due",
173
+ "total amount": "total_due",
174
+ "amount due": "total_due",
175
+ "bank": "bank_account_number",
176
+ "account no": "bank_account_number",
177
+ "account number": "bank_account_number",
178
+ "ifs code": "swift_code",
179
+ "ifsc": "payment_reference",
180
+ "swift code": "swift_code",
181
+ "iban": "iban",
182
+ "e-way bill no": "reference_number",
183
+ "eway bill": "reference_number",
184
+ "dispatched via": "additional_info",
185
+ "documents dispatched through": "additional_info",
186
+ "kind attn": "contact_person",
187
+ "billed to": "bill_to_name",
188
+ "receiver": "bill_to_name",
189
+ "shipped to": "ship_to_name",
190
+ "consignee": "ship_to_name",
191
+ }
192
+
193
+ def norm(s: str) -> str:
194
+ return re.sub(r"\s+", " ", s).strip()
195
+
196
+ def deep_copy_schema() -> Dict[str, Any]:
197
+ return json.loads(json.dumps(SCHEMA_JSON))
198
+
199
+ def extract_candidates(text: str) -> Dict[str, str]:
200
+ cands: Dict[str, str] = {}
201
+ for raw in text.splitlines():
202
+ line = raw.strip().strip("|").strip()
203
+ if not line:
204
+ continue
205
+ if ":" in line:
206
+ if "|" in raw:
207
+ parts = [p.strip() for p in raw.split("|") if p.strip()]
208
+ for cell in parts:
209
+ if ":" in cell:
210
+ k, v = cell.split(":", 1)
211
+ cands[norm(k)] = norm(v)
212
+ else:
213
+ k, v = line.split(":", 1)
214
+ cands[norm(k)] = norm(v)
215
+ for raw in text.splitlines():
216
+ m = re.search(r"\b(Taxable\s+Value|Total\s+Value|Total\s+Amount|Amount\s+Due)\b[:\s]*([0-9][0-9,]*(?:\.[0-9]{2})?)", raw, re.I)
217
+ if m:
218
+ k = norm(m.group(1))
219
+ v = norm(m.group(2))
220
+ cands[k] = v
221
+ return cands
222
+
223
+ def regex_extract_all(text: str) -> Dict[str, str]:
224
+ out: Dict[str, str] = {}
225
+ m = re.search(r"\bInvoice\s*(?:No\.?|Number|#)\s*[:\-]?\s*([A-Z0-9\-\/]+)", text, re.I)
226
+ if m: out["invoice_number"] = m.group(1)
227
+ m = re.search(r"\bInvoice\s*Date\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
228
+ if m: out["invoice_date"] = m.group(1)
229
+ m = re.search(r"\bPO\s*(?:No\.?|Number)?\s*[:\-]?\s*([A-Z0-9\-\/]+)", text, re.I)
230
+ if m: out["purchase_order_number"] = m.group(1)
231
+ m = re.search(r"\bPO\s*Date\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
232
+ if m: out["order_date"] = m.group(1)
233
+ if "order_date" not in out:
234
+ m = re.search(r"\bDate\s*of\s*Supply\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
235
+ if m: out["order_date"] = m.group(1)
236
+ m = re.search(r"\bPlace\s*of\s*Supply\s*[:\-]?\s*([A-Za-z0-9 ,\-\(\)]+)", text, re.I)
237
+ if m: out["shipping_point"] = m.group(1).strip(" |")
238
+ m = re.search(r"\bGSTIN\s*(?:No\.?)?\s*[:\-]?\s*([A-Z0-9]{15})", text, re.I)
239
+ if m: out["supplier_tax_id"] = m.group(1)
240
+ m = re.search(r"\bTaxable\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
241
+ if m: out["total_before_tax"] = m.group(1).replace(",", "")
242
+ cgst = re.search(r"\bCGST\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
243
+ sgst = re.search(r"\bSGST\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
244
+ if cgst and sgst:
245
+ try:
246
+ tax_total = float(cgst.group(1).replace(",", "")) + float(sgst.group(1).replace(",", ""))
247
+ out["tax_amount"] = f"{tax_total:.2f}"
248
+ cgstp = re.search(r"\bCGST\s*%?\s*[:\-]?\s*([0-9]+(?:\.[0-9]+)?)", text, re.I)
249
+ sgstp = re.search(r"\bSGST\s*%?\s*[:\-]?\s*([0-9]+(?:\.[0-9]+)?)", text, re.I)
250
+ if cgstp and sgstp:
251
+ try:
252
+ rate = float(cgstp.group(1)) + float(sgstp.group(1))
253
+ out["tax_rate"] = f"{rate:g}"
254
+ except:
255
+ pass
256
+ except:
257
+ pass
258
+ m = re.search(r"\bE[-\s]?Way\s*bill\s*no\.?\s*[:\-]?\s*([0-9 ]+)", text, re.I)
259
+ if m: out["reference_number"] = m.group(1).strip()
260
+ return out
261
+
262
+ def extract_bank_block(text: str) -> Dict[str, str]:
263
+ bank: Dict[str, str] = {}
264
+ m = re.search(r"\bAccount\s*Name\s*:\s*(.+)", text, re.I)
265
+ if m: bank["supplier_name"] = m.group(1).strip()
266
+ m = re.search(r"\bAccount\s*(?:No|Number)\s*:\s*([A-Za-z0-9\- ]+)", text, re.I)
267
+ if m: bank["bank_account_number"] = m.group(1).strip()
268
+ m = re.search(r"\bBank\s*:\s*([A-Za-z0-9 ,\-\(\)&]+)", text, re.I)
269
+ if m:
270
+ bank["additional_info"] = ("Bank: " + m.group(1).strip())
271
+ m = re.search(r"\bIFSC?\s*Code\s*:\s*([A-Za-z0-9]+)", text, re.I)
272
+ if m: bank["payment_reference"] = m.group(1).strip()
273
+ m = re.search(r"\bSWIFT\s*Code\s*:\s*([A-Za-z0-9]+)", text, re.I)
274
+ if m: bank["swift_code"] = m.group(1).strip()
275
+ branch = re.search(r"\bBranch\s*:\s*(.+)", text, re.I)
276
+ micr = re.search(r"\bMICR\s*Code\s*:\s*([0-9]+)", text, re.I)
277
+ extra_bits = []
278
+ if branch: extra_bits.append("Branch: " + branch.group(1).strip())
279
+ if micr: extra_bits.append("MICR: " + micr.group(1).strip())
280
+ if extra_bits:
281
+ bank["additional_info"] = ((bank.get("additional_info") + " | ") if bank.get("additional_info") else "") + " | ".join(extra_bits)
282
+ return bank
283
+
284
+ def _has_real_items(items) -> bool:
285
+ return (
286
+ isinstance(items, list)
287
+ and any(
288
+ isinstance(row, dict)
289
+ and any(val not in (None, "", "null") for val in row.values())
290
+ for row in items
291
+ )
292
+ )
293
+
294
+ def parse_line_items(text: str) -> List[Dict[str, Any]]:
295
+ """
296
+ Dynamic, header-agnostic line-item extractor.
297
+ - Auto-detects header row (no hardcoded labels)
298
+ - Supports pipe '|' tables, multi-space/tab tables, and stacked/vertical layouts
299
+ - Fuzzy maps arbitrary headers to: description, quantity, units, price, amount
300
+ - Stitches wrapped descriptions; stops at totals/subtotals
301
+ """
302
+ import re
303
+ from typing import List, Dict, Any
304
+ import torch
305
+ from sentence_transformers import SentenceTransformer, util
306
+
307
+ # ---- local helpers (encapsulated; no external edits required) ----
308
+ def _tokenize_row(row: str) -> List[str]:
309
+ if "|" in row:
310
+ toks = [c.strip(" -") for c in row.split("|")]
311
+ else:
312
+ toks = re.split(r"\t+| {2,}", row)
313
+ toks = [c.strip(" -") for c in toks]
314
+ return [t for t in toks if t]
315
+
316
+ def _looks_like_separator(row: str) -> bool:
317
+ return bool(re.fullmatch(r"[-=–—\s]+", row))
318
+
319
+ def _numlike(s: str) -> bool:
320
+ return bool(re.fullmatch(r"[₹$€]?\s*\d[\d,]*(?:\.\d+)?", s.strip()))
321
+
322
+ def _normalize_num(s: str | None) -> str | None:
323
+ if not s: return None
324
+ return s.replace(",", "").replace("₹", "").replace("$", "").replace("€", "").strip() or None
325
+
326
+ STOP = re.compile(r"^\s*(subtotal|tax|vat|gst|cgst|sgst|igst|total\b|grand total|amount due|balance due)\b", re.I)
327
+
328
+ # Canonical targets + synonyms (broad, non-brittle)
329
+ CANON = ["description", "quantity", "units", "price", "amount"]
330
+ SYN = {
331
+ "description": ["description", "item", "details", "product", "material", "article", "part no", "part", "goods desc"],
332
+ "quantity": ["qty", "quantity", "qnty", "pcs", "pieces", "units qty", "ordered qty"],
333
+ "units": ["uom", "unit", "units", "measure", "type", "pkg", "pack", "u/m"],
334
+ "price": ["rate", "price", "unit price", "cost", "u/price", "list price"],
335
+ "amount": ["amount", "total", "line total", "ext price", "net", "value", "extended"]
336
+ }
337
+
338
+ def _find_header_idx(lines: List[str]) -> int:
339
+ """Heuristic header detection for horizontal tables."""
340
+ for i, row in enumerate(lines):
341
+ if _looks_like_separator(row):
342
+ continue
343
+ toks = _tokenize_row(row)
344
+ if len(toks) < 3:
345
+ continue
346
+ # low numeric density
347
+ if sum(_numlike(t) for t in toks) > len(toks) // 2:
348
+ continue
349
+ # at least 2 synonym hits
350
+ hits = 0
351
+ lowt = [t.lower() for t in toks]
352
+ for t in lowt:
353
+ for syns in SYN.values():
354
+ if any(s in t for s in syns):
355
+ hits += 1
356
+ break
357
+ if hits >= 2:
358
+ return i
359
+ return -1
360
+
361
+ def _map_headers_dynamic(header_tokens: List[str], model) -> Dict[int, str]:
362
+ """
363
+ Map arbitrary header tokens to canonical keys via:
364
+ 1) direct/synonym contains
365
+ 2) semantic similarity (best match)
366
+ """
367
+ mapped: Dict[int, str] = {}
368
+ used = set()
369
+
370
+ low = [h.lower() for h in header_tokens]
371
+ # 1) substring / synonyms
372
+ for j, h in enumerate(low):
373
+ for key, syns in SYN.items():
374
+ if any(s in h for s in syns):
375
+ if key not in used:
376
+ mapped[j] = key
377
+ used.add(key)
378
+ break
379
+
380
+ # 2) semantic backstop for unmapped
381
+ remaining = [j for j in range(len(header_tokens)) if j not in mapped]
382
+ if remaining:
383
+ label_texts, label_keys = [], []
384
+ for k, syns in SYN.items():
385
+ for s in syns + [k]:
386
+ label_texts.append(s)
387
+ label_keys.append(k)
388
+ h_emb = model.encode([header_tokens[i] for i in remaining], normalize_embeddings=True)
389
+ l_emb = model.encode(label_texts, normalize_embeddings=True)
390
+ sim = util.cos_sim(torch.tensor(h_emb), torch.tensor(l_emb)).cpu().numpy()
391
+ for ri, j in enumerate(remaining):
392
+ k_best = int(sim[ri].argmax())
393
+ key = label_keys[k_best]
394
+ if key not in used:
395
+ mapped[j] = key
396
+ used.add(key)
397
+
398
+ return mapped
399
+
400
+ def _parse_horizontal(lines: List[str]) -> List[Dict[str, Any]]:
401
+ """Parse pipe/whitespace horizontal tables with dynamic headers."""
402
+ header_idx = _find_header_idx(lines)
403
+ if header_idx == -1:
404
+ return []
405
+
406
+ header_tokens = _tokenize_row(lines[header_idx])
407
+
408
+ # lazy singleton on the function for perf (no external changes)
409
+ if not hasattr(parse_line_items, "_sent_model"):
410
+ parse_line_items._sent_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") # type: ignore[attr-defined]
411
+ sm = parse_line_items._sent_model # type: ignore[attr-defined]
412
+
413
+ idx2key = _map_headers_dynamic(header_tokens, sm)
414
+
415
+ items: List[Dict[str, Any]] = []
416
+ for row in lines[header_idx + 1:]:
417
+ if _looks_like_separator(row):
418
+ continue
419
+ if STOP.search(row):
420
+ break
421
+
422
+ toks = _tokenize_row(row)
423
+
424
+ # continuation-line heuristic (wrapped description)
425
+ if (len(toks) == 1 or len(toks) < (max(idx2key.keys(), default=-1) + 1)) and items:
426
+ last = items[-1]
427
+ prev = (last.get("description") or "").strip()
428
+ last["description"] = (prev + " " + toks[0]).strip() if toks else prev
429
+ continue
430
+
431
+ rowd = {"description": None, "quantity": None, "units": None,
432
+ "price": None, "amount": None, "footage": None, "notes": None}
433
+
434
+ for j, tok in enumerate(toks):
435
+ key = idx2key.get(j)
436
+ if not key:
437
+ continue
438
+ val = tok.strip()
439
+ if key in ("quantity", "price", "amount"):
440
+ val = _normalize_num(val)
441
+ rowd[key] = val or rowd.get(key)
442
+
443
+ if rowd["quantity"] and rowd["units"]:
444
+ rowd["footage"] = f'{rowd["quantity"]} {rowd["units"]}'
445
+
446
+ if any(rowd.get(k) for k in ("description", "amount", "price")):
447
+ items.append(rowd)
448
+
449
+ # prune empties
450
+ return [it for it in items if any(v for k, v in it.items() if k != "notes")]
451
+
452
+ def _parse_vertical(text: str) -> List[Dict[str, Any]]:
453
+ """
454
+ Deterministic stacked/vertical parser for blocks like:
455
+
456
+ Description
457
+ Type
458
+ Quantity
459
+ Rate
460
+ Amount
461
+ <desc1>
462
+ <type1>
463
+ <qty1>
464
+ <rate1>
465
+ <amt1>
466
+ <desc2> ...
467
+
468
+ Stops at totals/subtotals.
469
+ """
470
+ lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
471
+ if not lines:
472
+ return []
473
+
474
+ # Find the exact 5-label header block (order-agnostic but contiguous)
475
+ LABELS = ["description", "type", "quantity", "rate", "amount"]
476
+ def is_label(s: str) -> str | None:
477
+ t = s.lower()
478
+ if re.fullmatch(r"[₹$€]?\s*\d[\d,]*(?:\.\d+)?", t):
479
+ return None
480
+ if "desc" in t or "item" in t or "product" in t or "material" in t or "article" in t:
481
+ return "description"
482
+ if "type" in t or "uom" in t or "unit" in t or "units" in t:
483
+ return "type"
484
+ if "qty" in t or "quantity" in t:
485
+ return "quantity"
486
+ if "rate" in t or "price" in t or "unit price" in t:
487
+ return "rate"
488
+ if "amount" in t or "total" in t:
489
+ return "amount"
490
+ return None
491
+
492
+ start = -1
493
+ for i in range(len(lines) - 4):
494
+ block = lines[i:i+5]
495
+ mapped = [is_label(x) for x in block]
496
+ if None not in mapped and len(set(mapped)) == 5:
497
+ start = i
498
+ header_keys = mapped # e.g. ["description","type","quantity","rate","amount"]
499
+ break
500
+ if start == -1:
501
+ return []
502
+
503
+ # Build a position→canonical map in this exact order
504
+ pos2key = {idx: key for idx, key in enumerate(header_keys)}
505
+
506
+ # Consume values in chunks of 5
507
+ items: List[Dict[str, Any]] = []
508
+ i = start + 5
509
+ STOP = re.compile(r"^\s*(subtotal|tax|vat|gst|cgst|sgst|igst|total\b|grand total|amount due|balance due)\b", re.I)
510
+
511
+ def norm_num(s: str | None) -> str | None:
512
+ if not s: return None
513
+ return s.replace(",", "").replace("₹", "").replace("$", "").replace("€", "").strip() or None
514
+
515
+ while i + 4 < len(lines):
516
+ if STOP.search(lines[i]): # hit totals, bail
517
+ break
518
+ chunk = lines[i:i+5]
519
+
520
+ row = {"description": None, "units": None, "quantity": None,
521
+ "price": None, "amount": None, "footage": None, "notes": None}
522
+
523
+ # map chunk by discovered order
524
+ for j, val in enumerate(chunk):
525
+ key = pos2key[j]
526
+ if key == "type":
527
+ row["units"] = val # map "Type" -> "units"
528
+ elif key == "quantity":
529
+ row["quantity"] = norm_num(val)
530
+ elif key == "rate":
531
+ row["price"] = norm_num(val)
532
+ elif key == "amount":
533
+ row["amount"] = norm_num(val)
534
+ elif key == "description":
535
+ row["description"] = val
536
+
537
+ if row["quantity"] and row["units"]:
538
+ row["footage"] = f'{row["quantity"]} {row["units"]}'
539
+
540
+ # minimal acceptance: description or amount or price
541
+ if any(row.get(k) for k in ("description", "amount", "price")):
542
+ items.append(row)
543
+
544
+ i += 5
545
+
546
+ return items
547
+
548
+ # ---- main body ----
549
+ raw_lines = [ln.rstrip() for ln in text.splitlines()]
550
+ lines = [ln for ln in raw_lines if ln.strip()]
551
+ if not lines:
552
+ return []
553
+
554
+ # 1) Try horizontal first
555
+ items = _parse_horizontal(lines)
556
+ if items:
557
+ return items
558
+
559
+ # 2) Fallback to vertical/stacked
560
+ items = _parse_vertical(text)
561
+ return items
562
+
563
+
564
+
565
+ def semantic_map_candidates(candidates: Dict[str, str], static_headers: List[str], thresh: float, sentence_model) -> Dict[str, str]:
566
+ if not candidates:
567
+ return {}
568
+ cand_keys = list(candidates.keys())
569
+ mapped: Dict[str, str] = {}
570
+ leftovers: Dict[str, str] = {}
571
+ for k, v in candidates.items():
572
+ lk = k.lower()
573
+ lk_norm = re.sub(r"[^a-z0-9]+", " ", lk).strip()
574
+ hit = None
575
+ for syn, key in SYN2KEY.items():
576
+ if syn in lk_norm:
577
+ hit = key
578
+ break
579
+ if hit:
580
+ mapped[hit] = v
581
+ else:
582
+ leftovers[k] = v
583
+ if leftovers:
584
+ cand_emb = sentence_model.encode(list(leftovers.keys()), normalize_embeddings=True)
585
+ head_emb = sentence_model.encode(static_headers, normalize_embeddings=True)
586
+ M = util.cos_sim(torch.tensor(cand_emb), torch.tensor(head_emb)).cpu().numpy()
587
+ keys_left = list(leftovers.keys())
588
+ for i, ck in enumerate(keys_left):
589
+ j = int(np.argmax(M[i]))
590
+ score = float(M[i][j])
591
+ if score >= thresh:
592
+ mapped[static_headers[j]] = leftovers[ck]
593
+ return mapped
594
+
595
+ def build_prompt(invoice_text: str, mapped_hints: Dict[str, str], items_hints: List[Dict[str, Any]]) -> str:
596
+ instruction = (
597
+ 'Use this schema:\n'
598
+ '{\n'
599
+ ' "invoice_header": {\n'
600
+ ' "car_number": "string or null",\n'
601
+ ' "shipment_number": "string or null",\n'
602
+ ' "shipping_point": "string or null",\n'
603
+ ' "currency": "string or null",\n'
604
+ ' "invoice_number": "string or null",\n'
605
+ ' "invoice_date": "string or null",\n'
606
+ ' "order_number": "string or null",\n'
607
+ ' "customer_order_number": "string or null",\n'
608
+ ' "our_order_number": "string or null",\n'
609
+ ' "sales_order_number": "string or null",\n'
610
+ ' "purchase_order_number": "string or null",\n'
611
+ ' "order_date": "string or null",\n'
612
+ ' "supplier_name": "string or null",\n'
613
+ ' "supplier_address": "string or null",\n'
614
+ ' "supplier_phone": "string or null",\n'
615
+ ' "supplier_email": "string or null",\n'
616
+ ' "supplier_tax_id": "string or null",\n'
617
+ ' "customer_name": "string or null",\n'
618
+ ' "customer_address": "string or null",\n'
619
+ ' "customer_phone": "string or null",\n'
620
+ ' "customer_email": "string or null",\n'
621
+ ' "customer_tax_id": "string or null",\n'
622
+ ' "ship_to_name": "string or null",\n'
623
+ ' "ship_to_address": "string or null",\n'
624
+ ' "bill_to_name": "string or null",\n'
625
+ ' "bill_to_address": "string or null",\n'
626
+ ' "remit_to_name": "string or null",\n'
627
+ ' "remit_to_address": "string or null",\n'
628
+ ' "tax_id": "string or null",\n'
629
+ ' "tax_registration_number": "string or null",\n'
630
+ ' "vat_number": "string or null",\n'
631
+ ' "payment_terms": "string or null",\n'
632
+ ' "payment_method": "string or null",\n'
633
+ ' "payment_reference": "string or null",\n'
634
+ ' "bank_account_number": "string or null",\n'
635
+ ' "iban": "string or null",\n'
636
+ ' "swift_code": "string or null",\n'
637
+ ' "total_before_tax": "string or null",\n'
638
+ ' "tax_amount": "string or null",\n'
639
+ ' "tax_rate": "string or null",\n'
640
+ ' "shipping_charges": "string or null",\n'
641
+ ' "discount": "string or null",\n'
642
+ ' "total_due": "string or null",\n'
643
+ ' "amount_paid": "string or null",\n'
644
+ ' "balance_due": "string or null",\n'
645
+ ' "due_date": "string or null",\n'
646
+ ' "invoice_status": "string or null",\n'
647
+ ' "reference_number": "string or null",\n'
648
+ ' "project_code": "string or null",\n'
649
+ ' "department": "string or null",\n'
650
+ ' "contact_person": "string or null",\n'
651
+ ' "notes": "string or null",\n'
652
+ ' "additional_info": "string or null"\n'
653
+ ' },\n'
654
+ ' "line_items": [\n'
655
+ ' {\n'
656
+ ' "quantity": "string or null",\n'
657
+ ' "units": "string or null",\n'
658
+ ' "description": "string or null",\n'
659
+ ' "footage": "string or null",\n'
660
+ ' "price": "string or null",\n'
661
+ ' "amount": "string or null",\n'
662
+ ' "notes": "string or null"\n'
663
+ ' }\n'
664
+ ' ]\n'
665
+ '}\n'
666
+ 'If a field is missing for a line item or header, use null. '
667
+ 'Do not invent fields. Do not add any header or shipment data to any line item. '
668
+ 'Return ONLY the JSON object, no explanation.\n'
669
+ )
670
+ hints = ""
671
+ if mapped_hints:
672
+ hints += "\nHints (header):\n" + " ".join([f"#{k}: {v}" for k, v in mapped_hints.items()])
673
+ if items_hints:
674
+ try:
675
+ hints += "\nHints (line_items):\n" + json.dumps(items_hints, ensure_ascii=False)
676
+ except:
677
+ pass
678
+ return instruction + "\nInvoice Text:\n" + invoice_text.strip() + hints
679
+
680
+ def strict_json(text: str) -> Dict[str, Any]:
681
+ try:
682
+ return json.loads(text)
683
+ except:
684
+ pass
685
+ start = text.find("{")
686
+ end = text.rfind("}")
687
+ if start != -1 and end != -1 and end > start:
688
+ try:
689
+ return json.loads(text[start:end+1])
690
+ except:
691
+ pass
692
+ raise ValueError("Model did not return valid JSON.")
693
+
694
+ def merge_schema(rule_json: Dict[str, Any], model_json: Dict[str, Any]) -> Dict[str, Any]:
695
+ final = copy.deepcopy(rule_json)
696
+
697
+ # --- headers (rules win where present) ---
698
+ hdr = final["invoice_header"]
699
+ mdl_hdr = (model_json.get("invoice_header") or {})
700
+ for k in hdr.keys():
701
+ if hdr[k] in [None, "", "null"]:
702
+ v = mdl_hdr.get(k, None)
703
+ if v not in [None, "", "null"]:
704
+ hdr[k] = v
705
+
706
+ # --- line_items (prefer parsed items -> model -> empty) ---
707
+ rule_items = rule_json.get("line_items") or []
708
+ model_items = model_json.get("line_items") or []
709
+
710
+ if _has_real_items(rule_items):
711
+ final["line_items"] = rule_items
712
+ elif _has_real_items(model_items):
713
+ final["line_items"] = model_items
714
+ else:
715
+ final["line_items"] = []
716
+
717
+ return final
718
+
719
+ def _prune_empty_items(payload: Dict[str, Any]) -> Dict[str, Any]:
720
+ items = payload.get("line_items")
721
+ if isinstance(items, list):
722
+ payload["line_items"] = [
723
+ it for it in items
724
+ if isinstance(it, dict) and any(v not in (None, "", "null") for v in it.values())
725
+ ]
726
+ return payload
727
+
728
+
729
+ # ---------------------- MAIN FUNCTION ----------------------
730
+ def invoice_text_to_json(
731
+ invoice_text: str,
732
+ threshold: float = 0.60,
733
+ max_new_tokens: int = 512
734
+ ) -> Dict[str, Any]:
735
+ # Load models once (cache if you like for production)
736
+ sentence_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
737
+ json_converter = pipeline("text2text-generation", model="yahyakhoder/MD2JSON-T5-small-V1")
738
+
739
+ txt = invoice_text
740
+
741
+ # 1) Deterministic extraction
742
+ candidates = extract_candidates(txt)
743
+ hard = regex_extract_all(txt)
744
+ bank = extract_bank_block(txt)
745
+ items = parse_line_items(txt)
746
+ print("Extracted line items:", items)
747
+
748
+ sem_mapped = semantic_map_candidates(candidates, STATIC_HEADERS, threshold, sentence_model)
749
+ header_found: Dict[str, Any] = {}
750
+ header_found.update(sem_mapped)
751
+ header_found.update(hard)
752
+ header_found.update(bank)
753
+
754
+ # 2) Build RULE JSON (schema-shaped, rules filled)
755
+ rule_json = deep_copy_schema()
756
+ if _has_real_items(items):
757
+ rule_json["line_items"] = items
758
+ else:
759
+ rule_json["line_items"] = []
760
+ for k, v in header_found.items():
761
+ if k in rule_json["invoice_header"]:
762
+ rule_json["invoice_header"][k] = v
763
+
764
+
765
+ # 3) MD2JSON generation with strong hints
766
+ prompt = build_prompt(txt, header_found, items)
767
+ gen = json_converter(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]
768
+ try:
769
+ model_json = strict_json(gen)
770
+ except Exception as e:
771
+ model_json = deep_copy_schema() # model failed; keep empty shape
772
+
773
+ # 4) Final merge (rules win)
774
+ final_json = merge_schema(rule_json, model_json)
775
+ final_json = _prune_empty_items(final_json)
776
+ return final_json
777
+
778
+ # ---------- Gradio UI ----------
779
+ TITLE = "docTR OCR — Text Extractor"
780
+ DESC = (
781
+ "Upload an image or PDF. This Space uses Mindee's docTR (PyTorch backend) to detect & recognize text, "
782
+ "and returns plain text per page. CPU-friendly and ready for enterprise prototyping."
783
+ )
784
+
785
+ with gr.Blocks(theme="soft", title=TITLE) as demo:
786
+ gr.Markdown(f"# {TITLE}\n{DESC}")
787
+
788
+ with gr.Row():
789
+ inp = gr.File(label="Upload image/PDF", file_types=[".png", ".jpg", ".jpeg", ".tif", ".tiff", ".pdf"])
790
+ out = gr.Code(label="Extracted JSON", language="json")
791
+
792
+
793
+ run_btn = gr.Button("Run OCR", variant="primary")
794
+ run_btn.click(fn=run_ocr, inputs=inp, outputs=out)
795
+
796
+ gr.Examples(
797
+ examples=[
798
+ # You can drop a couple of public sample URLs here if desired,
799
+ # but Spaces won't auto-download without code. Leave empty by default.
800
+ ],
801
+ inputs=inp,
802
+ outputs=out,
803
+ cache_examples=False,
804
+ label="(Optional) Examples"
805
+ )
806
+
807
+ gr.Markdown(
808
+ "Tip: For multi-page PDFs, the output shows a **PAGE BREAK** separator between pages.\n"
809
+ "For production pipelines, capture this output and route it to your parsing/LLM layer."
810
+ )
811
+
812
+ if __name__ == "__main__":
813
+ demo.launch(
814
+ server_name="0.0.0.0",
815
+ server_port=7860,
816
+ share=True,
817
+ show_error=True
818
+ )