Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +383 -92
src/streamlit_app.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
# =========================
|
| 2 |
# Invoice Extractor (Qwen3-VL via RunPod vLLM) - Batch Mode with Tax Validation
|
| 3 |
-
# UPDATED:
|
| 4 |
-
# FIX: Tax calculation skips both empty ("") and explicit zero (0.00) values
|
| 5 |
# =========================
|
| 6 |
import os
|
| 7 |
from pathlib import Path
|
|
@@ -90,27 +89,118 @@ def ensure_state(k: str, default):
|
|
| 90 |
st.session_state[k] = default
|
| 91 |
|
| 92 |
def clean_float(x) -> float:
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
if x is None:
|
| 95 |
return 0.0
|
| 96 |
if isinstance(x, (int, float)):
|
| 97 |
return float(x)
|
|
|
|
| 98 |
s = str(x).strip()
|
| 99 |
if s == "":
|
| 100 |
return 0.0
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
return 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
try:
|
| 106 |
-
|
| 107 |
-
|
|
|
|
| 108 |
return 0.0
|
| 109 |
|
| 110 |
def normalize_date(date_str) -> str:
|
| 111 |
"""
|
| 112 |
-
Normalize various date formats
|
| 113 |
-
|
|
|
|
| 114 |
Returns empty string if date cannot be parsed
|
| 115 |
"""
|
| 116 |
if not date_str or date_str == "":
|
|
@@ -121,8 +211,8 @@ def normalize_date(date_str) -> str:
|
|
| 121 |
if date_str == "":
|
| 122 |
return ""
|
| 123 |
|
| 124 |
-
#
|
| 125 |
-
|
| 126 |
# ISO formats (4-digit year)
|
| 127 |
"%Y-%m-%d", # 2025-01-15
|
| 128 |
"%Y/%m/%d", # 2025/01/15
|
|
@@ -162,7 +252,7 @@ def normalize_date(date_str) -> str:
|
|
| 162 |
|
| 163 |
# European formats with 2-digit year - Day first
|
| 164 |
"%d-%m-%y", # 15-01-25
|
| 165 |
-
"%d/%m/%y", # 15/01/25
|
| 166 |
"%d.%m.%y", # 15.01.25
|
| 167 |
"%d %m %y", # 15 01 25
|
| 168 |
|
|
@@ -205,35 +295,72 @@ def normalize_date(date_str) -> str:
|
|
| 205 |
"%Y%d%m", # 20251501
|
| 206 |
]
|
| 207 |
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
# Try parsing with each format
|
| 211 |
-
for fmt in formats:
|
| 212 |
try:
|
| 213 |
parsed_date = datetime.strptime(str(date_str), fmt)
|
| 214 |
-
|
| 215 |
except (ValueError, TypeError):
|
| 216 |
continue
|
| 217 |
|
| 218 |
-
#
|
| 219 |
-
if
|
| 220 |
-
import re
|
| 221 |
cleaned = re.sub(r'(\d+)(st|nd|rd|th)\b', r'\1', date_str, flags=re.IGNORECASE)
|
| 222 |
-
|
| 223 |
if cleaned != date_str:
|
| 224 |
-
for fmt in
|
| 225 |
try:
|
| 226 |
parsed_date = datetime.strptime(cleaned, fmt)
|
| 227 |
-
|
| 228 |
except (ValueError, TypeError):
|
| 229 |
continue
|
| 230 |
|
| 231 |
-
#
|
| 232 |
-
|
| 233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
|
| 235 |
-
#
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
def parse_date_to_object(date_str):
|
| 239 |
"""
|
|
@@ -331,6 +458,41 @@ def parse_date_to_object(date_str):
|
|
| 331 |
"%d%m%Y", # 15012025
|
| 332 |
"%m%d%Y", # 01152025
|
| 333 |
"%Y%d%m", # 20251501
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
]
|
| 335 |
|
| 336 |
# Try parsing with each format
|
|
@@ -356,22 +518,6 @@ def parse_date_to_object(date_str):
|
|
| 356 |
|
| 357 |
return None
|
| 358 |
|
| 359 |
-
# -----------------------------
|
| 360 |
-
# HF login flow (REMOVED - No longer needed for vLLM API)
|
| 361 |
-
# -----------------------------
|
| 362 |
-
# Authentication is now handled via POD_URL and VLLM_API_KEY instead
|
| 363 |
-
|
| 364 |
-
# -----------------------------
|
| 365 |
-
# Model config
|
| 366 |
-
# -----------------------------
|
| 367 |
-
# OLD DONUT CODE (COMMENTED OUT - Now using vLLM API)
|
| 368 |
-
# -----------------------------
|
| 369 |
-
# HF_MODEL_ID = "Bhuvi13/model-V7"
|
| 370 |
-
# TASK_PROMPT = "<s_cord-v2>"
|
| 371 |
-
#
|
| 372 |
-
# @st.cache_resource(show_spinner=False)
|
| 373 |
-
# def load_model_and_processor(hf_model_id: str, task_prompt: str):
|
| 374 |
-
# ...
|
| 375 |
|
| 376 |
# -----------------------------
|
| 377 |
# vLLM Inference Function (RunPod API)
|
|
@@ -407,7 +553,8 @@ Extract the data into this exact JSON structure:
|
|
| 407 |
"quantity": "Quantity of items",
|
| 408 |
"unit_price": "Price per unit",
|
| 409 |
"amount": "Total amount for this line item",
|
| 410 |
-
"
|
|
|
|
| 411 |
"Line_total": "Total amount including tax for this line"
|
| 412 |
}
|
| 413 |
],
|
|
@@ -434,7 +581,7 @@ IMPORTANT GUIDELINES:
|
|
| 434 |
- Extract text exactly as it appears, including special characters and formatting
|
| 435 |
- For dates, preserve the original format shown in the invoice
|
| 436 |
- If both sender and receiver addresses are in the United States, extract ACH; otherwise extract Wire transfer (WT).
|
| 437 |
-
- If payment terms specify a number of days (e.g.,
|
| 438 |
- if tax_rate is not given in invoice but tax_amount is given, calculate the tax_rate using tax_amount and subtotal.
|
| 439 |
- line-item wise tax calculation has to be done properly based ONLY on the tax_rate given in the summary, and the same tax_rate must be used for every line item in that invoice.
|
| 440 |
- If currency symbols are present, note them appropriately
|
|
@@ -519,12 +666,98 @@ Return only the JSON object with the extracted information"""
|
|
| 519 |
def parse_vllm_json(raw_json_text):
|
| 520 |
"""Parse vLLM JSON output into structured format"""
|
| 521 |
try:
|
| 522 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
|
| 524 |
def clean_amount(value):
|
| 525 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
return 0.0
|
| 527 |
-
return float(re.sub(r'[^\d\.-]', '', str(value)))
|
| 528 |
|
| 529 |
header = data.get("header", {})
|
| 530 |
summary = data.get("summary", {})
|
|
@@ -792,6 +1025,7 @@ def map_prediction_to_ui(pred):
|
|
| 792 |
return None
|
| 793 |
|
| 794 |
def clean_number(x):
|
|
|
|
| 795 |
if x is None:
|
| 796 |
return 0.0
|
| 797 |
if isinstance(x, (int, float)):
|
|
@@ -799,13 +1033,71 @@ def map_prediction_to_ui(pred):
|
|
| 799 |
s = str(x).strip()
|
| 800 |
if s == "":
|
| 801 |
return 0.0
|
| 802 |
-
|
| 803 |
-
|
| 804 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 805 |
return 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 806 |
try:
|
| 807 |
-
|
| 808 |
-
|
|
|
|
| 809 |
return 0.0
|
| 810 |
|
| 811 |
def collect_keys(obj, out):
|
|
@@ -1076,16 +1368,6 @@ def flatten_invoice_to_rows(invoice_data) -> list:
|
|
| 1076 |
rows.append(row)
|
| 1077 |
return rows
|
| 1078 |
|
| 1079 |
-
# -----------------------------
|
| 1080 |
-
# Load model (COMMENTED OUT - Now using vLLM API)
|
| 1081 |
-
# -----------------------------
|
| 1082 |
-
# try:
|
| 1083 |
-
# with st.spinner("Loading model & processor (cached) ..."):
|
| 1084 |
-
# processor, model, device, decoder_input_ids = load_model_and_processor(HF_MODEL_ID, TASK_PROMPT)
|
| 1085 |
-
# except Exception as e:
|
| 1086 |
-
# st.error("Could not load model automatically. See details below.")
|
| 1087 |
-
# st.exception(e)
|
| 1088 |
-
# st.stop()
|
| 1089 |
|
| 1090 |
# -----------------------------
|
| 1091 |
# Session scaffolding
|
|
@@ -1156,6 +1438,8 @@ if not st.session_state.is_processing_batch and len(st.session_state.batch_resul
|
|
| 1156 |
continue
|
| 1157 |
|
| 1158 |
# vLLM Inference + parsing + tax validation
|
|
|
|
|
|
|
| 1159 |
try:
|
| 1160 |
# Call vLLM API
|
| 1161 |
raw_json = run_inference_vllm(image)
|
|
@@ -1174,18 +1458,18 @@ if not st.session_state.is_processing_batch and len(st.session_state.batch_resul
|
|
| 1174 |
st.warning(f"No response from vLLM for {uploaded_file.name}")
|
| 1175 |
mapped = {}
|
| 1176 |
|
| 1177 |
-
pred = raw_json # Store raw JSON for debugging
|
| 1178 |
except Exception as e:
|
| 1179 |
st.warning(f"Error processing {uploaded_file.name}: {str(e)}")
|
| 1180 |
-
|
| 1181 |
mapped = {}
|
| 1182 |
|
| 1183 |
safe_mapped = mapped if isinstance(mapped, dict) else {}
|
| 1184 |
|
|
|
|
| 1185 |
st.session_state.batch_results[file_hash] = {
|
| 1186 |
"file_name": uploaded_file.name,
|
| 1187 |
"image": image,
|
| 1188 |
-
"raw_pred":
|
| 1189 |
"mapped_data": safe_mapped,
|
| 1190 |
"edited_data": safe_mapped.copy()
|
| 1191 |
}
|
|
@@ -1320,9 +1604,17 @@ elif len(st.session_state.batch_results) > 0:
|
|
| 1320 |
with frame_left:
|
| 1321 |
st.image(image, caption=current["file_name"], width=FIXED_IMG_WIDTH)
|
| 1322 |
st.write(f"**File Hash:** {selected_hash[:8]}...")
|
| 1323 |
-
|
| 1324 |
-
|
| 1325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1326 |
|
| 1327 |
if st.button("🔁 Re-Run Inference", key=f"rerun_{selected_hash}"):
|
| 1328 |
with st.spinner("Re-running inference..."):
|
|
@@ -1345,10 +1637,9 @@ elif len(st.session_state.batch_results) > 0:
|
|
| 1345 |
mapped = {}
|
| 1346 |
|
| 1347 |
safe_mapped = mapped if isinstance(mapped, dict) else {}
|
| 1348 |
-
pred = raw_json # Store raw JSON
|
| 1349 |
|
| 1350 |
# Update stored results
|
| 1351 |
-
st.session_state.batch_results[selected_hash]["raw_pred"] =
|
| 1352 |
st.session_state.batch_results[selected_hash]["mapped_data"] = mapped
|
| 1353 |
st.session_state.batch_results[selected_hash]["edited_data"] = safe_mapped.copy()
|
| 1354 |
|
|
@@ -1396,10 +1687,10 @@ elif len(st.session_state.batch_results) > 0:
|
|
| 1396 |
if st.session_state.get(f"Currency_{selected_hash}") == 'Other':
|
| 1397 |
st.text_input("Specify Currency", key=f"Currency_Custom_{selected_hash}")
|
| 1398 |
|
| 1399 |
-
st.number_input("Subtotal", key=f"Subtotal_{selected_hash}")
|
| 1400 |
-
st.number_input("Tax %", key=f"Tax Percentage_{selected_hash}")
|
| 1401 |
-
st.number_input("Total Tax", key=f"Total Tax_{selected_hash}")
|
| 1402 |
-
st.number_input("Total Amount", key=f"Total Amount_{selected_hash}")
|
| 1403 |
|
| 1404 |
with tabs[1]:
|
| 1405 |
st.text_input("Sender Name", key=f"Sender Name_{selected_hash}")
|
|
@@ -1512,7 +1803,7 @@ elif len(st.session_state.batch_results) > 0:
|
|
| 1512 |
|
| 1513 |
st.dataframe(
|
| 1514 |
totals_df,
|
| 1515 |
-
|
| 1516 |
hide_index=True,
|
| 1517 |
height=38
|
| 1518 |
)
|
|
@@ -1562,16 +1853,16 @@ elif len(st.session_state.batch_results) > 0:
|
|
| 1562 |
calculated_tax_pct = round((calculated_total_tax / calculated_subtotal) * 100, 4)
|
| 1563 |
|
| 1564 |
if saved:
|
| 1565 |
-
# Build updated data structure
|
| 1566 |
updated = {
|
| 1567 |
'Invoice Number': st.session_state.get(f"Invoice Number_{selected_hash}", ''),
|
| 1568 |
'Invoice Date': invoice_date_str,
|
| 1569 |
'Due Date': due_date_str,
|
| 1570 |
'Currency': currency,
|
| 1571 |
-
'Subtotal':
|
| 1572 |
-
'Tax Percentage':
|
| 1573 |
-
'Total Tax':
|
| 1574 |
-
'Total Amount':
|
| 1575 |
'Sender Name': st.session_state.get(f"Sender Name_{selected_hash}", ''),
|
| 1576 |
'Sender Address': st.session_state.get(f"Sender Address_{selected_hash}", ''),
|
| 1577 |
'Recipient Name': st.session_state.get(f"Recipient Name_{selected_hash}", ''),
|
|
@@ -1595,10 +1886,10 @@ elif len(st.session_state.batch_results) > 0:
|
|
| 1595 |
# Save to batch_results (this persists the data)
|
| 1596 |
st.session_state.batch_results[selected_hash]["edited_data"] = updated
|
| 1597 |
|
| 1598 |
-
# CRITICAL: Clear
|
| 1599 |
-
|
| 1600 |
-
|
| 1601 |
-
del st.session_state[
|
| 1602 |
|
| 1603 |
# Show success message
|
| 1604 |
st.success("✅ Saved")
|
|
@@ -1606,16 +1897,16 @@ elif len(st.session_state.batch_results) > 0:
|
|
| 1606 |
# Rerun to reload the form with saved data
|
| 1607 |
st.rerun()
|
| 1608 |
|
| 1609 |
-
# Per-file CSV download (ALWAYS visible, uses current
|
| 1610 |
download_data = {
|
| 1611 |
'Invoice Number': st.session_state.get(f"Invoice Number_{selected_hash}", ''),
|
| 1612 |
'Invoice Date': invoice_date_str,
|
| 1613 |
'Due Date': due_date_str,
|
| 1614 |
'Currency': currency,
|
| 1615 |
-
'Subtotal':
|
| 1616 |
-
'Tax Percentage':
|
| 1617 |
-
'Total Tax':
|
| 1618 |
-
'Total Amount':
|
| 1619 |
'Sender Name': st.session_state.get(f"Sender Name_{selected_hash}", ''),
|
| 1620 |
'Sender Address': st.session_state.get(f"Sender Address_{selected_hash}", ''),
|
| 1621 |
'Recipient Name': st.session_state.get(f"Recipient Name_{selected_hash}", ''),
|
|
|
|
| 1 |
# =========================
|
| 2 |
# Invoice Extractor (Qwen3-VL via RunPod vLLM) - Batch Mode with Tax Validation
|
| 3 |
+
# UPDATED: Fixed raw model output display
|
|
|
|
| 4 |
# =========================
|
| 5 |
import os
|
| 6 |
from pathlib import Path
|
|
|
|
| 89 |
st.session_state[k] = default
|
| 90 |
|
| 91 |
def _is_thousands_group(head: str, tail: str) -> bool:
    """Return True when ``head<sep>tail`` looks like a grouped integer (e.g. 1<sep>234).

    A single separator is read as a thousands separator only when exactly
    three digits follow it and one to three digits precede it. A leading
    zero ("0.125", "0,125") rules grouping out: nobody writes 125 as "0.125",
    so such values are decimals.
    """
    return (
        len(tail) == 3
        and tail.isdigit()
        and head.isdigit()
        and len(head) <= 3
        and not head.startswith('0')
    )


def clean_float(x) -> float:
    """
    Parse a number (or number-like string) handling US and European formats.

    US Format:  1,234,567.89 (comma = thousands, period = decimal)
    European:   1.234.567,89 (period = thousands, comma = decimal)

    Rules, in order:
      * ``None``, empty strings and unparseable text yield 0.0.
      * ``int`` / ``float`` inputs are returned as ``float`` unchanged.
      * Currency symbols (€ $ £ ¥ ₹) and whitespace are removed first, so a
        sign that follows the symbol ("$-5", "$(5.00)") is still recognised.
      * A leading minus, trailing minus, or surrounding parentheses
        (accounting notation) marks the value as negative.
      * When both separators occur, the LAST one is the decimal separator.
      * A single separator followed by exactly three digits and preceded by
        one to three digits (no leading zero) is a thousands separator;
        otherwise a single comma with 1-4 digits after it is a decimal comma
        and a single period is a decimal point.
      * Repeated commas or repeated periods are always thousands separators.

    Examples:
        "1,234.56"   → 1234.56   (US)
        "1.234,56"   → 1234.56   (European)
        "3.000,2234" → 3000.2234 (European with 4 decimal places)
        "261,49"     → 261.49    (European decimal only)
        "39,22-"     → -39.22    (trailing minus)
        "$-39.22"    → -39.22    (sign after the currency symbol)
        "1,234"      → 1234.0    (thousands)
        "0.125"      → 0.125     (leading zero ⇒ decimal, not 125)

    Args:
        x: ``None``, a number, or any object whose ``str()`` is a number.

    Returns:
        The parsed value as ``float``; 0.0 when nothing numeric can be read.
    """
    if x is None:
        return 0.0
    if isinstance(x, (int, float)):
        return float(x)

    s = str(x).strip()
    if s == "":
        return 0.0

    # Strip currency symbols and whitespace BEFORE looking for a sign;
    # otherwise "$-5" or "$(5)" would hide the sign behind the symbol and the
    # final digit-only cleanup below would silently drop it.
    s = re.sub(r'[€$£¥₹\s]', '', s)
    if s == "":
        return 0.0

    # Handle negative markers (leading minus, trailing minus, or parentheses)
    is_negative = False
    if s.startswith('-'):
        is_negative = True
        s = s[1:]
    elif s.endswith('-'):
        is_negative = True
        s = s[:-1]
    elif s.startswith('(') and s.endswith(')'):
        # Accounting format: (123.45) means negative
        is_negative = True
        s = s[1:-1]

    if s == "":
        return 0.0

    comma_count = s.count(',')
    period_count = s.count('.')

    if comma_count > 0 and period_count > 0:
        # Both separators present - the LAST one is the decimal separator
        if s.rfind(',') > s.rfind('.'):
            # European format: 1.234,56 → drop periods, comma becomes the point
            s = s.replace('.', '').replace(',', '.')
        else:
            # US format: 1,234.56 → drop the thousands commas
            s = s.replace(',', '')

    elif comma_count > 0:
        head, _, tail = s.rpartition(',')
        if comma_count == 1 and _is_thousands_group(head, tail):
            # "1,234" → 1234 (checked first so the decimal rule cannot shadow it)
            s = head + tail
        elif comma_count == 1 and len(tail) <= 4 and tail.isdigit():
            # Single comma with 1-4 digits after → European decimal
            # "261,49" → 261.49, "1234,5678" → 1234.5678, "0,125" → 0.125
            s = head + '.' + tail
        else:
            # Multiple commas (or a dangling comma) → thousands separators
            # "1,234,567" → 1234567
            s = s.replace(',', '')

    elif period_count > 0:
        head, _, tail = s.rpartition('.')
        if period_count > 1:
            # Multiple periods → European thousands: "1.234.567" → 1234567
            s = s.replace('.', '')
        elif _is_thousands_group(head, tail):
            # Single period, three digits after → European thousands: "1.000" → 1000
            s = head + tail
        # Otherwise keep as is (standard decimal like "1.50", "123.45", "0.125")

    # Drop anything that is not a digit or the decimal point; the sign has
    # already been captured in is_negative above.
    s = re.sub(r'[^\d.]', '', s)

    if s == "" or s == ".":
        return 0.0

    try:
        result = float(s)
    except ValueError:
        # e.g. more than one decimal point survived the cleanup
        return 0.0
    return -result if is_negative else result
|
| 198 |
|
| 199 |
def normalize_date(date_str) -> str:
|
| 200 |
"""
|
| 201 |
+
Normalize various date formats:
|
| 202 |
+
- Full dates (day-month-year) → dd-MMM-yyyy (e.g., 01-Jan-2025)
|
| 203 |
+
- Month-year only → MMM-yyyy (e.g., Aug-2025)
|
| 204 |
Returns empty string if date cannot be parsed
|
| 205 |
"""
|
| 206 |
if not date_str or date_str == "":
|
|
|
|
| 211 |
if date_str == "":
|
| 212 |
return ""
|
| 213 |
|
| 214 |
+
# FULL DATE FORMATS (day-month-year) - try these first
|
| 215 |
+
full_date_formats = [
|
| 216 |
# ISO formats (4-digit year)
|
| 217 |
"%Y-%m-%d", # 2025-01-15
|
| 218 |
"%Y/%m/%d", # 2025/01/15
|
|
|
|
| 252 |
|
| 253 |
# European formats with 2-digit year - Day first
|
| 254 |
"%d-%m-%y", # 15-01-25
|
| 255 |
+
"%d/%m/%y", # 15/01/25
|
| 256 |
"%d.%m.%y", # 15.01.25
|
| 257 |
"%d %m %y", # 15 01 25
|
| 258 |
|
|
|
|
| 295 |
"%Y%d%m", # 20251501
|
| 296 |
]
|
| 297 |
|
| 298 |
+
# Try full date formats first → output as dd-MMM-yyyy
|
| 299 |
+
for fmt in full_date_formats:
|
|
|
|
|
|
|
| 300 |
try:
|
| 301 |
parsed_date = datetime.strptime(str(date_str), fmt)
|
| 302 |
+
return parsed_date.strftime("%d-%b-%Y")
|
| 303 |
except (ValueError, TypeError):
|
| 304 |
continue
|
| 305 |
|
| 306 |
+
# Try with ordinal suffixes removed (1st, 2nd, 3rd, etc.)
|
| 307 |
+
if isinstance(date_str, str):
|
|
|
|
| 308 |
cleaned = re.sub(r'(\d+)(st|nd|rd|th)\b', r'\1', date_str, flags=re.IGNORECASE)
|
|
|
|
| 309 |
if cleaned != date_str:
|
| 310 |
+
for fmt in full_date_formats:
|
| 311 |
try:
|
| 312 |
parsed_date = datetime.strptime(cleaned, fmt)
|
| 313 |
+
return parsed_date.strftime("%d-%b-%Y")
|
| 314 |
except (ValueError, TypeError):
|
| 315 |
continue
|
| 316 |
|
| 317 |
+
# MONTH-YEAR ONLY FORMATS - output as MMM-yyyy
|
| 318 |
+
month_year_formats = [
|
| 319 |
+
# Full month name with year
|
| 320 |
+
"%B %Y", # August 2025
|
| 321 |
+
"%b %Y", # Aug 2025
|
| 322 |
+
"%B, %Y", # August, 2025
|
| 323 |
+
"%b, %Y", # Aug, 2025
|
| 324 |
+
"%B-%Y", # August-2025
|
| 325 |
+
"%b-%Y", # Aug-2025
|
| 326 |
+
"%B/%Y", # August/2025
|
| 327 |
+
"%b/%Y", # Aug/2025
|
| 328 |
+
|
| 329 |
+
# Numeric month-year (4-digit year)
|
| 330 |
+
"%m/%Y", # 08/2025
|
| 331 |
+
"%m-%Y", # 08-2025
|
| 332 |
+
"%m.%Y", # 08.2025
|
| 333 |
+
"%m %Y", # 08 2025
|
| 334 |
+
"%Y-%m", # 2025-08
|
| 335 |
+
"%Y/%m", # 2025/08
|
| 336 |
+
"%Y.%m", # 2025.08
|
| 337 |
+
"%Y %m", # 2025 08
|
| 338 |
+
|
| 339 |
+
# Numeric month-year (2-digit year)
|
| 340 |
+
"%m/%y", # 08/25
|
| 341 |
+
"%m-%y", # 08-25
|
| 342 |
+
"%m.%y", # 08.25
|
| 343 |
+
"%m %y", # 08 25
|
| 344 |
+
"%y-%m", # 25-08
|
| 345 |
+
"%y/%m", # 25/08
|
| 346 |
+
|
| 347 |
+
# Full month name with 2-digit year
|
| 348 |
+
"%B %y", # August 25
|
| 349 |
+
"%b %y", # Aug 25
|
| 350 |
+
"%B-%y", # August-25
|
| 351 |
+
"%b-%y", # Aug-25
|
| 352 |
+
]
|
| 353 |
|
| 354 |
+
# Try month-year formats → output as MMM-yyyy (no day)
|
| 355 |
+
for fmt in month_year_formats:
|
| 356 |
+
try:
|
| 357 |
+
parsed_date = datetime.strptime(str(date_str), fmt)
|
| 358 |
+
return parsed_date.strftime("%b-%Y") # Aug-2025 format
|
| 359 |
+
except (ValueError, TypeError):
|
| 360 |
+
continue
|
| 361 |
+
|
| 362 |
+
# If no format matched, return empty string
|
| 363 |
+
return ""
|
| 364 |
|
| 365 |
def parse_date_to_object(date_str):
|
| 366 |
"""
|
|
|
|
| 458 |
"%d%m%Y", # 15012025
|
| 459 |
"%m%d%Y", # 01152025
|
| 460 |
"%Y%d%m", # 20251501
|
| 461 |
+
|
| 462 |
+
# ========== MONTH-YEAR ONLY FORMATS (defaults to 1st of month) ==========
|
| 463 |
+
# Full month name with year
|
| 464 |
+
"%B %Y", # August 2025
|
| 465 |
+
"%b %Y", # Aug 2025
|
| 466 |
+
"%B, %Y", # August, 2025
|
| 467 |
+
"%b, %Y", # Aug, 2025
|
| 468 |
+
"%B-%Y", # August-2025
|
| 469 |
+
"%b-%Y", # Aug-2025
|
| 470 |
+
"%B/%Y", # August/2025
|
| 471 |
+
"%b/%Y", # Aug/2025
|
| 472 |
+
|
| 473 |
+
# Numeric month-year (4-digit year)
|
| 474 |
+
"%m/%Y", # 08/2025
|
| 475 |
+
"%m-%Y", # 08-2025
|
| 476 |
+
"%m.%Y", # 08.2025
|
| 477 |
+
"%m %Y", # 08 2025
|
| 478 |
+
"%Y-%m", # 2025-08
|
| 479 |
+
"%Y/%m", # 2025/08
|
| 480 |
+
"%Y.%m", # 2025.08
|
| 481 |
+
"%Y %m", # 2025 08
|
| 482 |
+
|
| 483 |
+
# Numeric month-year (2-digit year)
|
| 484 |
+
"%m/%y", # 08/25
|
| 485 |
+
"%m-%y", # 08-25
|
| 486 |
+
"%m.%y", # 08.25
|
| 487 |
+
"%m %y", # 08 25
|
| 488 |
+
"%y-%m", # 25-08
|
| 489 |
+
"%y/%m", # 25/08
|
| 490 |
+
|
| 491 |
+
# Full month name with 2-digit year
|
| 492 |
+
"%B %y", # August 25
|
| 493 |
+
"%b %y", # Aug 25
|
| 494 |
+
"%B-%y", # August-25
|
| 495 |
+
"%b-%y", # Aug-25
|
| 496 |
]
|
| 497 |
|
| 498 |
# Try parsing with each format
|
|
|
|
| 518 |
|
| 519 |
return None
|
| 520 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 521 |
|
| 522 |
# -----------------------------
|
| 523 |
# vLLM Inference Function (RunPod API)
|
|
|
|
| 553 |
"quantity": "Quantity of items",
|
| 554 |
"unit_price": "Price per unit",
|
| 555 |
"amount": "Total amount for this line item",
|
| 556 |
+
"t_rate": "tax_rate",
|
| 557 |
+
"tax": "amount*t_rate/100",
|
| 558 |
"Line_total": "Total amount including tax for this line"
|
| 559 |
}
|
| 560 |
],
|
|
|
|
| 581 |
- Extract text exactly as it appears, including special characters and formatting
|
| 582 |
- For dates, preserve the original format shown in the invoice
|
| 583 |
- If both sender and receiver addresses are in the United States, extract ACH; otherwise extract Wire transfer (WT).
|
| 584 |
+
- If payment terms specify a number of days (e.g., "payment terms 30 days", "payable within 15 days", "terms 45 days", "Net 30", or any similar phrase), compute: due_date = invoice_date + N days. If the invoice states "due on receipt", "due upon receipt" ,"Immediate" or any similar phrase meaning immediate payment, then: due_date = invoice_date. Use the same date format as the invoice. Output only the computed due_date.
|
| 585 |
- if tax_rate is not given in invoice but tax_amount is given, calculate the tax_rate using tax_amount and subtotal.
|
| 586 |
- line-item wise tax calculation has to be done properly based ONLY on the tax_rate given in the summary, and the same tax_rate must be used for every line item in that invoice.
|
| 587 |
- If currency symbols are present, note them appropriately
|
|
|
|
| 666 |
def parse_vllm_json(raw_json_text):
|
| 667 |
"""Parse vLLM JSON output into structured format"""
|
| 668 |
try:
|
| 669 |
+
# Try to parse the JSON - handle potential markdown code blocks
|
| 670 |
+
text_to_parse = raw_json_text.strip()
|
| 671 |
+
|
| 672 |
+
# Remove markdown code fences if present
|
| 673 |
+
if text_to_parse.startswith("```json"):
|
| 674 |
+
text_to_parse = text_to_parse[7:]
|
| 675 |
+
elif text_to_parse.startswith("```"):
|
| 676 |
+
text_to_parse = text_to_parse[3:]
|
| 677 |
+
if text_to_parse.endswith("```"):
|
| 678 |
+
text_to_parse = text_to_parse[:-3]
|
| 679 |
+
text_to_parse = text_to_parse.strip()
|
| 680 |
+
|
| 681 |
+
data = json.loads(text_to_parse)
|
| 682 |
|
| 683 |
def clean_amount(value):
    """Parse an amount in US or European notation; return 0.0 when unparseable.

    The previous body was a line-for-line copy of the module-level
    ``clean_float`` (sign detection, currency stripping, separator
    heuristics). Delegating keeps a single implementation of those rules so
    amounts read from the model output are parsed the same way as everywhere
    else in the file.
    """
    return clean_float(value)
|
|
|
|
| 761 |
|
| 762 |
header = data.get("header", {})
|
| 763 |
summary = data.get("summary", {})
|
|
|
|
| 1025 |
return None
|
| 1026 |
|
| 1027 |
def clean_number(x):
|
| 1028 |
+
"""Parse number handling both US and European formats."""
|
| 1029 |
if x is None:
|
| 1030 |
return 0.0
|
| 1031 |
if isinstance(x, (int, float)):
|
|
|
|
| 1033 |
s = str(x).strip()
|
| 1034 |
if s == "":
|
| 1035 |
return 0.0
|
| 1036 |
+
|
| 1037 |
+
# Handle negative signs (leading or trailing)
|
| 1038 |
+
is_negative = False
|
| 1039 |
+
if s.startswith('-'):
|
| 1040 |
+
is_negative = True
|
| 1041 |
+
s = s[1:].strip()
|
| 1042 |
+
elif s.endswith('-'):
|
| 1043 |
+
is_negative = True
|
| 1044 |
+
s = s[:-1].strip()
|
| 1045 |
+
elif s.startswith('(') and s.endswith(')'):
|
| 1046 |
+
is_negative = True
|
| 1047 |
+
s = s[1:-1].strip()
|
| 1048 |
+
|
| 1049 |
+
# Remove currency symbols and spaces
|
| 1050 |
+
s = re.sub(r'[€$£¥₹\s]', '', s)
|
| 1051 |
+
|
| 1052 |
+
if s == "":
|
| 1053 |
return 0.0
|
| 1054 |
+
|
| 1055 |
+
comma_count = s.count(',')
|
| 1056 |
+
period_count = s.count('.')
|
| 1057 |
+
last_comma = s.rfind(',')
|
| 1058 |
+
last_period = s.rfind('.')
|
| 1059 |
+
|
| 1060 |
+
if comma_count > 0 and period_count > 0:
|
| 1061 |
+
# Both present - LAST one is decimal separator
|
| 1062 |
+
if last_comma > last_period:
|
| 1063 |
+
# European: 1.234,56 or 3.000,2234
|
| 1064 |
+
s = s.replace('.', '').replace(',', '.')
|
| 1065 |
+
else:
|
| 1066 |
+
# US: 1,234.56
|
| 1067 |
+
s = s.replace(',', '')
|
| 1068 |
+
elif comma_count > 0:
|
| 1069 |
+
# Only commas
|
| 1070 |
+
after_last_comma = s[last_comma + 1:] if last_comma < len(s) - 1 else ""
|
| 1071 |
+
if comma_count == 1 and len(after_last_comma) <= 4 and after_last_comma.isdigit():
|
| 1072 |
+
# European decimal: "261,49" or "9,60"
|
| 1073 |
+
s = s.replace(',', '.')
|
| 1074 |
+
elif len(after_last_comma) == 3:
|
| 1075 |
+
# 3 digits after comma → thousands: "1,234"
|
| 1076 |
+
s = s.replace(',', '')
|
| 1077 |
+
else:
|
| 1078 |
+
# Multiple commas = US thousands: "1,234,567"
|
| 1079 |
+
s = s.replace(',', '')
|
| 1080 |
+
elif period_count > 0:
|
| 1081 |
+
# Only periods
|
| 1082 |
+
after_last_period = s[last_period + 1:] if last_period < len(s) - 1 else ""
|
| 1083 |
+
if period_count > 1:
|
| 1084 |
+
# Multiple periods = European thousands: "1.234.567"
|
| 1085 |
+
s = s.replace('.', '')
|
| 1086 |
+
elif len(after_last_period) == 3 and after_last_period.isdigit():
|
| 1087 |
+
# Single period with exactly 3 digits → European thousands: "1.000"
|
| 1088 |
+
before_period = s[:last_period]
|
| 1089 |
+
if before_period.isdigit() and len(before_period) <= 3:
|
| 1090 |
+
s = s.replace('.', '')
|
| 1091 |
+
|
| 1092 |
+
s = re.sub(r'[^\d.]', '', s)
|
| 1093 |
+
|
| 1094 |
+
if s == "" or s == ".":
|
| 1095 |
+
return 0.0
|
| 1096 |
+
|
| 1097 |
try:
|
| 1098 |
+
result = float(s)
|
| 1099 |
+
return -result if is_negative else result
|
| 1100 |
+
except ValueError:
|
| 1101 |
return 0.0
|
| 1102 |
|
| 1103 |
def collect_keys(obj, out):
|
|
|
|
| 1368 |
rows.append(row)
|
| 1369 |
return rows
|
| 1370 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1371 |
|
| 1372 |
# -----------------------------
|
| 1373 |
# Session scaffolding
|
|
|
|
| 1438 |
continue
|
| 1439 |
|
| 1440 |
# vLLM Inference + parsing + tax validation
|
| 1441 |
+
raw_json = None
|
| 1442 |
+
mapped = {}
|
| 1443 |
try:
|
| 1444 |
# Call vLLM API
|
| 1445 |
raw_json = run_inference_vllm(image)
|
|
|
|
| 1458 |
st.warning(f"No response from vLLM for {uploaded_file.name}")
|
| 1459 |
mapped = {}
|
| 1460 |
|
|
|
|
| 1461 |
except Exception as e:
|
| 1462 |
st.warning(f"Error processing {uploaded_file.name}: {str(e)}")
|
| 1463 |
+
raw_json = None
|
| 1464 |
mapped = {}
|
| 1465 |
|
| 1466 |
safe_mapped = mapped if isinstance(mapped, dict) else {}
|
| 1467 |
|
| 1468 |
+
# Store BOTH raw string AND parsed dict for display
|
| 1469 |
st.session_state.batch_results[file_hash] = {
|
| 1470 |
"file_name": uploaded_file.name,
|
| 1471 |
"image": image,
|
| 1472 |
+
"raw_pred": raw_json, # Original string from API (untouched)
|
| 1473 |
"mapped_data": safe_mapped,
|
| 1474 |
"edited_data": safe_mapped.copy()
|
| 1475 |
}
|
|
|
|
| 1604 |
with frame_left:
|
| 1605 |
st.image(image, caption=current["file_name"], width=FIXED_IMG_WIDTH)
|
| 1606 |
st.write(f"**File Hash:** {selected_hash[:8]}...")
|
| 1607 |
+
|
| 1608 |
+
# ============ RAW MODEL OUTPUT DISPLAY (UNTOUCHED) ============
|
| 1609 |
+
with st.expander("🔍 Show raw model output"):
|
| 1610 |
+
raw_pred = current.get('raw_pred')
|
| 1611 |
+
|
| 1612 |
+
if raw_pred is None:
|
| 1613 |
+
st.warning("No raw output available (API may have returned None)")
|
| 1614 |
+
else:
|
| 1615 |
+
# Show raw output exactly as received from the model - UNTOUCHED
|
| 1616 |
+
st.code(str(raw_pred), language='json')
|
| 1617 |
+
# ==============================================================
|
| 1618 |
|
| 1619 |
if st.button("🔁 Re-Run Inference", key=f"rerun_{selected_hash}"):
|
| 1620 |
with st.spinner("Re-running inference..."):
|
|
|
|
| 1637 |
mapped = {}
|
| 1638 |
|
| 1639 |
safe_mapped = mapped if isinstance(mapped, dict) else {}
|
|
|
|
| 1640 |
|
| 1641 |
# Update stored results
|
| 1642 |
+
st.session_state.batch_results[selected_hash]["raw_pred"] = raw_json
|
| 1643 |
st.session_state.batch_results[selected_hash]["mapped_data"] = mapped
|
| 1644 |
st.session_state.batch_results[selected_hash]["edited_data"] = safe_mapped.copy()
|
| 1645 |
|
|
|
|
| 1687 |
if st.session_state.get(f"Currency_{selected_hash}") == 'Other':
|
| 1688 |
st.text_input("Specify Currency", key=f"Currency_Custom_{selected_hash}")
|
| 1689 |
|
| 1690 |
+
st.number_input("Subtotal", key=f"Subtotal_{selected_hash}", format="%.2f")
|
| 1691 |
+
st.number_input("Tax %", key=f"Tax Percentage_{selected_hash}", format="%.4f")
|
| 1692 |
+
st.number_input("Total Tax", key=f"Total Tax_{selected_hash}", format="%.2f")
|
| 1693 |
+
st.number_input("Total Amount", key=f"Total Amount_{selected_hash}", format="%.2f")
|
| 1694 |
|
| 1695 |
with tabs[1]:
|
| 1696 |
st.text_input("Sender Name", key=f"Sender Name_{selected_hash}")
|
|
|
|
| 1803 |
|
| 1804 |
st.dataframe(
|
| 1805 |
totals_df,
|
| 1806 |
+
use_container_width=True,
|
| 1807 |
hide_index=True,
|
| 1808 |
height=38
|
| 1809 |
)
|
|
|
|
| 1853 |
calculated_tax_pct = round((calculated_total_tax / calculated_subtotal) * 100, 4)
|
| 1854 |
|
| 1855 |
if saved:
|
| 1856 |
+
# Build updated data structure using ACTUAL user-entered values from form
|
| 1857 |
updated = {
|
| 1858 |
'Invoice Number': st.session_state.get(f"Invoice Number_{selected_hash}", ''),
|
| 1859 |
'Invoice Date': invoice_date_str,
|
| 1860 |
'Due Date': due_date_str,
|
| 1861 |
'Currency': currency,
|
| 1862 |
+
'Subtotal': st.session_state.get(f"Subtotal_{selected_hash}", 0.0),
|
| 1863 |
+
'Tax Percentage': st.session_state.get(f"Tax Percentage_{selected_hash}", 0.0),
|
| 1864 |
+
'Total Tax': st.session_state.get(f"Total Tax_{selected_hash}", 0.0),
|
| 1865 |
+
'Total Amount': st.session_state.get(f"Total Amount_{selected_hash}", 0.0),
|
| 1866 |
'Sender Name': st.session_state.get(f"Sender Name_{selected_hash}", ''),
|
| 1867 |
'Sender Address': st.session_state.get(f"Sender Address_{selected_hash}", ''),
|
| 1868 |
'Recipient Name': st.session_state.get(f"Recipient Name_{selected_hash}", ''),
|
|
|
|
| 1886 |
# Save to batch_results (this persists the data)
|
| 1887 |
st.session_state.batch_results[selected_hash]["edited_data"] = updated
|
| 1888 |
|
| 1889 |
+
# CRITICAL: Clear ALL session state keys for this file so they reload from saved edited_data
|
| 1890 |
+
keys_to_delete = [k for k in list(st.session_state.keys()) if k.endswith(f"_{selected_hash}")]
|
| 1891 |
+
for key in keys_to_delete:
|
| 1892 |
+
del st.session_state[key]
|
| 1893 |
|
| 1894 |
# Show success message
|
| 1895 |
st.success("✅ Saved")
|
|
|
|
| 1897 |
# Rerun to reload the form with saved data
|
| 1898 |
st.rerun()
|
| 1899 |
|
| 1900 |
+
# Per-file CSV download (ALWAYS visible, uses current form values)
|
| 1901 |
download_data = {
|
| 1902 |
'Invoice Number': st.session_state.get(f"Invoice Number_{selected_hash}", ''),
|
| 1903 |
'Invoice Date': invoice_date_str,
|
| 1904 |
'Due Date': due_date_str,
|
| 1905 |
'Currency': currency,
|
| 1906 |
+
'Subtotal': st.session_state.get(f"Subtotal_{selected_hash}", 0.0),
|
| 1907 |
+
'Tax Percentage': st.session_state.get(f"Tax Percentage_{selected_hash}", 0.0),
|
| 1908 |
+
'Total Tax': st.session_state.get(f"Total Tax_{selected_hash}", 0.0),
|
| 1909 |
+
'Total Amount': st.session_state.get(f"Total Amount_{selected_hash}", 0.0),
|
| 1910 |
'Sender Name': st.session_state.get(f"Sender Name_{selected_hash}", ''),
|
| 1911 |
'Sender Address': st.session_state.get(f"Sender Address_{selected_hash}", ''),
|
| 1912 |
'Recipient Name': st.session_state.get(f"Recipient Name_{selected_hash}", ''),
|