KarthiEz committed on
Commit
dcd1a3c
·
verified ·
1 Parent(s): 089db88

Upload 2 files

Browse files
Files changed (2) hide show
  1. req.txt +34 -0
  2. test.py +558 -0
req.txt ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core runtime
2
+ #paddlepaddle==2.6.1
3
+ #paddleocr==2.7.0.3
4
+
5
+ # PDF renderer compatible with PaddleOCR 2.7.0.3 (requires <1.21.0)
6
+ #pymupdf==1.20.2
7
+
8
+ # OpenCV: PaddleOCR 2.7 expects <=4.6.0.66 and needs contrib; use headless for servers
9
+ opencv-contrib-python-headless==4.6.0.66
10
+
11
+ # Numerics & imaging
12
+ numpy==1.26.4
13
+ Pillow==10.4.0
14
+
15
+ # UI
16
+ gradio==4.26.0
17
+ gradio-client==0.15.1
18
+ fastapi==0.109.2
19
+ starlette==0.36.3
20
+ pydantic==2.6.4
21
+ anyio==4.1.0
22
+
23
+ sentence-transformers==3.0.1
24
+ scikit-learn>=1.3
25
+
26
+ # Quality-of-life
27
+ tqdm==4.67.1
28
+
29
+
30
+ python-doctr[torch,viz]>=0.11.0
31
+ pypdfium2>=4.30.0
32
+
33
+ transformers==4.57.1
34
+ sentence-transformers
test.py ADDED
@@ -0,0 +1,558 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ from typing import List
4
+ import gradio as gr
5
+ # docTR imports (PyTorch backend)
6
+ from doctr.io import DocumentFile
7
+ from doctr.models import ocr_predictor
8
+
9
# ---------- One-time model bootstrap (CPU-friendly) ----------
# docTR picks its backend automatically; nothing here forces a device.
# Thread counts are pinned for stability on small CPU runners. setdefault()
# keeps any value the deployment environment already exported.
# NOTE(review): these variables are set after the docTR import above; if the
# numeric libraries read them at import time they may have no effect here --
# confirm, and move them ahead of the imports if so.
os.environ.setdefault("OMP_NUM_THREADS", "4")
os.environ.setdefault("MKL_NUM_THREADS", "4")

# Built once at import time so every request reuses the same predictor.
MODEL = ocr_predictor(pretrained=True)  # DBNet + CRNN (default) on PyTorch
16
+
17
+ def _collect_text_from_export(exported: dict) -> str:
18
+ """Flatten docTR exported structure into newline-separated text per page."""
19
+ pages: List[dict] = exported.get("pages", [])
20
+ text_pages: List[str] = []
21
+
22
+ for page in pages:
23
+ page_lines = []
24
+ for block in page.get("blocks", []):
25
+ for line in block.get("lines", []):
26
+ # Join word values in the line; fallback robustly
27
+ words = [w.get("value", "") for w in line.get("words", []) if isinstance(w, dict)]
28
+ line_text = " ".join([w for w in words if w])
29
+ if line_text.strip():
30
+ page_lines.append(line_text)
31
+ text_pages.append("\n".join(page_lines).strip())
32
+
33
+ # Join pages with a page delimiter
34
+ return ("\n\n" + ("─" * 32) + " PAGE BREAK " + ("─" * 32) + "\n\n").join(
35
+ [tp for tp in text_pages if tp]
36
+ ).strip()
37
+
38
def run_ocr(file: gr.File) -> str:
    """Run OCR on an uploaded image or PDF and return the extracted invoice JSON.

    Args:
        file: The upload handed over by Gradio. Either a plain path string or
            an object exposing the path as ``.name`` is accepted.

    Returns:
        A pretty-printed JSON string built by ``invoice_text_to_json``, or a
        short plain-text message when there is no file or no text was detected.
    """
    if file is None:
        return "No file received."

    # Accept both shapes: a path string, or a wrapper carrying the path in .name.
    path = str(file) if isinstance(file, str) else getattr(file, "name", None)
    if not path:
        return "No file received."

    # Load as DocumentFile: PDFs are rendered page by page, anything else is
    # treated as a single image.
    if path.lower().endswith(".pdf"):
        doc = DocumentFile.from_pdf(file=path)
    else:
        doc = DocumentFile.from_images([path])

    # Inference
    result = MODEL(doc)
    text = _collect_text_from_export(result.export())
    print("Extracted Text:\n", text)

    if not text:
        return "No text detected."

    # Serialise once and reuse the string for both the log line and the reply.
    string_json = json.dumps(invoice_text_to_json(text), indent=2)
    print(string_json)
    return string_json
64
+
65
+ import re
66
+ import json
67
+ from typing import List, Dict, Any
68
+ import copy
69
+ import numpy as np
70
+ import torch
71
+ from transformers import pipeline
72
+ from sentence_transformers import SentenceTransformer, util
73
+
74
+ # ----------------------------- Schema -----------------------------
75
# Canonical output shape. Every header field starts as None, and "line_items"
# holds a single all-None row that documents the per-item keys.
SCHEMA_JSON: Dict[str, Any] = {
    "invoice_header": {
        "car_number": None,
        "shipment_number": None,
        "shipping_point": None,
        "currency": None,
        "invoice_number": None,
        "invoice_date": None,
        "order_number": None,
        "customer_order_number": None,
        "our_order_number": None,
        "sales_order_number": None,
        "purchase_order_number": None,
        "order_date": None,
        "supplier_name": None,
        "supplier_address": None,
        "supplier_phone": None,
        "supplier_email": None,
        "supplier_tax_id": None,
        "customer_name": None,
        "customer_address": None,
        "customer_phone": None,
        "customer_email": None,
        "customer_tax_id": None,
        "ship_to_name": None,
        "ship_to_address": None,
        "bill_to_name": None,
        "bill_to_address": None,
        "remit_to_name": None,
        "remit_to_address": None,
        "tax_id": None,
        "tax_registration_number": None,
        "vat_number": None,
        "payment_terms": None,
        "payment_method": None,
        "payment_reference": None,
        "bank_account_number": None,
        "iban": None,
        "swift_code": None,
        "total_before_tax": None,
        "tax_amount": None,
        "tax_rate": None,
        "shipping_charges": None,
        "discount": None,
        "total_due": None,
        "amount_paid": None,
        "balance_due": None,
        "due_date": None,
        "invoice_status": None,
        "reference_number": None,
        "project_code": None,
        "department": None,
        "contact_person": None,
        "notes": None,
        "additional_info": None,
    },
    "line_items": [
        {
            "quantity": None,
            "units": None,
            "description": None,
            "footage": None,
            "price": None,
            "amount": None,
            "notes": None,
        }
    ],
}
# Header field names in schema order; these are the targets for key matching.
STATIC_HEADERS: List[str] = list(SCHEMA_JSON["invoice_header"])
144
+
145
# Synonym map: lower-case label fragments as they appear on invoices, mapped
# to the schema header they fill. Keys are compared against labels whose
# punctuation has been replaced by spaces, so a fragment containing
# punctuation can only match through a punctuation-free spelling.
SYN2KEY: Dict[str, str] = {
    "invoice no": "invoice_number",
    "invoice number": "invoice_number",
    "invoice#": "invoice_number",
    "inv no": "invoice_number",
    "inv#": "invoice_number",
    "invoice date": "invoice_date",
    "date of invoice": "invoice_date",
    "po no": "purchase_order_number",
    "po number": "purchase_order_number",
    "purchase order": "purchase_order_number",
    "order no": "order_number",
    "order number": "order_number",
    "sales order": "sales_order_number",
    "customer order": "customer_order_number",
    "our order": "our_order_number",
    "due date": "due_date",
    "date of supply": "order_date",
    "gstin": "supplier_tax_id",
    "gstin no": "supplier_tax_id",
    "tax id": "tax_id",
    "vat number": "vat_number",
    "tax registration number": "tax_registration_number",
    "place of supply": "shipping_point",
    "state code": "additional_info",
    "taxable value": "total_before_tax",
    "total value": "total_due",
    "total amount": "total_due",
    "amount due": "total_due",
    "bank": "bank_account_number",
    "account no": "bank_account_number",
    "account number": "bank_account_number",
    # "IFS code" and "IFSC" are the same label, so both go to the same field
    # (the dedicated bank-block parser also stores it as payment_reference).
    "ifs code": "payment_reference",
    "ifsc": "payment_reference",
    "swift code": "swift_code",
    "iban": "iban",
    "e-way bill no": "reference_number",
    # Punctuation-free spelling of the entry above, so "E-Way bill" labels match.
    "e way bill": "reference_number",
    "eway bill": "reference_number",
    "dispatched via": "additional_info",
    "documents dispatched through": "additional_info",
    "kind attn": "contact_person",
    "billed to": "bill_to_name",
    "receiver": "bill_to_name",
    "shipped to": "ship_to_name",
    "consignee": "ship_to_name",
}
192
+
193
def norm(s: str) -> str:
    """Collapse each run of whitespace in *s* to one space and trim both ends."""
    return re.sub(r"\s+", " ", s).strip()
195
+
196
def deep_copy_schema() -> Dict[str, Any]:
    """Return an independent, fully nested copy of ``SCHEMA_JSON``.

    Callers fill the copy in place, so it must not share any dict or list
    with the module-level template.
    """
    # A direct deep copy does the same job as a JSON round-trip, without the
    # serialisation work.
    return copy.deepcopy(SCHEMA_JSON)
198
+
199
def extract_candidates(text: str) -> Dict[str, str]:
    """Collect raw ``label -> value`` pairs from the OCR text.

    Two passes are made over the lines:

    1. Every line containing ``:`` is split at the first colon. Lines that
       also contain ``|`` are treated as table rows, and each cell is split
       on its own.
    2. Lines mentioning "Taxable Value", "Total Value", "Total Amount" or
       "Amount Due" followed by a number are recorded under that label. This
       pass runs last, so its clean numeric value replaces anything the first
       pass stored under the same label.

    Args:
        text: Multi-line OCR output.

    Returns:
        Mapping from whitespace-normalised label to whitespace-normalised
        value. Labels that are empty after normalisation are skipped.
    """
    cands: Dict[str, str] = {}
    amount_re = re.compile(
        r"\b(Taxable\s+Value|Total\s+Value|Total\s+Amount|Amount\s+Due)\b[:\s]*([0-9][0-9,]*(?:\.[0-9]{2})?)",
        re.I,
    )

    def remember(cell: str) -> None:
        key, _, value = cell.partition(":")
        key = norm(key)
        # A line such as ": foo" has no label; an empty key is useless later.
        if key:
            cands[key] = norm(value)

    rows = text.splitlines()

    for raw in rows:
        line = raw.strip().strip("|").strip()
        if not line or ":" not in line:
            continue
        if "|" in raw:
            cells = [p.strip() for p in raw.split("|") if p.strip()]
        else:
            cells = [line]
        for cell in cells:
            if ":" in cell:
                remember(cell)

    for raw in rows:
        m = amount_re.search(raw)
        if m:
            cands[norm(m.group(1))] = norm(m.group(2))

    return cands
222
+
223
def regex_extract_all(text: str) -> Dict[str, str]:
    """Pull well-known invoice fields out of *text* with fixed patterns.

    All patterns are case-insensitive and only the first match of each is
    used. Fields filled when found: ``invoice_number``, ``invoice_date``,
    ``purchase_order_number``, ``order_date`` ("PO Date", falling back to
    "Date of Supply"), ``shipping_point`` ("Place of Supply"),
    ``supplier_tax_id`` (15-character GSTIN), ``total_before_tax``
    ("Taxable Value", commas removed), ``tax_amount`` (CGST value + SGST
    value), ``tax_rate`` (CGST % + SGST %, only when both values were found)
    and ``reference_number`` (E-Way bill number).

    Args:
        text: Full OCR text of the invoice.

    Returns:
        Mapping of schema header name to extracted string; absent fields are
        simply not present.
    """
    date = r"([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})"
    money = r"([0-9][0-9,]*(?:\.[0-9]{2})?)"
    percent = r"([0-9]+(?:\.[0-9]+)?)"
    sep = r"\s*[:\-]?\s*"
    # "No" must not be the start of a longer word ("Note", "Notes", ...).
    number_label = r"(?:No(?![a-z])\.?|Number|#)"

    def first(pattern: str):
        return re.search(pattern, text, re.I)

    out: Dict[str, str] = {}

    m = first(r"\bInvoice\s*" + number_label + sep + r"([A-Z0-9\-\/]+)")
    if m:
        out["invoice_number"] = m.group(1)

    m = first(r"\bInvoice\s*Date" + sep + date)
    if m:
        out["invoice_date"] = m.group(1)

    # "PO" must be a whole word, and the captured token must contain a digit.
    # Otherwise "Point" yields "int" and "PO Date" yields "Date".
    m = first(r"\bPO\b\s*" + number_label + "?" + sep + r"((?=[A-Z\-\/]*[0-9])[A-Z0-9\-\/]+)")
    if m:
        out["purchase_order_number"] = m.group(1)

    m = first(r"\bPO\s*Date" + sep + date)
    if m:
        out["order_date"] = m.group(1)
    if "order_date" not in out:
        m = first(r"\bDate\s*of\s*Supply" + sep + date)
        if m:
            out["order_date"] = m.group(1)

    m = first(r"\bPlace\s*of\s*Supply" + sep + r"([A-Za-z0-9 ,\-\(\)]+)")
    if m:
        out["shipping_point"] = m.group(1).strip(" |")

    m = first(r"\bGSTIN\s*(?:No\.?)?" + sep + r"([A-Z0-9]{15})")
    if m:
        out["supplier_tax_id"] = m.group(1)

    m = first(r"\bTaxable\s*Value" + sep + money)
    if m:
        out["total_before_tax"] = m.group(1).replace(",", "")

    cgst = first(r"\bCGST\s*Value" + sep + money)
    sgst = first(r"\bSGST\s*Value" + sep + money)
    if cgst and sgst:
        # The money pattern only admits digits, commas and an optional
        # two-digit fraction, so float() cannot fail once commas are removed.
        tax_total = float(cgst.group(1).replace(",", "")) + float(sgst.group(1).replace(",", ""))
        out["tax_amount"] = f"{tax_total:.2f}"

        cgstp = first(r"\bCGST\s*%?" + sep + percent)
        sgstp = first(r"\bSGST\s*%?" + sep + percent)
        if cgstp and sgstp:
            rate = float(cgstp.group(1)) + float(sgstp.group(1))
            out["tax_rate"] = f"{rate:g}"

    m = first(r"\bE[-\s]?Way\s*bill\s*no\.?" + sep + r"([0-9 ]+)")
    if m:
        out["reference_number"] = m.group(1).strip()

    return out
261
+
262
def extract_bank_block(text: str) -> Dict[str, str]:
    """Extract bank / remittance details from *text*.

    Recognised labels (case-insensitive, each followed by ``:``):

    * "Account Name"  -> ``supplier_name``
    * "Account No" / "Account Number" -> ``bank_account_number``
    * "IFS Code" / "IFSC Code" -> ``payment_reference``
    * "SWIFT Code" -> ``swift_code``
    * "Bank", "Branch", "MICR Code" -> combined into ``additional_info`` as
      ``"Bank: ... | Branch: ... | MICR: ..."`` (only the parts that exist).

    Args:
        text: Full OCR text of the invoice.

    Returns:
        Mapping of schema header name to value. A label whose captured value
        is empty after stripping is left out, so it cannot overwrite a better
        value when the caller merges this dict over other results.
    """

    def grab(pattern: str):
        m = re.search(pattern, text, re.I)
        return m.group(1).strip() if m else None

    bank: Dict[str, str] = {}

    simple_fields = (
        ("supplier_name", r"\bAccount\s*Name\s*:\s*(.+)"),
        ("bank_account_number", r"\bAccount\s*(?:No|Number)\s*:\s*([A-Za-z0-9\- ]+)"),
        ("payment_reference", r"\bIFSC?\s*Code\s*:\s*([A-Za-z0-9]+)"),
        ("swift_code", r"\bSWIFT\s*Code\s*:\s*([A-Za-z0-9]+)"),
    )
    for field, pattern in simple_fields:
        value = grab(pattern)
        if value:
            bank[field] = value

    # Details without a schema field of their own are folded into one string.
    info_parts = (
        ("Bank: ", r"\bBank\s*:\s*([A-Za-z0-9 ,\-\(\)&]+)"),
        ("Branch: ", r"\bBranch\s*:\s*(.+)"),
        ("MICR: ", r"\bMICR\s*Code\s*:\s*([0-9]+)"),
    )
    extra_bits = []
    for prefix, pattern in info_parts:
        value = grab(pattern)
        if value:
            extra_bits.append(prefix + value)
    if extra_bits:
        bank["additional_info"] = " | ".join(extra_bits)

    return bank
283
+
284
def parse_line_items(text: str) -> List[Dict[str, Any]]:
    """Parse a pipe-delimited line-item table out of *text*.

    The header row is the first line that contains ``|`` and mentions
    "description", "qty"/"quantity", "rate"/"price" and "total"/"amount"
    (case-insensitive). Following lines that start with ``|`` are read as
    data rows until the first line that is not a table row. One such line
    directly under the header is tolerated and skipped.

    Cells are matched to header columns by position, so an empty cell keeps
    its place instead of shifting the cells to its right. Separator and
    alignment rows (cells made only of ``-`` and ``:``) are skipped, and a
    cell made only of dashes in a data row counts as empty.

    Args:
        text: Multi-line OCR output.

    Returns:
        One dict per data row with keys ``quantity``, ``units``,
        ``description``, ``footage``, ``price``, ``amount`` and ``notes``.
        Missing values are ``None``; ``footage`` is ``"<quantity> <units>"``
        when both are present. Returns an empty list when no header is found.
    """

    def split_row(row: str) -> List[str]:
        # Remove exactly one outer pipe per side so empty edge cells survive.
        body = row.strip()
        if body.startswith("|"):
            body = body[1:]
        if body.endswith("|"):
            body = body[:-1]
        return [cell.strip() for cell in body.split("|")]

    lines = [ln for ln in text.splitlines() if ln.strip()]

    header_idx = -1
    for i, ln in enumerate(lines):
        low = ln.lower()
        if (
            "|" in ln
            and "description" in low
            and ("qty" in low or "quantity" in low)
            and ("rate" in low or "price" in low)
            and ("total" in low or "amount" in low)
        ):
            header_idx = i
            break
    if header_idx == -1:
        return []

    headers = [h.lower() for h in split_row(lines[header_idx])]

    def column(name_parts: List[str], exclude: tuple = ()) -> int:
        for k, h in enumerate(headers):
            if k not in exclude and any(p in h for p in name_parts):
                return k
        return -1

    # Column positions depend only on the header, so resolve them once.
    i_desc = column(["description", "item"])
    i_qty = column(["qty", "quantity"])
    i_rate = column(["rate", "price"])
    # "unit price" is a price column, not a unit-of-measure column.
    i_uom = column(["uom", "unit"], exclude=(i_rate,))
    i_amt = column(["total value", "amount", "total"])

    items: List[Dict[str, Any]] = []
    for j in range(header_idx + 1, len(lines)):
        row = lines[j]
        if not (row.strip().startswith("|") and row.count("|") >= 2):
            if j > header_idx + 1:
                break
            continue

        cells = split_row(row)
        if all(not set(c) - set("-:") for c in cells):
            continue  # markdown separator / alignment row
        cells = ["" if c and not set(c) - set("-") else c for c in cells]
        if sum(1 for c in cells if c) < 3:
            continue

        def pick(i: int):
            return (cells[i] if 0 <= i < len(cells) else "") or None

        rowd: Dict[str, Any] = {
            "quantity": pick(i_qty),
            "units": pick(i_uom),
            "description": pick(i_desc),
            "footage": None,
            "price": pick(i_rate),
            "amount": pick(i_amt),
            "notes": None,
        }
        if rowd["units"] and rowd["quantity"]:
            rowd["footage"] = f'{rowd["quantity"]} {rowd["units"]}'
        items.append(rowd)

    return items
328
+
329
def semantic_map_candidates(candidates: Dict[str, str], static_headers: List[str], thresh: float, sentence_model) -> Dict[str, str]:
    """Map free-form invoice labels onto schema header names.

    Each candidate label is lower-cased and its punctuation replaced by
    spaces, then checked against ``SYN2KEY``. The longest synonym contained
    in the label wins, so "sales order no" resolves through "sales order"
    rather than the shorter "order no". Labels with no synonym hit are
    embedded with *sentence_model* and assigned to the most similar entry of
    *static_headers* when the cosine similarity reaches *thresh*.

    A synonym hit is never overwritten by an embedding match. When several
    embedded labels land on the same header, the best-scoring one is kept.

    Args:
        candidates: Raw ``label -> value`` pairs.
        static_headers: Schema header names to match against.
        thresh: Minimum cosine similarity for an embedding match.
        sentence_model: Object exposing
            ``encode(list_of_str, normalize_embeddings=True)``.

    Returns:
        Mapping of schema header name to the candidate's value.
    """
    if not candidates:
        return {}

    # Longest synonym first; sorted() is stable, so ties keep SYN2KEY order.
    synonyms = sorted(SYN2KEY.items(), key=lambda kv: len(kv[0]), reverse=True)

    mapped: Dict[str, str] = {}
    leftovers: Dict[str, str] = {}
    for k, v in candidates.items():
        lk_norm = re.sub(r"[^a-z0-9]+", " ", k.lower()).strip()
        hit = next((key for syn, key in synonyms if syn in lk_norm), None)
        if hit:
            mapped[hit] = v
        else:
            leftovers[k] = v

    if leftovers and static_headers:
        exact_hits = set(mapped)
        keys_left = list(leftovers)
        cand_emb = sentence_model.encode(keys_left, normalize_embeddings=True)
        head_emb = sentence_model.encode(static_headers, normalize_embeddings=True)
        M = util.cos_sim(torch.tensor(cand_emb), torch.tensor(head_emb)).cpu().numpy()

        best_score: Dict[str, float] = {}
        for i, ck in enumerate(keys_left):
            j = int(np.argmax(M[i]))
            score = float(M[i][j])
            header = static_headers[j]
            if score < thresh or header in exact_hits:
                continue
            if header not in best_score or score > best_score[header]:
                best_score[header] = score
                mapped[header] = leftovers[ck]

    return mapped
358
+
359
def build_prompt(invoice_text: str, mapped_hints: Dict[str, str], items_hints: List[Dict[str, Any]]) -> str:
    """Assemble the text prompt given to the JSON-generation model.

    The prompt consists of the target schema (every leaf shown as
    ``"string or null"``), the extraction rules, the invoice text and, when
    available, the header and line-item values found by the rule-based
    extractors as hints.

    The schema section is generated from ``SCHEMA_JSON``, so the prompt
    cannot drift from the schema the rest of the pipeline fills in.

    Args:
        invoice_text: OCR text of the invoice.
        mapped_hints: Header values already extracted
            (``header name -> value``).
        items_hints: Line items already extracted.

    Returns:
        The complete prompt string.
    """

    def as_template(node: Any) -> Any:
        # Mirror the schema's structure, replacing every leaf by the type hint.
        if isinstance(node, dict):
            return {k: as_template(v) for k, v in node.items()}
        if isinstance(node, list):
            return [as_template(v) for v in node]
        return "string or null"

    instruction = (
        "Use this schema:\n"
        + json.dumps(as_template(SCHEMA_JSON), indent=2)
        + "\n"
        "If a field is missing for a line item or header, use null. "
        "Do not invent fields. Do not add any header or shipment data to any line item. "
        "Return ONLY the JSON object, no explanation.\n"
    )

    hints = ""
    if mapped_hints:
        hints += "\nHints (header):\n" + " ".join(f"#{k}: {v}" for k, v in mapped_hints.items())
    if items_hints:
        try:
            hints += "\nHints (line_items):\n" + json.dumps(items_hints, ensure_ascii=False)
        except (TypeError, ValueError):
            # Hints are optional; skip them if the items are not serialisable.
            pass

    return instruction + "\nInvoice Text:\n" + invoice_text.strip() + hints
443
+
444
def strict_json(text: str) -> Dict[str, Any]:
    """Parse *text* as a JSON object, tolerating surrounding noise.

    The whole string is tried first. If that does not yield an object, the
    span from the first ``{`` to the last ``}`` is tried.

    Args:
        text: Raw model output.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: If neither attempt produces a JSON object. Valid JSON of
            any other type (list, number, string) is rejected as well,
            because callers treat the result as a dict.
    """
    attempts = [text]
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end > start:
        attempts.append(text[start:end + 1])

    for attempt in attempts:
        try:
            parsed = json.loads(attempt)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        if isinstance(parsed, dict):
            return parsed

    raise ValueError("Model did not return valid JSON.")
457
+
458
def merge_schema(rule_json: Dict[str, Any], model_json: Dict[str, Any]) -> Dict[str, Any]:
    """Merge model output into the rule-based result; rule values win.

    Header fields that the rules left blank (``None``, ``""`` or ``"null"``)
    are filled from the model's ``invoice_header`` when it has a non-blank
    value. The model's ``line_items`` are used only when the rule-based items
    contain no usable value at all.

    Args:
        rule_json: Schema-shaped dict produced by the deterministic
            extractors. It is not modified.
        model_json: Parsed model output. Anything that is not a dict, or
            whose sections have the wrong type, contributes nothing.

    Returns:
        A new schema-shaped dict sharing no mutable state with either input.
    """
    blank = (None, "", "null")
    final = copy.deepcopy(rule_json)
    if not isinstance(model_json, dict):
        return final

    hdr = final["invoice_header"]
    mdl_hdr = model_json.get("invoice_header")
    if isinstance(mdl_hdr, dict):
        for k, current in hdr.items():
            if current in blank:
                v = mdl_hdr.get(k)
                if v not in blank:
                    hdr[k] = copy.deepcopy(v)

    rule_items = final.get("line_items") or []
    rules_found_items = any(
        any(v for v in row.values() if v not in blank) for row in rule_items
    )
    if not rules_found_items:
        mdl_items = model_json.get("line_items")
        if isinstance(mdl_items, list) and mdl_items:
            # Copy so later edits to the result cannot reach the model output.
            final["line_items"] = copy.deepcopy(mdl_items)

    return final
474
+
475
+ # ---------------------- MAIN FUNCTION ----------------------
476
# Filled on first use so the two models are built once per process.
_INVOICE_MODELS: Dict[str, Any] = {}


def _load_invoice_models():
    """Return ``(sentence_model, json_converter)``, creating them on first call."""
    if not _INVOICE_MODELS:
        sentence_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
        json_converter = pipeline("text2text-generation", model="yahyakhoder/MD2JSON-T5-small-V1")
        # Store both together so a failure above never leaves a half-filled cache.
        _INVOICE_MODELS.update(sentence=sentence_model, converter=json_converter)
    return _INVOICE_MODELS["sentence"], _INVOICE_MODELS["converter"]


def invoice_text_to_json(
    invoice_text: str,
    threshold: float = 0.60,
    max_new_tokens: int = 512
) -> Dict[str, Any]:
    """Convert OCR invoice text into the schema-shaped JSON dict.

    Steps:
        1. Deterministic extraction: label/value candidates, fixed regex
           fields, the bank block and the line-item table. Candidates are
           mapped to schema headers by synonym and embedding similarity.
           On conflicts, regex fields override mapped candidates, and bank
           fields override both.
        2. The findings are written into a fresh copy of the schema.
        3. A text-to-text model is prompted with the invoice and the findings
           as hints. Output that cannot be parsed as a JSON object is
           replaced by an empty schema.
        4. The two results are merged with the rule-based values winning.

    Args:
        invoice_text: OCR text of the invoice.
        threshold: Minimum cosine similarity for embedding-based header
            matching.
        max_new_tokens: Generation budget passed to the model.

    Returns:
        Dict with ``invoice_header`` and ``line_items``.
    """
    sentence_model, json_converter = _load_invoice_models()

    txt = invoice_text

    # 1) Deterministic extraction
    candidates = extract_candidates(txt)
    hard = regex_extract_all(txt)
    bank = extract_bank_block(txt)
    items = parse_line_items(txt)
    sem_mapped = semantic_map_candidates(candidates, STATIC_HEADERS, threshold, sentence_model)

    # Later updates win: semantic guesses < fixed regexes < bank block.
    header_found: Dict[str, Any] = {}
    header_found.update(sem_mapped)
    header_found.update(hard)
    header_found.update(bank)

    # 2) Build RULE JSON (schema-shaped, rules filled)
    rule_json = deep_copy_schema()
    for k, v in header_found.items():
        if k in rule_json["invoice_header"]:
            rule_json["invoice_header"][k] = v
    if items:
        rule_json["line_items"] = items

    # 3) MD2JSON generation with strong hints
    prompt = build_prompt(txt, header_found, items)
    gen = json_converter(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]
    try:
        model_json = strict_json(gen)
    except ValueError:
        model_json = deep_copy_schema()  # model failed; keep empty shape
    if not isinstance(model_json, dict):
        # Valid JSON that is not an object carries nothing the merge can use.
        model_json = deep_copy_schema()

    # 4) Final merge (rules win)
    return merge_schema(rule_json, model_json)
517
+
518
+ # ---------- Gradio UI ----------
519
TITLE = "docTR OCR — Text Extractor"
# The description states what the app actually returns: structured JSON.
DESC = (
    "Upload an image or PDF. This Space uses Mindee's docTR (PyTorch backend) to detect & recognize text, "
    "then maps the recognized invoice text onto a structured JSON schema. "
    "CPU-friendly and ready for enterprise prototyping."
)

with gr.Blocks(theme="soft", title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}\n{DESC}")

    # Upload on the left, result on the right.
    with gr.Row():
        inp = gr.File(label="Upload image/PDF", file_types=[".png", ".jpg", ".jpeg", ".tif", ".tiff", ".pdf"])
        out = gr.Code(label="Extracted JSON", language="json")

    run_btn = gr.Button("Run OCR", variant="primary")
    run_btn.click(fn=run_ocr, inputs=inp, outputs=out)

    gr.Examples(
        examples=[
            # You can drop a couple of public sample URLs here if desired,
            # but Spaces won't auto-download without code. Leave empty by default.
        ],
        inputs=inp,
        outputs=out,
        cache_examples=False,
        label="(Optional) Examples"
    )

    gr.Markdown(
        "Tip: For multi-page PDFs, the recognized text of all pages is combined before the JSON is built; "
        "the raw text, with a **PAGE BREAK** separator between pages, is printed to the server log.\n"
        "For production pipelines, capture the JSON output and route it to your parsing/LLM layer."
    )
551
+
552
if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    # NOTE(review): share=True additionally requests a public Gradio tunnel.
    # Confirm that is intended wherever uploaded invoices are sensitive.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True
    )