# =========================
# Invoice Extractor (Donut) - Batch Mode (stable UI + original line-item logic)
# =========================
"""Streamlit app that batch-extracts invoice fields with a Donut model.

Flow: authenticate with Hugging Face -> load the model once (cached) ->
upload images/PDFs -> run inference per file -> review/edit each invoice in
a form -> export one invoice or all invoices as CSV.
"""
import os
from pathlib import Path

# -----------------------------
# Environment hardening (HF Spaces, /.cache issue)
# -----------------------------
# Runs before the third-party imports below so anything that derives
# cache/config paths from HOME at import time sees the corrected value.
# NOTE(review): this ordering looks deliberate — keep it above the imports.
_home = os.environ.get("HOME", "")
if _home in ("", "/") or not os.access(_home, os.W_OK):
    repo_dir = os.getcwd()
    safe_home = repo_dir if os.access(repo_dir, os.W_OK) else "/tmp"
    os.environ["HOME"] = safe_home
    print(f"[startup] HOME not set or unwritable — setting HOME={safe_home}")

streamlit_dir = Path(os.environ["HOME"]) / ".streamlit"
try:
    streamlit_dir.mkdir(parents=True, exist_ok=True)
    print(f"[startup] ensured {streamlit_dir}")
except OSError as e:
    # Best effort: the failure is logged, not fatal.
    print(f"[startup] WARNING: could not create {streamlit_dir}: {e}")

# -----------------------------
# Imports
# -----------------------------
import hashlib
import json
import re
from collections import defaultdict
from io import BytesIO
from typing import Any, Dict

import pandas as pd
import streamlit as st
from huggingface_hub import login
from PIL import Image

# Optional: pdf2image is only needed for PDFs
try:
    from pdf2image import convert_from_bytes
except Exception:  # best effort: PDFs are skipped with a warning when missing
    convert_from_bytes = None

# -----------------------------
# Page config & CSS
# -----------------------------
st.set_page_config(page_title="Invoice Extractor (Donut) - Batch Mode", layout="wide")
st.title("Invoice Extraction")
# NOTE(review): the HTML/CSS payload injected here is empty; the stylesheet it
# was meant to carry appears to be missing — confirm against the deployed app.
st.markdown(
    """ """,
    unsafe_allow_html=True,
)

# Fixed sizes to prevent reflow wobble
FIXED_IMG_WIDTH = 640
DATA_EDITOR_HEIGHT = 380

# -----------------------------
# Helpers
# -----------------------------
# Compiled once; used by clean_float on every numeric field.
_SEPARATORS_RE = re.compile(r"[,\s]")
_NON_NUMERIC_RE = re.compile(r"[^\d\.\-]")


def ensure_state(k: str, default):
    """Initialize a session_state key once, then let widgets bind to it via key=...
    (no value=...)."""
    if k not in st.session_state:
        st.session_state[k] = default


def clean_float(x) -> float:
    """Best-effort conversion of a model-extracted value to float.

    Numbers pass through unchanged (as float). Strings have commas and
    whitespace removed, then every character other than digits, '.' and '-'
    stripped before parsing (so "$1,234.50" -> 1234.5). None, empty, or
    unparseable input yields 0.0.
    """
    if x is None:
        return 0.0
    if isinstance(x, (int, float)):
        return float(x)
    s = str(x).strip()
    if s == "":
        return 0.0
    s = _NON_NUMERIC_RE.sub("", _SEPARATORS_RE.sub("", s))
    if s in ("", ".", "-", "-."):
        return 0.0
    try:
        return float(s)
    except ValueError:  # e.g. "1.2.3" or "1-2"
        return 0.0


# -----------------------------
# HF login flow (token from session/env/secrets)
# -----------------------------
def _get_hf_token():
    """Return (token, source) with priority session > env > secrets, else (None, None)."""
    session_tok = st.session_state.get("_hf_token")
    if session_tok:
        return session_tok, "session"
    env_tok = os.getenv("HF_TOKEN")
    if env_tok:
        return env_tok, "env"
    try:
        sec = st.secrets.get("HF_TOKEN", None)
    except Exception:
        # st.secrets may raise when no secrets file is configured: treat as "no token".
        sec = None
    if sec:
        return sec, "secrets"
    return None, None


hf_token, hf_token_source = _get_hf_token()
if hf_token is None:
    st.subheader("Login Token 🔑")
    token_input = st.text_input("Enter your Hugging Face token (starts with 'hf_'):", type="password")
    token_input = (token_input or "").strip()  # pasted tokens often carry stray whitespace
    if not token_input:
        st.warning("Provide a token here or set HF_TOKEN in the environment.")
        st.stop()
    if not token_input.startswith("hf_"):
        st.error("Invalid token format. Token must start with 'hf_'.")
        st.stop()
    try:
        login(token_input)
    except Exception as e:
        st.error(f"Failed to log in: {e}")
        st.stop()
    st.session_state["_hf_token"] = token_input
    st.session_state.logged_in = True
    st.success("Logged in successfully. Loading model...")
    st.rerun()
elif not st.session_state.get("logged_in"):
    # Streamlit re-executes this script on every interaction; log in only once
    # per session instead of calling the Hub on each rerun.
    try:
        login(hf_token)
    except Exception as e:
        st.error(f"Failed to log in with {hf_token_source or 'unknown'} token: {e}")
        st.stop()
    st.session_state.logged_in = True

# -----------------------------
# Model config
# -----------------------------
HF_MODEL_ID = "Bhuvi13/model-V7"
# NOTE(review): the task prompt is empty; Donut checkpoints are commonly
# prompted with a task start token (e.g. "<s_...>") — confirm against the
# model's training setup. run_inference_on_image tolerates an empty prompt.
TASK_PROMPT = ""


@st.cache_resource(show_spinner=False)
def load_model_and_processor(hf_model_id: str, task_prompt: str):
    """Load the Donut processor and model (cached by st.cache_resource).

    Returns:
        (processor, model, device, decoder_input_ids) where decoder_input_ids
        is the tokenized task prompt already moved to `device`; it has zero
        length when task_prompt is empty.

    Raises:
        RuntimeError: if torch/transformers cannot be imported or the
            checkpoint cannot be loaded.
    """
    try:
        import torch
        from transformers import DonutProcessor, VisionEncoderDecoderModel
    except Exception as e:
        raise RuntimeError(f"Failed to import ML libraries: {e}") from e

    try:
        processor = DonutProcessor.from_pretrained(hf_model_id)
        model = VisionEncoderDecoderModel.from_pretrained(hf_model_id)
    except Exception as e:
        raise RuntimeError(
            f"Failed to load model/processor from Hugging Face ({hf_model_id}). "
            f"Original error: {e}"
        ) from e

    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(device)
    return processor, model, device, decoder_input_ids


def run_inference_on_image(image: Image.Image, processor, model, device, decoder_input_ids,
                           max_length: int = 1536, num_beams: int = 4):
    """Run Donut generation on one image and parse the decoded sequence.

    Args:
        image: RGB page image.
        processor, model, device, decoder_input_ids: as returned by
            load_model_and_processor.
        max_length, num_beams: generation settings (defaults match the
            previous hard-coded values).

    Returns:
        The output of processor.token2json(...), normally a dict. If it is a
        string, it is JSON-decoded when possible and returned verbatim
        otherwise.
    """
    import torch

    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
    gen_kwargs = dict(
        pixel_values=pixel_values,
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=False,
    )
    # An empty task prompt tokenizes to a zero-length tensor, which gives the
    # decoder nothing to start from; omit it so generate() falls back to the
    # model's configured decoder start token.
    if decoder_input_ids is not None and decoder_input_ids.shape[-1] > 0:
        gen_kwargs["decoder_input_ids"] = decoder_input_ids

    with torch.no_grad():
        generated_ids = model.generate(**gen_kwargs)

    raw_pred = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    tokenizer = processor.tokenizer
    cleaned = (
        raw_pred
        .replace(tokenizer.eos_token or "", "")
        .replace(tokenizer.pad_token or "", "")
        .strip()
    )

    token2json_out = processor.token2json(cleaned)
    if isinstance(token2json_out, str):
        try:
            return json.loads(token2json_out)
        except ValueError:
            return token2json_out
    return token2json_out


# -----------------------------
# ORIGINAL (previous) mapping logic — restored, with these fixes:
#   * pick_first() ignores dict/list values (they used to leak into text
#     widgets, produce garbage numbers, and crash `.strip()` on Currency);
#   * list-shaped predictions are searched too (previously ignored).
# -----------------------------
def map_prediction_to_ui(pred):
    """Map a raw Donut prediction to the flat dict the editor UI works with.

    `pred` may be a dict (typical token2json output), a list, or a string
    (JSON, or text containing a JSON object). Keys are matched
    case-insensitively at any nesting depth using several synonyms per field.

    Returns:
        dict with header text fields, float totals, "Bank Details",
        flat "Sender/Recipient Name/Address" plus nested "Sender"/"Recipient"
        dicts, and "Itemized Data" (list of rows with Description, Quantity,
        Unit Price, Amount, Tax, Line Total).
    """

    def safe_json_load(s):
        """Parse s as JSON; on failure, try each balanced {...} substring in order."""
        if s is None:
            return None
        if isinstance(s, (dict, list)):
            return s
        if not isinstance(s, str):
            return None
        s = s.strip()
        if s == "":
            return None
        try:
            return json.loads(s)
        except (ValueError, RecursionError):
            pass
        subs = []
        depth = 0
        start = None
        for i, ch in enumerate(s):
            if ch == "{":
                if depth == 0:
                    start = i
                depth += 1
            elif ch == "}" and depth:
                depth -= 1
                if depth == 0 and start is not None:
                    subs.append(s[start:i + 1])
                    start = None
        for sub in subs:
            try:
                return json.loads(sub)
            except (ValueError, RecursionError):
                continue
        return None

    # Same parsing rules as the module-level helper.
    clean_number = clean_float

    def collect_keys(obj, out):
        """Record every value under its lower-cased key, recursing into dicts/lists."""
        if isinstance(obj, dict):
            for k, v in obj.items():
                out[str(k).strip().lower()].append(v)
                collect_keys(v, out)
        elif isinstance(obj, list):
            for it in obj:
                collect_keys(it, out)

    def collect_lists_of_dicts(obj, out_lists):
        """Collect nested lists whose first element is a dict (line-item candidates)."""
        if isinstance(obj, dict):
            for v in obj.values():
                if isinstance(v, list) and v and isinstance(v[0], dict):
                    out_lists.append(v)
                else:
                    collect_lists_of_dicts(v, out_lists)
        elif isinstance(obj, list):
            for it in obj:
                if isinstance(it, list) and it and isinstance(it[0], dict):
                    out_lists.append(it)
                else:
                    collect_lists_of_dicts(it, out_lists)

    def first_truthy(d, *keys):
        """Return the first truthy d[key] among keys, else ""."""
        for k in keys:
            v = d.get(k)
            if v:
                return v
        return ""

    def map_item_dict(it):
        """Convert one raw line-item dict to a UI row; None for non-dicts."""
        if not isinstance(it, dict):
            return None
        lower = {str(k).strip().lower(): v for k, v in it.items()}
        desc = first_truthy(lower, "descriptions", "description", "desc", "item")
        qty = first_truthy(lower, "quantity", "qty", "count")
        unit_price = first_truthy(lower, "unit_price", "unit price", "price")
        amount = first_truthy(lower, "amount", "line_total", "line total", "total")
        tax = first_truthy(lower, "tax", "tax_amount")
        line_total = first_truthy(lower, "line_total", "line total") or amount
        return {
            "Description": str(desc).strip(),
            "Quantity": float(clean_number(qty)),
            "Unit Price": float(clean_number(unit_price)),
            "Amount": float(clean_number(amount)),
            "Tax": float(clean_number(tax)),
            "Line Total": float(clean_number(line_total)),
        }

    parsed = safe_json_load(pred) if isinstance(pred, str) else pred
    if parsed is None and not isinstance(pred, dict):
        parsed = pred

    ui = {
        "Invoice Number": "",
        "Invoice Date": "",
        "Due Date": "",
        "Currency": "",
        "Subtotal": 0.0,
        "Tax Percentage": 0.0,
        "Total Tax": 0.0,
        "Total Amount": 0.0,
        "Sender": {"Name": "", "Address": ""},
        "Recipient": {"Name": "", "Address": ""},
        "Sender Name": "",
        "Sender Address": "",
        "Recipient Name": "",
        "Recipient Address": "",
        "Bank Details": {},
        "Itemized Data": [],
    }

    key_map = defaultdict(list)
    list_candidates = []
    if isinstance(parsed, (dict, list)):
        collect_keys(parsed, key_map)
        collect_lists_of_dicts(parsed, list_candidates)

    def pick_first(*candidate_keys):
        """Return the first non-empty scalar (as a stripped str) under any of the keys.

        dict/list values are skipped so containers never reach text widgets
        or numeric parsing.
        """
        for k in candidate_keys:
            for v in key_map.get(k.strip().lower(), ()):
                if v is None or isinstance(v, (dict, list)):
                    continue
                s = str(v).strip()
                if s != "":
                    return s
        return None

    ui["Invoice Number"] = pick_first("invoice_no", "invoice_number", "invoiceid", "invoice id") or ""
    ui["Invoice Date"] = pick_first("invoice_date", "date", "invoice date") or ""
    ui["Due Date"] = pick_first("due_date", "due") or ""
    ui["Sender Name"] = pick_first("sender_name", "sender") or ""
    ui["Sender Address"] = pick_first("sender_addr", "sender_address", "sender addr") or ""
    ui["Recipient Name"] = pick_first("rcpt_name", "recipient_name", "recipient", "rcpt") or ""
    ui["Recipient Address"] = pick_first("rcpt_addr", "recipient_address", "recipient addr") or ""

    # NOTE(review): the un-prefixed fallback (e.g. "name", "branch") is matched
    # against keys at any depth, so it can pick up values from unrelated
    # sections (such as a nested sender "name") — confirm this is intended.
    bank = {}
    for bk in ("bank_name", "bank_acc_no", "bank_account_number", "bank_acc_name",
               "bank_iban", "bank_swift", "bank_routing", "bank_branch", "iban"):
        val = pick_first(bk, bk.replace("bank_", ""))
        if val:
            if bk == "iban":
                bank["bank_iban"] = str(val)
            else:
                bank[bk if bk != "bank_acc_no" else "bank_account_number"] = str(val)
    ui["Bank Details"] = bank

    ui["Subtotal"] = clean_number(pick_first("subtotal", "sub_total", "sub total") or 0.0)
    ui["Tax Percentage"] = clean_number(
        pick_first("tax_rate", "tax_percentage", "tax pct", "tax percentage") or 0.0
    )
    ui["Total Tax"] = clean_number(pick_first("tax_amount", "tax", "total_tax") or 0.0)
    ui["Total Amount"] = clean_number(pick_first("total_amount", "grand_total", "total", "amount") or 0.0)
    ui["Currency"] = (pick_first("currency") or "").strip()

    # Line items: try progressively looser strategies until one yields rows.
    items_rows = []

    def list_looks_like_items(lst):
        if not isinstance(lst, list) or not lst or not isinstance(lst[0], dict):
            return False
        expected = {"descriptions", "description", "desc", "item", "quantity", "qty",
                    "amount", "unit_price", "line_total"}
        keys0 = {str(k).strip().lower() for k in lst[0].keys()}
        return bool(expected.intersection(keys0))

    # 1) The first nested list of dicts that looks like line items.
    for cand in list_candidates:
        if list_looks_like_items(cand):
            for it in cand:
                row = map_item_dict(it)
                if row is not None:
                    items_rows.append(row)
            if items_rows:
                break

    # 2) The top-level dict itself looks like a single item.
    if not items_rows and isinstance(parsed, dict):
        single_candidate_keys = {str(k).strip().lower() for k in parsed.keys()}
        item_like_keys = {"descriptions", "description", "desc", "item", "quantity", "qty",
                          "unit_price", "unit price", "price", "amount", "line_total",
                          "line total", "sku", "tax", "tax_amount"}
        if single_candidate_keys.intersection(item_like_keys):
            single_row = map_item_dict(parsed)
            if single_row is not None:
                items_rows.append(single_row)

    # 3) Any nested dict value that looks like an item.
    if not items_rows:
        for vals in key_map.values():
            for v in vals:
                if isinstance(v, dict):
                    lower_keys = {str(x).strip().lower() for x in v.keys()}
                    if lower_keys.intersection({"descriptions", "description", "desc", "amount",
                                                "line_total", "quantity", "qty", "unit_price"}):
                        row = map_item_dict(v)
                        if row is not None:
                            items_rows.append(row)

    # 4) Assemble one row from loose item-like keys anywhere in the prediction.
    if not items_rows:
        desc = pick_first("descriptions", "description")
        amt = pick_first("amount", "line_total")
        qty = pick_first("quantity", "qty")
        unit_price = pick_first("unit_price", "price")
        if desc or amt or qty or unit_price:
            items_rows.append({
                "Description": str(desc or ""),
                "Quantity": float(clean_number(qty)),
                "Unit Price": float(clean_number(unit_price)),
                "Amount": float(clean_number(amt)),
                "Tax": float(clean_number(pick_first("tax", "tax_amount") or 0.0)),
                "Line Total": float(clean_number(amt or 0.0)),
            })

    ui["Itemized Data"] = items_rows
    ui["Sender"] = {"Name": ui["Sender Name"], "Address": ui["Sender Address"]}
    ui["Recipient"] = {"Name": ui["Recipient Name"], "Address": ui["Recipient Address"]}
    return ui


def flatten_invoice_to_rows(invoice_data) -> list:
    """Flatten one invoice dict into CSV rows, one per line item.

    Every row repeats the header fields and the expected bank columns; an
    invoice without line items still yields a single row with empty item
    columns. Bank values come from "Bank Details" (keys normalized to a
    "bank_" prefix), overridden by any top-level "bank_*" keys.
    """
    EXPECTED_BANK_FIELDS = [
        "bank_name", "bank_account_number", "bank_acc_name", "bank_iban",
        "bank_swift", "bank_routing", "bank_branch",
    ]
    invoice_data = invoice_data or {}
    line_items = invoice_data.get("Itemized Data", []) or []

    bank_details = {}
    nested = invoice_data.get("Bank Details", {}) or {}
    if isinstance(nested, dict):
        for k, v in nested.items():
            bank_details[k if str(k).startswith("bank_") else f"bank_{k}"] = v
    for k, v in invoice_data.items():
        if isinstance(k, str) and k.lower().startswith("bank_"):
            bank_details[k] = v
    bank_columns = {f: bank_details.get(f, "") for f in EXPECTED_BANK_FIELDS}

    def party(name: str, field: str):
        # Nested "Sender"/"Recipient" may be missing or not a dict.
        block = invoice_data.get(name, {}) or {}
        return block.get(field, "") if isinstance(block, dict) else ""

    header = {
        "Invoice Number": invoice_data.get("Invoice Number", ""),
        "Invoice Date": invoice_data.get("Invoice Date", ""),
        "Due Date": invoice_data.get("Due Date", ""),
        "Currency": invoice_data.get("Currency", ""),
        "Subtotal": invoice_data.get("Subtotal", 0.0),
        "Tax Percentage": invoice_data.get("Tax Percentage", 0.0),
        "Total Tax": invoice_data.get("Total Tax", 0.0),
        "Total Amount": invoice_data.get("Total Amount", 0.0),
        "Sender Name": invoice_data.get("Sender Name", "") or party("Sender", "Name"),
        "Sender Address": invoice_data.get("Sender Address", "") or party("Sender", "Address"),
        "Recipient Name": invoice_data.get("Recipient Name", "") or party("Recipient", "Name"),
        "Recipient Address": invoice_data.get("Recipient Address", "") or party("Recipient", "Address"),
    }

    def item_columns(item):
        if not isinstance(item, dict):
            item = {}
        return {
            "Item Description": item.get("Description", ""),
            "Item Quantity": item.get("Quantity", 0),
            "Item Unit Price": item.get("Unit Price", 0.0),
            "Item Amount": item.get("Amount", 0.0),
            "Item Tax": item.get("Tax", 0.0),
            "Item Line Total": item.get("Line Total", item.get("Amount", 0.0)),
        }

    # An empty placeholder item produces the single "no line items" row.
    return [{**header, **bank_columns, **item_columns(item)} for item in (line_items or [{}])]


# -----------------------------
# Load model
# -----------------------------
try:
    with st.spinner("Loading model & processor (cached) ..."):
        processor, model, device, decoder_input_ids = load_model_and_processor(HF_MODEL_ID, TASK_PROMPT)
except Exception as e:
    st.error("Could not load model automatically. See details below.")
    st.exception(e)
    st.stop()


# -----------------------------
# Batch helpers (use the module-level model objects loaded above)
# -----------------------------
def _extract(image):
    """Run inference + mapping on one image; returns (raw_pred, mapped dict)."""
    pred = run_inference_on_image(image, processor, model, device, decoder_input_ids)
    mapped = map_prediction_to_ui(pred)
    return pred, (mapped if isinstance(mapped, dict) else {})


def _load_first_page_image(uploaded_file, data: bytes):
    """Decode an upload into an RGB image (first page for PDFs).

    Returns (image, None) on success or (None, warning message) on failure.
    """
    name = uploaded_file.name
    is_pdf = name.lower().endswith(".pdf") or getattr(uploaded_file, "type", None) == "application/pdf"
    if is_pdf:
        if convert_from_bytes is None:
            return None, f"PDF {name} could not be rendered (pdf2image/poppler missing)."
        try:
            # Only the first page is used, so only render the first page.
            pages = convert_from_bytes(data, dpi=200, first_page=1, last_page=1)
            if not pages:
                return None, f"PDF {name} has no pages."
            return pages[0].convert("RGB"), None
        except Exception:
            return None, f"Could not render PDF {name}. Ensure 'pdf2image' and poppler are installed."
    try:
        return Image.open(BytesIO(data)).convert("RGB"), None
    except Exception:
        return None, f"Failed to open {name}."


def _process_one_upload(uploaded_file) -> None:
    """Extract one uploaded file into st.session_state.batch_results (keyed by content hash)."""
    data = uploaded_file.getvalue()  # full content regardless of stream position
    file_hash = hashlib.sha256(data).hexdigest()
    if file_hash in st.session_state.batch_results:
        return  # identical content already processed in this batch

    image, error = _load_first_page_image(uploaded_file, data)
    if image is None:
        st.warning(error)
        return

    try:
        pred, mapped = _extract(image)
    except Exception as e:
        st.warning(f"Error processing {uploaded_file.name}: {str(e)}")
        pred, mapped = None, {}

    st.session_state.batch_results[file_hash] = {
        "file_name": uploaded_file.name,
        "image": image,
        "raw_pred": pred,
        "mapped_data": mapped,
        "edited_data": mapped.copy(),
    }


def _process_uploaded_batch(uploaded_files) -> None:
    """Process all uploads with a progress bar, always clearing the processing flag."""
    st.session_state.is_processing_batch = True
    total = len(uploaded_files)
    progress_bar = st.progress(0)
    status_text = st.empty()
    try:
        for idx, uploaded_file in enumerate(uploaded_files):
            status_text.text(f"Processing {idx+1}/{total}: {uploaded_file.name}")
            _process_one_upload(uploaded_file)
            progress_bar.progress((idx + 1) / total)
        status_text.text("✅ All files processed!")
    finally:
        # An exception or a Streamlit rerun/stop interrupting the loop used to
        # leave this True, locking the session on the "Processing batch..." screen.
        st.session_state.is_processing_batch = False


# -----------------------------
# Editor helpers
# -----------------------------
STANDARD_CURRENCIES = ["USD", "EUR", "GBP", "INR"]
# (field key inside "Bank Details", form label)
BANK_FIELDS = [
    ("bank_name", "Bank Name"),
    ("bank_account_number", "Account Number"),
    ("bank_acc_name", "Account Name"),
    ("bank_iban", "IBAN"),
    ("bank_swift", "SWIFT"),
    ("bank_routing", "Routing"),
    ("bank_branch", "Branch"),
]
ITEM_COLUMNS = ["Description", "Quantity", "Unit Price", "Amount", "Tax", "Line Total"]


def _wkey(name: str, file_hash: str) -> str:
    """Session-state/widget key for field `name` of the invoice identified by file_hash."""
    return f"{name}_{file_hash}"


def _init_widget_state(file_hash: str, form_data: Dict[str, Any]) -> None:
    """Seed widget state for one invoice from its saved data (existing keys are kept)."""
    def seed(name, value):
        ensure_state(_wkey(name, file_hash), value)

    bank = form_data.get("Bank Details", {})
    if not isinstance(bank, dict):
        bank = {}
    currency = form_data.get("Currency")

    seed("Invoice Number", form_data.get("Invoice Number", ""))
    seed("Invoice Date", str(form_data.get("Invoice Date", "")).strip())
    seed("Due Date", str(form_data.get("Due Date", "")).strip())
    seed("Currency", currency or "USD")
    seed("Currency_Custom", (currency or "") if currency not in STANDARD_CURRENCIES else "")
    for name in ("Subtotal", "Tax Percentage", "Total Tax", "Total Amount"):
        seed(name, clean_float(form_data.get(name, 0.0)))
    for name in ("Sender Name", "Sender Address", "Recipient Name", "Recipient Address"):
        seed(name, form_data.get(name, ""))
    for field, _label in BANK_FIELDS:
        value = bank.get(field, "")
        if field == "bank_account_number":
            value = value or bank.get("bank_acc_no", "")
        seed(f"Bank_{field}", value)


def _items_frame(form_data: Dict[str, Any]) -> pd.DataFrame:
    """Build the line-item editor DataFrame from saved data (accepts UI or CSV-style keys)."""
    normalized = []
    for it in form_data.get("Itemized Data", []) or []:
        if not isinstance(it, dict):
            it = {}
        normalized.append({
            "Description": it.get("Description", it.get("Item Description", "")),
            "Quantity": it.get("Quantity", it.get("Item Quantity", 0)),
            "Unit Price": it.get("Unit Price", it.get("Item Unit Price", 0.0)),
            "Amount": it.get("Amount", it.get("Item Amount", 0.0)),
            "Tax": it.get("Tax", it.get("Item Tax", 0.0)),
            "Line Total": it.get("Line Total", it.get("Item Line Total", 0.0)),
        })
    return pd.DataFrame(normalized) if normalized else pd.DataFrame(columns=ITEM_COLUMNS)


def _collect_form_values(file_hash: str, items_df: pd.DataFrame) -> Dict[str, Any]:
    """Assemble an invoice dict (same schema as map_prediction_to_ui) from widget state."""
    def get(name, default=""):
        return st.session_state.get(_wkey(name, file_hash), default)

    currency = get("Currency", "USD")
    if currency == "Other":
        currency = get("Currency_Custom", "")
    sender_name, sender_addr = get("Sender Name"), get("Sender Address")
    recipient_name, recipient_addr = get("Recipient Name"), get("Recipient Address")
    return {
        "Invoice Number": get("Invoice Number"),
        "Invoice Date": get("Invoice Date"),
        "Due Date": get("Due Date"),
        "Currency": currency,
        "Subtotal": get("Subtotal", 0.0),
        "Tax Percentage": get("Tax Percentage", 0.0),
        "Total Tax": get("Total Tax", 0.0),
        "Total Amount": get("Total Amount", 0.0),
        "Sender Name": sender_name,
        "Sender Address": sender_addr,
        "Recipient Name": recipient_name,
        "Recipient Address": recipient_addr,
        "Bank Details": {field: get(f"Bank_{field}") for field, _label in BANK_FIELDS},
        "Itemized Data": items_df.to_dict("records"),
        "Sender": {"Name": sender_name, "Address": sender_addr},
        "Recipient": {"Name": recipient_name, "Address": recipient_addr},
    }


# -----------------------------
# Session scaffolding
# -----------------------------
ensure_state("batch_results", {})
ensure_state("current_file_hash", None)
ensure_state("is_processing_batch", False)

# -----------------------------
# Pre-mount two-column skeleton to avoid layout jump
# -----------------------------
frame_left, frame_right = st.columns([1, 1], vertical_alignment="top")

# -----------------------------
# Upload / Process
# -----------------------------
if not st.session_state.is_processing_batch and len(st.session_state.batch_results) == 0:
    with frame_left:
        st.header("📤 Upload Invoices")
        uploaded_files = st.file_uploader(
            "Upload invoice images (png/jpg/jpeg/pdf)",
            type=["png", "jpg", "jpeg", "pdf"],
            accept_multiple_files=True,
        )
        if uploaded_files:
            _process_uploaded_batch(uploaded_files)
            # Rerun only if something was extracted. Otherwise the uploader
            # still holds the same files and an unconditional rerun would
            # reprocess them forever; staying here keeps the warnings visible.
            if st.session_state.batch_results:
                st.rerun()
    with frame_right:
        st.caption("Preview & editor will appear here after extraction.")

elif len(st.session_state.batch_results) > 0:
    # --------- Top row: All-results download + Back button ----------
    with frame_left:
        # Reserved here so the button renders at the top, but filled at the
        # end of the run so the export includes edits saved in this run.
        download_all_slot = st.empty()
    with frame_right:
        if st.button("⬅️ Back to Upload"):
            st.session_state.batch_results.clear()
            st.session_state.current_file_hash = None
            st.session_state.is_processing_batch = False
            st.rerun()

    # --------- Selector ----------
    with frame_left:
        file_options = {f"{v['file_name']} ({k[:6]})": k for k, v in st.session_state.batch_results.items()}
        selected_display = st.selectbox(
            "Select invoice to view/edit:", options=list(file_options.keys()), index=0, key="file_selector"
        )
        selected_hash = file_options[selected_display]
    st.session_state.current_file_hash = selected_hash

    current = st.session_state.batch_results[selected_hash]
    image = current["image"]
    form_data = current["edited_data"]

    # --------- Initialize widget state ONCE (no value= in widgets) ----------
    _init_widget_state(selected_hash, form_data)

    # --------- Display (no wobble) ----------
    with frame_left:
        st.image(image, caption=current["file_name"], width=FIXED_IMG_WIDTH)
        st.write(f"**File Hash:** {selected_hash[:8]}...")
        raw_pred = current.get("raw_pred")
        if raw_pred is not None:
            with st.expander("🔍 Show raw model output"):
                # token2json may yield a non-JSON string; st.json expects JSON.
                if isinstance(raw_pred, (dict, list)):
                    st.json(raw_pred)
                else:
                    st.code(str(raw_pred))

        if st.button("🔁 Re-Run Inference", key=f"rerun_{selected_hash}"):
            with st.spinner("Re-running inference..."):
                try:
                    pred, safe_mapped = _extract(image)
                except Exception as e:
                    st.error(f"Re-run failed: {e}")
                else:
                    st.session_state.batch_results[selected_hash].update(
                        raw_pred=pred,
                        mapped_data=safe_mapped,
                        edited_data=safe_mapped.copy(),
                    )
                    # Clear widget state for this file so defaults refresh from new mapped data
                    suffix = f"_{selected_hash}"
                    for key in [k for k in st.session_state.keys() if k.endswith(suffix)]:
                        del st.session_state[key]
                    st.success("✅ Re-run complete")
                    st.rerun()

    with frame_right:
        st.subheader(f"Editable Invoice: {current['file_name']}")

        # Quick swap outside the form (one clean rerun); it runs before the
        # form widgets are instantiated, so their state may still be changed.
        swap_cols = st.columns([1, 1, 2])
        with swap_cols[0]:
            if st.button("⇄ Swap Sender ↔ Recipient", key=f"swap_{selected_hash}"):
                for a, b in (("Sender Name", "Recipient Name"), ("Sender Address", "Recipient Address")):
                    ka, kb = _wkey(a, selected_hash), _wkey(b, selected_hash)
                    st.session_state[ka], st.session_state[kb] = st.session_state[kb], st.session_state[ka]
                st.rerun()

        # ----------------- FORM START -----------------
        with st.form(key=f"edit_form_{selected_hash}", clear_on_submit=False):
            tabs = st.tabs(["Invoice Details", "Sender/Recipient", "Bank Details", "Line Items"])

            with tabs[0]:
                st.text_input("Invoice Number", key=_wkey("Invoice Number", selected_hash))
                st.text_input("Invoice Date", key=_wkey("Invoice Date", selected_hash))
                st.text_input("Due Date", key=_wkey("Due Date", selected_hash))

                curr_options = STANDARD_CURRENCIES + ["Other"]
                currency_key = _wkey("Currency", selected_hash)
                if st.session_state[currency_key] not in curr_options:
                    st.session_state[currency_key] = "Other"
                st.selectbox("Currency", options=curr_options, key=currency_key)
                if st.session_state.get(currency_key) == "Other":
                    st.text_input("Specify Currency", key=_wkey("Currency_Custom", selected_hash))

                for name, label in (("Subtotal", "Subtotal"), ("Tax Percentage", "Tax %"),
                                    ("Total Tax", "Total Tax"), ("Total Amount", "Total Amount")):
                    st.number_input(label, key=_wkey(name, selected_hash))

            with tabs[1]:
                st.text_input("Sender Name", key=_wkey("Sender Name", selected_hash))
                st.text_area("Sender Address", key=_wkey("Sender Address", selected_hash), height=80)
                st.text_input("Recipient Name", key=_wkey("Recipient Name", selected_hash))
                st.text_area("Recipient Address", key=_wkey("Recipient Address", selected_hash), height=80)

            with tabs[2]:
                for field, label in BANK_FIELDS:
                    st.text_input(label, key=_wkey(f"Bank_{field}", selected_hash))

            with tabs[3]:
                # Base DF comes from edited_data (not raw mapped) so it is always what the user last saved
                edited_df = st.data_editor(
                    _items_frame(form_data),
                    num_rows="dynamic",
                    key=f"items_editor_{selected_hash}",
                    use_container_width=True,
                    height=DATA_EDITOR_HEIGHT,
                )

            saved = st.form_submit_button("💾 Save All Edits")
        # ----------------- FORM END -----------------

        form_values = _collect_form_values(selected_hash, edited_df)
        if saved:
            st.session_state.batch_results[selected_hash]["edited_data"] = form_values
            st.success(f"✅ Saved: {current['file_name']}")

        # Per-file CSV download (uses the current editor contents even if not saved)
        one_df = pd.DataFrame(flatten_invoice_to_rows(form_values))
        st.download_button(
            "📥 Download This Invoice (CSV)",
            one_df.to_csv(index=False).encode("utf-8"),
            file_name=f"{Path(current['file_name']).stem}_full.csv",
            mime="text/csv",
            key=f"dl_{selected_hash}",
        )

    # --------- All-results export (rendered into the top-left slot) ----------
    all_rows = []
    for file_hash, result in st.session_state.batch_results.items():
        rows = flatten_invoice_to_rows(result["edited_data"])
        for r in rows:
            r["Source File"] = result.get("file_name", file_hash)
        all_rows.extend(rows)
    if all_rows:
        full_df = pd.DataFrame(all_rows)
        cols = ["Source File"] + [c for c in full_df.columns if c != "Source File"]
        csv_bytes = full_df[cols].to_csv(index=False).encode("utf-8")
        download_all_slot.download_button(
            "📦 Download All Results (CSV)",
            csv_bytes,
            file_name="all_extracted_invoices.csv",
            mime="text/csv",
            key="download_all_csv",
        )

elif st.session_state.is_processing_batch:
    with frame_left:
        st.info("⏳ Processing batch... Please wait.")
        st.progress(0)
    with frame_right:
        st.caption("Preview & editor will appear here after extraction.")
else:
    # Shouldn't happen, but keeps skeleton steady
    with frame_left:
        st.caption("Ready when you are.")
    with frame_right:
        st.caption("Preview & editor will appear here after extraction.")