Spaces:
Sleeping
Sleeping
import html
import json
import mimetypes
import os
import re
import time
from io import StringIO

import numpy as np
import pandas as pd
import requests
import streamlit as st
from fuzzywuzzy import fuzz
from langchain.agents import initialize_agent, Tool, AgentType
from langchain_community.chat_models import ChatOpenAI
from sentence_transformers import SentenceTransformer
st.set_page_config(page_title="EZOFIS Accounts Payable Agent", layout="wide")

# Canonical PO column names. Uploaded CSV headers are mapped onto these by
# semantic_header_normalization. Defined exactly once, at module level, so the
# name exists on every Streamlit rerun and cannot drift between two copies.
CANON_HEADERS = [
    "PO Number", "Supplier Name", "Ship To", "Bill To", "PO Date",
    "Line Item Number", "Item Description", "Item Quantity",
    "Item Unit Price", "Line Item Total", "PO Total Value",
    "Currency", "Payment Terms", "Expected Delivery"
]

# --- Initialize embedding model only once per session ---
if 'embedding_model' not in st.session_state:
    with st.spinner("Loading embedding model for PO header normalization (1x cost)..."):
        st.session_state['embedding_model'] = SentenceTransformer('all-MiniLM-L6-v2')

# The header vectors get their own guard: if an earlier run loaded the model but
# failed before storing the vectors, they are still computed here.
if 'canon_header_vectors' not in st.session_state:
    st.session_state['canon_header_vectors'] = {
        h: st.session_state['embedding_model'].encode(h)
        for h in CANON_HEADERS
    }

# Page-wide CSS for the card layout, the numbered step badges and the sliders.
st.markdown("""
<style>
.block-card { background: #fff; border-radius: 20px; box-shadow: 0 2px 16px rgba(25,39,64,0.05); padding: 32px 26px 24px 26px; margin-bottom: 24px; }
.step-num { background: #A020F0; color: #fff; border-radius: 999px; padding: 6px 13px; font-weight: 700; margin-right: 14px; font-size: 20px; display: inline-block; vertical-align: middle; }
.stSlider>div>div>div>div { background: #F3F6FB !important; border-radius: 999px; }
</style>
""", unsafe_allow_html=True)
def semantic_header_normalization(df, canon_headers, canon_vectors, model):
    """Rename DataFrame columns to their closest canonical header.

    Each column name is embedded with ``model.encode`` and compared by cosine
    similarity with the pre-computed vectors in ``canon_vectors``. A column is
    renamed only when its best similarity is strictly above 0.80.

    Every canonical header is handed out at most once. When several columns
    compete for the same canonical name, the most similar one wins and the
    others keep their original names, so the result never gains duplicate
    column labels.

    Args:
        df: DataFrame whose columns should be normalized (not modified).
        canon_headers: Iterable of canonical header names.
        canon_vectors: Mapping of canonical header name to embedding vector.
        model: Object exposing ``encode(text)`` that returns a vector.

    Returns:
        A new DataFrame with the renamed columns.
    """
    def _cosine(a, b):
        # A zero-length vector has no direction; treat it as "no similarity"
        # instead of dividing by zero.
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        return float(np.dot(a, b) / denom) if denom else 0.0

    candidates = []
    for position, header in enumerate(df.columns.tolist()):
        vec = model.encode(str(header))
        best_name, best_sim = None, 0.80
        for canon in canon_headers:
            sim = _cosine(vec, canon_vectors[canon])
            if sim > best_sim:
                best_name, best_sim = canon, sim
        if best_name is not None:
            candidates.append((best_sim, position, header, best_name))

    mapping = {}
    claimed = set()
    # Strongest matches claim their canonical name first; ties go to the
    # left-most column.
    for _sim, _position, header, canon in sorted(candidates, key=lambda c: (-c[0], c[1])):
        if canon in claimed:
            continue
        claimed.add(canon)
        mapping[header] = canon
    return df.rename(columns=mapping)
# Selectable LLM back ends, keyed by the label shown in the model select box.
# Each entry holds the model id and base URL that get_llm passes to ChatOpenAI,
# plus the name of the environment variable that stores the API key.
MODELS = {
    # NOTE(review): the label says "GPT-4.1" but the id is "gpt-4-1106-preview";
    # confirm which model is actually intended.
    "OpenAI GPT-4.1": {
        "model": "gpt-4-1106-preview",
        "api_env": "OPENAI_API_KEY",
        # presumably None lets the client use its default endpoint; confirm
        "openai_api_base": None,
    },
    "Mistral (OpenRouter)": {
        "model": "mistralai/ministral-8b",
        "api_env": "OPENROUTER_API_KEY",
        "openai_api_base": "https://openrouter.ai/api/v1",
    }
}
def get_llm(model_choice):
    """Build the chat client for the model selected in the UI.

    Looks up ``model_choice`` in ``MODELS`` and reads the API key from the
    environment variable named by that entry. When the key is absent, an error
    is shown and the Streamlit script run is stopped.

    Args:
        model_choice: A key of ``MODELS``.

    Returns:
        A ``ChatOpenAI`` instance configured for the chosen model.
    """
    settings = MODELS[model_choice]
    env_name = settings["api_env"]
    secret = os.getenv(env_name)
    if not secret:
        st.error(f"API key not set: {env_name}")
        st.stop()
    client_kwargs = {
        "model": settings["model"],
        "openai_api_key": secret,
        "openai_api_base": settings["openai_api_base"],
        "temperature": 0.1,
        "max_tokens": 2000,
    }
    return ChatOpenAI(**client_kwargs)
# Field names of the "invoice_header" object in the extraction schema, in the
# order they are presented to the model.
_INVOICE_HEADER_FIELDS = (
    "car_number", "shipment_number", "shipping_point", "currency",
    "invoice_number", "invoice_date", "order_number", "customer_order_number",
    "our_order_number", "sales_order_number", "purchase_order_number",
    "order_date", "supplier_name", "supplier_address", "supplier_phone",
    "supplier_email", "supplier_tax_id", "customer_name", "customer_address",
    "customer_phone", "customer_email", "customer_tax_id", "ship_to_name",
    "ship_to_address", "bill_to_name", "bill_to_address", "remit_to_name",
    "remit_to_address", "tax_id", "tax_registration_number", "vat_number",
    "payment_terms", "payment_method", "payment_reference",
    "bank_account_number", "iban", "swift_code", "total_before_tax",
    "tax_amount", "tax_rate", "shipping_charges", "discount", "total_due",
    "amount_paid", "balance_due", "due_date", "invoice_status",
    "reference_number", "project_code", "department", "contact_person",
    "notes", "additional_info",
)

# Field names of each object in the "line_items" array of the schema.
_LINE_ITEM_FIELDS = (
    "quantity", "units", "description", "footage", "price", "amount", "notes",
)


def get_extraction_prompt(txt):
    """Return the LLM prompt that asks for structured invoice JSON.

    The prompt contains the parsing rules, a JSON schema listing every header
    and line-item field as ``"string or null"``, and finally the raw invoice
    text.

    Args:
        txt: Text extracted from the invoice document; it is interpolated
            verbatim at the end of the prompt.

    Returns:
        The complete prompt string.
    """
    def _schema_fields(names):
        # One ' "name": "string or null"' entry per line, comma separated,
        # with no comma after the last entry.
        return ",\n".join(f' "{name}": "string or null"' for name in names) + "\n"

    rules = (
        "You are an expert invoice parser. "
        "Extract data according to the visible table structure and column headers in the invoice. "
        "For every line item, only extract fields that correspond to the table columns for that row (do not include header/shipment fields in line items). "
        "Merge all multi-line content within a single cell into that field (especially for the 'description' and 'notes'). "
        # Wording fixed: the original said "deduct" and "knowldedge".
        "If you are not able to deduce the currency then use your knowledge and predict the 'currency' field using the 'supplier_address' data. "
        "Shipment/invoice-level fields such as CAR NUMBER, SHIPPING POINT, SHIPMENT NUMBER, CURRENCY, etc., must go ONLY into the 'invoice_header', not as line item fields.\n"
        "Use this schema:\n"
    )
    schema = (
        '{\n'
        ' "invoice_header": {\n'
        + _schema_fields(_INVOICE_HEADER_FIELDS)
        + ' },\n'
        ' "line_items": [\n'
        ' {\n'
        + _schema_fields(_LINE_ITEM_FIELDS)
        + ' }\n'
        ' ]\n'
        '}'
    )
    closing = (
        "\nIf a field is missing for a line item or header, use null. "
        "Do not invent fields. Do not add any header or shipment data to any line item. Return ONLY the JSON object, no explanation.\n"
        "\nInvoice Text:\n"
    )
    return rules + schema + closing + f"{txt}"
def clean_json_response(text):
    """Extract and parse the JSON object embedded in an LLM response.

    Markdown code fences are removed, the span from the first ``{`` to the
    last ``}`` is taken, and trailing commas before ``}`` / ``]`` are dropped.
    Parsing is then attempted in this order:

    1. strict parse of that span,
    2. strict parse after the quote-repair regex has been applied,
    3. parse of the first complete JSON value in the span, ignoring anything
       that follows it (covers answers with trailing prose that itself
       contains a ``}``).

    Args:
        text: The model output. A dict is returned unchanged; any other
            non-string value is converted with ``str``.

    Returns:
        The parsed object, or ``None`` when the input is empty or no JSON
        could be parsed. In the failure case an error and the offending text
        are shown through Streamlit.
    """
    if not text:
        return None
    if isinstance(text, dict):
        return text
    orig = text
    text = re.sub(r'```(?:json)?', '', str(text)).strip()
    start, end = text.find('{'), text.rfind('}') + 1
    if start < 0 or end <= start:
        st.error("Couldn't locate JSON in response.")
        st.code(str(orig))
        return None
    frag = re.sub(r',\s*([}\]])', r'\1', text[start:end])
    repaired = re.sub(r'"\s*"\s*(?="[^"]+"\s*:)', '","', frag)

    first_error = None
    for candidate in (frag, repaired):
        try:
            return json.loads(candidate)
        except json.JSONDecodeError as exc:
            if first_error is None:
                first_error = exc

    # Last resort: accept the first complete JSON value and ignore the rest.
    decoder = json.JSONDecoder()
    for candidate in (frag, repaired):
        try:
            parsed, _consumed = decoder.raw_decode(candidate)
        except json.JSONDecodeError:
            continue
        return parsed

    st.error(f"JSON parse error: {first_error}")
    st.code(frag)
    return None
def ensure_total_due(invoice_header):
    """Fill a missing ``total_due`` from the first populated fallback field.

    When ``total_due`` is absent, ``None`` or ``""``, the fields
    ``invoice_total``, ``invoice_value``, ``total_before_tax``,
    ``balance_due`` and ``amount_paid`` are checked in that order and the
    first truthy value is copied into ``total_due``.

    Args:
        invoice_header: The extracted header dict (updated in place). ``None``
            is accepted because the extraction schema allows null values.

    Returns:
        The same dict; ``{}`` when ``None`` was passed; any other non-dict
        value unchanged.
    """
    if invoice_header is None:
        return {}
    if not isinstance(invoice_header, dict):
        return invoice_header
    if invoice_header.get("total_due") in (None, ""):
        fallbacks = ("invoice_total", "invoice_value", "total_before_tax", "balance_due", "amount_paid")
        for field in fallbacks:
            value = invoice_header.get(field)
            if value:
                invoice_header["total_due"] = value
                break
    return invoice_header
def clean_num(val):
    """Parse a numeric value out of free-form text.

    Args:
        val: ``None``, a number, or anything whose ``str()`` may contain
            numbers such as ``"USD 1,234.50"``.

    Returns:
        ``None`` for ``None``, for NaN and for text without digits. Numbers
        are returned as ``float``. For text, every number found is parsed
        (thousands separators removed, leading sign honoured) and the largest
        one is returned.
    """
    if val is None:
        return None
    if isinstance(val, (int, float)):
        number = float(val)
        # NaN is the only float that differs from itself; it marks a missing
        # value, so report it as "no number".
        return None if number != number else number
    parsed = []
    for token in re.findall(r"[-+]?\d[\d,]*\.?\d*", str(val)):
        # float() understands the optional sign that the pattern captures, so
        # signed amounts such as "-12.50" are kept rather than discarded.
        try:
            parsed.append(float(token.replace(',', '')))
        except ValueError:
            continue
    return max(parsed) if parsed else None
def weighted_fuzzy_score(s1, s2):
    """Score the similarity of two values on a 0-100 scale.

    Both values are converted to text, lower-cased and whitespace-collapsed.
    ``None`` and float NaN count as empty text, so they are not compared as
    the literal words "none" or "nan".

    Returns:
        100 when the normalized texts are equal (including both empty), 0 when
        exactly one of them is empty, otherwise
        ``fuzz.token_set_ratio`` of the two texts.
    """
    def _normalize(value):
        if value is None or (isinstance(value, float) and value != value):
            return ""
        return " ".join(str(value).strip().lower().split())

    norm1, norm2 = _normalize(s1), _normalize(s2)
    if norm1 == norm2:
        return 100
    if not norm1 or not norm2:
        # A missing value cannot match a present one.
        return 0
    return fuzz.token_set_ratio(norm1, norm2)
def find_po_number_in_json(po_number, invoice_json, min_partial_len=4):
    """Report whether a PO number occurs in any value of the invoice JSON.

    All scalar values of ``invoice_json`` (nested dicts and lists included,
    keys excluded) are searched. The PO number matches a value when it occurs
    there as a whole number: it may be preceded by leading zeros, but it must
    not be part of a longer digit run or of a decimal amount. This keeps
    unrelated short values (a quantity of "1", an amount of "10.00") from
    matching every PO.

    Besides the PO number itself, two variants are searched:

    * its integer form when it parses as a number (``12345.0`` -> ``12345``);
    * its digits alone when it is a non-digit prefix followed by at least
      ``min_partial_len`` digits (``PO-12345`` -> ``12345``).

    Note that a genuinely short numeric PO number (for example "7") still
    matches any value that is exactly that number.

    Args:
        po_number: PO number from the PO sheet (text or number).
        invoice_json: Parsed invoice data.
        min_partial_len: Minimum digit count for the digits-only variant.

    Returns:
        ``True`` when a match is found, otherwise ``False``.
    """
    def _flatten(obj):
        if isinstance(obj, dict):
            for value in obj.values():
                yield from _flatten(value)
        elif isinstance(obj, list):
            for item in obj:
                yield from _flatten(item)
        elif obj is not None:
            yield str(obj)

    def _strip_float_suffix(text):
        # Only a trailing ".0" (float artefact) is removed; "10.00" or
        # "PO10.05" are left alone.
        return re.sub(r"\.0$", "", text)

    if po_number is None:
        return False
    po_str = _strip_float_suffix("".join(str(po_number).split()))
    if not po_str or po_str.lower() in ("nan", "none"):
        return False

    variants = {po_str}
    try:
        variants.add(str(int(float(po_number))))
    except (TypeError, ValueError, OverflowError):
        pass
    prefixed = re.fullmatch(r"\D+(\d+)", po_str)
    if prefixed and len(prefixed.group(1)) >= min_partial_len:
        variants.add(prefixed.group(1))

    patterns = [
        re.compile(
            r"(?<!\d)(?<!\d[.,])0*" + re.escape(variant) + r"(?![.,]?\d)",
            re.IGNORECASE,
        )
        for variant in variants
        if variant
    ]

    for field in _flatten(invoice_json):
        raw = _strip_float_suffix(field.strip())
        if not raw:
            continue
        # Also try the value with whitespace removed so "PO 12345" can match
        # a PO number stored as "PO12345".
        compact = "".join(raw.split())
        for pattern in patterns:
            if pattern.search(raw) or pattern.search(compact):
                return True
    return False
def drop_embedding_keys(d):
    """Return a plain dict copy of ``d`` without embedding entries.

    Args:
        d: A dict, or any object with ``to_dict()`` such as a pandas Series.
            ``None`` yields an empty dict.

    Returns:
        A new dict holding every item whose key does not contain the text
        "embedding". Non-string keys are checked by their ``str()`` form.
    """
    if d is None:
        return {}
    if hasattr(d, "to_dict"):
        d = d.to_dict()
    return {k: v for k, v in d.items() if "embedding" not in str(k)}
# --- UI Layout ---
st.markdown("<h1 style='font-weight:800; margin-bottom:8px;'>EZOFIS Accounts Payable Agent</h1>", unsafe_allow_html=True)
st.markdown("<div style='font-size:20px; margin-bottom:28px; color:#24345C;'>Modern workflow automation for finance teams</div>", unsafe_allow_html=True)
col1, col2, col3 = st.columns([2,2,3])

# --- col1: PO CSV, Model & Weights, Threshold Sliders ---
with col1:
    st.markdown("<span class='step-num'>1</span> <b>Upload Active Purchase Orders (POs)</b>", unsafe_allow_html=True)
    po_file = st.file_uploader("CSV with PO number, Supplier, Items, etc.", type=["csv"], key="po_csv", label_visibility="collapsed")
    if 'po_file_bytes' not in st.session_state:
        st.session_state['po_file_bytes'] = None
    if 'last_po_df' not in st.session_state:
        st.session_state['last_po_df'] = None

    if po_file is not None:
        file_bytes = po_file.getvalue()
        # Parse and embed only when the upload actually changed; reruns with
        # the same file reuse the DataFrame kept in session state.
        if st.session_state['po_file_bytes'] != file_bytes:
            try:
                # "utf-8-sig" also reads plain UTF-8 and drops a leading BOM,
                # which would otherwise become part of the first header name.
                df = pd.read_csv(StringIO(file_bytes.decode("utf-8-sig")))
            except (UnicodeDecodeError, pd.errors.ParserError, pd.errors.EmptyDataError) as exc:
                # Forget the previous PO set so a broken upload is not
                # silently answered from stale data, and leave the cached
                # bytes unset so the file is parsed again on the next rerun.
                st.session_state['po_file_bytes'] = None
                st.session_state['last_po_df'] = None
                df = None
                st.error(f"Could not read the PO CSV: {exc}")
            else:
                # Semantic normalization (embeddings)
                df = semantic_header_normalization(
                    df, CANON_HEADERS,
                    st.session_state['canon_header_vectors'],
                    st.session_state['embedding_model']
                )
                # Precompute and store embeddings for PO line item descriptions
                for c in ["Line Item Description", "Item Description", "Description"]:
                    if c in df.columns:
                        # fillna must run before astype(str); otherwise missing
                        # descriptions are embedded as the text "nan".
                        po_line_descs = df[c].fillna("").astype(str).tolist()
                        po_line_desc_embeds = st.session_state['embedding_model'].encode(po_line_descs)
                        df["_desc_embedding"] = list(po_line_desc_embeds)
                        break
                # Remember the bytes only after the file was processed.
                st.session_state['po_file_bytes'] = file_bytes
                st.session_state['last_po_df'] = df
                st.success(f"Loaded {len(df)} records from uploaded CSV.")
        else:
            df = st.session_state['last_po_df']
            if df is not None:
                st.success(f"Loaded {len(df)} records from uploaded CSV.")
    else:
        df = None

    st.markdown("<span class='step-num'>2</span> <b>Select Model & Scoring Weights</b>", unsafe_allow_html=True)
    mdl = st.selectbox("LLM Model", list(MODELS.keys()), key="extract_model", index=0)

    def int_slider(label, value, key):
        """Render a 0-100 integer slider with step 1."""
        return st.slider(label, 0, 100, value, 1, key=key, format="%d")

    weight_supplier = int_slider("Supplier Name (%)", 35, "w_supplier")
    weight_po_number = int_slider("PO Number (%)", 35, "w_po")
    weight_currency = int_slider("Currency (%)", 10, "w_curr")
    weight_total_due = int_slider("Total Due (%)", 10, "w_due")
    weight_line_item = int_slider("Line Item (%)", 10, "w_line")
    weight_sum = weight_supplier + weight_po_number + weight_currency + weight_total_due + weight_line_item
    if weight_sum != 100:
        st.warning(f"Sum of weights is {weight_sum}%. Adjust so it equals 100%.")

    st.markdown("<span class='step-num'>3</span> <b>Set Decision Thresholds</b>", unsafe_allow_html=True)
    approved_threshold = st.slider(
        "Threshold for 'APPROVED'",
        min_value=0,
        max_value=100,
        value=90,
        format="%d"
    )
    # The partial threshold must stay below the approved one. Its default of 60
    # is clamped into the allowed range, and no slider is drawn when that range
    # is empty (approved threshold of 0 or 1).
    partial_max = approved_threshold - 1
    if partial_max > 0:
        partial_threshold = st.slider(
            "Threshold for 'PARTIALLY APPROVED'",
            min_value=0,
            max_value=partial_max,
            value=min(60, partial_max),
            format="%d"
        )
    else:
        partial_threshold = 0
# --- col2: Invoice Upload & Extract Button ---
with col2:
    st.markdown("<span class='step-num'>4</span> <b>Upload Invoice/Document</b>", unsafe_allow_html=True)
    inv_file = st.file_uploader("Upload PDF, DOCX, XLSX, PNG, JPG, TIFF", type=["pdf", "docx", "xlsx", "xls", "png", "jpg", "jpeg", "tiff"], key="invoice_file", label_visibility="collapsed")
    st.markdown("<span class='step-num'>5</span> <b>Extract Data</b>", unsafe_allow_html=True)
    if st.button("Extract"):
        if inv_file:
            unstract_key = os.getenv("UNSTRACT_API_KEY")
            if not unstract_key:
                # A missing key would otherwise be sent as a None header value.
                st.error("API key not set: UNSTRACT_API_KEY")
                st.stop()
            api_root = "https://llmwhisperer-api.us-central.unstract.com/api/v2"
            auth_headers = {"unstract-key": unstract_key}
            http_timeout = 60  # seconds per HTTP call, so a stalled request cannot hang the app
            extracted_text = None  # stays None unless OCR text was actually retrieved

            with st.spinner("Extracting text from document..."):
                filename = getattr(inv_file, "name", "uploaded_file")
                # getvalue() returns the whole upload regardless of the current
                # read position, so repeated clicks send the same bytes.
                file_bytes = inv_file.getvalue()
                content_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
                headers = {
                    "unstract-key": unstract_key,
                    "Content-Type": content_type,
                }
                try:
                    r = requests.post(f"{api_root}/whisper", headers=headers, data=file_bytes, timeout=http_timeout)
                    if r.status_code != 202:
                        st.error(f"Unstract: Error uploading file: {r.status_code} - {r.text}")
                    else:
                        whisper_hash = r.json().get("whisper_hash")
                        if not whisper_hash:
                            st.error("Unstract: Upload was accepted but no whisper_hash was returned.")
                        else:
                            status_url = f"{api_root}/whisper-status?whisper_hash={whisper_hash}"
                            status_placeholder = st.empty()
                            processed = False
                            # Poll up to 30 times, 2 seconds apart.
                            for i in range(30):
                                status_r = requests.get(status_url, headers=auth_headers, timeout=http_timeout)
                                if status_r.status_code != 200:
                                    st.error(f"Unstract: Error checking status: {status_r.status_code} - {status_r.text}")
                                    break
                                status = status_r.json().get("status")
                                if status == "processed":
                                    status_placeholder.info("EZOFIS AI OCR AGENT STATUS: processed! 🎉")
                                    processed = True
                                    break
                                # NOTE(review): assumes the service reports a failed job
                                # as status "error"; confirm against the API docs.
                                if status == "error":
                                    status_placeholder.error("Unstract: OCR processing failed.")
                                    break
                                status_placeholder.info(f"EZOFIS AI OCR AGENT STATUS: {status or 'waiting'}... ({i+1})")
                                time.sleep(2)
                            else:
                                status_placeholder.error("Unstract: Timeout waiting for OCR to finish.")

                            # Fetch the text only when OCR actually finished; a
                            # failed or timed-out job has nothing to retrieve.
                            if processed:
                                retrieve_url = f"{api_root}/whisper-retrieve?whisper_hash={whisper_hash}&text_only=true"
                                r = requests.get(retrieve_url, headers=auth_headers, timeout=http_timeout)
                                if r.status_code != 200:
                                    st.error(f"Unstract: Error retrieving extracted text: {r.status_code} - {r.text}")
                                else:
                                    try:
                                        data = r.json()
                                        extracted_text = data.get("result_text") or r.text
                                    except (ValueError, AttributeError):
                                        # Body is not a JSON object; use it as plain text.
                                        extracted_text = r.text
                except (requests.RequestException, ValueError) as exc:
                    st.error(f"Unstract: Request failed: {exc}")

            if extracted_text:
                with st.spinner("Fine Tuning The Extracted Output..."):
                    llm = get_llm(mdl)
                    response = llm.invoke([{"role": "user", "content": get_extraction_prompt(extracted_text)}])
                    result = response.content if hasattr(response, "content") else response
                    extracted_info = clean_json_response(result)
                if extracted_info:
                    st.success("Extraction Complete")
                    if "invoice_header" in extracted_info:
                        extracted_info["invoice_header"] = ensure_total_due(extracted_info["invoice_header"])
                    st.session_state['last_extracted_info'] = extracted_info
            elif extracted_text is not None:
                st.error("Unstract: No text was extracted from the document.")
        else:
            st.warning("Please upload an invoice/document first.")
# --- col3: Decision, Side-by-Side, and Line Item Hybrid Matching with Navigation ---
with col3:
    st.markdown("<span class='step-num'>6</span> <b>AP Agent Decision (LLM Powered)</b>", unsafe_allow_html=True)
    decision_made = st.button("Make a decision (EZOFIS AP AGENT)")
    if decision_made:
        extracted_info = st.session_state.get('last_extracted_info', None)
        po_df = st.session_state.get('last_po_df', None)
        if extracted_info is None or po_df is None:
            st.warning("Upload a PO CSV and extract an invoice before requesting a decision.")
        else:
            # Keeps the last payload produced by the matching tool so the
            # authoritative numbers stay available after the agent has answered.
            tool_state = {}

            def _json_default(obj):
                """Serialize values json cannot handle (numpy scalars/arrays, timestamps)."""
                if hasattr(obj, "tolist"):
                    return obj.tolist()
                return str(obj)

            def _first_truthy(record, names):
                """Return the first truthy value among ``names``; else the last value looked up."""
                value = ""
                for name in names:
                    value = record.get(name, "")
                    if value:
                        return value
                return value

            def _first_text(record, names):
                """Return ``str()`` of the first column in ``names`` that yields non-empty text."""
                for name in names:
                    text = str(record.get(name, ""))
                    if text:
                        return text
                return ""

            def _first_number(record, names):
                """Return the first value among ``names`` that ``clean_num`` can parse (0.0 counts)."""
                for name in names:
                    number = clean_num(record.get(name, ""))
                    if number is not None:
                        return number
                return None

            def _cosine_pct(vec_a, vec_b):
                """Cosine similarity as an integer percentage; 0 when a vector is missing or zero."""
                if vec_a is None or vec_b is None:
                    return 0
                denom = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
                if not denom:
                    return 0
                return int(100 * float(np.dot(vec_a, vec_b) / denom))

            def _decision_from_score(score):
                """Apply the user-configured thresholds to a match score."""
                if score >= approved_threshold:
                    return "APPROVED"
                if score >= partial_threshold:
                    return "PARTIALLY APPROVED"
                return "REJECTED"

            def po_match_tool_func(text):
                """Score the extracted invoice against every PO row.

                The ``text`` argument supplied by the agent is ignored; the
                invoice comes from session state. Returns a JSON string with
                the best-scoring PO row, its score, the thresholds and the
                matching details.
                """
                desc_cols = ("Line Item Description", "Item Description", "Description")
                qty_cols = ("Qty", "Item Quantity", "Quantity")
                price_cols = ("Rate", "Item Unit Price", "Unit Price", "Price")
                amount_cols = ("Amount", "Line Item Total", "Line Amount")

                inv = extracted_info
                scores = []
                inv_hdr = inv.get("invoice_header") or {}
                inv_supplier = inv_hdr.get("supplier_name") or ""
                inv_po_number = inv_hdr.get("purchase_order_number") or inv_hdr.get("po_number") or inv_hdr.get("order_number") or ""
                inv_currency = inv_hdr.get("currency") or ""
                inv_total_due = clean_num(inv_hdr.get("total_due"))
                embedding_model = st.session_state['embedding_model']
                matched_po_indices = set()  # PO row indices that were the best match for an invoice line

                # Embed every invoice line once instead of once per PO row.
                inv_lines = []
                for inv_line in inv.get("line_items") or []:
                    if not isinstance(inv_line, dict):
                        continue
                    desc = str(inv_line.get("description") or "")
                    inv_lines.append({
                        "description": desc,
                        "quantity": inv_line.get("quantity", ""),
                        "price": inv_line.get("price", ""),
                        "amount": inv_line.get("amount", ""),
                        "embedding": embedding_model.encode(desc) if desc else None,
                    })

                # PO numbers as text, so a numeric CSV column still compares
                # equal to the text form used while scoring.
                po_number_text = po_df["PO Number"].astype(str) if "PO Number" in po_df.columns else None
                line_match_cache = {}

                def _best_line_matches(po_number):
                    """For each invoice line, find its best-scoring line of the given PO.

                    Results are cached per PO number because every row of the
                    same PO would otherwise repeat the identical computation.
                    """
                    if po_number in line_match_cache:
                        return line_match_cache[po_number]
                    if po_number_text is None:
                        po_lines = po_df.iloc[0:0]
                    else:
                        po_lines = po_df[po_number_text == po_number]
                    matches = []
                    for inv_line in inv_lines:
                        best_detail = None
                        best_score = 0
                        for po_idx, po_line in po_lines.iterrows():
                            po_ldesc = _first_truthy(po_line, desc_cols)
                            po_lqty = _first_text(po_line, qty_cols)
                            po_lunit = _first_text(po_line, price_cols)
                            po_lamt = _first_number(po_line, amount_cols)
                            desc_score_sem = _cosine_pct(po_line.get("_desc_embedding", None), inv_line["embedding"])
                            desc_score_fuz = weighted_fuzzy_score(inv_line["description"], po_ldesc)
                            desc_score = int((desc_score_sem + desc_score_fuz)/2)
                            qty_score = 100 if clean_num(inv_line["quantity"]) == clean_num(po_lqty) else 0
                            unit_score = 100 if clean_num(inv_line["price"]) == clean_num(po_lunit) else 0
                            amount_score = 100 if clean_num(inv_line["amount"]) == po_lamt else 0
                            total = desc_score * 0.5 + qty_score * 0.2 + unit_score * 0.15 + amount_score * 0.15
                            if total > best_score:
                                best_score = total
                                # All details are taken from the winning PO
                                # line, not from whichever line came last.
                                best_detail = {
                                    "field": "Line Item",
                                    "invoice": {
                                        "description": inv_line["description"], "quantity": inv_line["quantity"],
                                        "price": inv_line["price"], "amount": inv_line["amount"]
                                    },
                                    "po": {
                                        "description": po_ldesc, "quantity": po_lqty,
                                        "price": po_lunit, "amount": po_lamt
                                    },
                                    "desc_score": desc_score, "qty_score": qty_score,
                                    "unit_score": unit_score, "amount_score": amount_score,
                                    "line_item_score": total,
                                    "po_df_index": po_idx
                                }
                        if best_detail is not None:
                            matches.append(best_detail)
                            matched_po_indices.add(best_detail["po_df_index"])
                    line_match_cache[po_number] = matches
                    return matches

                # --- Hybrid Line Item Matching: Best-Match for Each Invoice Line ---
                for pos, (idx, row) in enumerate(po_df.iterrows()):
                    po_supplier = row.get("Supplier Name", "")
                    po_po_number = po_number_text.iloc[pos] if po_number_text is not None else ""
                    po_currency = row.get("Currency", "")
                    po_total = _first_number(row, ("PO Total Value", "PO Total"))

                    field_details = []
                    s_supplier = weighted_fuzzy_score(inv_supplier, po_supplier)
                    field_details.append({
                        "field": "Supplier Name", "invoice": inv_supplier, "po": po_supplier, "score": s_supplier
                    })
                    s_po_number = 100 if find_po_number_in_json(po_po_number, inv) else 0
                    # Show the invoice's own PO reference on the invoice side.
                    field_details.append({
                        "field": "PO Number", "invoice": inv_po_number, "po": po_po_number, "score": s_po_number
                    })
                    s_currency = weighted_fuzzy_score(inv_currency, po_currency)
                    field_details.append({
                        "field": "Currency", "invoice": inv_currency, "po": po_currency, "score": s_currency
                    })
                    s_total = 100 if inv_total_due is not None and po_total is not None and abs(inv_total_due - po_total) < 2 else 0
                    field_details.append({
                        "field": "Total Due", "invoice": inv_total_due, "po": po_total, "score": s_total
                    })

                    all_line_matches = _best_line_matches(po_po_number)
                    # Choose best line (highest total) as the match for this PO
                    line_item_score = max([m["line_item_score"] for m in all_line_matches], default=0)
                    best_line_detail = max(all_line_matches, key=lambda m: m["line_item_score"], default=None)

                    wsum = weight_supplier + weight_po_number + weight_currency + weight_total_due + weight_line_item
                    total_score = (
                        s_supplier * weight_supplier/100 +
                        s_po_number * weight_po_number/100 +
                        s_currency * weight_currency/100 +
                        s_total * weight_total_due/100 +
                        line_item_score * weight_line_item/100
                    ) if wsum == 100 else 0
                    reason = (
                        f"Supplier match: {s_supplier}/100 (invoice: '{inv_supplier}' vs PO: '{po_supplier}'), "
                        f"PO Number: {s_po_number}/100 ({'found anywhere in JSON' if s_po_number else 'not found'}), "
                        f"Currency: {s_currency}/100 (invoice: '{inv_currency}' vs PO: '{po_currency}'), "
                        f"Total Due: {'match' if s_total else 'no match'} (invoice: {inv_total_due} vs PO: {po_total}), "
                        f"Line item best match: {int(line_item_score)}/100."
                    )
                    debug = {
                        "po_idx": idx, "po_supplier": po_supplier, "po_po_number": po_po_number, "po_total": po_total,
                        "scores": field_details, "line_item_score": line_item_score, "best_line_detail": best_line_detail,
                        "all_line_matches": all_line_matches, "total_score": total_score, "inv_total_due": inv_total_due
                    }
                    scores.append((row, total_score, reason, debug))

                scores.sort(key=lambda tup: tup[1], reverse=True)
                if not scores:
                    payload = {"decision": "REJECTED", "reason": "No POs found.", "debug": {}}
                else:
                    best_row, best_score, reason, debug = scores[0]
                    payload = {
                        "score": best_score,
                        "approved_threshold": approved_threshold,
                        "partial_threshold": partial_threshold,
                        "reason": f"Best match score: {int(best_score)}/100. {reason}",
                        "debug": debug,
                        "po_row": drop_embedding_keys(best_row) if best_row is not None else None,
                        "matched_po_indices": list(matched_po_indices)
                    }
                # numpy / pandas scalars are not JSON serializable by default.
                encoded = json.dumps(payload, default=_json_default)
                tool_state["payload"] = json.loads(encoded)
                return encoded

            tool = Tool(
                name="po_match_tool",
                func=po_match_tool_func,
                description="Returns JSON with score, thresholds, details. LLM must use this result with strict user logic.",
            )
            agent_llm = get_llm(mdl)
            agent = initialize_agent(
                [tool],
                agent_llm,
                agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
                verbose=True,
                max_iterations=2,
            )
            prompt = (
                "You are an expert accounts payable agent.\n"
                "You MUST use the output of po_match_tool to make the approval decision.\n"
                f"User-configured thresholds are: APPROVED = {approved_threshold}, PARTIALLY APPROVED = {partial_threshold}.\n"
                "You must STRICTLY follow this logic with NO exceptions and NO overrides:\n"
                "- If the score from po_match_tool is greater than or equal to the APPROVED threshold, return decision = APPROVED.\n"
                "- Else, if the score is greater than or equal to the PARTIALLY APPROVED threshold, return decision = PARTIALLY APPROVED.\n"
                "- Else, return decision = REJECTED.\n"
                "DO NOT reject just because of a field mismatch. Do not add your own rules.\n"
                "Use only this logic for all decisions, based on the score and thresholds from po_match_tool.\n"
                "Always return JSON using this schema, and nothing else:\n"
                '{\n'
                ' "decision": "APPROVED" | "PARTIALLY APPROVED" | "REJECTED",\n'
                ' "reason": "...",\n'
                ' "debug": {...},\n'
                ' "po_row": {...},\n'
                ' "matched_po_indices": [index, ...]\n'
                '}\n'
                "First, call po_match_tool and get its JSON. Then apply the threshold rules above to its output.\n"
                "Return only the decision JSON.\n"
                f"Invoice JSON:\n{json.dumps(extracted_info, indent=2)}"
            )
            result = None
            with st.spinner("AP Agent reasoning and making a decision..."):
                try:
                    result = agent.run(prompt)
                except Exception as exc:  # UI boundary: report instead of crashing the page
                    st.error(f"AP Agent failed: {exc}")

            result_json = clean_json_response(result) if isinstance(result, str) else None
            if not isinstance(result_json, dict):
                result_json = None

            tool_payload = tool_state.get("payload")
            if tool_payload is not None:
                if result_json is None:
                    # The agent's answer could not be parsed. Apply the same
                    # threshold rule the prompt prescribes to the tool's score.
                    score = tool_payload.get("score")
                    result_json = {
                        "decision": _decision_from_score(score) if score is not None else tool_payload.get("decision", "REJECTED"),
                        "reason": tool_payload.get("reason"),
                    }
                    st.warning("Agent output was not valid JSON; the decision was derived from the PO match score.")
                # The tool output is the source of truth for the details; the
                # agent may truncate or reword them when echoing the JSON.
                result_json["debug"] = tool_payload.get("debug") or {}
                result_json["po_row"] = tool_payload.get("po_row")
                result_json["matched_po_indices"] = tool_payload.get("matched_po_indices") or []

            if result_json is not None:
                st.session_state['result_json'] = result_json
                st.session_state['extracted_info_for_decision'] = extracted_info
                st.session_state['line_item_idx'] = 0
            else:
                st.session_state['result_json'] = None
                st.session_state['extracted_info_for_decision'] = None
                st.session_state['line_item_idx'] = 0
                st.subheader("Decision output not in standard format.")
                st.code(str(result))
def _shift_line_item(step, count):
    """Button callback: move the line-item cursor by ``step``, clamped to 0..count-1."""
    current = st.session_state.get("line_item_idx", 0)
    st.session_state["line_item_idx"] = min(max(current + step, 0), max(count - 1, 0))


def _td(value):
    """Return a table cell with the value HTML-escaped.

    The values come from uploaded documents and LLM output and are rendered
    with ``unsafe_allow_html``, so they must not be inserted as raw markup.
    """
    return f"<td>{html.escape(str(value))}</td>"


def _first_truthy_value(record, names):
    """Return the first truthy value among ``names``; else the last value looked up."""
    value = ""
    for name in names:
        value = record.get(name, "")
        if value:
            return value
    return value


with col3:
    # --- Always show results if available ---
    result_json = st.session_state.get('result_json', None)
    extracted_info = st.session_state.get('extracted_info_for_decision', None)
    po_df = st.session_state.get('last_po_df', None)
    if isinstance(result_json, dict) and result_json and extracted_info:
        status = result_json.get('decision', 'N/A')
        debug = result_json.get('debug')
        if not isinstance(debug, dict):
            debug = {}
        score = debug.get('total_score', None)
        color = "#4CAF50" if status == "APPROVED" else "#FFC107" if status == "PARTIALLY APPROVED" else "#F44336"
        score_text = score if score is not None else '--'
        st.markdown(
            f"<div style='border:2px solid {color};border-radius:12px;padding:18px 12px;"
            "margin-bottom:14px;text-align:center;background: #f8f9fb;'>"
            f"<div style='font-size:2.1em;font-weight:800;color:{color};'>{html.escape(str(status))}</div>"
            "<div style='font-size:1.2em;color:#24345C;font-weight:500;margin-top:8px;'>"
            f"Total Score: <span style='color:{color};font-weight:700;font-size:1.1em;'>{html.escape(str(score_text))}</span>"
            "</div></div>",
            unsafe_allow_html=True,
        )

        st.markdown("#### Side-by-Side Field Matching")
        field_scores = debug.get("scores") or []
        if field_scores:
            table_rows = "".join(
                "<tr><td><b>" + html.escape(str(fs.get('field', ''))) + "</b></td>"
                + _td(fs.get('invoice', '')) + _td(fs.get('po', '')) + _td(fs.get('score', ''))
                + "</tr>"
                for fs in field_scores
                if isinstance(fs, dict)
            )
            st.markdown(
                "<table style='width:100%;border-collapse:collapse;margin-bottom:20px;'>"
                "<thead>"
                "<tr style='background:#f7f7fa;'><th>Field</th><th>Invoice Value</th><th>PO Value</th><th>Score</th></tr>"
                "</thead>"
                f"<tbody>{table_rows}</tbody>"
                "</table>",
                unsafe_allow_html=True,
            )

        st.markdown("#### Line Item Side-by-Side Matching")
        invoice_line_items = debug.get("all_line_matches") or []
        num_items = len(invoice_line_items)
        key = "line_item_idx"
        if key not in st.session_state or st.session_state[key] >= num_items:
            st.session_state[key] = 0
        if num_items > 0:
            cols = st.columns([1, 3, 1])
            # The cursor is moved in on_click callbacks, which run before the
            # script reruns; the label, the disabled state of both buttons and
            # the table therefore always show the same line item.
            with cols[0]:
                st.button("⬅️", key="prev_btn", disabled=st.session_state[key] == 0,
                          on_click=_shift_line_item, args=(-1, num_items))
            with cols[1]:
                st.markdown(
                    f"<div style='text-align:center;font-weight:500;font-size:1.1em; margin-top:6px;'>"
                    f"Line Item {st.session_state[key]+1} of {num_items}"
                    f"</div>", unsafe_allow_html=True
                )
            with cols[2]:
                st.button("➡️", key="next_btn", disabled=st.session_state[key] == num_items - 1,
                          on_click=_shift_line_item, args=(1, num_items))

            idx = st.session_state[key]
            inv_line = invoice_line_items[idx] if isinstance(invoice_line_items[idx], dict) else {}
            inv = inv_line.get("invoice") or {}
            best_po = inv_line.get("po") or {}
            desc_score = inv_line.get("desc_score", 0)
            qty_score = inv_line.get("qty_score", 0)
            unit_score = inv_line.get("unit_score", 0)
            amt_score = inv_line.get("amount_score", 0)
            line_score = int(inv_line.get("line_item_score", 0) or 0)
            fields = ["Description", "Quantity", "Price", "Amount"]
            value_keys = ("description", "quantity", "price", "amount")

            parts = ["<table style='width:90%;border-collapse:collapse;margin-bottom:24px;'>"]
            parts.append("<thead><tr><th></th>")
            parts.extend(f"<th style='text-align:center'>{f}</th>" for f in fields)
            parts.append("</tr></thead><tbody>")
            parts.append("<tr><td><b>Invoice Value</b></td>")
            parts.extend(_td(inv.get(k, '')) for k in value_keys)
            parts.append("</tr>")
            parts.append("<tr><td><b>PO Value</b></td>")
            parts.extend(_td(best_po.get(k, '')) for k in value_keys)
            parts.append("</tr>")
            parts.append("<tr><td><b>Score</b></td>")
            parts.extend(_td(s) for s in (desc_score, qty_score, unit_score, amt_score))
            parts.append("</tr>")
            parts.append("<tr><td><b>Line Score</b></td><td></td><td></td><td></td><td><b>{}</b></td></tr>".format(line_score))
            parts.append("</tbody></table>")
            st.markdown("".join(parts), unsafe_allow_html=True)
        else:
            st.info("No line items found.")

        # --------- PO LINE ITEMS MISSING IN INVOICE ---------
        matched_po_row = result_json.get("po_row", None)
        matched_indices = {
            i for i in (result_json.get("matched_po_indices") or [])
            if isinstance(i, (int, str))
        }
        if isinstance(matched_po_row, dict) and matched_po_row and po_df is not None and "PO Number" in po_df.columns:
            po_no = matched_po_row.get("PO Number", None)
            # Compare as text: the stored row went through JSON, so its PO
            # number may no longer have the dtype of the DataFrame column.
            po_lines_df = po_df[po_df["PO Number"].astype(str) == str(po_no)]
            missing_po_lines = po_lines_df.loc[~po_lines_df.index.isin(matched_indices)]
            if not missing_po_lines.empty:
                st.markdown("#### PO Line Items Missing in Invoice")
                parts = ["<table style='width:90%;border-collapse:collapse;margin-bottom:24px;'>"]
                parts.append("<thead><tr>")
                parts.extend(f"<th style='text-align:center'>{f}</th>" for f in ["Description", "Quantity", "Price", "Amount"])
                parts.append("</tr></thead><tbody>")
                for _, row in missing_po_lines.iterrows():
                    parts.append("<tr>")
                    parts.append(_td(_first_truthy_value(row, ('Description', 'Item Description', 'Line Item Description'))))
                    parts.append(_td(_first_truthy_value(row, ('Quantity', 'Item Quantity', 'Qty'))))
                    parts.append(_td(_first_truthy_value(row, ('Price', 'Item Unit Price', 'Rate'))))
                    parts.append(_td(_first_truthy_value(row, ('Amount', 'Line Item Total', 'Line Amount'))))
                    parts.append("</tr>")
                parts.append("</tbody></table>")
                st.markdown("".join(parts), unsafe_allow_html=True)
            else:
                st.info("No PO line items missing in invoice.")

        st.markdown("#### Reason")
        st.write(result_json.get('reason', 'N/A'))
        st.markdown("#### Debug & Matching Details (Full JSON)")
        st.json(result_json.get('debug'))
        st.markdown("#### Extracted Invoice JSON")
        st.json(extracted_info)
        st.markdown("#### Matched PO Row")
        st.json(result_json.get('po_row'))
    else:
        st.info("Run a decision to see results.")