Seth0330 commited on
Commit
3a1b8b7
·
verified ·
1 Parent(s): 6443176

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +272 -595
app.py CHANGED
@@ -1,619 +1,296 @@
1
  # app.py
2
- # Invoice -> JSON (Paste Text Only) with better accuracy:
3
- # - Pipe-table aware parsing
4
- # - Regex extractors for common headers (Invoice No, Dates, PO, totals, taxes, GSTIN, etc.)
5
- # - Line-item table parser (SNO, Description, Qty, UOM, Rate, Total Value)
6
- # - Synonym dictionary -> canonical schema keys
7
- # - Semantic mapping (MiniLM) for leftovers
8
- # - MD2JSON prompt with strong hints; final schema = RULES ∪ MODEL (model cannot remove found values)
9
-
10
- import re
11
  import json
12
- from typing import List, Dict, Any, Tuple
13
- import copy
 
14
 
15
- import numpy as np
16
  import streamlit as st
17
- import torch
18
- from transformers import pipeline
19
- from sentence_transformers import SentenceTransformer, util
20
-
21
- st.set_page_config(page_title="Invoice JSON (Paste Text) · Accurate v2", layout="wide")
22
- st.title("Invoice JSON (Paste Text) Accurate v2")
23
-
24
- # ----------------------------- Schema -----------------------------
25
- SCHEMA_JSON: Dict[str, Any] = {
26
- "invoice_header": {
27
- "car_number": None,
28
- "shipment_number": None,
29
- "shipping_point": None,
30
- "currency": None,
31
- "invoice_number": None,
32
- "invoice_date": None,
33
- "order_number": None,
34
- "customer_order_number": None,
35
- "our_order_number": None,
36
- "sales_order_number": None,
37
- "purchase_order_number": None,
38
- "order_date": None,
39
- "supplier_name": None,
40
- "supplier_address": None,
41
- "supplier_phone": None,
42
- "supplier_email": None,
43
- "supplier_tax_id": None,
44
- "customer_name": None,
45
- "customer_address": None,
46
- "customer_phone": None,
47
- "customer_email": None,
48
- "customer_tax_id": None,
49
- "ship_to_name": None,
50
- "ship_to_address": None,
51
- "bill_to_name": None,
52
- "bill_to_address": None,
53
- "remit_to_name": None,
54
- "remit_to_address": None,
55
- "tax_id": None,
56
- "tax_registration_number": None,
57
- "vat_number": None,
58
- "payment_terms": None,
59
- "payment_method": None,
60
- "payment_reference": None,
61
- "bank_account_number": None,
62
- "iban": None,
63
- "swift_code": None,
64
- "total_before_tax": None,
65
- "tax_amount": None,
66
- "tax_rate": None,
67
- "shipping_charges": None,
68
- "discount": None,
69
- "total_due": None,
70
- "amount_paid": None,
71
- "balance_due": None,
72
- "due_date": None,
73
- "invoice_status": None,
74
- "reference_number": None,
75
- "project_code": None,
76
- "department": None,
77
- "contact_person": None,
78
- "notes": None,
79
- "additional_info": None
80
- },
81
- "line_items": [
82
- {
83
- "quantity": None,
84
- "units": None,
85
- "description": None,
86
- "footage": None,
87
- "price": None,
88
- "amount": None,
89
- "notes": None
90
- }
91
- ]
92
- }
93
- STATIC_HEADERS: List[str] = list(SCHEMA_JSON["invoice_header"].keys())
94
-
95
- # ----------------------------- Sidebar -----------------------------
96
- st.sidebar.header("Settings")
97
- threshold = st.sidebar.slider("Semantic match threshold (cosine)", 0.0, 1.0, 0.60, 0.01)
98
- max_new_tokens = st.sidebar.slider("Max new tokens (MD2JSON)", 128, 2048, 512, 32)
99
- show_intermediates = st.sidebar.checkbox("Show intermediates", value=True)
100
-
101
- # ----------------------------- Models (cached) -----------------------------
102
- @st.cache_resource(show_spinner=True)
103
- def load_models():
104
- sentence_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
105
- json_converter = pipeline("text2text-generation", model="yahyakhoder/MD2JSON-T5-small-V1")
106
- return sentence_model, json_converter
107
- sentence_model, json_converter = load_models()
108
-
109
- # ----------------------------- Synonym map -> schema keys -----------------------------
110
- SYN2KEY: Dict[str, str] = {
111
- # direct header synonyms
112
- "invoice no": "invoice_number",
113
- "invoice number": "invoice_number",
114
- "invoice#": "invoice_number",
115
- "inv no": "invoice_number",
116
- "inv#": "invoice_number",
117
-
118
- "invoice date": "invoice_date",
119
- "date of invoice": "invoice_date",
120
-
121
- "po no": "purchase_order_number",
122
- "po number": "purchase_order_number",
123
- "purchase order": "purchase_order_number",
124
- "order no": "order_number",
125
- "order number": "order_number",
126
- "sales order": "sales_order_number",
127
- "customer order": "customer_order_number",
128
- "our order": "our_order_number",
129
-
130
- "due date": "due_date",
131
- "date of supply": "order_date",
132
-
133
- "gstin": "supplier_tax_id",
134
- "gstin no": "supplier_tax_id",
135
- "tax id": "tax_id",
136
- "vat number": "vat_number",
137
- "tax registration number": "tax_registration_number",
138
-
139
- "place of supply": "shipping_point",
140
- "state code": "additional_info", # keep if you prefer a specific field
141
-
142
- "taxable value": "total_before_tax",
143
- "total value": "total_due",
144
- "total amount": "total_due",
145
- "amount due": "total_due",
146
-
147
- "bank": "bank_account_number", # we’ll fix value using bank block parsing
148
- "account no": "bank_account_number",
149
- "account number": "bank_account_number",
150
- "ifs code": "swift_code", # India: really IFSC; we’ll drop it into 'payment_reference' or keep separate
151
- "ifsc": "payment_reference",
152
- "swift code": "swift_code",
153
- "iban": "iban",
154
-
155
- "e-way bill no": "reference_number",
156
- "eway bill": "reference_number",
157
-
158
- "dispatched via": "additional_info",
159
- "documents dispatched through": "additional_info",
160
- "kind attn": "contact_person",
161
-
162
- # parties
163
- "billed to": "bill_to_name",
164
- "receiver": "bill_to_name",
165
- "shipped to": "ship_to_name",
166
- "consignee": "ship_to_name",
167
- }
168
-
169
- # ----------------------------- Utilities -----------------------------
170
- def norm(s: str) -> str:
171
- return re.sub(r"\s+", " ", s).strip()
172
-
173
- def to_lower(s: str) -> str:
174
- return s.lower().strip()
175
-
176
- def deep_copy_schema() -> Dict[str, Any]:
177
- return json.loads(json.dumps(SCHEMA_JSON))
178
 
179
- # ----------------------------- Pipe-table aware candidate extractor -----------------------------
180
- def extract_candidates(text: str) -> Dict[str, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  """
182
- Build candidates from:
183
- 1) colon lines: Key: Value
184
- 2) pipe rows: | ... | ... | (pick obvious key:value pairs like "Invoice No: X" inside cells)
185
- 3) single-value lines for totals (Taxable Value, Total, etc.)
186
  """
187
- cands: Dict[str, str] = {}
 
188
 
189
- # 1) colon lines
190
- for raw in text.splitlines():
191
- line = raw.strip().strip("|").strip()
192
- if not line:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  continue
194
- if ":" in line:
195
- # multiple '|'? try to split cells and parse each cell
196
- if "|" in raw:
197
- parts = [p.strip() for p in raw.split("|") if p.strip()]
198
- for cell in parts:
199
- if ":" in cell:
200
- k, v = cell.split(":", 1)
201
- cands[norm(k)] = norm(v)
202
- else:
203
- k, v = line.split(":", 1)
204
- cands[norm(k)] = norm(v)
205
-
206
- # 2) rows with ' | ' patterns but without colon in cells (rare)
207
- for raw in text.splitlines():
208
- if "|" in raw and ":" not in raw:
209
- parts = [p.strip() for p in raw.split("|") if p.strip() and not set(p.strip()) <= set("-")]
210
- # Heuristic: e.g., ["Dispatched Via","From","To","Under","No","Dated","Freight","Freight Amount"]
211
- # Hard to build k:v reliably here without a header row + next row; we skip unless obvious.
212
-
213
- # 3) totals without colon (e.g., "Taxable Value: 201801.60" already handled; but catch "Taxable Value 201801.60")
214
- for raw in text.splitlines():
215
- m = re.search(r"\b(Taxable\s+Value|Total\s+Value|Total\s+Amount|Amount\s+Due)\b[:\s]*([0-9][0-9,]*(?:\.[0-9]{2})?)", raw, re.I)
216
- if m:
217
- k = norm(m.group(1))
218
- v = norm(m.group(2))
219
- cands[k] = v
220
-
221
- return cands
222
-
223
- # ----------------------------- Regex “hard extractors” -----------------------------
224
- def regex_extract_all(text: str) -> Dict[str, str]:
225
- out: Dict[str, str] = {}
226
-
227
- # Invoice number
228
- m = re.search(r"\bInvoice\s*(?:No\.?|Number|#)\s*[:\-]?\s*([A-Z0-9\-\/]+)", text, re.I)
229
- if m: out["invoice_number"] = m.group(1)
230
-
231
- # Invoice date (DD-MM-YYYY or similar)
232
- m = re.search(r"\bInvoice\s*Date\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
233
- if m: out["invoice_date"] = m.group(1)
234
-
235
- # PO number + date
236
- m = re.search(r"\bPO\s*(?:No\.?|Number)?\s*[:\-]?\s*([A-Z0-9\-\/]+)", text, re.I)
237
- if m: out["purchase_order_number"] = m.group(1)
238
- m = re.search(r"\bPO\s*Date\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
239
- if m: out["order_date"] = m.group(1)
240
-
241
- # Date of Supply -> order_date (if not already)
242
- if "order_date" not in out:
243
- m = re.search(r"\bDate\s*of\s*Supply\s*[:\-]?\s*([0-9]{1,2}[-/][0-9]{1,2}[-/][0-9]{2,4})", text, re.I)
244
- if m: out["order_date"] = m.group(1)
245
-
246
- # Place of Supply -> shipping_point
247
- m = re.search(r"\bPlace\s*of\s*Supply\s*[:\-]?\s*([A-Za-z0-9 ,\-\(\)]+)", text, re.I)
248
- if m: out["shipping_point"] = m.group(1).strip(" |")
249
-
250
- # GSTIN (take the first)
251
- m = re.search(r"\bGSTIN\s*(?:No\.?)?\s*[:\-]?\s*([A-Z0-9]{15})", text, re.I)
252
- if m: out["supplier_tax_id"] = m.group(1)
253
-
254
- # Taxable Value -> total_before_tax
255
- m = re.search(r"\bTaxable\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
256
- if m: out["total_before_tax"] = m.group(1).replace(",", "")
257
-
258
- # CGST/SGST values -> tax_amount (sum)
259
- cgst = re.search(r"\bCGST\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
260
- sgst = re.search(r"\bSGST\s*Value\s*[:\-]?\s*([0-9][0-9,]*(?:\.[0-9]{2})?)", text, re.I)
261
- if cgst and sgst:
262
  try:
263
- tax_total = float(cgst.group(1).replace(",", "")) + float(sgst.group(1).replace(",", ""))
264
- out["tax_amount"] = f"{tax_total:.2f}"
265
- # Tax rate (if both % available and equal, set combined)
266
- cgstp = re.search(r"\bCGST\s*%?\s*[:\-]?\s*([0-9]+(?:\.[0-9]+)?)", text, re.I)
267
- sgstp = re.search(r"\bSGST\s*%?\s*[:\-]?\s*([0-9]+(?:\.[0-9]+)?)", text, re.I)
268
- if cgstp and sgstp:
269
- try:
270
- rate = float(cgstp.group(1)) + float(sgstp.group(1))
271
- out["tax_rate"] = f"{rate:g}"
272
- except:
273
- pass
274
- except:
275
- pass
276
-
277
- # E-Way bill -> reference_number
278
- m = re.search(r"\bE[-\s]?Way\s*bill\s*no\.?\s*[:\-]?\s*([0-9 ]+)", text, re.I)
279
- if m: out["reference_number"] = m.group(1).strip()
280
-
281
- return out
282
-
283
- # ----------------------------- Bank block parsing -----------------------------
284
- def extract_bank_block(text: str) -> Dict[str, str]:
285
- bank: Dict[str, str] = {}
286
- # account name
287
- m = re.search(r"\bAccount\s*Name\s*:\s*(.+)", text, re.I)
288
- if m: bank["supplier_name"] = m.group(1).strip()
289
-
290
- # account no
291
- m = re.search(r"\bAccount\s*(?:No|Number)\s*:\s*([A-Za-z0-9\- ]+)", text, re.I)
292
- if m: bank["bank_account_number"] = m.group(1).strip()
293
-
294
- # bank name
295
- m = re.search(r"\bBank\s*:\s*([A-Za-z0-9 ,\-\(\)&]+)", text, re.I)
296
- if m:
297
- # place bank name into additional_info to avoid overwriting bank_account_number
298
- bank["additional_info"] = ("Bank: " + m.group(1).strip())
299
-
300
- # IFSC/IFS Code
301
- m = re.search(r"\bIFSC?\s*Code\s*:\s*([A-Za-z0-9]+)", text, re.I)
302
- if m: bank["payment_reference"] = m.group(1).strip()
303
-
304
- # SWIFT
305
- m = re.search(r"\bSWIFT\s*Code\s*:\s*([A-Za-z0-9]+)", text, re.I)
306
- if m: bank["swift_code"] = m.group(1).strip()
307
-
308
- # Branch / MICR etc -> additional_info
309
- branch = re.search(r"\bBranch\s*:\s*(.+)", text, re.I)
310
- micr = re.search(r"\bMICR\s*Code\s*:\s*([0-9]+)", text, re.I)
311
- extra_bits = []
312
- if branch: extra_bits.append("Branch: " + branch.group(1).strip())
313
- if micr: extra_bits.append("MICR: " + micr.group(1).strip())
314
- if extra_bits:
315
- bank["additional_info"] = ((bank.get("additional_info") + " | ") if bank.get("additional_info") else "") + " | ".join(extra_bits)
316
- return bank
317
-
318
- # ----------------------------- Line-item parser (from table) -----------------------------
319
- def parse_line_items(text: str) -> List[Dict[str, Any]]:
320
  """
321
- Parse a classic table with header like:
322
- | SNO | Description | HSN/SAC | Qty | UOM | Rate | ... | Total Value |
323
  """
324
- items: List[Dict[str, Any]] = []
325
- lines = [ln for ln in text.splitlines() if ln.strip()]
326
- # find header row index
327
- header_idx = -1
328
- for i, ln in enumerate(lines):
329
- if ("|") in ln and ("Description" in ln and ("Qty" in ln or "QTY" in ln)) and ("Rate" in ln or "Price" in ln) and ("Total" in ln):
330
- header_idx = i
331
- break
332
- if header_idx == -1:
333
- return items
334
-
335
- # parse header cells
336
- headers = [c.strip().lower() for c in lines[header_idx].split("|")]
337
- # clean
338
- headers = [h for h in headers if h and set(h) - set("-")]
339
-
340
- # parse body until a blank line or a non-table line
341
- for j in range(header_idx + 1, len(lines)):
342
- row = lines[j]
343
- if row.strip().startswith("|") and row.count("|") >= 2:
344
- cells = [c.strip() for c in row.split("|")]
345
- cells = [c for c in cells if c and set(c) - set("-")]
346
- if len(cells) < 3:
347
- continue
348
- # map to our schema per best-effort
349
- rowd = {"quantity": None, "units": None, "description": None, "footage": None, "price": None, "amount": None, "notes": None}
350
- # Try to find index of each logical column
351
- def idx_of(name_parts: List[str]) -> int:
352
- for k, h in enumerate(headers):
353
- if any(p in h for p in name_parts):
354
- return k
355
- return -1
356
- i_desc = idx_of(["description", "item"])
357
- i_qty = idx_of(["qty", "quantity"])
358
- i_uom = idx_of(["uom", "unit"])
359
- i_rate = idx_of(["rate", "price"])
360
- i_amt = idx_of(["total value", "amount", "total"])
361
-
362
- # safe get
363
- def safe(i: int) -> str:
364
- return cells[i] if 0 <= i < len(cells) else ""
365
-
366
- if i_desc != -1: rowd["description"] = safe(i_desc) or None
367
- if i_qty != -1: rowd["quantity"] = safe(i_qty) or None
368
- if i_uom != -1: rowd["units"] = safe(i_uom) or None
369
- if i_rate != -1: rowd["price"] = safe(i_rate) or None
370
- if i_amt != -1: rowd["amount"] = safe(i_amt) or None
371
-
372
- # optional: footage if present in desc like "60.000 mtrs"
373
- if rowd["units"] and rowd["quantity"]:
374
- rowd["footage"] = f'{rowd["quantity"]} {rowd["units"]}'
375
- items.append(rowd)
376
- else:
377
- # stop at first non-table line after header
378
- if j > header_idx + 1:
379
- break
380
- return items
381
-
382
- # ----------------------------- Semantic mapping for leftovers -----------------------------
383
- def semantic_map_candidates(candidates: Dict[str, str], static_headers: List[str], thresh: float) -> Dict[str, str]:
384
- if not candidates:
385
- return {}
386
- cand_keys = list(candidates.keys())
387
- # synonym pass first
388
- mapped: Dict[str, str] = {}
389
- leftovers: Dict[str, str] = {}
390
- for k, v in candidates.items():
391
- lk = k.lower()
392
- lk_norm = re.sub(r"[^a-z0-9]+", " ", lk).strip()
393
- hit = None
394
- for syn, key in SYN2KEY.items():
395
- if syn in lk_norm:
396
- hit = key
397
- break
398
- if hit:
399
- mapped[hit] = v
400
- else:
401
- leftovers[k] = v
402
-
403
- if leftovers:
404
- cand_emb = sentence_model.encode(list(leftovers.keys()), normalize_embeddings=True)
405
- head_emb = sentence_model.encode(static_headers, normalize_embeddings=True)
406
- M = util.cos_sim(torch.tensor(cand_emb), torch.tensor(head_emb)).cpu().numpy()
407
- keys_left = list(leftovers.keys())
408
- for i, ck in enumerate(keys_left):
409
- j = int(np.argmax(M[i]))
410
- score = float(M[i][j])
411
- if score >= thresh:
412
- mapped[static_headers[j]] = leftovers[ck]
413
- return mapped
414
-
415
- # ----------------------------- Build MD2JSON prompt -----------------------------
416
- def build_prompt(invoice_text: str, mapped_hints: Dict[str, str], items_hints: List[Dict[str, Any]]) -> str:
417
- instruction = (
418
- 'Use this schema:\n'
419
- '{\n'
420
- ' "invoice_header": {\n'
421
- ' "car_number": "string or null",\n'
422
- ' "shipment_number": "string or null",\n'
423
- ' "shipping_point": "string or null",\n'
424
- ' "currency": "string or null",\n'
425
- ' "invoice_number": "string or null",\n'
426
- ' "invoice_date": "string or null",\n'
427
- ' "order_number": "string or null",\n'
428
- ' "customer_order_number": "string or null",\n'
429
- ' "our_order_number": "string or null",\n'
430
- ' "sales_order_number": "string or null",\n'
431
- ' "purchase_order_number": "string or null",\n'
432
- ' "order_date": "string or null",\n'
433
- ' "supplier_name": "string or null",\n'
434
- ' "supplier_address": "string or null",\n'
435
- ' "supplier_phone": "string or null",\n'
436
- ' "supplier_email": "string or null",\n'
437
- ' "supplier_tax_id": "string or null",\n'
438
- ' "customer_name": "string or null",\n'
439
- ' "customer_address": "string or null",\n'
440
- ' "customer_phone": "string or null",\n'
441
- ' "customer_email": "string or null",\n'
442
- ' "customer_tax_id": "string or null",\n'
443
- ' "ship_to_name": "string or null",\n'
444
- ' "ship_to_address": "string or null",\n'
445
- ' "bill_to_name": "string or null",\n'
446
- ' "bill_to_address": "string or null",\n'
447
- ' "remit_to_name": "string or null",\n'
448
- ' "remit_to_address": "string or null",\n'
449
- ' "tax_id": "string or null",\n'
450
- ' "tax_registration_number": "string or null",\n'
451
- ' "vat_number": "string or null",\n'
452
- ' "payment_terms": "string or null",\n'
453
- ' "payment_method": "string or null",\n'
454
- ' "payment_reference": "string or null",\n'
455
- ' "bank_account_number": "string or null",\n'
456
- ' "iban": "string or null",\n'
457
- ' "swift_code": "string or null",\n'
458
- ' "total_before_tax": "string or null",\n'
459
- ' "tax_amount": "string or null",\n'
460
- ' "tax_rate": "string or null",\n'
461
- ' "shipping_charges": "string or null",\n'
462
- ' "discount": "string or null",\n'
463
- ' "total_due": "string or null",\n'
464
- ' "amount_paid": "string or null",\n'
465
- ' "balance_due": "string or null",\n'
466
- ' "due_date": "string or null",\n'
467
- ' "invoice_status": "string or null",\n'
468
- ' "reference_number": "string or null",\n'
469
- ' "project_code": "string or null",\n'
470
- ' "department": "string or null",\n'
471
- ' "contact_person": "string or null",\n'
472
- ' "notes": "string or null",\n'
473
- ' "additional_info": "string or null"\n'
474
- ' },\n'
475
- ' "line_items": [\n'
476
- ' {\n'
477
- ' "quantity": "string or null",\n'
478
- ' "units": "string or null",\n'
479
- ' "description": "string or null",\n'
480
- ' "footage": "string or null",\n'
481
- ' "price": "string or null",\n'
482
- ' "amount": "string or null",\n'
483
- ' "notes": "string or null"\n'
484
- ' }\n'
485
- ' ]\n'
486
- '}\n'
487
- 'If a field is missing for a line item or header, use null. '
488
- 'Do not invent fields. Do not add any header or shipment data to any line item. '
489
- 'Return ONLY the JSON object, no explanation.\n'
490
- )
491
- hints = ""
492
- if mapped_hints:
493
- hints += "\nHints (header):\n" + " ".join([f"#{k}: {v}" for k, v in mapped_hints.items()])
494
- if items_hints:
495
  try:
496
- hints += "\nHints (line_items):\n" + json.dumps(items_hints, ensure_ascii=False)
497
- except:
498
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
 
500
- return instruction + "\nInvoice Text:\n" + invoice_text.strip() + hints
 
 
 
 
501
 
502
- def strict_json(text: str) -> Dict[str, Any]:
503
- # try direct
 
 
 
 
 
 
 
 
504
  try:
505
- return json.loads(text)
506
- except:
507
- pass
508
- # extract largest {...}
509
- start = text.find("{")
510
- end = text.rfind("}")
511
- if start != -1 and end != -1 and end > start:
512
- try:
513
- return json.loads(text[start:end+1])
514
- except:
515
- pass
516
- raise ValueError("Model did not return valid JSON.")
517
 
518
- # ----------------------------- Final merge policy -----------------------------
519
- def merge_schema(rule_json: Dict[str, Any], model_json: Dict[str, Any]) -> Dict[str, Any]:
520
- """
521
- RULES WIN: Keep everything we extracted deterministically; fill only missing (None) from model.
522
- """
523
- final = copy.deepcopy(rule_json)
524
-
525
- # header
526
- hdr = final["invoice_header"]
527
- mdl_hdr = (model_json.get("invoice_header") or {})
528
- for k in hdr.keys():
529
- if hdr[k] in [None, "", "null"]:
530
- v = mdl_hdr.get(k, None)
531
- if v not in [None, "", "null"]:
532
- hdr[k] = v
533
-
534
- # line_items: if we got some via rules, keep them; else take model's
535
- if final["line_items"] and any(any(v for v in row.values() if v not in [None, "", "null"]) for row in final["line_items"]):
536
- pass
537
- else:
538
- mdl_items = model_json.get("line_items")
539
- if isinstance(mdl_items, list) and mdl_items:
540
- final["line_items"] = mdl_items
541
  else:
542
- # keep template with nulls
543
- pass
 
 
544
 
545
- return final
 
546
 
547
- # ----------------------------- UI -----------------------------
548
- invoice_text = st.text_area(
549
- "Paste the invoice text here.",
550
- height=320,
551
- placeholder="Paste the invoice content (OCR/plain text) ..."
552
- )
553
 
554
- if st.button("Generate JSON", type="primary", use_container_width=True):
555
- if not invoice_text.strip():
556
- st.error("Please paste the invoice text first.")
557
  st.stop()
558
 
559
- txt = invoice_text
560
-
561
- # 1) Deterministic extraction
562
- # 1a) candidates (pipe-table aware)
563
- candidates = extract_candidates(txt)
564
-
565
- # 1b) regex “hard” fields
566
- hard = regex_extract_all(txt)
567
-
568
- # 1c) bank block
569
- bank = extract_bank_block(txt)
570
-
571
- # 1d) line items from table
572
- items = parse_line_items(txt)
573
-
574
- # 1e) map candidates (synonyms + semantic) to schema headers
575
- sem_mapped = semantic_map_candidates(candidates, STATIC_HEADERS, threshold)
576
-
577
- # 1f) combine deterministic header fields
578
- header_found: Dict[str, Any] = {}
579
- header_found.update(sem_mapped)
580
- header_found.update(hard)
581
- header_found.update(bank)
582
-
583
- # 2) Build RULE JSON (schema-shaped, rules filled)
584
- rule_json = deep_copy_schema()
585
- for k, v in header_found.items():
586
- if k in rule_json["invoice_header"]:
587
- rule_json["invoice_header"][k] = v
588
- # line items
589
- if items:
590
- rule_json["line_items"] = items
591
-
592
- if show_intermediates:
593
- st.subheader("Candidates (first 20)")
594
- st.json(dict(list(candidates.items())[:20]))
595
- st.subheader("Regex/Hard fields")
596
- st.json(hard)
597
- st.subheader("Bank block")
598
- st.json(bank)
599
- st.subheader("Semantic-mapped headers")
600
- st.json(sem_mapped)
601
- st.subheader("Line items (parsed)")
602
- st.json(items)
603
-
604
- # 3) MD2JSON generation with strong hints
605
- with st.spinner("Generating structured JSON with MD2JSON-T5-small-V1..."):
606
- prompt = build_prompt(txt, header_found, items)
607
- gen = json_converter(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]
608
- try:
609
- model_json = strict_json(gen)
610
- except:
611
- model_json = deep_copy_schema() # model failed; keep empty shape
612
-
613
- # 4) Final merge (rules win)
614
- final_json = merge_schema(rule_json, model_json)
615
-
616
- st.subheader("Final JSON")
617
- st.json(final_json)
618
- st.download_button("Download JSON", data=json.dumps(final_json, indent=2),
619
- file_name="invoice.json", mime="application/json", use_container_width=True)
 
1
  # app.py
2
+ import os
3
+ import io
4
+ import base64
 
 
 
 
 
 
5
  import json
6
+ import time
7
+ import requests
8
+ from typing import List, Dict, Any, Optional, Tuple
9
 
 
10
  import streamlit as st
11
+ from PIL import Image
12
+
13
+ # --------- CONFIG ---------
14
+ # Expected environment variables on HF Space:
15
+ # OPENAI_API_KEY -> for LLM + Vision
16
+ # SERPAPI_KEY -> optional, enables web price research via Google Shopping
17
+ # --------------------------
18
+
19
+ # --- OpenAI client (v1) ---
20
+ try:
21
+ from openai import OpenAI
22
+ oai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
23
+ except Exception as e:
24
+ oai_client = None
25
+
26
+ # ---------- UI ----------
27
+ st.set_page_config(page_title="Grocery Savings Agent", page_icon="🧾", layout="centered")
28
+ st.title("🧾 Grocery Savings Agent (Canada)")
29
+ st.caption("Upload a grocery receipt. I’ll extract your items, research prices at other stores, and tell you what you could’ve saved (in ≤5 lines).")
30
+
31
+ with st.expander("🔧 Setup checklist (first run)"):
32
+ st.markdown(
33
+ "- Add **OPENAI_API_KEY** (required) in your Space Secrets\n"
34
+ "- (Optional) Add **SERPAPI_KEY** to enable live web price lookups via Google Shopping\n"
35
+ "- Supported uploads: JPG/PNG/PDF (first page used if multi-page)\n"
36
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
+ # ---------- Helpers ----------
39
+ def img_or_pdf_to_image_bytes(upload) -> bytes:
40
+ """Accepts image or PDF and returns JPG bytes suitable for OpenAI Vision."""
41
+ name = upload.name.lower()
42
+ data = upload.read()
43
+
44
+ if name.endswith((".jpg", ".jpeg", ".png", ".webp")):
45
+ img = Image.open(io.BytesIO(data)).convert("RGB")
46
+ buf = io.BytesIO()
47
+ img.save(buf, format="JPEG", quality=90)
48
+ return buf.getvalue()
49
+
50
+ # Simple PDF first-page render (without system poppler): use pillow-pdf if available,
51
+ # else pass PDF bytes directly to GPT-V (works reasonably for many receipts).
52
+ if name.endswith(".pdf"):
53
+ # Try pillow's built-in PDF rendering via ghostscript providers (may be limited).
54
+ try:
55
+ # Many minimal environments can at least open single-page vector PDFs as images
56
+ img = Image.open(io.BytesIO(data)).convert("RGB")
57
+ buf = io.BytesIO()
58
+ img.save(buf, format="JPEG", quality=90)
59
+ return buf.getvalue()
60
+ except Exception:
61
+ # Fallback: return the PDF bytes; OpenAI Vision can read PDFs as a "file" content
62
+ return data
63
+
64
+ # Fallback treat as image
65
+ return data
66
+
67
+ def b64_data_uri(data: bytes, mime: str) -> str:
68
+ return f"data:{mime};base64," + base64.b64encode(data).decode("utf-8")
69
+
70
+ def call_openai_vision_for_receipt(image_bytes: bytes) -> Dict[str, Any]:
71
  """
72
+ Ask OpenAI to parse the receipt into a strict JSON schema.
73
+ Uses gpt-4o-mini for cost/perf. You can swap to 'gpt-4.1-mini' if preferred.
 
 
74
  """
75
+ if oai_client is None:
76
+ raise RuntimeError("OpenAI client not initialized. Check OPENAI_API_KEY.")
77
 
78
+ # Heuristically decide MIME
79
+ is_pdf = image_bytes[0:4] == b"%PDF"
80
+ mime = "application/pdf" if is_pdf else "image/jpeg"
81
+
82
+ system = (
83
+ "You are a strict, no-chitchat receipt parser for Canadian grocery receipts. "
84
+ "Return ONLY valid JSON matching the schema. Prices in CAD. "
85
+ "Do not infer items not clearly present."
86
+ )
87
+ user_prompt = """
88
+ Extract a clean JSON that follows this schema exactly:
89
+
90
+ {
91
+ "store": {"name": "string", "address": "string|null", "date": "YYYY-MM-DD|null"},
92
+ "items": [
93
+ {"name": "string", "size": "string|null", "qty": 1, "unit_price": 0.00, "line_total": 0.00}
94
+ ],
95
+ "subtotal": 0.00,
96
+ "tax": 0.00,
97
+ "total": 0.00
98
+ }
99
+
100
+ Rules:
101
+ - item.name should be shopper-friendly (e.g., "Natrel 2% Milk 2L" not cryptic codes).
102
+ - qty is integer >=1; prefer the printed quantity if any.
103
+ - unit_price is per single unit before tax.
104
+ - line_total = qty * unit_price (or the printed extended line).
105
+ - If a value is missing on the receipt, set it to null or sensible default (e.g., qty=1).
106
+ - Return ONLY JSON.
107
+ """
108
+
109
+ content = [
110
+ {"type": "input_text", "text": user_prompt},
111
+ {
112
+ "type": "input_image",
113
+ "image_url": b64_data_uri(image_bytes, mime)
114
+ }
115
+ ]
116
+
117
+ resp = oai_client.responses.create(
118
+ model="gpt-4o-mini",
119
+ temperature=0,
120
+ messages=[
121
+ {"role": "system", "content": system},
122
+ {"role": "user", "content": content}
123
+ ]
124
+ )
125
+ # The new Responses API puts text in output_text helper or in output[0]
126
+ try:
127
+ parsed = resp.output_text
128
+ except Exception:
129
+ # Fallback deep extraction
130
+ chunks = []
131
+ for out in resp.output or []:
132
+ for ct in getattr(out, "content", []) or []:
133
+ if ct.type == "output_text":
134
+ chunks.append(ct.text)
135
+ parsed = "\n".join(chunks)
136
+
137
+ # Strip fences if any and parse JSON
138
+ s = parsed.strip()
139
+ if s.startswith("```"):
140
+ s = s.split("```", 2)[1]
141
+ if s.startswith(("json", "JSON")):
142
+ s = s.split("\n", 1)[1]
143
+ data = json.loads(s)
144
+ return data
145
+
146
+ def serpapi_google_shopping(query: str) -> Optional[Dict[str, Any]]:
147
+ """Search price via Google Shopping using SerpAPI."""
148
+ key = os.getenv("SERPAPI_KEY")
149
+ if not key:
150
+ return None
151
+ url = "https://serpapi.com/search.json"
152
+ params = {
153
+ "engine": "google_shopping",
154
+ "q": query,
155
+ "gl": "ca",
156
+ "hl": "en",
157
+ "api_key": key
158
+ }
159
+ r = requests.get(url, params=params, timeout=20)
160
+ if r.status_code != 200:
161
+ return None
162
+ data = r.json()
163
+ products = data.get("shopping_results") or []
164
+ # Pick the first reasonable priced result
165
+ for p in products:
166
+ price = p.get("price")
167
+ if not price:
168
  continue
169
+ # Normalize "$3.99"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  try:
171
+ price_num = float(price.replace("$","").replace(",","").strip())
172
+ except Exception:
173
+ continue
174
+ return {
175
+ "title": p.get("title"),
176
+ "price": price_num,
177
+ "source": p.get("source"),
178
+ "link": p.get("link")
179
+ }
180
+ return None
181
+
182
+ def normalize_query(item: Dict[str, Any]) -> str:
183
+ base = item.get("name") or ""
184
+ size = item.get("size") or ""
185
+ # keep it concise
186
+ q = f"{base} {size}".strip()
187
+ # remove store-specific codes
188
+ return " ".join([tok for tok in q.split() if len(tok) > 1])
189
+
190
+ def research_prices(items: List[Dict[str, Any]], max_items: int = 6) -> List[Dict[str, Any]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  """
192
+ For each item, query Google Shopping via SerpAPI (if available).
193
+ Returns list with possibly a cheaper offer for each item.
194
  """
195
+ results = []
196
+ for item in items[:max_items]: # keep it snappy
197
+ query = normalize_query(item)
198
+ if not query:
199
+ continue
200
+ offer = serpapi_google_shopping(query)
201
+ if not offer:
202
+ continue
203
+ unit_price = item.get("unit_price") or None
204
+ cheaper = None
205
+ if unit_price is not None and isinstance(unit_price, (int, float)):
206
+ if offer["price"] < float(unit_price) - 0.005:
207
+ cheaper = offer
208
+ results.append({
209
+ "item_name": item.get("name"),
210
+ "receipt_unit_price": unit_price,
211
+ "found_price": offer["price"],
212
+ "found_store": offer["source"],
213
+ "found_title": offer["title"],
214
+ "found_link": offer["link"],
215
+ "is_cheaper": bool(cheaper)
216
+ })
217
+ time.sleep(0.4) # be gentle
218
+ return results
219
+
220
+ def compute_savings(receipt: Dict[str, Any], found: List[Dict[str, Any]]) -> Tuple[float, List[Dict[str, Any]]]:
221
+ cheaper = [f for f in found if f.get("is_cheaper")]
222
+ savings = 0.0
223
+ for f in cheaper:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  try:
225
+ savings += float(f["receipt_unit_price"]) - float(f["found_price"])
226
+ except Exception:
227
  pass
228
+ return round(savings, 2), cheaper
229
+
230
+ def format_five_lines(receipt: Dict[str, Any], savings: float, cheaper_list: List[Dict[str, Any]]) -> str:
231
+ store = (receipt.get("store") or {}).get("name") or "your store"
232
+ total = receipt.get("total") or receipt.get("subtotal") or None
233
+ total_txt = f"${total:.2f}" if isinstance(total, (int, float)) else "N/A"
234
+
235
+ lines = []
236
+ # 1
237
+ lines.append(f"Receipt read: {store}, total {total_txt}.")
238
+ # 2
239
+ lines.append(f"I found potential savings of ${savings:.2f} by checking other stores.")
240
+ # 3
241
+ if cheaper_list:
242
+ bullets = []
243
+ for f in cheaper_list[:3]:
244
+ item = f['item_name'] or 'Item'
245
+ shop = f['found_store'] or 'other store'
246
+ price = f['found_price']
247
+ bullets.append(f"{item} @ {shop} ${price:.2f}")
248
+ lines.append("Cheaper picks: " + "; ".join(bullets) + ".")
249
+ else:
250
+ lines.append("No clearly cheaper matches found right now for your items.")
251
 
252
+ # 4
253
+ # Simple best “deal” heuristic: top 1 lowest price vs its own receipt price
254
+ if cheaper_list:
255
+ best = sorted(cheaper_list, key=lambda x: x["found_price"])[0]
256
+ lines.append(f"Best deal now: {best['item_name']} at {best['found_store']} for ${best['found_price']:.2f}.")
257
 
258
+ # 5
259
+ lines.append("Reply 'DEALS' anytime to get weekly picks tailored to your receipts.")
260
+
261
+ # Ensure ≤5 lines
262
+ return "\n".join(lines[:5])
263
+
264
+ # ---------- Main UI flow ----------
265
+ uploaded = st.file_uploader("Upload receipt (image or PDF)", type=["jpg","jpeg","png","webp","pdf"])
266
+
267
+ if uploaded and st.button("Analyze Receipt"):
268
  try:
269
+ img_bytes = img_or_pdf_to_image_bytes(uploaded)
 
 
 
 
 
 
 
 
 
 
 
270
 
271
+ with st.spinner("Reading receipt with OpenAI…"):
272
+ receipt = call_openai_vision_for_receipt(img_bytes)
273
+
274
+ items = receipt.get("items") or []
275
+ if not items:
276
+ st.error("I couldn't find any line-items on that receipt. Try a higher-resolution image.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  else:
278
+ with st.spinner("Researching prices at other stores…"):
279
+ found = research_prices(items)
280
+ savings, cheaper_list = compute_savings(receipt, found)
281
+ summary = format_five_lines(receipt, savings, cheaper_list)
282
 
283
+ # ✅ User-facing message (≤5 lines)
284
+ st.success(summary)
285
 
286
+ with st.expander("🔎 What I parsed (debug)"):
287
+ st.json(receipt)
288
+ with st.expander("🌐 Price lookups (debug)"):
289
+ st.json(found)
 
 
290
 
291
+ except Exception as e:
292
+ st.error(f"Something went wrong: {e}")
 
293
  st.stop()
294
 
295
+ st.markdown("---")
296
+ st.caption("Tip: Add SERPAPI_KEY for stronger live price checks (Google Shopping). SMS integration can be added later (e.g., Twilio).")