| | |
| | |
| | |
| | import os |
| | from pathlib import Path |
| |
|
| | |
| | |
| | |
# Startup guard: make sure HOME points somewhere usable and that
# $HOME/.streamlit exists. This runs before the Streamlit import further down,
# presumably because Streamlit writes under that directory -- TODO confirm.
_home = os.environ.get("HOME", "")
if not _home or _home == "/":
    repo_dir = os.getcwd()
    # Prefer the working directory; fall back to /tmp when it is read-only.
    if os.access(repo_dir, os.W_OK):
        safe_home = repo_dir
    else:
        safe_home = "/tmp"
    os.environ["HOME"] = safe_home
    print(f"[startup] HOME not set or unwritable — setting HOME={safe_home}")

streamlit_dir = Path(os.environ["HOME"]) / ".streamlit"
try:
    streamlit_dir.mkdir(parents=True, exist_ok=True)
except Exception as e:
    # Best effort only: report the problem and keep starting up.
    print(f"[startup] WARNING: could not create {streamlit_dir}: {e}")
else:
    print(f"[startup] ensured {streamlit_dir}")
| |
|
| | |
| | |
| | |
| | import json |
| | from io import BytesIO |
| | import hashlib |
| | from typing import Dict, Any |
| |
|
| | import streamlit as st |
| | import pandas as pd |
| | from PIL import Image |
| | from huggingface_hub import login |
| |
|
| | |
| | try: |
| | from pdf2image import convert_from_bytes |
| | except Exception: |
| | convert_from_bytes = None |
| |
|
| | |
| | |
| | |
# Page chrome: wide layout, title, and a small CSS override block.
st.set_page_config(page_title="Invoice Extractor (Donut) - Batch Mode", layout="wide")
st.title("Invoice Extraction")

_PAGE_CSS = """
<style>
.stApp { background-color: #ECECEC !important; }
div.block-container { padding-top: 1rem; padding-bottom: 1rem; }
[data-testid="stSidebar"] { background-color: #F7F7F7 !important; }
div[data-testid="stTabs"] > div > div { padding-bottom: 6px !important; }
/* Keep right column steady on first render post-extraction */
[data-testid="column"]:nth-of-type(2) { min-height: 780px; }
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# Layout constants used by the invoice preview and the line-item editor.
FIXED_IMG_WIDTH = 640
DATA_EDITOR_HEIGHT = 380
| |
|
| | |
| | |
| | |
def ensure_state(k: str, default):
    """Seed ``st.session_state[k]`` with *default* unless the key already exists.

    Widgets are then bound to the key via ``key=...`` (no ``value=...``), so an
    existing value is never overwritten here.
    """
    if k in st.session_state:
        return
    st.session_state[k] = default
| |
|
def clean_float(x) -> float:
    """Coerce *x* to a finite float, returning 0.0 for anything unusable.

    ``None``, blank strings, text without a parseable number, and non-finite
    values (NaN / infinity) all map to 0.0. For strings, every character that
    is not a digit, ``.`` or ``-`` is discarded before parsing, so currency
    symbols, thousands separators and whitespace are ignored.
    """
    import math
    import re

    if x is None:
        return 0.0

    if isinstance(x, (int, float)):
        try:
            value = float(x)
        except OverflowError:
            # An int too large for a float cannot be represented; treat as unusable.
            return 0.0
        return value if math.isfinite(value) else 0.0

    # One pass is enough: commas and whitespace are not in the keep-set either.
    digits = re.sub(r"[^\d.\-]", "", str(x))
    try:
        value = float(digits)
    except ValueError:
        # Covers "", ".", "-", "1.2.3", "12-34", ...
        return 0.0
    # A very long digit run parses to inf; do not let that leak out.
    return value if math.isfinite(value) else 0.0
| |
|
| | |
| | |
| | |
def _get_hf_token():
    """Look up a Hugging Face token and report where it came from.

    Sources are tried in order: the ``_hf_token`` session-state entry, the
    ``HF_TOKEN`` environment variable, then ``st.secrets["HF_TOKEN"]``.

    Returns:
        ``(token, source)`` with *source* one of ``"session"``, ``"env"`` or
        ``"secrets"``; ``(None, None)`` when no token is found.
    """
    session_token = st.session_state.get("_hf_token")
    if session_token:
        return session_token, "session"

    env_token = os.getenv("HF_TOKEN")
    if env_token:
        return env_token, "env"

    try:
        secret_token = st.secrets.get("HF_TOKEN", None)
    except Exception:
        # Any failure reading secrets is treated as "no token"
        # (presumably a missing secrets file -- TODO confirm).
        secret_token = None
    if secret_token:
        return secret_token, "secrets"

    return None, None
| |
|
# Hugging Face authentication gate. Everything below this block only runs once
# a token has been accepted by `login`.
hf_token, hf_token_source = _get_hf_token()

if hf_token is None:
    # No token anywhere: ask for one and halt the script until it is provided.
    st.subheader("Login Token 🔑")
    token_input = st.text_input("Enter your Hugging Face token (starts with 'hf_'):", type="password")
    # Pasted tokens often carry stray whitespace.
    token_input = (token_input or "").strip()
    if not token_input:
        st.warning("Provide a token here or set HF_TOKEN in the environment.")
        st.stop()
    if not token_input.startswith("hf_"):
        st.error("Invalid token format. Token must start with 'hf_'.")
        st.stop()
    try:
        login(token_input)
    except Exception as e:
        st.error(f"Failed to log in: {e}")
        st.stop()
    st.session_state["_hf_token"] = token_input
    st.session_state.logged_in = True
    st.success("Logged in successfully. Loading model...")
    st.rerun()
elif not st.session_state.get("logged_in"):
    # Streamlit re-executes this script on every interaction; authenticate only
    # once per session instead of calling `login` on each rerun.
    try:
        login(hf_token)
    except Exception as e:
        st.error(f"Failed to log in with {hf_token_source or 'unknown'} token: {e}")
        st.stop()
    st.session_state.logged_in = True
| |
|
| | |
| | |
| | |
# Hub repository of the model and the task prompt fed to the decoder.
HF_MODEL_ID = "Bhuvi13/model-V7"
TASK_PROMPT = "<s_cord-v2>"


@st.cache_resource(show_spinner=False)
def load_model_and_processor(hf_model_id: str, task_prompt: str):
    """Load the Donut processor and model and pre-tokenize the task prompt.

    Args:
        hf_model_id: Hugging Face repository id passed to ``from_pretrained``.
        task_prompt: Prompt string tokenized (without special tokens) into the
            decoder input ids.

    Returns:
        ``(processor, model, device, decoder_input_ids)``; the model is in eval
        mode and, like ``decoder_input_ids``, lives on ``device`` (CUDA when
        available, otherwise CPU).

    Raises:
        RuntimeError: if the ML libraries cannot be imported or the
            model/processor cannot be loaded. The original exception is
            attached as ``__cause__``.
    """
    try:
        import torch
        from transformers import VisionEncoderDecoderModel, DonutProcessor
    except Exception as e:
        raise RuntimeError(f"Failed to import ML libraries: {e}") from e

    try:
        processor = DonutProcessor.from_pretrained(hf_model_id)
        model = VisionEncoderDecoderModel.from_pretrained(hf_model_id)
    except Exception as e:
        raise RuntimeError(
            f"Failed to load model/processor from Hugging Face ({hf_model_id}). "
            f"Original error: {e}"
        ) from e

    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Tokenized once here and reused as the decoder prompt for every image.
    decoder_input_ids = processor.tokenizer(
        task_prompt,
        add_special_tokens=False,
        return_tensors="pt",
    ).input_ids.to(device)

    return processor, model, device, decoder_input_ids
| |
|
def run_inference_on_image(image: Image.Image, processor, model, device, decoder_input_ids):
    """Run the model on one image and return the decoded prediction.

    The image is preprocessed with ``processor``, decoded with 4-beam search
    (``max_length=1536``), and the text is converted with
    ``processor.token2json``.

    Returns:
        Whatever ``token2json`` produced; if that is a string, the result of
        ``json.loads`` on it when it parses, otherwise the string itself.
    """
    import torch

    encoded = processor(images=image, return_tensors="pt")
    pixel_values = encoded.pixel_values.to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=1536,
            num_beams=4,
            early_stopping=False,
        )

    tokenizer = processor.tokenizer
    sequence = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    # Remove any literal EOS / PAD markers still present in the decoded text.
    for marker in (tokenizer.eos_token, tokenizer.pad_token):
        sequence = sequence.replace(marker or "", "")
    sequence = sequence.strip()

    structured = processor.token2json(sequence)
    if not isinstance(structured, str):
        return structured
    try:
        return json.loads(structured)
    except Exception:
        return structured
| |
|
| | |
| | |
| | |
def map_prediction_to_ui(pred):
    """Map a raw model prediction onto the flat invoice structure used by the UI.

    Args:
        pred: A dict or list, or a string containing JSON (possibly embedded in
            surrounding text). Anything else yields an all-defaults result.

    Returns:
        A dict with the header fields ("Invoice Number", "Invoice Date",
        "Due Date", "Currency", "Subtotal", "Tax Percentage", "Total Tax",
        "Total Amount"), sender/recipient fields (flat and nested under
        "Sender" / "Recipient"), "Bank Details" (dict of strings) and
        "Itemized Data" (list of row dicts with "Description", "Quantity",
        "Unit Price", "Amount", "Tax" and "Line Total").

    Keys are matched case-insensitively anywhere in the (nested) prediction.
    Only scalar values are used for header, bank and fallback item fields;
    dict/list values are skipped so a nested object never ends up in a text
    or number field.
    """
    import json
    import math
    import re
    from collections import defaultdict

    def safe_json_load(s):
        """Parse *s* as JSON, falling back to the first balanced ``{...}`` span that parses."""
        if s is None:
            return None
        if isinstance(s, (dict, list)):
            return s
        if not isinstance(s, str):
            return None
        s = s.strip()
        if not s:
            return None
        try:
            return json.loads(s)
        except (ValueError, RecursionError):
            pass
        # Scan for top-level brace-balanced spans and return the first that parses.
        depth = 0
        start = None
        for i, ch in enumerate(s):
            if ch == "{":
                if depth == 0:
                    start = i
                depth += 1
            elif ch == "}" and depth:
                depth -= 1
                if depth == 0 and start is not None:
                    try:
                        return json.loads(s[start:i + 1])
                    except (ValueError, RecursionError):
                        pass
                    start = None
        return None

    def clean_number(x):
        """Best-effort conversion to a finite float; unusable input becomes 0.0."""
        if x is None:
            return 0.0
        if isinstance(x, (int, float)):
            try:
                value = float(x)
            except OverflowError:
                return 0.0
            return value if math.isfinite(value) else 0.0
        digits = re.sub(r"[^\d.\-]", "", str(x))
        try:
            value = float(digits)
        except ValueError:
            return 0.0
        return value if math.isfinite(value) else 0.0

    def collect_keys(obj, out):
        """Record every value under its lower-cased key, recursing through dicts and lists."""
        if isinstance(obj, dict):
            for k, v in obj.items():
                out[str(k).strip().lower()].append(v)
                collect_keys(v, out)
        elif isinstance(obj, list):
            for it in obj:
                collect_keys(it, out)

    def collect_lists_of_dicts(obj, out_lists):
        """Append every nested list whose first element is a dict (not descending into it)."""
        children = obj.values() if isinstance(obj, dict) else obj if isinstance(obj, list) else ()
        for child in children:
            if isinstance(child, list) and child and isinstance(child[0], dict):
                out_lists.append(child)
            else:
                collect_lists_of_dicts(child, out_lists)

    def map_item_dict(it):
        """Convert one line-item dict to a UI row, or return None for non-dicts."""
        if not isinstance(it, dict):
            return None
        lower = {str(k).strip().lower(): v for k, v in it.items()}

        def first_truthy(*names):
            for name in names:
                value = lower.get(name)
                if value:
                    return value
            return ""

        amount = first_truthy("amount", "line_total", "line total", "total")
        return {
            "Description": str(first_truthy("descriptions", "description", "desc", "item")).strip(),
            "Quantity": clean_number(first_truthy("quantity", "qty", "count")),
            "Unit Price": clean_number(first_truthy("unit_price", "price")),
            "Amount": clean_number(amount),
            "Tax": clean_number(first_truthy("tax", "tax_amount")),
            # Fall back to the amount when no explicit line total exists.
            "Line Total": clean_number(first_truthy("line_total", "line total") or amount),
        }

    parsed = safe_json_load(pred) if isinstance(pred, str) else pred

    ui = {
        "Invoice Number": "",
        "Invoice Date": "",
        "Due Date": "",
        "Currency": "",
        "Subtotal": 0.0,
        "Tax Percentage": 0.0,
        "Total Tax": 0.0,
        "Total Amount": 0.0,
        "Sender": {"Name": "", "Address": ""},
        "Recipient": {"Name": "", "Address": ""},
        "Sender Name": "",
        "Sender Address": "",
        "Recipient Name": "",
        "Recipient Address": "",
        "Bank Details": {},
        "Itemized Data": [],
    }

    key_map = defaultdict(list)
    list_candidates = []
    if isinstance(parsed, (dict, list)):
        # A top-level list of dicts may itself be the line-item list.
        if isinstance(parsed, list) and parsed and isinstance(parsed[0], dict):
            list_candidates.append(parsed)
        collect_keys(parsed, key_map)
        collect_lists_of_dicts(parsed, list_candidates)

    def pick_first(*candidate_keys):
        """Return the first non-empty scalar (as stripped text) found under any candidate key."""
        for k in candidate_keys:
            for v in key_map.get(k.strip().lower(), ()):
                # Containers are skipped: they would otherwise leak into
                # text/number fields (and crash `.strip()` on Currency).
                if v is None or isinstance(v, (dict, list)):
                    continue
                s = str(v).strip()
                if s:
                    return s
        return None

    ui["Invoice Number"] = pick_first("invoice_no", "invoice_number", "invoiceid", "invoice id") or ""
    ui["Invoice Date"] = pick_first("invoice_date", "date", "invoice date") or ""
    ui["Due Date"] = pick_first("due_date", "due date", "due") or ""
    ui["Sender Name"] = pick_first("sender_name", "sender") or ""
    ui["Sender Address"] = pick_first("sender_addr", "sender_address", "sender addr") or ""
    ui["Recipient Name"] = pick_first("rcpt_name", "recipient_name", "recipient", "rcpt") or ""
    ui["Recipient Address"] = pick_first("rcpt_addr", "recipient_address", "recipient addr") or ""

    bank = {}
    # Each key is also tried without its "bank_" prefix.
    # NOTE(review): for "bank_name" that alias is the very generic key "name".
    for bk in ("bank_name", "bank_acc_no", "bank_account_number", "bank_acc_name",
               "bank_iban", "bank_swift", "bank_routing", "bank_branch", "iban"):
        val = pick_first(bk, bk.replace("bank_", ""))
        if not val:
            continue
        if bk == "iban":
            bank["bank_iban"] = val
        elif bk == "bank_acc_no":
            bank["bank_account_number"] = val
        else:
            bank[bk] = val
    ui["Bank Details"] = bank

    ui["Subtotal"] = clean_number(pick_first("subtotal", "sub_total", "sub total"))
    ui["Tax Percentage"] = clean_number(pick_first("tax_rate", "tax_percentage", "tax pct", "tax percentage"))
    ui["Total Tax"] = clean_number(pick_first("tax_amount", "tax", "total_tax"))
    ui["Total Amount"] = clean_number(pick_first("total_amount", "grand_total", "total", "amount"))
    ui["Currency"] = pick_first("currency") or ""

    items_rows = []

    def list_looks_like_items(lst):
        """True when the first dict of *lst* carries at least one line-item key."""
        if not isinstance(lst, list) or not lst or not isinstance(lst[0], dict):
            return False
        expected = {"descriptions", "description", "desc", "item", "quantity", "qty",
                    "amount", "unit_price", "line_total"}
        keys0 = {str(k).strip().lower() for k in lst[0]}
        return bool(expected & keys0)

    # 1) The first nested list that looks like line items wins.
    for cand in list_candidates:
        if not list_looks_like_items(cand):
            continue
        items_rows.extend(row for row in map(map_item_dict, cand) if row is not None)
        if items_rows:
            break

    # 2) A flat prediction whose top-level keys look like a single line item.
    if not items_rows and isinstance(parsed, dict):
        top_keys = {str(k).strip().lower() for k in parsed}
        item_like_keys = {"descriptions", "description", "desc", "item", "quantity", "qty",
                          "unit_price", "unit price", "price", "amount", "line_total",
                          "line total", "sku", "tax", "tax_amount"}
        if top_keys & item_like_keys:
            single_row = map_item_dict(parsed)
            if single_row is not None:
                items_rows.append(single_row)

    # 3) Any nested dict value that carries line-item keys.
    if not items_rows:
        nested_item_keys = {"descriptions", "description", "desc", "amount", "line_total",
                            "quantity", "qty", "unit_price"}
        for vals in key_map.values():
            for v in vals:
                if isinstance(v, dict) and nested_item_keys & {str(x).strip().lower() for x in v}:
                    row = map_item_dict(v)
                    if row is not None:
                        items_rows.append(row)

    # 4) Last resort: assemble one row from loose scalar fields.
    if not items_rows:
        desc = pick_first("descriptions", "description")
        amt = pick_first("amount", "line_total")
        qty = pick_first("quantity", "qty")
        unit_price = pick_first("unit_price", "price")
        if desc or amt or qty or unit_price:
            items_rows.append({
                "Description": str(desc or ""),
                "Quantity": clean_number(qty),
                "Unit Price": clean_number(unit_price),
                "Amount": clean_number(amt),
                "Tax": clean_number(pick_first("tax", "tax_amount")),
                "Line Total": clean_number(amt),
            })

    ui["Itemized Data"] = items_rows
    ui["Sender"] = {"Name": ui["Sender Name"], "Address": ui["Sender Address"]}
    ui["Recipient"] = {"Name": ui["Recipient Name"], "Address": ui["Recipient Address"]}

    return ui
| |
|
def flatten_invoice_to_rows(invoice_data) -> list:
    """Flatten one invoice into CSV-ready rows, one per line item.

    Every row repeats the invoice header fields and the expected bank fields,
    followed by the ``Item ...`` columns. An invoice without line items still
    yields a single row with blank/zero item columns.

    Args:
        invoice_data: Invoice dict in the UI layout ("Invoice Number",
            "Bank Details", "Itemized Data", ...). ``None`` or a non-dict is
            treated as an empty invoice.

    Returns:
        A non-empty list of row dicts sharing the same keys in the same order.
    """
    EXPECTED_BANK_FIELDS = [
        "bank_name",
        "bank_account_number",
        "bank_acc_name",
        "bank_iban",
        "bank_swift",
        "bank_routing",
        "bank_branch",
    ]
    if not isinstance(invoice_data, dict):
        invoice_data = {}

    line_items = invoice_data.get("Itemized Data") or []
    if not isinstance(line_items, (list, tuple)):
        line_items = []

    # Bank fields come from the nested "Bank Details" dict (keys get a "bank_"
    # prefix when missing) and from top-level "bank_*" keys, which take precedence.
    bank_details = {}
    nested = invoice_data.get("Bank Details") or {}
    if isinstance(nested, dict):
        for k, v in nested.items():
            key_name = k if str(k).startswith("bank_") else f"bank_{k}"
            bank_details[key_name] = v
    for k, v in invoice_data.items():
        if isinstance(k, str) and k.lower().startswith("bank_"):
            # Matched case-insensitively, so store under the lower-cased name;
            # otherwise e.g. "Bank_Name" would never reach the "bank_name" column.
            bank_details[k.lower()] = v

    def party_field(role, field):
        """Prefer the flat "<role> <field>" entry, else the nested <role>[<field>]."""
        flat = invoice_data.get(f"{role} {field}", "")
        if flat:
            return flat
        party = invoice_data.get(role) or {}
        return party.get(field, "") if isinstance(party, dict) else ""

    header = {
        "Invoice Number": invoice_data.get("Invoice Number", ""),
        "Invoice Date": invoice_data.get("Invoice Date", ""),
        "Due Date": invoice_data.get("Due Date", ""),
        "Currency": invoice_data.get("Currency", ""),
        "Subtotal": invoice_data.get("Subtotal", 0.0),
        "Tax Percentage": invoice_data.get("Tax Percentage", 0.0),
        "Total Tax": invoice_data.get("Total Tax", 0.0),
        "Total Amount": invoice_data.get("Total Amount", 0.0),
        "Sender Name": party_field("Sender", "Name"),
        "Sender Address": party_field("Sender", "Address"),
        "Recipient Name": party_field("Recipient", "Name"),
        "Recipient Address": party_field("Recipient", "Address"),
    }
    for bank_field in EXPECTED_BANK_FIELDS:
        header[bank_field] = bank_details.get(bank_field, "")

    def item_columns(item):
        """Item columns for one line item; blanks/zeros when *item* is not a dict."""
        if not isinstance(item, dict):
            item = {}
        return {
            "Item Description": item.get("Description", ""),
            "Item Quantity": item.get("Quantity", 0),
            "Item Unit Price": item.get("Unit Price", 0.0),
            "Item Amount": item.get("Amount", 0.0),
            "Item Tax": item.get("Tax", 0.0),
            # The line total defaults to the amount when it is absent.
            "Item Line Total": item.get("Line Total", item.get("Amount", 0.0)),
        }

    # `[None]` produces the single blank-item row for invoices without items.
    return [{**header, **item_columns(item)} for item in (line_items or [None])]
| |
|
| | |
| | |
| | |
# Load (or fetch from the resource cache) the model bundle; on failure show
# the error and halt the script run.
try:
    with st.spinner("Loading model & processor (cached) ..."):
        _model_bundle = load_model_and_processor(HF_MODEL_ID, TASK_PROMPT)
        processor, model, device, decoder_input_ids = _model_bundle
except Exception as e:
    st.error("Could not load model automatically. See details below.")
    st.exception(e)
    st.stop()
| |
|
| | |
| | |
| | |
# Per-session defaults: extraction results keyed by file hash, the hash of the
# invoice currently shown, and a flag set while a batch is being processed.
ensure_state("batch_results", {})
ensure_state("current_file_hash", None)
ensure_state("is_processing_batch", False)
| |
|
| | |
| | |
| | |
# Two-column page: upload / preview on the left, editor on the right.
frame_left, frame_right = st.columns([1, 1], vertical_alignment="top")

# Bank fields edited in the form; widget keys are "Bank_<field>_<file hash>".
_BANK_FIELDS = (
    "bank_name",
    "bank_account_number",
    "bank_acc_name",
    "bank_iban",
    "bank_swift",
    "bank_routing",
    "bank_branch",
)
_CURRENCY_CHOICES = ['USD', 'EUR', 'GBP', 'INR', 'Other']


def _load_invoice_image(uploaded_file, uploaded_bytes):
    """Return the upload as an RGB image, or None after warning the user.

    For PDFs only the first rendered page is used.
    """
    is_pdf = uploaded_file.name.lower().endswith('.pdf') or (
        hasattr(uploaded_file, 'type') and uploaded_file.type == 'application/pdf'
    )
    if not is_pdf:
        try:
            return Image.open(BytesIO(uploaded_bytes)).convert("RGB")
        except Exception:
            st.warning(f"Failed to open {uploaded_file.name}.")
            return None

    if convert_from_bytes is None:
        st.warning(f"PDF {uploaded_file.name} could not be rendered (pdf2image/poppler missing).")
        return None
    try:
        pages = convert_from_bytes(uploaded_bytes, dpi=200)
        if len(pages) > 0:
            return pages[0].convert("RGB")
    except Exception:
        st.warning(f"Could not render PDF {uploaded_file.name}. Ensure 'pdf2image' and poppler are installed.")
        return None
    st.warning(f"PDF {uploaded_file.name} has no pages.")
    return None


def _invoice_from_widgets(file_hash, item_records):
    """Assemble the invoice dict from the current widget values of one file."""
    def widget_value(name, default=''):
        return st.session_state.get(f"{name}_{file_hash}", default)

    currency = widget_value("Currency", 'USD')
    if currency == 'Other':
        currency = widget_value("Currency_Custom")

    invoice = {
        'Invoice Number': widget_value("Invoice Number"),
        'Invoice Date': widget_value("Invoice Date"),
        'Due Date': widget_value("Due Date"),
        'Currency': currency,
        'Subtotal': widget_value("Subtotal", 0.0),
        'Tax Percentage': widget_value("Tax Percentage", 0.0),
        'Total Tax': widget_value("Total Tax", 0.0),
        'Total Amount': widget_value("Total Amount", 0.0),
        'Sender Name': widget_value("Sender Name"),
        'Sender Address': widget_value("Sender Address"),
        'Recipient Name': widget_value("Recipient Name"),
        'Recipient Address': widget_value("Recipient Address"),
        'Bank Details': {field: widget_value(f"Bank_{field}") for field in _BANK_FIELDS},
        'Itemized Data': item_records,
    }
    invoice['Sender'] = {"Name": invoice['Sender Name'], "Address": invoice['Sender Address']}
    invoice['Recipient'] = {"Name": invoice['Recipient Name'], "Address": invoice['Recipient Address']}
    return invoice


if not st.session_state.is_processing_batch and len(st.session_state.batch_results) == 0:
    # ---- Upload view -------------------------------------------------------
    with frame_left:
        st.header("📤 Upload Invoices")
        uploaded_files = st.file_uploader(
            "Upload invoice images (png/jpg/jpeg/pdf)",
            type=["png", "jpg", "jpeg", "pdf"],
            accept_multiple_files=True
        )

        if uploaded_files:
            st.session_state.is_processing_batch = True
            progress_bar = st.progress(0)
            status_text = st.empty()
            total_files = len(uploaded_files)

            try:
                for position, uploaded_file in enumerate(uploaded_files, start=1):
                    status_text.text(f"Processing {position}/{total_files}: {uploaded_file.name}")
                    # getvalue() does not depend on the stream position, unlike read().
                    uploaded_bytes = uploaded_file.getvalue()
                    file_hash = hashlib.sha256(uploaded_bytes).hexdigest()

                    # Identical content already processed in this batch is skipped.
                    if file_hash not in st.session_state.batch_results:
                        image = _load_invoice_image(uploaded_file, uploaded_bytes)
                        if image is not None:
                            try:
                                pred = run_inference_on_image(image, processor, model, device, decoder_input_ids)
                                mapped = map_prediction_to_ui(pred)
                            except Exception as e:
                                st.warning(f"Error processing {uploaded_file.name}: {str(e)}")
                                pred = None
                                mapped = {}

                            safe_mapped = mapped if isinstance(mapped, dict) else {}
                            st.session_state.batch_results[file_hash] = {
                                "file_name": uploaded_file.name,
                                "image": image,
                                "raw_pred": pred,
                                "mapped_data": safe_mapped,
                                "edited_data": safe_mapped.copy()
                            }

                    progress_bar.progress(position / total_files)
            finally:
                # Always clear the flag, even when the run is interrupted, so
                # the page cannot get stuck on the "Processing batch" view.
                st.session_state.is_processing_batch = False

            if st.session_state.batch_results:
                status_text.text("✅ All files processed!")
                st.rerun()
            else:
                # Nothing usable: do NOT rerun. The uploader still holds the
                # files, so a rerun would reprocess them in an endless loop
                # and wipe the warnings shown above.
                status_text.text("No files could be processed.")
                st.error("None of the uploaded files could be processed. See the warnings above.")

    with frame_right:
        st.caption("Preview & editor will appear here after extraction.")

elif len(st.session_state.batch_results) > 0:
    # ---- Results view ------------------------------------------------------
    with frame_left:
        # Reserve the top slot for "Download All"; it is filled at the end of
        # this branch so it reflects edits saved during this very run.
        download_all_slot = st.empty()

    with frame_right:
        if st.button("⬅️ Back to Upload"):
            st.session_state.batch_results.clear()
            st.session_state.current_file_hash = None
            st.session_state.is_processing_batch = False
            st.rerun()

    with frame_left:
        file_options = {f"{v['file_name']} ({k[:6]})": k for k, v in st.session_state.batch_results.items()}
        selected_display = st.selectbox("Select invoice to view/edit:", options=list(file_options.keys()), index=0, key="file_selector")
        selected_hash = file_options[selected_display]
        st.session_state.current_file_hash = selected_hash

    current = st.session_state.batch_results[selected_hash]
    image = current["image"]
    form_data = current["edited_data"]

    # Seed widget state once per file; widgets below bind by key only.
    bank = form_data.get("Bank Details", {})
    if not isinstance(bank, dict):
        bank = {}
    # Text inputs need strings; coerce whatever the mapping produced.
    for text_field in ("Invoice Number", "Sender Name", "Sender Address", "Recipient Name", "Recipient Address"):
        ensure_state(f"{text_field}_{selected_hash}", str(form_data.get(text_field, '') or ''))
    for date_field in ("Invoice Date", "Due Date"):
        ensure_state(f"{date_field}_{selected_hash}", str(form_data.get(date_field, '') or '').strip())
    extracted_currency = form_data.get('Currency', '')
    ensure_state(f"Currency_{selected_hash}", extracted_currency or 'USD')
    ensure_state(
        f"Currency_Custom_{selected_hash}",
        extracted_currency if extracted_currency not in ['USD', 'EUR', 'GBP', 'INR'] else ''
    )
    for amount_field in ("Subtotal", "Tax Percentage", "Total Tax", "Total Amount"):
        ensure_state(f"{amount_field}_{selected_hash}", float(form_data.get(amount_field) or 0.0))
    for bank_field in _BANK_FIELDS:
        bank_value = bank.get(bank_field, '')
        if bank_field == "bank_account_number":
            bank_value = bank_value or bank.get('bank_acc_no', '')
        ensure_state(f"Bank_{bank_field}_{selected_hash}", bank_value)

    with frame_left:
        st.image(image, caption=current["file_name"], width=FIXED_IMG_WIDTH)
        st.write(f"**File Hash:** {selected_hash[:8]}...")
        if current.get('raw_pred') is not None:
            with st.expander("🔍 Show raw model output"):
                st.json(current['raw_pred'])

        if st.button("🔁 Re-Run Inference", key=f"rerun_{selected_hash}"):
            with st.spinner("Re-running inference..."):
                try:
                    pred = run_inference_on_image(image, processor, model, device, decoder_input_ids)
                    mapped = map_prediction_to_ui(pred)
                except Exception as e:
                    st.error(f"Re-run failed: {e}")
                else:
                    safe_mapped = mapped if isinstance(mapped, dict) else {}
                    stored = st.session_state.batch_results[selected_hash]
                    stored["raw_pred"] = pred
                    stored["mapped_data"] = safe_mapped
                    stored["edited_data"] = safe_mapped.copy()

                    # Drop this file's widget state so the form re-seeds from the new result.
                    stale_keys = [
                        k for k in st.session_state.keys()
                        if isinstance(k, str) and k.endswith(f"_{selected_hash}")
                    ]
                    for key in stale_keys:
                        del st.session_state[key]

                    st.success("✅ Re-run complete")
                    st.rerun()

    with frame_right:
        st.subheader(f"Editable Invoice: {current['file_name']}")

        swap_cols = st.columns([1, 1, 2])
        with swap_cols[0]:
            if st.button("⇄ Swap Sender ↔ Recipient", key=f"swap_{selected_hash}"):
                sn = f"Sender Name_{selected_hash}"
                rn = f"Recipient Name_{selected_hash}"
                sa = f"Sender Address_{selected_hash}"
                ra = f"Recipient Address_{selected_hash}"
                st.session_state[sn], st.session_state[rn] = st.session_state[rn], st.session_state[sn]
                st.session_state[sa], st.session_state[ra] = st.session_state[ra], st.session_state[sa]
                st.rerun()

        with st.form(key=f"edit_form_{selected_hash}", clear_on_submit=False):
            tabs = st.tabs(["Invoice Details", "Sender/Recipient", "Bank Details", "Line Items"])

            with tabs[0]:
                st.text_input("Invoice Number", key=f"Invoice Number_{selected_hash}")
                st.text_input("Invoice Date", key=f"Invoice Date_{selected_hash}")
                st.text_input("Due Date", key=f"Due Date_{selected_hash}")

                currency_key = f"Currency_{selected_hash}"
                # Anything outside the fixed choices is edited through "Other".
                if st.session_state[currency_key] not in _CURRENCY_CHOICES:
                    st.session_state[currency_key] = 'Other'
                st.selectbox("Currency", options=_CURRENCY_CHOICES, key=currency_key)

                if st.session_state.get(currency_key) == 'Other':
                    st.text_input("Specify Currency", key=f"Currency_Custom_{selected_hash}")

                st.number_input("Subtotal", key=f"Subtotal_{selected_hash}")
                st.number_input("Tax %", key=f"Tax Percentage_{selected_hash}")
                st.number_input("Total Tax", key=f"Total Tax_{selected_hash}")
                st.number_input("Total Amount", key=f"Total Amount_{selected_hash}")

            with tabs[1]:
                st.text_input("Sender Name", key=f"Sender Name_{selected_hash}")
                st.text_area("Sender Address", key=f"Sender Address_{selected_hash}", height=80)
                st.text_input("Recipient Name", key=f"Recipient Name_{selected_hash}")
                st.text_area("Recipient Address", key=f"Recipient Address_{selected_hash}", height=80)

            with tabs[2]:
                st.text_input("Bank Name", key=f"Bank_bank_name_{selected_hash}")
                st.text_input("Account Number", key=f"Bank_bank_account_number_{selected_hash}")
                st.text_input("Account Name", key=f"Bank_bank_acc_name_{selected_hash}")
                st.text_input("IBAN", key=f"Bank_bank_iban_{selected_hash}")
                st.text_input("SWIFT", key=f"Bank_bank_swift_{selected_hash}")
                st.text_input("Routing", key=f"Bank_bank_routing_{selected_hash}")
                st.text_input("Branch", key=f"Bank_bank_branch_{selected_hash}")

            with tabs[3]:
                # Accept rows in either the UI layout or the flattened "Item ..." layout.
                normalized = []
                for it in form_data.get('Itemized Data', []) or []:
                    if not isinstance(it, dict):
                        it = {}
                    normalized.append({
                        "Description": it.get("Description", it.get("Item Description", "")),
                        "Quantity": it.get("Quantity", it.get("Item Quantity", 0)),
                        "Unit Price": it.get("Unit Price", it.get("Item Unit Price", 0.0)),
                        "Amount": it.get("Amount", it.get("Item Amount", 0.0)),
                        "Tax": it.get("Tax", it.get("Item Tax", 0.0)),
                        "Line Total": it.get("Line Total", it.get("Item Line Total", 0.0)),
                    })

                items_df = pd.DataFrame(normalized) if normalized else pd.DataFrame(
                    columns=["Description", "Quantity", "Unit Price", "Amount", "Tax", "Line Total"]
                )
                edited_df = st.data_editor(
                    items_df,
                    num_rows="dynamic",
                    key=f"items_editor_{selected_hash}",
                    use_container_width=True,
                    height=DATA_EDITOR_HEIGHT,
                )

            saved = st.form_submit_button("💾 Save All Edits")

        # One snapshot of the widget values feeds both "save" and the per-file download.
        invoice_snapshot = _invoice_from_widgets(selected_hash, edited_df.to_dict('records'))

        if saved:
            st.session_state.batch_results[selected_hash]["edited_data"] = invoice_snapshot
            st.success(f"✅ Saved: {current['file_name']}")

        csv_bytes_one = pd.DataFrame(flatten_invoice_to_rows(invoice_snapshot)).to_csv(index=False).encode("utf-8")
        st.download_button(
            "📥 Download This Invoice (CSV)",
            csv_bytes_one,
            file_name=f"{Path(current['file_name']).stem}_full.csv",
            mime="text/csv",
            key=f"dl_{selected_hash}"
        )

    # Built last so the combined CSV already contains anything saved above.
    all_rows = []
    for file_hash, result in st.session_state.batch_results.items():
        source_name = result.get("file_name", file_hash)
        for row in flatten_invoice_to_rows(result["edited_data"]):
            # "Source File" goes first so it becomes the first CSV column.
            all_rows.append({"Source File": source_name, **row})

    if all_rows:
        csv_bytes = pd.DataFrame(all_rows).to_csv(index=False).encode("utf-8")
        download_all_slot.download_button(
            "📦 Download All Results (CSV)", csv_bytes,
            file_name="all_extracted_invoices.csv", mime="text/csv", key="download_all_csv"
        )

else:
    # Only reachable while `is_processing_batch` is set and no results exist yet.
    with frame_left:
        st.info("⏳ Processing batch... Please wait.")
        st.progress(0)
    with frame_right:
        st.caption("Preview & editor will appear here after extraction.")