Ankushbl6 committed on
Commit
efa6392
·
verified ·
1 Parent(s): c5b9438

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +383 -92
src/streamlit_app.py CHANGED
@@ -1,7 +1,6 @@
1
  # =========================
2
  # Invoice Extractor (Qwen3-VL via RunPod vLLM) - Batch Mode with Tax Validation
3
- # UPDATED: Comprehensive date parsing (50+ formats) + Hybrid date display
4
- # FIX: Tax calculation skips both empty ("") and explicit zero (0.00) values
5
  # =========================
6
  import os
7
  from pathlib import Path
@@ -90,27 +89,118 @@ def ensure_state(k: str, default):
90
  st.session_state[k] = default
91
 
92
  def clean_float(x) -> float:
93
- import re
 
 
 
 
 
 
 
 
 
 
 
 
94
  if x is None:
95
  return 0.0
96
  if isinstance(x, (int, float)):
97
  return float(x)
 
98
  s = str(x).strip()
99
  if s == "":
100
  return 0.0
101
- s = re.sub(r"[,\s]", "", s)
102
- s = re.sub(r"[^\d\.\-]", "", s)
103
- if s in ("", ".", "-", "-."):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  return 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  try:
106
- return float(s)
107
- except Exception:
 
108
  return 0.0
109
 
110
  def normalize_date(date_str) -> str:
111
  """
112
- Normalize various date formats to dd-MMM-yyyy format (e.g., 01-Jan-2025)
113
- Handles: ISO, US, EU, Asian, two-digit years, and 50+ worldwide date formats
 
114
  Returns empty string if date cannot be parsed
115
  """
116
  if not date_str or date_str == "":
@@ -121,8 +211,8 @@ def normalize_date(date_str) -> str:
121
  if date_str == "":
122
  return ""
123
 
124
- # Comprehensive list of date formats to try (order matters - most specific first)
125
- formats = [
126
  # ISO formats (4-digit year)
127
  "%Y-%m-%d", # 2025-01-15
128
  "%Y/%m/%d", # 2025/01/15
@@ -162,7 +252,7 @@ def normalize_date(date_str) -> str:
162
 
163
  # European formats with 2-digit year - Day first
164
  "%d-%m-%y", # 15-01-25
165
- "%d/%m/%y", # 15/01/25 or 25/09/25 ← FIXES YOUR ISSUE!
166
  "%d.%m.%y", # 15.01.25
167
  "%d %m %y", # 15 01 25
168
 
@@ -205,35 +295,72 @@ def normalize_date(date_str) -> str:
205
  "%Y%d%m", # 20251501
206
  ]
207
 
208
- parsed_date = None
209
-
210
- # Try parsing with each format
211
- for fmt in formats:
212
  try:
213
  parsed_date = datetime.strptime(str(date_str), fmt)
214
- break
215
  except (ValueError, TypeError):
216
  continue
217
 
218
- # If still not parsed, try removing ordinal suffixes (st, nd, rd, th)
219
- if parsed_date is None and isinstance(date_str, str):
220
- import re
221
  cleaned = re.sub(r'(\d+)(st|nd|rd|th)\b', r'\1', date_str, flags=re.IGNORECASE)
222
-
223
  if cleaned != date_str:
224
- for fmt in formats:
225
  try:
226
  parsed_date = datetime.strptime(cleaned, fmt)
227
- break
228
  except (ValueError, TypeError):
229
  continue
230
 
231
- # If no format matched, return empty string
232
- if parsed_date is None:
233
- return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
- # Format as dd-MMM-yyyy (e.g., 01-Jan-2025)
236
- return parsed_date.strftime("%d-%b-%Y")
 
 
 
 
 
 
 
 
237
 
238
  def parse_date_to_object(date_str):
239
  """
@@ -331,6 +458,41 @@ def parse_date_to_object(date_str):
331
  "%d%m%Y", # 15012025
332
  "%m%d%Y", # 01152025
333
  "%Y%d%m", # 20251501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  ]
335
 
336
  # Try parsing with each format
@@ -356,22 +518,6 @@ def parse_date_to_object(date_str):
356
 
357
  return None
358
 
359
- # -----------------------------
360
- # HF login flow (REMOVED - No longer needed for vLLM API)
361
- # -----------------------------
362
- # Authentication is now handled via POD_URL and VLLM_API_KEY instead
363
-
364
- # -----------------------------
365
- # Model config
366
- # -----------------------------
367
- # OLD DONUT CODE (COMMENTED OUT - Now using vLLM API)
368
- # -----------------------------
369
- # HF_MODEL_ID = "Bhuvi13/model-V7"
370
- # TASK_PROMPT = "<s_cord-v2>"
371
- #
372
- # @st.cache_resource(show_spinner=False)
373
- # def load_model_and_processor(hf_model_id: str, task_prompt: str):
374
- # ...
375
 
376
  # -----------------------------
377
  # vLLM Inference Function (RunPod API)
@@ -407,7 +553,8 @@ Extract the data into this exact JSON structure:
407
  "quantity": "Quantity of items",
408
  "unit_price": "Price per unit",
409
  "amount": "Total amount for this line item",
410
- "tax": "Tax amount for this item",
 
411
  "Line_total": "Total amount including tax for this line"
412
  }
413
  ],
@@ -434,7 +581,7 @@ IMPORTANT GUIDELINES:
434
  - Extract text exactly as it appears, including special characters and formatting
435
  - For dates, preserve the original format shown in the invoice
436
  - If both sender and receiver addresses are in the United States, extract ACH; otherwise extract Wire transfer (WT).
437
- - If payment terms specify a number of days (e.g., payment terms 30 days”, payable within 15 days”, terms 45 days”, Net 30”, or any similar phrase), compute: due_date = invoice_date + N days. If the invoice states due on receipt”, due upon receipt ,"Immediate" or any similar phrase meaning immediate payment, then: due_date = invoice_date. Use the same date format as the invoice. Output only the computed due_date.
438
  - if tax_rate is not given in invoice but tax_amount is given, calculate the tax_rate using tax_amount and subtotal.
439
  - line-item wise tax calculation has to be done properly based ONLY on the tax_rate given in the summary, and the same tax_rate must be used for every line item in that invoice.
440
  - If currency symbols are present, note them appropriately
@@ -519,12 +666,98 @@ Return only the JSON object with the extracted information"""
519
  def parse_vllm_json(raw_json_text):
520
  """Parse vLLM JSON output into structured format"""
521
  try:
522
- data = json.loads(raw_json_text)
 
 
 
 
 
 
 
 
 
 
 
 
523
 
524
  def clean_amount(value):
525
- if not value or value == "":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  return 0.0
527
- return float(re.sub(r'[^\d\.-]', '', str(value)))
528
 
529
  header = data.get("header", {})
530
  summary = data.get("summary", {})
@@ -792,6 +1025,7 @@ def map_prediction_to_ui(pred):
792
  return None
793
 
794
  def clean_number(x):
 
795
  if x is None:
796
  return 0.0
797
  if isinstance(x, (int, float)):
@@ -799,13 +1033,71 @@ def map_prediction_to_ui(pred):
799
  s = str(x).strip()
800
  if s == "":
801
  return 0.0
802
- s = re.sub(r"[,\s]", "", s)
803
- s = re.sub(r"[^\d\.\-]", "", s)
804
- if s in ("", ".", "-", "-."):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
805
  return 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
806
  try:
807
- return float(s)
808
- except Exception:
 
809
  return 0.0
810
 
811
  def collect_keys(obj, out):
@@ -1076,16 +1368,6 @@ def flatten_invoice_to_rows(invoice_data) -> list:
1076
  rows.append(row)
1077
  return rows
1078
 
1079
- # -----------------------------
1080
- # Load model (COMMENTED OUT - Now using vLLM API)
1081
- # -----------------------------
1082
- # try:
1083
- # with st.spinner("Loading model & processor (cached) ..."):
1084
- # processor, model, device, decoder_input_ids = load_model_and_processor(HF_MODEL_ID, TASK_PROMPT)
1085
- # except Exception as e:
1086
- # st.error("Could not load model automatically. See details below.")
1087
- # st.exception(e)
1088
- # st.stop()
1089
 
1090
  # -----------------------------
1091
  # Session scaffolding
@@ -1156,6 +1438,8 @@ if not st.session_state.is_processing_batch and len(st.session_state.batch_resul
1156
  continue
1157
 
1158
  # vLLM Inference + parsing + tax validation
 
 
1159
  try:
1160
  # Call vLLM API
1161
  raw_json = run_inference_vllm(image)
@@ -1174,18 +1458,18 @@ if not st.session_state.is_processing_batch and len(st.session_state.batch_resul
1174
  st.warning(f"No response from vLLM for {uploaded_file.name}")
1175
  mapped = {}
1176
 
1177
- pred = raw_json # Store raw JSON for debugging
1178
  except Exception as e:
1179
  st.warning(f"Error processing {uploaded_file.name}: {str(e)}")
1180
- pred = None
1181
  mapped = {}
1182
 
1183
  safe_mapped = mapped if isinstance(mapped, dict) else {}
1184
 
 
1185
  st.session_state.batch_results[file_hash] = {
1186
  "file_name": uploaded_file.name,
1187
  "image": image,
1188
- "raw_pred": pred,
1189
  "mapped_data": safe_mapped,
1190
  "edited_data": safe_mapped.copy()
1191
  }
@@ -1320,9 +1604,17 @@ elif len(st.session_state.batch_results) > 0:
1320
  with frame_left:
1321
  st.image(image, caption=current["file_name"], width=FIXED_IMG_WIDTH)
1322
  st.write(f"**File Hash:** {selected_hash[:8]}...")
1323
- if current.get('raw_pred') is not None:
1324
- with st.expander("🔍 Show raw model output"):
1325
- st.json(current['raw_pred'])
 
 
 
 
 
 
 
 
1326
 
1327
  if st.button("🔁 Re-Run Inference", key=f"rerun_{selected_hash}"):
1328
  with st.spinner("Re-running inference..."):
@@ -1345,10 +1637,9 @@ elif len(st.session_state.batch_results) > 0:
1345
  mapped = {}
1346
 
1347
  safe_mapped = mapped if isinstance(mapped, dict) else {}
1348
- pred = raw_json # Store raw JSON
1349
 
1350
  # Update stored results
1351
- st.session_state.batch_results[selected_hash]["raw_pred"] = pred
1352
  st.session_state.batch_results[selected_hash]["mapped_data"] = mapped
1353
  st.session_state.batch_results[selected_hash]["edited_data"] = safe_mapped.copy()
1354
 
@@ -1396,10 +1687,10 @@ elif len(st.session_state.batch_results) > 0:
1396
  if st.session_state.get(f"Currency_{selected_hash}") == 'Other':
1397
  st.text_input("Specify Currency", key=f"Currency_Custom_{selected_hash}")
1398
 
1399
- st.number_input("Subtotal", key=f"Subtotal_{selected_hash}")
1400
- st.number_input("Tax %", key=f"Tax Percentage_{selected_hash}")
1401
- st.number_input("Total Tax", key=f"Total Tax_{selected_hash}")
1402
- st.number_input("Total Amount", key=f"Total Amount_{selected_hash}")
1403
 
1404
  with tabs[1]:
1405
  st.text_input("Sender Name", key=f"Sender Name_{selected_hash}")
@@ -1512,7 +1803,7 @@ elif len(st.session_state.batch_results) > 0:
1512
 
1513
  st.dataframe(
1514
  totals_df,
1515
- width="stretch", # <- see note below
1516
  hide_index=True,
1517
  height=38
1518
  )
@@ -1562,16 +1853,16 @@ elif len(st.session_state.batch_results) > 0:
1562
  calculated_tax_pct = round((calculated_total_tax / calculated_subtotal) * 100, 4)
1563
 
1564
  if saved:
1565
- # Build updated data structure
1566
  updated = {
1567
  'Invoice Number': st.session_state.get(f"Invoice Number_{selected_hash}", ''),
1568
  'Invoice Date': invoice_date_str,
1569
  'Due Date': due_date_str,
1570
  'Currency': currency,
1571
- 'Subtotal': calculated_subtotal, # Auto-calculated from line items
1572
- 'Tax Percentage': calculated_tax_pct, # Auto-calculated
1573
- 'Total Tax': calculated_total_tax, # Auto-calculated from line items
1574
- 'Total Amount': calculated_total, # Auto-calculated from line items
1575
  'Sender Name': st.session_state.get(f"Sender Name_{selected_hash}", ''),
1576
  'Sender Address': st.session_state.get(f"Sender Address_{selected_hash}", ''),
1577
  'Recipient Name': st.session_state.get(f"Recipient Name_{selected_hash}", ''),
@@ -1595,10 +1886,10 @@ elif len(st.session_state.batch_results) > 0:
1595
  # Save to batch_results (this persists the data)
1596
  st.session_state.batch_results[selected_hash]["edited_data"] = updated
1597
 
1598
- # CRITICAL: Clear items_df from session state so it rebuilds from saved data on next rerun
1599
- items_state_key = f"items_df_{selected_hash}"
1600
- if items_state_key in st.session_state:
1601
- del st.session_state[items_state_key]
1602
 
1603
  # Show success message
1604
  st.success("✅ Saved")
@@ -1606,16 +1897,16 @@ elif len(st.session_state.batch_results) > 0:
1606
  # Rerun to reload the form with saved data
1607
  st.rerun()
1608
 
1609
- # Per-file CSV download (ALWAYS visible, uses current edited values)
1610
  download_data = {
1611
  'Invoice Number': st.session_state.get(f"Invoice Number_{selected_hash}", ''),
1612
  'Invoice Date': invoice_date_str,
1613
  'Due Date': due_date_str,
1614
  'Currency': currency,
1615
- 'Subtotal': calculated_subtotal, # Use calculated value
1616
- 'Tax Percentage': calculated_tax_pct, # Use calculated value
1617
- 'Total Tax': calculated_total_tax, # Use calculated value
1618
- 'Total Amount': calculated_total, # Use calculated value
1619
  'Sender Name': st.session_state.get(f"Sender Name_{selected_hash}", ''),
1620
  'Sender Address': st.session_state.get(f"Sender Address_{selected_hash}", ''),
1621
  'Recipient Name': st.session_state.get(f"Recipient Name_{selected_hash}", ''),
 
1
  # =========================
2
  # Invoice Extractor (Qwen3-VL via RunPod vLLM) - Batch Mode with Tax Validation
3
+ # UPDATED: Fixed raw model output display
 
4
  # =========================
5
  import os
6
  from pathlib import Path
 
89
  st.session_state[k] = default
90
 
91
  def clean_float(x) -> float:
92
+ """
93
+ Parse a number string handling both US and European formats.
94
+
95
+ US Format: 1,234,567.89 (comma = thousands, period = decimal)
96
+ European: 1.234.567,89 (period = thousands, comma = decimal)
97
+
98
+ Examples:
99
+ "1,234.56" → 1234.56 (US)
100
+ "1.234,56" → 1234.56 (European)
101
+ "3.000,2234" → 3000.2234 (European with 4 decimal places)
102
+ "261,49" → 261.49 (European decimal only)
103
+ "39,22-" → -39.22 (European with trailing minus)
104
+ """
105
  if x is None:
106
  return 0.0
107
  if isinstance(x, (int, float)):
108
  return float(x)
109
+
110
  s = str(x).strip()
111
  if s == "":
112
  return 0.0
113
+
114
+ # Handle negative signs (could be leading or trailing)
115
+ is_negative = False
116
+ if s.startswith('-'):
117
+ is_negative = True
118
+ s = s[1:].strip()
119
+ elif s.endswith('-'):
120
+ is_negative = True
121
+ s = s[:-1].strip()
122
+ elif s.startswith('(') and s.endswith(')'):
123
+ # Accounting format: (123.45) means negative
124
+ is_negative = True
125
+ s = s[1:-1].strip()
126
+
127
+ # Remove currency symbols and spaces
128
+ s = re.sub(r'[€$£¥₹\s]', '', s)
129
+
130
+ if s == "":
131
  return 0.0
132
+
133
+ # Count occurrences
134
+ comma_count = s.count(',')
135
+ period_count = s.count('.')
136
+
137
+ # Find positions of last comma and last period
138
+ last_comma = s.rfind(',')
139
+ last_period = s.rfind('.')
140
+
141
+ # Determine format based on which separator comes last
142
+ if comma_count > 0 and period_count > 0:
143
+ # Both separators present - the LAST one is the decimal separator
144
+ if last_comma > last_period:
145
+ # European format: 1.234,56 → comma is decimal
146
+ # Remove periods (thousands), replace comma with period
147
+ s = s.replace('.', '').replace(',', '.')
148
+ else:
149
+ # US format: 1,234.56 → period is decimal
150
+ # Remove commas (thousands)
151
+ s = s.replace(',', '')
152
+
153
+ elif comma_count > 0 and period_count == 0:
154
+ # Only commas present
155
+ # Check what comes after the LAST comma
156
+ after_last_comma = s[last_comma + 1:] if last_comma < len(s) - 1 else ""
157
+
158
+ if comma_count == 1 and len(after_last_comma) <= 4 and after_last_comma.isdigit():
159
+ # Single comma with 1-4 digits after → European decimal
160
+ # "261,49" → 261.49, "1234,5678" → 1234.5678
161
+ s = s.replace(',', '.')
162
+ elif len(after_last_comma) == 3 and comma_count >= 1:
163
+ # 3 digits after comma(s) → likely thousands separator
164
+ # "1,234" → 1234, "1,234,567" → 1234567
165
+ s = s.replace(',', '')
166
+ else:
167
+ # Multiple commas → thousands separator
168
+ # "1,234,567" → 1234567
169
+ s = s.replace(',', '')
170
+
171
+ elif period_count > 0 and comma_count == 0:
172
+ # Only periods present
173
+ # Check what comes after the LAST period
174
+ after_last_period = s[last_period + 1:] if last_period < len(s) - 1 else ""
175
+
176
+ if period_count > 1:
177
+ # Multiple periods → definitely thousands separator (European: "1.234.567")
178
+ s = s.replace('.', '')
179
+ elif len(after_last_period) == 3 and after_last_period.isdigit():
180
+ # Single period with exactly 3 digits after → European thousands: "1.000" → 1000
181
+ # (In invoices, "1.000" almost always means 1000, not 1.0 with trailing zeros)
182
+ before_period = s[:last_period]
183
+ if before_period.isdigit() and len(before_period) <= 3:
184
+ s = s.replace('.', '')
185
+ # Otherwise keep as is (standard decimal like "1.50", "123.45")
186
+
187
+ # Clean any remaining non-numeric characters except period and minus
188
+ s = re.sub(r'[^\d.]', '', s)
189
+
190
+ if s == "" or s == ".":
191
+ return 0.0
192
+
193
  try:
194
+ result = float(s)
195
+ return -result if is_negative else result
196
+ except ValueError:
197
  return 0.0
198
 
199
  def normalize_date(date_str) -> str:
200
  """
201
+ Normalize various date formats:
202
+ - Full dates (day-month-year) dd-MMM-yyyy (e.g., 01-Jan-2025)
203
+ - Month-year only → MMM-yyyy (e.g., Aug-2025)
204
  Returns empty string if date cannot be parsed
205
  """
206
  if not date_str or date_str == "":
 
211
  if date_str == "":
212
  return ""
213
 
214
+ # FULL DATE FORMATS (day-month-year) - try these first
215
+ full_date_formats = [
216
  # ISO formats (4-digit year)
217
  "%Y-%m-%d", # 2025-01-15
218
  "%Y/%m/%d", # 2025/01/15
 
252
 
253
  # European formats with 2-digit year - Day first
254
  "%d-%m-%y", # 15-01-25
255
+ "%d/%m/%y", # 15/01/25
256
  "%d.%m.%y", # 15.01.25
257
  "%d %m %y", # 15 01 25
258
 
 
295
  "%Y%d%m", # 20251501
296
  ]
297
 
298
+ # Try full date formats first → output as dd-MMM-yyyy
299
+ for fmt in full_date_formats:
 
 
300
  try:
301
  parsed_date = datetime.strptime(str(date_str), fmt)
302
+ return parsed_date.strftime("%d-%b-%Y")
303
  except (ValueError, TypeError):
304
  continue
305
 
306
+ # Try with ordinal suffixes removed (1st, 2nd, 3rd, etc.)
307
+ if isinstance(date_str, str):
 
308
  cleaned = re.sub(r'(\d+)(st|nd|rd|th)\b', r'\1', date_str, flags=re.IGNORECASE)
 
309
  if cleaned != date_str:
310
+ for fmt in full_date_formats:
311
  try:
312
  parsed_date = datetime.strptime(cleaned, fmt)
313
+ return parsed_date.strftime("%d-%b-%Y")
314
  except (ValueError, TypeError):
315
  continue
316
 
317
+ # MONTH-YEAR ONLY FORMATS - output as MMM-yyyy
318
+ month_year_formats = [
319
+ # Full month name with year
320
+ "%B %Y", # August 2025
321
+ "%b %Y", # Aug 2025
322
+ "%B, %Y", # August, 2025
323
+ "%b, %Y", # Aug, 2025
324
+ "%B-%Y", # August-2025
325
+ "%b-%Y", # Aug-2025
326
+ "%B/%Y", # August/2025
327
+ "%b/%Y", # Aug/2025
328
+
329
+ # Numeric month-year (4-digit year)
330
+ "%m/%Y", # 08/2025
331
+ "%m-%Y", # 08-2025
332
+ "%m.%Y", # 08.2025
333
+ "%m %Y", # 08 2025
334
+ "%Y-%m", # 2025-08
335
+ "%Y/%m", # 2025/08
336
+ "%Y.%m", # 2025.08
337
+ "%Y %m", # 2025 08
338
+
339
+ # Numeric month-year (2-digit year)
340
+ "%m/%y", # 08/25
341
+ "%m-%y", # 08-25
342
+ "%m.%y", # 08.25
343
+ "%m %y", # 08 25
344
+ "%y-%m", # 25-08
345
+ "%y/%m", # 25/08
346
+
347
+ # Full month name with 2-digit year
348
+ "%B %y", # August 25
349
+ "%b %y", # Aug 25
350
+ "%B-%y", # August-25
351
+ "%b-%y", # Aug-25
352
+ ]
353
 
354
+ # Try month-year formats → output as MMM-yyyy (no day)
355
+ for fmt in month_year_formats:
356
+ try:
357
+ parsed_date = datetime.strptime(str(date_str), fmt)
358
+ return parsed_date.strftime("%b-%Y") # Aug-2025 format
359
+ except (ValueError, TypeError):
360
+ continue
361
+
362
+ # If no format matched, return empty string
363
+ return ""
364
 
365
  def parse_date_to_object(date_str):
366
  """
 
458
  "%d%m%Y", # 15012025
459
  "%m%d%Y", # 01152025
460
  "%Y%d%m", # 20251501
461
+
462
+ # ========== MONTH-YEAR ONLY FORMATS (defaults to 1st of month) ==========
463
+ # Full month name with year
464
+ "%B %Y", # August 2025
465
+ "%b %Y", # Aug 2025
466
+ "%B, %Y", # August, 2025
467
+ "%b, %Y", # Aug, 2025
468
+ "%B-%Y", # August-2025
469
+ "%b-%Y", # Aug-2025
470
+ "%B/%Y", # August/2025
471
+ "%b/%Y", # Aug/2025
472
+
473
+ # Numeric month-year (4-digit year)
474
+ "%m/%Y", # 08/2025
475
+ "%m-%Y", # 08-2025
476
+ "%m.%Y", # 08.2025
477
+ "%m %Y", # 08 2025
478
+ "%Y-%m", # 2025-08
479
+ "%Y/%m", # 2025/08
480
+ "%Y.%m", # 2025.08
481
+ "%Y %m", # 2025 08
482
+
483
+ # Numeric month-year (2-digit year)
484
+ "%m/%y", # 08/25
485
+ "%m-%y", # 08-25
486
+ "%m.%y", # 08.25
487
+ "%m %y", # 08 25
488
+ "%y-%m", # 25-08
489
+ "%y/%m", # 25/08
490
+
491
+ # Full month name with 2-digit year
492
+ "%B %y", # August 25
493
+ "%b %y", # Aug 25
494
+ "%B-%y", # August-25
495
+ "%b-%y", # Aug-25
496
  ]
497
 
498
  # Try parsing with each format
 
518
 
519
  return None
520
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521
 
522
  # -----------------------------
523
  # vLLM Inference Function (RunPod API)
 
553
  "quantity": "Quantity of items",
554
  "unit_price": "Price per unit",
555
  "amount": "Total amount for this line item",
556
+ "t_rate": "tax_rate",
557
+ "tax": "amount*t_rate/100",
558
  "Line_total": "Total amount including tax for this line"
559
  }
560
  ],
 
581
  - Extract text exactly as it appears, including special characters and formatting
582
  - For dates, preserve the original format shown in the invoice
583
  - If both sender and receiver addresses are in the United States, extract ACH; otherwise extract Wire transfer (WT).
584
+ - If payment terms specify a number of days (e.g., "payment terms 30 days", "payable within 15 days", "terms 45 days", "Net 30", or any similar phrase), compute: due_date = invoice_date + N days. If the invoice states "due on receipt", "due upon receipt" ,"Immediate" or any similar phrase meaning immediate payment, then: due_date = invoice_date. Use the same date format as the invoice. Output only the computed due_date.
585
  - if tax_rate is not given in invoice but tax_amount is given, calculate the tax_rate using tax_amount and subtotal.
586
  - line-item wise tax calculation has to be done properly based ONLY on the tax_rate given in the summary, and the same tax_rate must be used for every line item in that invoice.
587
  - If currency symbols are present, note them appropriately
 
666
  def parse_vllm_json(raw_json_text):
667
  """Parse vLLM JSON output into structured format"""
668
  try:
669
+ # Try to parse the JSON - handle potential markdown code blocks
670
+ text_to_parse = raw_json_text.strip()
671
+
672
+ # Remove markdown code fences if present
673
+ if text_to_parse.startswith("```json"):
674
+ text_to_parse = text_to_parse[7:]
675
+ elif text_to_parse.startswith("```"):
676
+ text_to_parse = text_to_parse[3:]
677
+ if text_to_parse.endswith("```"):
678
+ text_to_parse = text_to_parse[:-3]
679
+ text_to_parse = text_to_parse.strip()
680
+
681
+ data = json.loads(text_to_parse)
682
 
683
  def clean_amount(value):
684
+ """Parse number handling both US and European formats."""
685
+ if value is None:
686
+ return 0.0
687
+ if isinstance(value, (int, float)):
688
+ return float(value)
689
+ if value == "":
690
+ return 0.0
691
+
692
+ s = str(value).strip()
693
+ if s == "":
694
+ return 0.0
695
+
696
+ # Handle negative signs (leading or trailing)
697
+ is_negative = False
698
+ if s.startswith('-'):
699
+ is_negative = True
700
+ s = s[1:].strip()
701
+ elif s.endswith('-'):
702
+ is_negative = True
703
+ s = s[:-1].strip()
704
+ elif s.startswith('(') and s.endswith(')'):
705
+ is_negative = True
706
+ s = s[1:-1].strip()
707
+
708
+ # Remove currency symbols and spaces
709
+ s = re.sub(r'[€$£¥₹\s]', '', s)
710
+
711
+ if s == "":
712
+ return 0.0
713
+
714
+ comma_count = s.count(',')
715
+ period_count = s.count('.')
716
+ last_comma = s.rfind(',')
717
+ last_period = s.rfind('.')
718
+
719
+ if comma_count > 0 and period_count > 0:
720
+ # Both present - LAST one is decimal separator
721
+ if last_comma > last_period:
722
+ # European: 1.234,56 or 3.000,2234
723
+ s = s.replace('.', '').replace(',', '.')
724
+ else:
725
+ # US: 1,234.56
726
+ s = s.replace(',', '')
727
+ elif comma_count > 0:
728
+ # Only commas
729
+ after_last_comma = s[last_comma + 1:] if last_comma < len(s) - 1 else ""
730
+ if comma_count == 1 and len(after_last_comma) <= 4 and after_last_comma.isdigit():
731
+ # European decimal: "261,49" or "9,60"
732
+ s = s.replace(',', '.')
733
+ elif len(after_last_comma) == 3:
734
+ # 3 digits after comma → thousands: "1,234"
735
+ s = s.replace(',', '')
736
+ else:
737
+ # Multiple commas = US thousands: "1,234,567"
738
+ s = s.replace(',', '')
739
+ elif period_count > 0:
740
+ # Only periods
741
+ after_last_period = s[last_period + 1:] if last_period < len(s) - 1 else ""
742
+ if period_count > 1:
743
+ # Multiple periods = European thousands: "1.234.567"
744
+ s = s.replace('.', '')
745
+ elif len(after_last_period) == 3 and after_last_period.isdigit():
746
+ # Single period with exactly 3 digits → European thousands: "1.000"
747
+ before_period = s[:last_period]
748
+ if before_period.isdigit() and len(before_period) <= 3:
749
+ s = s.replace('.', '')
750
+
751
+ s = re.sub(r'[^\d.]', '', s)
752
+
753
+ if s == "" or s == ".":
754
+ return 0.0
755
+
756
+ try:
757
+ result = float(s)
758
+ return -result if is_negative else result
759
+ except ValueError:
760
  return 0.0
 
761
 
762
  header = data.get("header", {})
763
  summary = data.get("summary", {})
 
1025
  return None
1026
 
1027
  def clean_number(x):
1028
+ """Parse number handling both US and European formats."""
1029
  if x is None:
1030
  return 0.0
1031
  if isinstance(x, (int, float)):
 
1033
  s = str(x).strip()
1034
  if s == "":
1035
  return 0.0
1036
+
1037
+ # Handle negative signs (leading or trailing)
1038
+ is_negative = False
1039
+ if s.startswith('-'):
1040
+ is_negative = True
1041
+ s = s[1:].strip()
1042
+ elif s.endswith('-'):
1043
+ is_negative = True
1044
+ s = s[:-1].strip()
1045
+ elif s.startswith('(') and s.endswith(')'):
1046
+ is_negative = True
1047
+ s = s[1:-1].strip()
1048
+
1049
+ # Remove currency symbols and spaces
1050
+ s = re.sub(r'[€$£¥₹\s]', '', s)
1051
+
1052
+ if s == "":
1053
  return 0.0
1054
+
1055
+ comma_count = s.count(',')
1056
+ period_count = s.count('.')
1057
+ last_comma = s.rfind(',')
1058
+ last_period = s.rfind('.')
1059
+
1060
+ if comma_count > 0 and period_count > 0:
1061
+ # Both present - LAST one is decimal separator
1062
+ if last_comma > last_period:
1063
+ # European: 1.234,56 or 3.000,2234
1064
+ s = s.replace('.', '').replace(',', '.')
1065
+ else:
1066
+ # US: 1,234.56
1067
+ s = s.replace(',', '')
1068
+ elif comma_count > 0:
1069
+ # Only commas
1070
+ after_last_comma = s[last_comma + 1:] if last_comma < len(s) - 1 else ""
1071
+ if comma_count == 1 and len(after_last_comma) <= 4 and after_last_comma.isdigit():
1072
+ # European decimal: "261,49" or "9,60"
1073
+ s = s.replace(',', '.')
1074
+ elif len(after_last_comma) == 3:
1075
+ # 3 digits after comma → thousands: "1,234"
1076
+ s = s.replace(',', '')
1077
+ else:
1078
+ # Multiple commas = US thousands: "1,234,567"
1079
+ s = s.replace(',', '')
1080
+ elif period_count > 0:
1081
+ # Only periods
1082
+ after_last_period = s[last_period + 1:] if last_period < len(s) - 1 else ""
1083
+ if period_count > 1:
1084
+ # Multiple periods = European thousands: "1.234.567"
1085
+ s = s.replace('.', '')
1086
+ elif len(after_last_period) == 3 and after_last_period.isdigit():
1087
+ # Single period with exactly 3 digits → European thousands: "1.000"
1088
+ before_period = s[:last_period]
1089
+ if before_period.isdigit() and len(before_period) <= 3:
1090
+ s = s.replace('.', '')
1091
+
1092
+ s = re.sub(r'[^\d.]', '', s)
1093
+
1094
+ if s == "" or s == ".":
1095
+ return 0.0
1096
+
1097
  try:
1098
+ result = float(s)
1099
+ return -result if is_negative else result
1100
+ except ValueError:
1101
  return 0.0
1102
 
1103
  def collect_keys(obj, out):
 
1368
  rows.append(row)
1369
  return rows
1370
 
 
 
 
 
 
 
 
 
 
 
1371
 
1372
  # -----------------------------
1373
  # Session scaffolding
 
1438
  continue
1439
 
1440
  # vLLM Inference + parsing + tax validation
1441
+ raw_json = None
1442
+ mapped = {}
1443
  try:
1444
  # Call vLLM API
1445
  raw_json = run_inference_vllm(image)
 
1458
  st.warning(f"No response from vLLM for {uploaded_file.name}")
1459
  mapped = {}
1460
 
 
1461
  except Exception as e:
1462
  st.warning(f"Error processing {uploaded_file.name}: {str(e)}")
1463
+ raw_json = None
1464
  mapped = {}
1465
 
1466
  safe_mapped = mapped if isinstance(mapped, dict) else {}
1467
 
1468
+ # Store BOTH raw string AND parsed dict for display
1469
  st.session_state.batch_results[file_hash] = {
1470
  "file_name": uploaded_file.name,
1471
  "image": image,
1472
+ "raw_pred": raw_json, # Original string from API (untouched)
1473
  "mapped_data": safe_mapped,
1474
  "edited_data": safe_mapped.copy()
1475
  }
 
1604
  with frame_left:
1605
  st.image(image, caption=current["file_name"], width=FIXED_IMG_WIDTH)
1606
  st.write(f"**File Hash:** {selected_hash[:8]}...")
1607
+
1608
+ # ============ RAW MODEL OUTPUT DISPLAY (UNTOUCHED) ============
1609
+ with st.expander("🔍 Show raw model output"):
1610
+ raw_pred = current.get('raw_pred')
1611
+
1612
+ if raw_pred is None:
1613
+ st.warning("No raw output available (API may have returned None)")
1614
+ else:
1615
+ # Show raw output exactly as received from the model - UNTOUCHED
1616
+ st.code(str(raw_pred), language='json')
1617
+ # ==============================================================
1618
 
1619
  if st.button("🔁 Re-Run Inference", key=f"rerun_{selected_hash}"):
1620
  with st.spinner("Re-running inference..."):
 
1637
  mapped = {}
1638
 
1639
  safe_mapped = mapped if isinstance(mapped, dict) else {}
 
1640
 
1641
  # Update stored results
1642
+ st.session_state.batch_results[selected_hash]["raw_pred"] = raw_json
1643
  st.session_state.batch_results[selected_hash]["mapped_data"] = mapped
1644
  st.session_state.batch_results[selected_hash]["edited_data"] = safe_mapped.copy()
1645
 
 
1687
  if st.session_state.get(f"Currency_{selected_hash}") == 'Other':
1688
  st.text_input("Specify Currency", key=f"Currency_Custom_{selected_hash}")
1689
 
1690
+ st.number_input("Subtotal", key=f"Subtotal_{selected_hash}", format="%.2f")
1691
+ st.number_input("Tax %", key=f"Tax Percentage_{selected_hash}", format="%.4f")
1692
+ st.number_input("Total Tax", key=f"Total Tax_{selected_hash}", format="%.2f")
1693
+ st.number_input("Total Amount", key=f"Total Amount_{selected_hash}", format="%.2f")
1694
 
1695
  with tabs[1]:
1696
  st.text_input("Sender Name", key=f"Sender Name_{selected_hash}")
 
1803
 
1804
  st.dataframe(
1805
  totals_df,
1806
+ use_container_width=True,
1807
  hide_index=True,
1808
  height=38
1809
  )
 
1853
  calculated_tax_pct = round((calculated_total_tax / calculated_subtotal) * 100, 4)
1854
 
1855
  if saved:
1856
+ # Build updated data structure using ACTUAL user-entered values from form
1857
  updated = {
1858
  'Invoice Number': st.session_state.get(f"Invoice Number_{selected_hash}", ''),
1859
  'Invoice Date': invoice_date_str,
1860
  'Due Date': due_date_str,
1861
  'Currency': currency,
1862
+ 'Subtotal': st.session_state.get(f"Subtotal_{selected_hash}", 0.0),
1863
+ 'Tax Percentage': st.session_state.get(f"Tax Percentage_{selected_hash}", 0.0),
1864
+ 'Total Tax': st.session_state.get(f"Total Tax_{selected_hash}", 0.0),
1865
+ 'Total Amount': st.session_state.get(f"Total Amount_{selected_hash}", 0.0),
1866
  'Sender Name': st.session_state.get(f"Sender Name_{selected_hash}", ''),
1867
  'Sender Address': st.session_state.get(f"Sender Address_{selected_hash}", ''),
1868
  'Recipient Name': st.session_state.get(f"Recipient Name_{selected_hash}", ''),
 
1886
  # Save to batch_results (this persists the data)
1887
  st.session_state.batch_results[selected_hash]["edited_data"] = updated
1888
 
1889
+ # CRITICAL: Clear ALL session state keys for this file so they reload from saved edited_data
1890
+ keys_to_delete = [k for k in list(st.session_state.keys()) if k.endswith(f"_{selected_hash}")]
1891
+ for key in keys_to_delete:
1892
+ del st.session_state[key]
1893
 
1894
  # Show success message
1895
  st.success("✅ Saved")
 
1897
  # Rerun to reload the form with saved data
1898
  st.rerun()
1899
 
1900
+ # Per-file CSV download (ALWAYS visible, uses current form values)
1901
  download_data = {
1902
  'Invoice Number': st.session_state.get(f"Invoice Number_{selected_hash}", ''),
1903
  'Invoice Date': invoice_date_str,
1904
  'Due Date': due_date_str,
1905
  'Currency': currency,
1906
+ 'Subtotal': st.session_state.get(f"Subtotal_{selected_hash}", 0.0),
1907
+ 'Tax Percentage': st.session_state.get(f"Tax Percentage_{selected_hash}", 0.0),
1908
+ 'Total Tax': st.session_state.get(f"Total Tax_{selected_hash}", 0.0),
1909
+ 'Total Amount': st.session_state.get(f"Total Amount_{selected_hash}", 0.0),
1910
  'Sender Name': st.session_state.get(f"Sender Name_{selected_hash}", ''),
1911
  'Sender Address': st.session_state.get(f"Sender Address_{selected_hash}", ''),
1912
  'Recipient Name': st.session_state.get(f"Recipient Name_{selected_hash}", ''),