Update src/streamlit_app.py
Browse files- src/streamlit_app.py +95 -87
src/streamlit_app.py
CHANGED
|
@@ -196,7 +196,6 @@ def run_inference_on_image(image: Image.Image, processor, model, device, decoder
|
|
| 196 |
|
| 197 |
# ---------------------------
|
| 198 |
# Helper: map donut output to our UI schema
|
| 199 |
-
# (kept unchanged from your original)
|
| 200 |
# ---------------------------
|
| 201 |
def map_prediction_to_ui(pred):
|
| 202 |
import json, re
|
|
@@ -218,77 +217,6 @@ def map_prediction_to_ui(pred):
|
|
| 218 |
except Exception:
|
| 219 |
return None
|
| 220 |
return None
|
| 221 |
-
def flatten_invoice_to_rows(invoice_data) -> list:
    """Flatten nested invoice data into a list of flat row dicts.

    One row is produced per entry of ``invoice_data["Itemized Data"]``.
    Invoice-level fields, sender/recipient fields and bank details (each
    bank key prefixed with ``bank_``) are repeated on every row, followed by
    the four ``Item ...`` columns. When there are no line items, a single
    row with empty/zero item columns is returned so the invoice still
    appears in the export.

    Args:
        invoice_data: Mapping holding the invoice fields. Nested sections
            ("Sender", "Recipient", "Bank Details") and individual line
            items that are missing, ``None`` or not a dict are treated as
            empty instead of raising ``AttributeError``.

    Returns:
        A list of independent dicts (at least one), suitable for
        ``pd.DataFrame(rows)``.
    """

    def _as_dict(value):
        # A key can be present but hold None (or some other non-mapping);
        # ``.get(key, {})`` does not protect against that case.
        return value if isinstance(value, dict) else {}

    invoice = _as_dict(invoice_data)
    sender = _as_dict(invoice.get("Sender"))
    recipient = _as_dict(invoice.get("Recipient"))
    bank = _as_dict(invoice.get("Bank Details"))

    # Invoice-level context shared by every row; built once, copied per row.
    base_row = {
        "Invoice Number": invoice.get("Invoice Number", ""),
        "Invoice Date": invoice.get("Invoice Date", ""),
        "Due Date": invoice.get("Due Date", ""),
        "Currency": invoice.get("Currency", ""),
        "Subtotal": invoice.get("Subtotal", 0.0),
        "Tax Percentage": invoice.get("Tax Percentage", 0.0),
        "Total Tax": invoice.get("Total Tax", 0.0),
        "Total Amount": invoice.get("Total Amount", 0.0),
        "Sender Name": sender.get("Name", ""),
        "Sender Address": sender.get("Address", ""),
        "Recipient Name": recipient.get("Name", ""),
        "Recipient Address": recipient.get("Address", ""),
    }
    # Flatten bank details into prefixed columns.
    base_row.update({f"bank_{key}": value for key, value in bank.items()})

    line_items = invoice.get("Itemized Data") or []
    if isinstance(line_items, dict):
        # Defensive: a lone line item stored as a dict instead of a list.
        line_items = [line_items]

    if not line_items:
        # No line items: emit one row carrying only the invoice context.
        return [
            {
                **base_row,
                "Item Description": "",
                "Item Quantity": 0,
                "Item Unit Price": 0.0,
                "Item Amount": 0.0,
            }
        ]

    rows = []
    for raw_item in line_items:
        item = _as_dict(raw_item)
        rows.append(
            {
                **base_row,
                "Item Description": item.get("Description", ""),
                "Item Quantity": item.get("Quantity", 0),
                "Item Unit Price": item.get("Unit Price", 0.0),
                "Item Amount": item.get("Amount", 0.0),
            }
        )
    return rows
|
| 292 |
|
| 293 |
def clean_number(x):
|
| 294 |
if x is None:
|
|
@@ -474,6 +402,81 @@ def flatten_invoice_to_rows(invoice_data) -> list:
|
|
| 474 |
|
| 475 |
return ui
|
| 476 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
# Load model once
|
| 478 |
try:
|
| 479 |
with st.spinner("Loading model & processor (cached) ..."):
|
|
@@ -553,7 +556,11 @@ if not st.session_state.is_processing_batch and len(st.session_state.batch_resul
|
|
| 553 |
mapped = map_prediction_to_ui(pred)
|
| 554 |
except Exception as e:
|
| 555 |
st.warning(f"Error processing {uploaded_file.name}: {str(e)}")
|
| 556 |
-
pred
|
|
|
|
|
|
|
|
|
|
|
|
|
| 557 |
|
| 558 |
# Save to session state
|
| 559 |
st.session_state.batch_results[file_hash] = {
|
|
@@ -561,7 +568,7 @@ if not st.session_state.is_processing_batch and len(st.session_state.batch_resul
|
|
| 561 |
"image": image,
|
| 562 |
"raw_pred": pred,
|
| 563 |
"mapped_data": mapped,
|
| 564 |
-
"edited_data":
|
| 565 |
}
|
| 566 |
|
| 567 |
progress_bar.progress((idx + 1) / len(uploaded_files))
|
|
@@ -711,6 +718,9 @@ elif len(st.session_state.batch_results) > 0:
|
|
| 711 |
data['Bank Details'] = bank_info
|
| 712 |
|
| 713 |
# ---------- Line Items ----------
|
|
|
|
|
|
|
|
|
|
| 714 |
with tabs[3]:
|
| 715 |
editor_key = f"item_editor_{selected_hash}"
|
| 716 |
|
|
@@ -720,29 +730,27 @@ elif len(st.session_state.batch_results) > 0:
|
|
| 720 |
if col not in df.columns:
|
| 721 |
df[col] = ""
|
| 722 |
|
| 723 |
-
|
|
|
|
| 724 |
edited_df = st.data_editor(
|
| 725 |
df,
|
| 726 |
num_rows="dynamic",
|
| 727 |
key=editor_key,
|
| 728 |
use_container_width=True,
|
| 729 |
-
on_change=trigger_rerun
|
| 730 |
)
|
| 731 |
|
| 732 |
-
# ✅ Safe to use — edited_df is updated after forced rerun
|
| 733 |
-
data['Itemized Data'] = edited_df.to_dict('records')
|
| 734 |
-
|
| 735 |
-
# ❗ OPTIONAL: Auto-calculate Amount = Quantity × Unit Price
|
| 736 |
-
# Uncomment below if you want auto-calculation:
|
| 737 |
-
# if "Quantity" in edited_df.columns and "Unit Price" in edited_df.columns:
|
| 738 |
-
# edited_df["Amount"] = (edited_df["Quantity"] * edited_df["Unit Price"]).round(2)
|
| 739 |
-
# data['Itemized Data'] = edited_df.to_dict('records')
|
| 740 |
-
|
| 741 |
if len(edited_df) == 0:
|
| 742 |
st.info("No line items found in the invoice.")
|
| 743 |
|
| 744 |
-
#
|
|
|
|
|
|
|
| 745 |
if st.button("💾 Save Edits for This File", key=f"save_{selected_hash}"):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 746 |
st.session_state.batch_results[selected_hash]["edited_data"] = data
|
| 747 |
st.success(f"✅ Edits saved for {current['file_name']}")
|
| 748 |
|
|
@@ -798,7 +806,7 @@ elif len(st.session_state.batch_results) > 0:
|
|
| 798 |
json_name = f"{Path(result['file_name']).stem}_extracted.json"
|
| 799 |
zf.writestr(json_name, json_data)
|
| 800 |
|
| 801 |
-
|
| 802 |
rows = flatten_invoice_to_rows(result["edited_data"])
|
| 803 |
full_df = pd.DataFrame(rows)
|
| 804 |
|
|
|
|
| 196 |
|
| 197 |
# ---------------------------
|
| 198 |
# Helper: map donut output to our UI schema
|
|
|
|
| 199 |
# ---------------------------
|
| 200 |
def map_prediction_to_ui(pred):
|
| 201 |
import json, re
|
|
|
|
| 217 |
except Exception:
|
| 218 |
return None
|
| 219 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
def clean_number(x):
|
| 222 |
if x is None:
|
|
|
|
| 402 |
|
| 403 |
return ui
|
| 404 |
|
| 405 |
+
# ---------------------------
|
| 406 |
+
# Helper: flatten invoice to CSV rows
|
| 407 |
+
# ---------------------------
|
| 408 |
+
def flatten_invoice_to_rows(invoice_data) -> list:
    """Convert nested invoice data into flat rows, one per line item.

    Every row repeats the invoice-level fields, the sender/recipient name
    and address, and the bank details (each bank key prefixed with
    ``bank_``), then adds the ``Item Description`` / ``Item Quantity`` /
    ``Item Unit Price`` / ``Item Amount`` columns for that line item. An
    invoice without line items yields exactly one row whose item columns
    are empty/zero.

    Args:
        invoice_data: Mapping holding the invoice fields. A "Sender",
            "Recipient" or "Bank Details" value that is missing, ``None``
            or not a dict is treated as empty rather than raising
            ``AttributeError``; the same applies to each line item.

    Returns:
        A non-empty list of independent row dicts, ready for
        ``pd.DataFrame(rows)``.
    """

    def _mapping_or_empty(value):
        # ``d.get(key, {})`` only covers a missing key, not a key that is
        # present with a None / non-dict value.
        return value if isinstance(value, dict) else {}

    invoice = _mapping_or_empty(invoice_data)
    sender = _mapping_or_empty(invoice.get("Sender"))
    recipient = _mapping_or_empty(invoice.get("Recipient"))
    bank = _mapping_or_empty(invoice.get("Bank Details"))

    # Shared invoice context: computed once and copied into each row.
    shared = {
        "Invoice Number": invoice.get("Invoice Number", ""),
        "Invoice Date": invoice.get("Invoice Date", ""),
        "Due Date": invoice.get("Due Date", ""),
        "Currency": invoice.get("Currency", ""),
        "Subtotal": invoice.get("Subtotal", 0.0),
        "Tax Percentage": invoice.get("Tax Percentage", 0.0),
        "Total Tax": invoice.get("Total Tax", 0.0),
        "Total Amount": invoice.get("Total Amount", 0.0),
        "Sender Name": sender.get("Name", ""),
        "Sender Address": sender.get("Address", ""),
        "Recipient Name": recipient.get("Name", ""),
        "Recipient Address": recipient.get("Address", ""),
    }
    # Flatten bank details into prefixed columns.
    for bank_key, bank_value in bank.items():
        shared[f"bank_{bank_key}"] = bank_value

    line_items = invoice.get("Itemized Data") or []
    if isinstance(line_items, dict):
        # Defensive: a lone line item stored as a dict instead of a list.
        line_items = [line_items]

    if not line_items:
        # No line items: a single row with invoice info and blank item fields.
        placeholder = dict(shared)
        placeholder.update(
            {
                "Item Description": "",
                "Item Quantity": 0,
                "Item Unit Price": 0.0,
                "Item Amount": 0.0,
            }
        )
        return [placeholder]

    rows = []
    for entry in line_items:
        item = _mapping_or_empty(entry)
        row = dict(shared)
        row.update(
            {
                "Item Description": item.get("Description", ""),
                "Item Quantity": item.get("Quantity", 0),
                "Item Unit Price": item.get("Unit Price", 0.0),
                "Item Amount": item.get("Amount", 0.0),
            }
        )
        rows.append(row)
    return rows
|
| 479 |
+
|
| 480 |
# Load model once
|
| 481 |
try:
|
| 482 |
with st.spinner("Loading model & processor (cached) ..."):
|
|
|
|
| 556 |
mapped = map_prediction_to_ui(pred)
|
| 557 |
except Exception as e:
|
| 558 |
st.warning(f"Error processing {uploaded_file.name}: {str(e)}")
|
| 559 |
+
pred = None
|
| 560 |
+
mapped = {} # 👈 Ensure mapped is always a dict
|
| 561 |
+
|
| 562 |
+
# ✅ SAFETY: Ensure mapped is a dict before copying
|
| 563 |
+
safe_mapped = mapped if isinstance(mapped, dict) else {}
|
| 564 |
|
| 565 |
# Save to session state
|
| 566 |
st.session_state.batch_results[file_hash] = {
|
|
|
|
| 568 |
"image": image,
|
| 569 |
"raw_pred": pred,
|
| 570 |
"mapped_data": mapped,
|
| 571 |
+
"edited_data": safe_mapped.copy() # editable copy — now safe
|
| 572 |
}
|
| 573 |
|
| 574 |
progress_bar.progress((idx + 1) / len(uploaded_files))
|
|
|
|
| 718 |
data['Bank Details'] = bank_info
|
| 719 |
|
| 720 |
# ---------- Line Items ----------
|
| 721 |
+
# ---------- Line Items ----------
|
| 722 |
+
# ---------- Line Items ----------
|
| 723 |
+
# ---------- Line Items ----------
|
| 724 |
with tabs[3]:
|
| 725 |
editor_key = f"item_editor_{selected_hash}"
|
| 726 |
|
|
|
|
| 730 |
if col not in df.columns:
|
| 731 |
df[col] = ""
|
| 732 |
|
| 733 |
+
st.write("✏️ Edit line items below. Press Enter or click outside a cell to confirm each edit.")
|
| 734 |
+
# Get the edited DataFrame directly from data_editor
|
| 735 |
edited_df = st.data_editor(
|
| 736 |
df,
|
| 737 |
num_rows="dynamic",
|
| 738 |
key=editor_key,
|
| 739 |
use_container_width=True,
|
| 740 |
+
on_change=trigger_rerun
|
| 741 |
)
|
| 742 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 743 |
if len(edited_df) == 0:
|
| 744 |
st.info("No line items found in the invoice.")
|
| 745 |
|
| 746 |
+
# Line-item edits are written back to session state only when the Save button is clicked.
|
| 747 |
+
|
| 748 |
+
# Save button (per file)
|
| 749 |
if st.button("💾 Save Edits for This File", key=f"save_{selected_hash}"):
|
| 750 |
+
# Update line items from the current edited_df (which is a DataFrame)
|
| 751 |
+
data['Itemized Data'] = edited_df.to_dict('records')
|
| 752 |
+
|
| 753 |
+
# Save the entire data to session state
|
| 754 |
st.session_state.batch_results[selected_hash]["edited_data"] = data
|
| 755 |
st.success(f"✅ Edits saved for {current['file_name']}")
|
| 756 |
|
|
|
|
| 806 |
json_name = f"{Path(result['file_name']).stem}_extracted.json"
|
| 807 |
zf.writestr(json_name, json_data)
|
| 808 |
|
| 809 |
+
# Save FULL CSV (all data)
|
| 810 |
rows = flatten_invoice_to_rows(result["edited_data"])
|
| 811 |
full_df = pd.DataFrame(rows)
|
| 812 |
|