donut_UI / src /streamlit_app.py
Ankushbl6's picture
Update src/streamlit_app.py
34e1651 verified
# =========================
# Invoice Extractor (Donut) - Batch Mode (stable UI + original line-item logic)
# =========================
import os
from pathlib import Path
# -----------------------------
# Environment hardening (HF Spaces, /.cache issue)
# -----------------------------
# Startup guard: when HOME is empty or "/", point it at a writable directory
# (the working directory if writable, otherwise /tmp) before anything tries
# to create dot-directories under it.
_home = os.environ.get("HOME", "")
if not _home or _home == "/":
    repo_dir = os.getcwd()
    if os.access(repo_dir, os.W_OK):
        safe_home = repo_dir
    else:
        safe_home = "/tmp"
    os.environ["HOME"] = safe_home
    print(f"[startup] HOME not set or unwritable — setting HOME={safe_home}")

# Make sure $HOME/.streamlit exists; failure is reported but not fatal.
streamlit_dir = Path(os.environ["HOME"], ".streamlit")
try:
    streamlit_dir.mkdir(parents=True, exist_ok=True)
    print(f"[startup] ensured {streamlit_dir}")
except Exception as exc:
    print(f"[startup] WARNING: could not create {streamlit_dir}: {exc}")
# -----------------------------
# Imports
# -----------------------------
import json
from io import BytesIO
import hashlib
from typing import Dict, Any
import streamlit as st
import pandas as pd
from PIL import Image
from huggingface_hub import login
# Optional: pdf2image is only needed for PDFs
try:
from pdf2image import convert_from_bytes
except Exception:
convert_from_bytes = None
# -----------------------------
# Page config & CSS
# -----------------------------
# Fixed sizes to prevent reflow wobble
FIXED_IMG_WIDTH = 640
DATA_EDITOR_HEIGHT = 380

# Global CSS overrides injected once per script run.
_PAGE_CSS = """
<style>
.stApp { background-color: #ECECEC !important; }
div.block-container { padding-top: 1rem; padding-bottom: 1rem; }
[data-testid="stSidebar"] { background-color: #F7F7F7 !important; }
div[data-testid="stTabs"] > div > div { padding-bottom: 6px !important; }
/* Keep right column steady on first render post-extraction */
[data-testid="column"]:nth-of-type(2) { min-height: 780px; }
</style>
"""

# set_page_config stays the first Streamlit call of the script.
st.set_page_config(
    page_title="Invoice Extractor (Donut) - Batch Mode",
    layout="wide",
)
st.title("Invoice Extraction")
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# -----------------------------
# Helpers
# -----------------------------
def ensure_state(k: str, default):
    """Seed ``st.session_state[k]`` with ``default`` unless the key already exists.

    Widgets then bind to the key via ``key=...`` without passing ``value=...``.
    """
    if k in st.session_state:
        return
    st.session_state[k] = default
def clean_float(x) -> float:
    """Best-effort conversion of a scraped value to ``float``.

    Args:
        x: ``None``, a number, or any object whose ``str()`` holds a number,
            possibly decorated with currency symbols, thousands separators
            or unit text (e.g. ``"$ 1,234.50"``, ``"Rs. 500"``).

    Returns:
        The parsed value, or ``0.0`` when nothing numeric can be recovered.
        This function never raises.
    """
    import re

    if x is None:
        return 0.0
    if isinstance(x, (int, float)):
        return float(x)

    s = str(x).strip()
    if s == "":
        return 0.0

    # Thousands separators and inner whitespace carry no numeric meaning.
    s = re.sub(r"[,\s]", "", s)
    # A dot directly after a letter is abbreviation punctuation ("Rs.",
    # "No.", "USD."), not a decimal point. Without this step "Rs. 500"
    # collapsed to ".500" and was reported as 0.5.
    s = re.sub(r"(?<=[^\W\d_])\.", "", s)
    # Keep only characters that can be part of a float literal.
    s = re.sub(r"[^\d\.\-]", "", s)
    if s in ("", ".", "-", "-."):
        return 0.0
    try:
        return float(s)
    except ValueError:
        # e.g. "1.2.3" or "12-34": ambiguous, treat as missing.
        return 0.0
# -----------------------------
# HF login flow (token from session/env/secrets)
# -----------------------------
def _get_hf_token():
    """Look up a Hugging Face token and report where it came from.

    Lookup order: the ``_hf_token`` session-state entry, the ``HF_TOKEN``
    environment variable, then ``st.secrets["HF_TOKEN"]``.

    Returns:
        ``(token, source)`` where ``source`` is ``"session"``, ``"env"`` or
        ``"secrets"``; ``(None, None)`` when no token is available.
    """
    session_token = st.session_state.get("_hf_token")
    if session_token:
        return session_token, "session"

    env_token = os.getenv("HF_TOKEN")
    if env_token:
        return env_token, "env"

    # Accessing st.secrets may raise (e.g. no secrets configured); treat
    # that the same as "no secret present".
    try:
        secret_token = st.secrets.get("HF_TOKEN", None)
    except Exception:
        secret_token = None
    if secret_token:
        return secret_token, "secrets"

    return None, None
# Authenticate against the Hugging Face Hub before the model is loaded.
# Without a token the user is prompted for one and the script halts until
# a well-formed token has been accepted.
hf_token, hf_token_source = _get_hf_token()
if hf_token is None:
    st.subheader("Login Token 🔑")
    token_input = st.text_input("Enter your Hugging Face token (starts with 'hf_'):", type="password")
    if not token_input:
        st.warning("Provide a token here or set HF_TOKEN in the environment.")
        st.stop()
    if not token_input.startswith("hf_"):
        st.error("Invalid token format. Token must start with 'hf_'.")
        st.stop()
    # Only the login call itself is guarded, so a failure message always
    # refers to the login and never to the bookkeeping below.
    try:
        login(token_input)
    except Exception as e:
        st.error(f"Failed to log in: {e}")
        st.stop()
    st.session_state["_hf_token"] = token_input
    st.session_state.logged_in = True
    st.success("Logged in successfully. Loading model...")
    st.rerun()
elif not st.session_state.get("logged_in"):
    # Streamlit re-executes this script on every interaction; log in once
    # per session instead of calling login() on each rerun.
    try:
        login(hf_token)
    except Exception as e:
        st.error(f"Failed to log in with {hf_token_source or 'unknown'} token: {e}")
        st.stop()
    st.session_state.logged_in = True
# -----------------------------
# Model config
# -----------------------------
# Hugging Face Hub repository id passed to load_model_and_processor().
HF_MODEL_ID = "Bhuvi13/model-V7"
# Task prompt that load_model_and_processor() tokenizes into the initial
# decoder input ids used for generation.
TASK_PROMPT = "<s_cord-v2>"
@st.cache_resource(show_spinner=False)
def load_model_and_processor(hf_model_id: str, task_prompt: str):
    """Load the Donut processor and model and prepare the decoder prompt.

    The result is cached by ``st.cache_resource``, so the download and
    initialisation happen once per distinct argument pair.

    Args:
        hf_model_id: Hugging Face Hub repository id of the model.
        task_prompt: Prompt string tokenized into the initial decoder ids.

    Returns:
        ``(processor, model, device, decoder_input_ids)`` with the model in
        eval mode on ``device`` (CUDA when available, else CPU) and
        ``decoder_input_ids`` already moved to that device.

    Raises:
        RuntimeError: if the ML libraries cannot be imported or the
            model/processor cannot be loaded; the original exception is
            attached as ``__cause__``.
    """
    # Heavy imports are deferred so the rest of the app can render first.
    try:
        import torch
        from transformers import VisionEncoderDecoderModel, DonutProcessor
    except Exception as e:
        raise RuntimeError(f"Failed to import ML libraries: {e}") from e

    try:
        processor = DonutProcessor.from_pretrained(hf_model_id)
        model = VisionEncoderDecoderModel.from_pretrained(hf_model_id)
    except Exception as e:
        raise RuntimeError(
            f"Failed to load model/processor from Hugging Face ({hf_model_id}). "
            f"Original error: {e}"
        ) from e

    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Tokenization builds no autograd graph, so no no_grad() context is needed.
    decoder_input_ids = processor.tokenizer(
        task_prompt,
        add_special_tokens=False,
        return_tensors="pt",
    ).input_ids.to(device)
    return processor, model, device, decoder_input_ids
def run_inference_on_image(image: Image.Image, processor, model, device, decoder_input_ids):
    """Run generation on one image and return the decoded prediction.

    Args:
        image: Image handed to ``processor(images=...)``.
        processor: Object providing preprocessing, ``batch_decode``,
            ``tokenizer`` and ``token2json``.
        model: Model whose ``generate`` method is called.
        device: Device the pixel values are moved to.
        decoder_input_ids: Prompt ids passed to ``generate``.

    Returns:
        Whatever ``processor.token2json`` yields; when that is a string it
        is parsed as JSON if possible and returned unchanged otherwise.
    """
    import torch

    encoded = processor(images=image, return_tensors="pt")
    pixel_values = encoded.pixel_values.to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=1536,
            num_beams=4,
            early_stopping=False,
        )

    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
    cleaned = decoded[0].strip()

    # Drop any literal EOS / PAD markers left in the decoded text.
    tokenizer = processor.tokenizer
    for marker in (tokenizer.eos_token, tokenizer.pad_token):
        cleaned = cleaned.replace(marker or "", "")
    cleaned = cleaned.strip()

    result = processor.token2json(cleaned)
    if not isinstance(result, str):
        return result
    try:
        return json.loads(result)
    except Exception:
        return result
# -----------------------------
# ORIGINAL (previous) mapping logic — restored verbatim
# -----------------------------
def map_prediction_to_ui(pred):
    """Normalise a raw model prediction into the invoice dict used by the UI.

    Args:
        pred: A dict or list (already-parsed prediction), a string that holds
            JSON (optionally surrounded by other text), or ``None``.

    Returns:
        A dict with the keys ``Invoice Number``, ``Invoice Date``,
        ``Due Date``, ``Currency``, ``Subtotal``, ``Tax Percentage``,
        ``Total Tax``, ``Total Amount``, ``Sender``, ``Recipient``,
        ``Sender Name``, ``Sender Address``, ``Recipient Name``,
        ``Recipient Address``, ``Bank Details`` and ``Itemized Data``.
        Text fields are always strings, numeric fields are floats and
        ``Itemized Data`` is a list of row dicts. Input that cannot be
        interpreted yields the empty defaults.
    """
    import json
    import re
    from collections import defaultdict

    def safe_json_load(s):
        """Parse ``s`` as JSON, falling back to the first parseable top-level {...} span."""
        if s is None:
            return None
        if isinstance(s, (dict, list)):
            return s
        if not isinstance(s, str):
            return None
        s = s.strip()
        if s == "":
            return None
        try:
            return json.loads(s)
        except Exception:
            pass
        # Collect every balanced top-level {...} substring and try each.
        spans = []
        depth = 0
        start = None
        for i, ch in enumerate(s):
            if ch == "{":
                if depth == 0:
                    start = i
                depth += 1
            elif ch == "}" and depth:
                depth -= 1
                if depth == 0 and start is not None:
                    spans.append(s[start:i + 1])
                    start = None
        for span in spans:
            try:
                return json.loads(span)
            except Exception:
                continue
        return None

    def clean_number(x):
        """Convert a scraped value to float; 0.0 when nothing numeric is found."""
        if x is None:
            return 0.0
        if isinstance(x, (int, float)):
            return float(x)
        s = str(x).strip()
        if s == "":
            return 0.0
        s = re.sub(r"[,\s]", "", s)
        # A dot right after a letter is abbreviation punctuation ("Rs."),
        # not a decimal point; otherwise "Rs. 500" would parse as 0.5.
        s = re.sub(r"(?<=[^\W\d_])\.", "", s)
        s = re.sub(r"[^\d\.\-]", "", s)
        if s in ("", ".", "-", "-."):
            return 0.0
        try:
            return float(s)
        except ValueError:
            return 0.0

    def collect_keys(obj, out):
        """Index every value in the structure under its lower-cased key, recursively."""
        if isinstance(obj, dict):
            for k, v in obj.items():
                out[str(k).strip().lower()].append(v)
                collect_keys(v, out)
        elif isinstance(obj, list):
            for it in obj:
                collect_keys(it, out)

    def collect_lists_of_dicts(obj, out_lists):
        """Gather nested lists whose first element is a dict (line-item candidates)."""
        if isinstance(obj, dict):
            children = obj.values()
        elif isinstance(obj, list):
            children = obj
        else:
            return
        for child in children:
            if isinstance(child, list) and child and isinstance(child[0], dict):
                out_lists.append(child)
            else:
                collect_lists_of_dicts(child, out_lists)

    def map_item_dict(it):
        """Translate one raw line-item dict into a UI row; None for non-dicts."""
        if not isinstance(it, dict):
            return None
        lower = {str(k).strip().lower(): v for k, v in it.items()}
        desc = (lower.get("descriptions") or lower.get("description")
                or lower.get("desc") or lower.get("item") or "")
        qty = lower.get("quantity") or lower.get("qty") or lower.get("count") or ""
        unit_price = lower.get("unit_price") or lower.get("price") or ""
        amount = (lower.get("amount") or lower.get("line_total")
                  or lower.get("line total") or lower.get("total") or "")
        tax = lower.get("tax") or lower.get("tax_amount") or ""
        line_total = lower.get("line_total") or lower.get("line total") or amount
        return {
            "Description": str(desc).strip(),
            "Quantity": clean_number(qty),
            "Unit Price": clean_number(unit_price),
            "Amount": clean_number(amount),
            "Tax": clean_number(tax),
            "Line Total": clean_number(line_total),
        }

    parsed = safe_json_load(pred) if isinstance(pred, str) else pred
    # Only containers can be searched; a top-level list is walked as well.
    root = parsed if isinstance(parsed, (dict, list)) else None

    ui = {
        "Invoice Number": "",
        "Invoice Date": "",
        "Due Date": "",
        "Currency": "",
        "Subtotal": 0.0,
        "Tax Percentage": 0.0,
        "Total Tax": 0.0,
        "Total Amount": 0.0,
        "Sender": {"Name": "", "Address": ""},
        "Recipient": {"Name": "", "Address": ""},
        "Sender Name": "",
        "Sender Address": "",
        "Recipient Name": "",
        "Recipient Address": "",
        "Bank Details": {},
        "Itemized Data": [],
    }

    key_map = defaultdict(list)
    list_candidates = []
    if root is not None:
        collect_keys(root, key_map)
        collect_lists_of_dicts(root, list_candidates)

    def pick_first(*candidate_keys, scalars_only=False):
        """Return the first usable value stored under any candidate key.

        Scalars are returned as stripped, non-empty strings. Containers are
        returned as-is unless ``scalars_only`` is set, in which case they
        are skipped.
        """
        for k in candidate_keys:
            for v in key_map.get(k.strip().lower(), ()):
                if v is None:
                    continue
                if isinstance(v, (dict, list)):
                    if scalars_only:
                        continue
                    return v
                s = str(v).strip()
                if s != "":
                    return s
        return None

    def pick_text(*candidate_keys):
        """Like pick_first, but always yields a string ('' when nothing matches)."""
        return pick_first(*candidate_keys, scalars_only=True) or ""

    # Header text fields are bound to text widgets, so they must be strings.
    ui["Invoice Number"] = pick_text("invoice_no", "invoice_number", "invoiceid", "invoice id")
    ui["Invoice Date"] = pick_text("invoice_date", "date", "invoice date")
    ui["Due Date"] = pick_text("due_date", "due date", "due")
    ui["Sender Name"] = pick_text("sender_name", "sender")
    ui["Sender Address"] = pick_text("sender_addr", "sender_address", "sender addr")
    ui["Recipient Name"] = pick_text("rcpt_name", "recipient_name", "recipient", "rcpt")
    ui["Recipient Address"] = pick_text("rcpt_addr", "recipient_address", "recipient addr")

    bank = {}
    for bk in ("bank_name", "bank_acc_no", "bank_account_number", "bank_acc_name",
               "bank_iban", "bank_swift", "bank_routing", "bank_branch", "iban"):
        val = pick_first(bk, bk.replace("bank_", ""))
        if not val:
            continue
        # Fold the short spellings onto the canonical field names.
        if bk == "iban":
            bank["bank_iban"] = str(val)
        elif bk == "bank_acc_no":
            bank["bank_account_number"] = str(val)
        else:
            bank[bk] = str(val)
    ui["Bank Details"] = bank

    ui["Subtotal"] = clean_number(pick_first("subtotal", "sub_total", "sub total") or 0.0)
    ui["Tax Percentage"] = clean_number(
        pick_first("tax_rate", "tax_percentage", "tax pct", "tax percentage") or 0.0)
    ui["Total Tax"] = clean_number(pick_first("tax_amount", "tax", "total_tax") or 0.0)
    ui["Total Amount"] = clean_number(
        pick_first("total_amount", "grand_total", "total", "amount") or 0.0)
    ui["Currency"] = pick_text("currency")

    items_rows = []

    def list_looks_like_items(lst):
        """True when the first dict of the list carries a typical line-item key."""
        if not isinstance(lst, list) or not lst:
            return False
        if not isinstance(lst[0], dict):
            return False
        expected = {"descriptions", "description", "desc", "item", "quantity",
                    "qty", "amount", "unit_price", "line_total"}
        keys0 = {str(k).strip().lower() for k in lst[0].keys()}
        return bool(expected.intersection(keys0))

    # 1) The first nested list that looks like line items wins.
    for cand in list_candidates:
        if not list_looks_like_items(cand):
            continue
        for it in cand:
            row = map_item_dict(it)
            if row is not None:
                items_rows.append(row)
        if items_rows:
            break

    # 2) The top-level dict may itself describe a single item.
    if not items_rows and isinstance(parsed, dict):
        top_keys = {str(k).strip().lower() for k in parsed}
        item_like_keys = {"descriptions", "description", "desc", "item", "quantity",
                          "qty", "unit_price", "unit price", "price", "amount",
                          "line_total", "line total", "sku", "tax", "tax_amount"}
        if top_keys.intersection(item_like_keys):
            single_row = map_item_dict(parsed)
            if single_row is not None:
                items_rows.append(single_row)

    # 3) Any nested dict value carrying item-like keys.
    if not items_rows:
        nested_item_keys = {"descriptions", "description", "desc", "amount",
                            "line_total", "quantity", "qty", "unit_price"}
        for vals in key_map.values():
            for v in vals:
                if not isinstance(v, dict):
                    continue
                lower_keys = {str(x).strip().lower() for x in v.keys()}
                if lower_keys.intersection(nested_item_keys):
                    row = map_item_dict(v)
                    if row is not None:
                        items_rows.append(row)

    # 4) Last resort: assemble one row from loose fields.
    if not items_rows:
        desc = pick_first("descriptions", "description")
        amt = pick_first("amount", "line_total")
        qty = pick_first("quantity", "qty")
        unit_price = pick_first("unit_price", "price")
        if desc or amt or qty or unit_price:
            items_rows.append({
                "Description": str(desc or ""),
                "Quantity": clean_number(qty),
                "Unit Price": clean_number(unit_price),
                "Amount": clean_number(amt),
                "Tax": clean_number(pick_first("tax", "tax_amount") or 0.0),
                "Line Total": clean_number(amt or 0.0),
            })

    ui["Itemized Data"] = items_rows
    ui["Sender"] = {"Name": ui["Sender Name"], "Address": ui["Sender Address"]}
    ui["Recipient"] = {"Name": ui["Recipient Name"], "Address": ui["Recipient Address"]}
    return ui
def flatten_invoice_to_rows(invoice_data) -> list:
    """Expand one invoice dict into flat rows, one per line item.

    Each row repeats the invoice header fields and the fixed set of bank
    columns, followed by the ``Item ...`` columns of a single line item.
    An invoice without line items yields exactly one row with empty item
    columns. A falsy ``invoice_data`` is treated as an empty invoice.
    """
    bank_columns = (
        "bank_name",
        "bank_account_number",
        "bank_acc_name",
        "bank_iban",
        "bank_swift",
        "bank_routing",
        "bank_branch",
    )
    data = invoice_data or {}
    items = data.get("Itemized Data", []) or []

    # Nested bank details first (prefixing bare keys with "bank_"), then
    # any top-level "bank_*" keys, which take precedence.
    bank = {}
    nested_bank = data.get("Bank Details", {}) or {}
    if isinstance(nested_bank, dict):
        for name, value in nested_bank.items():
            column = name if str(name).startswith("bank_") else f"bank_{name}"
            bank[column] = value
    for name, value in data.items():
        if isinstance(name, str) and name.lower().startswith("bank_"):
            bank[name] = value

    def party(flat_key, group, field):
        # Flat key wins; fall back to the nested Sender/Recipient dict.
        return data.get(flat_key, "") or (data.get(group, {}) or {}).get(field, "")

    header = {
        "Invoice Number": data.get("Invoice Number", ""),
        "Invoice Date": data.get("Invoice Date", ""),
        "Due Date": data.get("Due Date", ""),
        "Currency": data.get("Currency", ""),
        "Subtotal": data.get("Subtotal", 0.0),
        "Tax Percentage": data.get("Tax Percentage", 0.0),
        "Total Tax": data.get("Total Tax", 0.0),
        "Total Amount": data.get("Total Amount", 0.0),
        "Sender Name": party("Sender Name", "Sender", "Name"),
        "Sender Address": party("Sender Address", "Sender", "Address"),
        "Recipient Name": party("Recipient Name", "Recipient", "Name"),
        "Recipient Address": party("Recipient Address", "Recipient", "Address"),
    }
    for column in bank_columns:
        header[column] = bank.get(column, "")

    def item_columns(item):
        # Anything that is not a dict contributes empty item columns.
        if not isinstance(item, dict):
            return {
                "Item Description": "",
                "Item Quantity": 0,
                "Item Unit Price": 0.0,
                "Item Amount": 0.0,
                "Item Tax": 0.0,
                "Item Line Total": 0.0,
            }
        return {
            "Item Description": item.get("Description", ""),
            "Item Quantity": item.get("Quantity", 0),
            "Item Unit Price": item.get("Unit Price", 0.0),
            "Item Amount": item.get("Amount", 0.0),
            "Item Tax": item.get("Tax", 0.0),
            "Item Line Total": item.get("Line Total", item.get("Amount", 0.0)),
        }

    # No items: emit a single placeholder row (None maps to empty columns).
    sources = items if items else [None]
    rows = []
    for item in sources:
        row = dict(header)
        row.update(item_columns(item))
        rows.append(row)
    return rows
# -----------------------------
# Load model
# -----------------------------
# Load (or fetch from the resource cache) everything inference needs; on
# failure show the error and halt this script run.
try:
    with st.spinner("Loading model & processor (cached) ..."):
        _loaded = load_model_and_processor(HF_MODEL_ID, TASK_PROMPT)
except Exception as load_error:
    st.error("Could not load model automatically. See details below.")
    st.exception(load_error)
    st.stop()
processor, model, device, decoder_input_ids = _loaded
# -----------------------------
# Session scaffolding
# -----------------------------
# Seed the per-session bookkeeping keys on first run only.
for _state_key, _initial in (
    ("batch_results", {}),
    ("current_file_hash", None),
    ("is_processing_batch", False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial

# -----------------------------
# Pre-mount two-column skeleton to avoid layout jump
# -----------------------------
frame_left, frame_right = st.columns([1, 1], vertical_alignment="top")
# -----------------------------
# Upload / Process
# -----------------------------
def _load_first_page(uploaded_file, uploaded_bytes):
    """Decode an upload into an RGB image (first page only for PDFs).

    Returns:
        ``(image, warning)``: ``image`` is ``None`` when the file could not
        be rendered, in which case ``warning`` holds the message to show.
    """
    name = uploaded_file.name
    is_pdf = name.lower().endswith('.pdf') or getattr(uploaded_file, 'type', None) == 'application/pdf'
    if not is_pdf:
        try:
            return Image.open(BytesIO(uploaded_bytes)).convert("RGB"), None
        except Exception:
            return None, f"Failed to open {name}."
    if convert_from_bytes is None:
        return None, f"PDF {name} could not be rendered (pdf2image/poppler missing)."
    try:
        # Only page 1 is used, so do not rasterise the rest of the document.
        pages = convert_from_bytes(uploaded_bytes, dpi=200, first_page=1, last_page=1)
        if len(pages) == 0:
            return None, f"PDF {name} has no pages."
        return pages[0].convert("RGB"), None
    except Exception:
        return None, f"Could not render PDF {name}. Ensure 'pdf2image' and poppler are installed."


def _collect_form_values(selected_hash, item_records):
    """Assemble the invoice dict for one file from its widget session-state values."""
    def field(name, default=''):
        return st.session_state.get(f"{name}_{selected_hash}", default)

    currency = field("Currency", 'USD')
    if currency == 'Other':
        currency = field("Currency_Custom")
    return {
        'Invoice Number': field("Invoice Number"),
        'Invoice Date': field("Invoice Date"),
        'Due Date': field("Due Date"),
        'Currency': currency,
        'Subtotal': field("Subtotal", 0.0),
        'Tax Percentage': field("Tax Percentage", 0.0),
        'Total Tax': field("Total Tax", 0.0),
        'Total Amount': field("Total Amount", 0.0),
        'Sender Name': field("Sender Name"),
        'Sender Address': field("Sender Address"),
        'Recipient Name': field("Recipient Name"),
        'Recipient Address': field("Recipient Address"),
        'Bank Details': {
            'bank_name': field("Bank_bank_name"),
            'bank_account_number': field("Bank_bank_account_number"),
            'bank_acc_name': field("Bank_bank_acc_name"),
            'bank_iban': field("Bank_bank_iban"),
            'bank_swift': field("Bank_bank_swift"),
            'bank_routing': field("Bank_bank_routing"),
            'bank_branch': field("Bank_bank_branch"),
        },
        'Itemized Data': item_records,
        'Sender': {"Name": field("Sender Name"), "Address": field("Sender Address")},
        'Recipient': {"Name": field("Recipient Name"), "Address": field("Recipient Address")},
    }


if not st.session_state.is_processing_batch and len(st.session_state.batch_results) == 0:
    # ---------------- Upload view ----------------
    with frame_left:
        st.header("📤 Upload Invoices")
        uploaded_files = st.file_uploader(
            "Upload invoice images (png/jpg/jpeg/pdf)",
            type=["png", "jpg", "jpeg", "pdf"],
            accept_multiple_files=True
        )
        if uploaded_files:
            st.session_state.is_processing_batch = True
            progress_bar = st.progress(0)
            status_text = st.empty()
            total_files = len(uploaded_files)
            try:
                for idx, uploaded_file in enumerate(uploaded_files):
                    status_text.text(f"Processing {idx+1}/{total_files}: {uploaded_file.name}")
                    # getvalue() is independent of the stream position, so
                    # re-processing the same upload never sees empty bytes.
                    uploaded_bytes = uploaded_file.getvalue()
                    file_hash = hashlib.sha256(uploaded_bytes).hexdigest()
                    if file_hash not in st.session_state.batch_results:
                        image, problem = _load_first_page(uploaded_file, uploaded_bytes)
                        if image is None:
                            st.warning(problem)
                        else:
                            # Inference + mapping; a failure still records the
                            # file so the user can edit it or re-run inference.
                            try:
                                pred = run_inference_on_image(image, processor, model, device, decoder_input_ids)
                                mapped = map_prediction_to_ui(pred)
                            except Exception as e:
                                st.warning(f"Error processing {uploaded_file.name}: {str(e)}")
                                pred = None
                                mapped = {}
                            safe_mapped = mapped if isinstance(mapped, dict) else {}
                            st.session_state.batch_results[file_hash] = {
                                "file_name": uploaded_file.name,
                                "image": image,
                                "raw_pred": pred,
                                "mapped_data": safe_mapped,
                                "edited_data": safe_mapped.copy()
                            }
                    # Advance for skipped and failed files too.
                    progress_bar.progress((idx + 1) / total_files)
            finally:
                # Never leave the app stuck in the "processing" view.
                st.session_state.is_processing_batch = False
            if st.session_state.batch_results:
                status_text.text("✅ All files processed!")
                st.rerun()
            else:
                # Rerunning with no results would re-enter this branch with
                # the same uploads and loop forever; stop and report instead.
                status_text.empty()
                st.error("None of the uploaded files could be processed. See the warnings above.")
    with frame_right:
        st.caption("Preview & editor will appear here after extraction.")
elif len(st.session_state.batch_results) > 0:
    # --------- Top row: All-results download + Back button ----------
    with frame_left:
        all_rows = []
        for file_hash, result in st.session_state.batch_results.items():
            rows = flatten_invoice_to_rows(result["edited_data"])
            for r in rows:
                r["Source File"] = result.get("file_name", file_hash)
            all_rows.extend(rows)
        if all_rows:
            full_df = pd.DataFrame(all_rows)
            cols = list(full_df.columns)
            if "Source File" in cols:
                cols = ["Source File"] + [c for c in cols if c != "Source File"]
                full_df = full_df[cols]
            csv_bytes = full_df.to_csv(index=False).encode("utf-8")
            st.download_button("📦 Download All Results (CSV)", csv_bytes,
                               file_name="all_extracted_invoices.csv", mime="text/csv", key="download_all_csv")
    with frame_right:
        if st.button("⬅️ Back to Upload"):
            st.session_state.batch_results.clear()
            st.session_state.current_file_hash = None
            st.session_state.is_processing_batch = False
            st.rerun()

    # --------- Selector ----------
    with frame_left:
        # The hash prefix keeps labels unique when file names repeat.
        file_options = {f"{v['file_name']} ({k[:6]})": k for k, v in st.session_state.batch_results.items()}
        selected_display = st.selectbox("Select invoice to view/edit:", options=list(file_options.keys()), index=0, key="file_selector")
        selected_hash = file_options[selected_display]
    if st.session_state.current_file_hash != selected_hash:
        st.session_state.current_file_hash = selected_hash
    current = st.session_state.batch_results[selected_hash]
    image = current["image"]
    form_data = current["edited_data"]

    # --------- Initialize widget state ONCE (no value= in widgets) ----------
    bank = form_data.get("Bank Details", {}) if isinstance(form_data.get("Bank Details", {}), dict) else {}
    ensure_state(f"Invoice Number_{selected_hash}", form_data.get('Invoice Number', ''))
    ensure_state(f"Invoice Date_{selected_hash}", str(form_data.get('Invoice Date', '')).strip())
    ensure_state(f"Due Date_{selected_hash}", str(form_data.get('Due Date', '')).strip())
    ensure_state(f"Currency_{selected_hash}", form_data.get('Currency', 'USD') or 'USD')
    ensure_state(f"Currency_Custom_{selected_hash}", form_data.get('Currency', '') if form_data.get('Currency') not in ['USD','EUR','GBP','INR'] else '')
    # clean_float tolerates None / text where float() would raise.
    ensure_state(f"Subtotal_{selected_hash}", clean_float(form_data.get('Subtotal', 0.0)))
    ensure_state(f"Tax Percentage_{selected_hash}", clean_float(form_data.get('Tax Percentage', 0.0)))
    ensure_state(f"Total Tax_{selected_hash}", clean_float(form_data.get('Total Tax', 0.0)))
    ensure_state(f"Total Amount_{selected_hash}", clean_float(form_data.get('Total Amount', 0.0)))
    ensure_state(f"Sender Name_{selected_hash}", form_data.get('Sender Name', ''))
    ensure_state(f"Sender Address_{selected_hash}", form_data.get('Sender Address', ''))
    ensure_state(f"Recipient Name_{selected_hash}", form_data.get('Recipient Name', ''))
    ensure_state(f"Recipient Address_{selected_hash}", form_data.get('Recipient Address', ''))
    ensure_state(f"Bank_bank_name_{selected_hash}", bank.get('bank_name', ''))
    ensure_state(f"Bank_bank_account_number_{selected_hash}", bank.get('bank_account_number', '') or bank.get('bank_acc_no', ''))
    ensure_state(f"Bank_bank_acc_name_{selected_hash}", bank.get('bank_acc_name', ''))
    ensure_state(f"Bank_bank_iban_{selected_hash}", bank.get('bank_iban', ''))
    ensure_state(f"Bank_bank_swift_{selected_hash}", bank.get('bank_swift', ''))
    ensure_state(f"Bank_bank_routing_{selected_hash}", bank.get('bank_routing', ''))
    ensure_state(f"Bank_bank_branch_{selected_hash}", bank.get('bank_branch', ''))

    # --------- Display (no wobble) ----------
    with frame_left:
        st.image(image, caption=current["file_name"], width=FIXED_IMG_WIDTH)
        st.write(f"**File Hash:** {selected_hash[:8]}...")
        if current.get('raw_pred') is not None:
            with st.expander("🔍 Show raw model output"):
                st.json(current['raw_pred'])
        if st.button("🔁 Re-Run Inference", key=f"rerun_{selected_hash}"):
            with st.spinner("Re-running inference..."):
                try:
                    pred = run_inference_on_image(image, processor, model, device, decoder_input_ids)
                    mapped = map_prediction_to_ui(pred)
                    safe_mapped = mapped if isinstance(mapped, dict) else {}
                    # Update stored results (same sanitised shape as the batch path)
                    st.session_state.batch_results[selected_hash]["raw_pred"] = pred
                    st.session_state.batch_results[selected_hash]["mapped_data"] = safe_mapped
                    st.session_state.batch_results[selected_hash]["edited_data"] = safe_mapped.copy()
                    # Clear widget state for this file so defaults refresh from new mapped data
                    for key in [k for k in st.session_state.keys() if k.endswith(f"_{selected_hash}")]:
                        del st.session_state[key]
                    st.success("✅ Re-run complete")
                    st.rerun()
                except Exception as e:
                    st.error(f"Re-run failed: {e}")

    with frame_right:
        st.subheader(f"Editable Invoice: {current['file_name']}")
        # Quick swap outside the form (one clean rerun). The keys are
        # written before the bound widgets are created in this run.
        swap_cols = st.columns([1,1,2])
        with swap_cols[0]:
            if st.button("⇄ Swap Sender ↔ Recipient", key=f"swap_{selected_hash}"):
                sn = f"Sender Name_{selected_hash}"
                rn = f"Recipient Name_{selected_hash}"
                sa = f"Sender Address_{selected_hash}"
                ra = f"Recipient Address_{selected_hash}"
                st.session_state[sn], st.session_state[rn] = st.session_state[rn], st.session_state[sn]
                st.session_state[sa], st.session_state[ra] = st.session_state[ra], st.session_state[sa]
                st.rerun()

        # ----------------- FORM START -----------------
        with st.form(key=f"edit_form_{selected_hash}", clear_on_submit=False):
            tabs = st.tabs(["Invoice Details", "Sender/Recipient", "Bank Details", "Line Items"])
            with tabs[0]:
                st.text_input("Invoice Number", key=f"Invoice Number_{selected_hash}")
                st.text_input("Invoice Date", key=f"Invoice Date_{selected_hash}")
                st.text_input("Due Date", key=f"Due Date_{selected_hash}")
                curr_options = ['USD', 'EUR', 'GBP', 'INR', 'Other']
                # Unknown currencies are shown as "Other" plus a free-text field.
                if st.session_state[f"Currency_{selected_hash}"] not in curr_options:
                    st.session_state[f"Currency_{selected_hash}"] = 'Other'
                st.selectbox("Currency", options=curr_options, key=f"Currency_{selected_hash}")
                if st.session_state.get(f"Currency_{selected_hash}") == 'Other':
                    st.text_input("Specify Currency", key=f"Currency_Custom_{selected_hash}")
                st.number_input("Subtotal", key=f"Subtotal_{selected_hash}")
                st.number_input("Tax %", key=f"Tax Percentage_{selected_hash}")
                st.number_input("Total Tax", key=f"Total Tax_{selected_hash}")
                st.number_input("Total Amount", key=f"Total Amount_{selected_hash}")
            with tabs[1]:
                st.text_input("Sender Name", key=f"Sender Name_{selected_hash}")
                st.text_area("Sender Address", key=f"Sender Address_{selected_hash}", height=80)
                st.text_input("Recipient Name", key=f"Recipient Name_{selected_hash}")
                st.text_area("Recipient Address", key=f"Recipient Address_{selected_hash}", height=80)
            with tabs[2]:
                st.text_input("Bank Name", key=f"Bank_bank_name_{selected_hash}")
                st.text_input("Account Number", key=f"Bank_bank_account_number_{selected_hash}")
                st.text_input("Account Name", key=f"Bank_bank_acc_name_{selected_hash}")
                st.text_input("IBAN", key=f"Bank_bank_iban_{selected_hash}")
                st.text_input("SWIFT", key=f"Bank_bank_swift_{selected_hash}")
                st.text_input("Routing", key=f"Bank_bank_routing_{selected_hash}")
                st.text_input("Branch", key=f"Bank_bank_branch_{selected_hash}")
            with tabs[3]:
                # Build base DF from current edited_data (not raw mapped) so it's always what the user last saved
                item_rows = form_data.get('Itemized Data', []) or []
                normalized = []
                for it in item_rows:
                    if not isinstance(it, dict):
                        it = {}
                    normalized.append({
                        "Description": it.get("Description", it.get("Item Description", "")),
                        "Quantity": it.get("Quantity", it.get("Item Quantity", 0)),
                        "Unit Price": it.get("Unit Price", it.get("Item Unit Price", 0.0)),
                        "Amount": it.get("Amount", it.get("Item Amount", 0.0)),
                        "Tax": it.get("Tax", it.get("Item Tax", 0.0)),
                        "Line Total": it.get("Line Total", it.get("Item Line Total", 0.0)),
                    })
                items_df = pd.DataFrame(normalized) if normalized else pd.DataFrame(
                    columns=["Description", "Quantity", "Unit Price", "Amount", "Tax", "Line Total"]
                )
                edited_df = st.data_editor(
                    items_df,
                    num_rows="dynamic",
                    key=f"items_editor_{selected_hash}",
                    use_container_width=True,
                    height=DATA_EDITOR_HEIGHT,
                )
            saved = st.form_submit_button("💾 Save All Edits")
        # ----------------- FORM END -----------------

        # One snapshot of the widget values serves both save and download.
        current_values = _collect_form_values(selected_hash, edited_df.to_dict('records'))
        if saved:
            st.session_state.batch_results[selected_hash]["edited_data"] = current_values
            st.success(f"✅ Saved: {current['file_name']}")

        # Per-file CSV download (uses the current editor contents even if not saved)
        rows = flatten_invoice_to_rows(current_values)
        full_df = pd.DataFrame(rows)
        csv_bytes_one = full_df.to_csv(index=False).encode("utf-8")
        st.download_button(
            "📥 Download This Invoice (CSV)",
            csv_bytes_one,
            file_name=f"{Path(current['file_name']).stem}_full.csv",
            mime="text/csv",
            key=f"dl_{selected_hash}"
        )
elif st.session_state.is_processing_batch:
    with frame_left:
        st.info("⏳ Processing batch... Please wait.")
        st.progress(0)
    with frame_right:
        st.caption("Preview & editor will appear here after extraction.")
else:
    # Shouldn't happen, but keeps skeleton steady
    with frame_left:
        st.caption("Ready when you are.")
    with frame_right:
        st.caption("Preview & editor will appear here after extraction.")