# EZOFIS Accounts Payable Agent — Streamlit app (source: Hugging Face Space, commit 41f1938).
import streamlit as st
import requests
import json
import re
import os
import time
import mimetypes
import pandas as pd
from fuzzywuzzy import fuzz
from sentence_transformers import SentenceTransformer
import numpy as np
from langchain_community.chat_models import ChatOpenAI
from langchain.agents import initialize_agent, Tool, AgentType
from io import StringIO
st.set_page_config(page_title="EZOFIS Accounts Payable Agent", layout="wide")

# Canonical PO column names used as rename targets during semantic header
# normalization. Defined once at module level (previously this list was
# duplicated: once inside the one-time init block and again below it).
CANON_HEADERS = [
    "PO Number", "Supplier Name", "Ship To", "Bill To", "PO Date",
    "Line Item Number", "Item Description", "Item Quantity",
    "Item Unit Price", "Line Item Total", "PO Total Value",
    "Currency", "Payment Terms", "Expected Delivery"
]

# --- Initialize embedding model only once per session ---
# Streamlit reruns this script on every interaction; caching the model and the
# canonical-header vectors in session_state makes the expensive load a 1x cost.
if 'embedding_model' not in st.session_state:
    with st.spinner("Loading embedding model for PO header normalization (1x cost)..."):
        st.session_state['embedding_model'] = SentenceTransformer('all-MiniLM-L6-v2')
    st.session_state['canon_header_vectors'] = {
        h: st.session_state['embedding_model'].encode(h)
        for h in CANON_HEADERS
    }

# Global CSS for cards, step badges, and slider styling.
st.markdown("""
<style>
.block-card { background: #fff; border-radius: 20px; box-shadow: 0 2px 16px rgba(25,39,64,0.05); padding: 32px 26px 24px 26px; margin-bottom: 24px; }
.step-num { background: #A020F0; color: #fff; border-radius: 999px; padding: 6px 13px; font-weight: 700; margin-right: 14px; font-size: 20px; display: inline-block; vertical-align: middle; }
.stSlider>div>div>div>div { background: #F3F6FB !important; border-radius: 999px; }
</style>
""", unsafe_allow_html=True)
def semantic_header_normalization(df, canon_headers, canon_vectors, model, threshold=0.80):
    """Rename DataFrame columns to canonical PO header names via embedding similarity.

    Each original column header is encoded with `model` and compared by cosine
    similarity against every canonical header vector. Headers whose best match
    exceeds `threshold` are renamed to that canonical name; all others keep
    their original name.

    Args:
        df: pandas DataFrame whose columns should be normalized.
        canon_headers: iterable of canonical header strings.
        canon_vectors: dict mapping canonical header -> embedding vector.
        model: object exposing ``encode(str) -> vector``.
        threshold: minimum cosine similarity required to rename (default 0.80,
            the value previously hard-coded).

    Returns:
        A new DataFrame with renamed columns (``df`` itself is not mutated).
    """
    mapping = {}
    for orig in df.columns.tolist():
        vec = np.asarray(model.encode(orig), dtype=float)
        vec_norm = np.linalg.norm(vec)
        best_header, best_sim = orig, -1.0
        for canon in canon_headers:
            cvec = np.asarray(canon_vectors[canon], dtype=float)
            denom = vec_norm * np.linalg.norm(cvec)
            if denom == 0:
                # Guard against zero-norm vectors (previously a divide-by-zero).
                continue
            sim = float(np.dot(vec, cvec) / denom)
            if sim > best_sim:
                best_header, best_sim = canon, sim
        # Keep the original name unless the best match clears the threshold.
        mapping[orig] = best_header if best_sim > threshold else orig
    return df.rename(columns=mapping)
# Registry of selectable LLM backends for extraction and decisioning.
#   "model": provider model identifier passed to ChatOpenAI.
#   "api_env": environment variable expected to hold the provider API key.
#   "openai_api_base": override base URL (None = default OpenAI endpoint).
MODELS = {
    "OpenAI GPT-4.1": {
        "model": "gpt-4-1106-preview",
        "api_env": "OPENAI_API_KEY",
        "openai_api_base": None,
    },
    "Mistral (OpenRouter)": {
        "model": "mistralai/ministral-8b",
        "api_env": "OPENROUTER_API_KEY",
        "openai_api_base": "https://openrouter.ai/api/v1",
    }
}
def get_llm(model_choice):
    """Build a ChatOpenAI client for the selected MODELS entry.

    Halts the Streamlit script (st.stop) with an error banner when the
    provider's API-key environment variable is unset.
    """
    cfg = MODELS[model_choice]
    key = os.getenv(cfg["api_env"])
    if not key:
        st.error(f"API key not set: {cfg['api_env']}")
        st.stop()
    return ChatOpenAI(
        model=cfg["model"],
        openai_api_key=key,
        openai_api_base=cfg["openai_api_base"],
        temperature=0.1,
        max_tokens=2000,
    )
def get_extraction_prompt(txt):
    """Return the LLM prompt instructing extraction of invoice data as strict JSON.

    The prompt embeds the full target schema (an ``invoice_header`` object plus
    a ``line_items`` array) followed by the raw OCR text ``txt``. The model is
    told to return ONLY the JSON object, using null for missing fields, and to
    keep shipment/invoice-level fields out of line items.
    """
    # NOTE: the prompt wording below is behavior — do not reflow or edit it
    # casually; clean_json_response() depends on the model returning bare JSON.
    return (
        "You are an expert invoice parser. "
        "Extract data according to the visible table structure and column headers in the invoice. "
        "For every line item, only extract fields that correspond to the table columns for that row (do not include header/shipment fields in line items). "
        "Merge all multi-line content within a single cell into that field (especially for the 'description' and 'notes'). "
        "If you are not able to deduct the currency then use your knowldedge and predict the 'currency' field using the 'supplier_address' data. "
        "Shipment/invoice-level fields such as CAR NUMBER, SHIPPING POINT, SHIPMENT NUMBER, CURRENCY, etc., must go ONLY into the 'invoice_header', not as line item fields.\n"
        "Use this schema:\n"
        '{\n'
        '  "invoice_header": {\n'
        '    "car_number": "string or null",\n'
        '    "shipment_number": "string or null",\n'
        '    "shipping_point": "string or null",\n'
        '    "currency": "string or null",\n'
        '    "invoice_number": "string or null",\n'
        '    "invoice_date": "string or null",\n'
        '    "order_number": "string or null",\n'
        '    "customer_order_number": "string or null",\n'
        '    "our_order_number": "string or null",\n'
        '    "sales_order_number": "string or null",\n'
        '    "purchase_order_number": "string or null",\n'
        '    "order_date": "string or null",\n'
        '    "supplier_name": "string or null",\n'
        '    "supplier_address": "string or null",\n'
        '    "supplier_phone": "string or null",\n'
        '    "supplier_email": "string or null",\n'
        '    "supplier_tax_id": "string or null",\n'
        '    "customer_name": "string or null",\n'
        '    "customer_address": "string or null",\n'
        '    "customer_phone": "string or null",\n'
        '    "customer_email": "string or null",\n'
        '    "customer_tax_id": "string or null",\n'
        '    "ship_to_name": "string or null",\n'
        '    "ship_to_address": "string or null",\n'
        '    "bill_to_name": "string or null",\n'
        '    "bill_to_address": "string or null",\n'
        '    "remit_to_name": "string or null",\n'
        '    "remit_to_address": "string or null",\n'
        '    "tax_id": "string or null",\n'
        '    "tax_registration_number": "string or null",\n'
        '    "vat_number": "string or null",\n'
        '    "payment_terms": "string or null",\n'
        '    "payment_method": "string or null",\n'
        '    "payment_reference": "string or null",\n'
        '    "bank_account_number": "string or null",\n'
        '    "iban": "string or null",\n'
        '    "swift_code": "string or null",\n'
        '    "total_before_tax": "string or null",\n'
        '    "tax_amount": "string or null",\n'
        '    "tax_rate": "string or null",\n'
        '    "shipping_charges": "string or null",\n'
        '    "discount": "string or null",\n'
        '    "total_due": "string or null",\n'
        '    "amount_paid": "string or null",\n'
        '    "balance_due": "string or null",\n'
        '    "due_date": "string or null",\n'
        '    "invoice_status": "string or null",\n'
        '    "reference_number": "string or null",\n'
        '    "project_code": "string or null",\n'
        '    "department": "string or null",\n'
        '    "contact_person": "string or null",\n'
        '    "notes": "string or null",\n'
        '    "additional_info": "string or null"\n'
        '  },\n'
        '  "line_items": [\n'
        '    {\n'
        '      "quantity": "string or null",\n'
        '      "units": "string or null",\n'
        '      "description": "string or null",\n'
        '      "footage": "string or null",\n'
        '      "price": "string or null",\n'
        '      "amount": "string or null",\n'
        '      "notes": "string or null"\n'
        '    }\n'
        '  ]\n'
        '}'
        "\nIf a field is missing for a line item or header, use null. "
        "Do not invent fields. Do not add any header or shipment data to any line item. Return ONLY the JSON object, no explanation.\n"
        "\nInvoice Text:\n"
        f"{txt}"
    )
def clean_json_response(text):
    """Extract and parse the outermost {...} JSON object from an LLM reply.

    Strips markdown code fences, removes trailing commas before closing
    braces/brackets, and attempts one repair pass for a comma missing between
    string values and the next key. Returns the parsed object, or None (after
    surfacing the problem in the Streamlit UI) when parsing fails.
    """
    if not text:
        return None
    raw = text
    stripped = re.sub(r'```(?:json)?', '', text).strip()
    start = stripped.find('{')
    end = stripped.rfind('}') + 1
    if start < 0 or end < 1:
        st.error("Couldn't locate JSON in response.")
        st.code(raw)
        return None
    # Drop trailing commas like {"a": 1,} / [1, 2,] which LLMs often emit.
    frag = re.sub(r',\s*([}\]])', r'\1', stripped[start:end])
    try:
        return json.loads(frag)
    except json.JSONDecodeError as first_err:
        # Second chance: insert the comma likely missing between "..." "key":
        repaired = re.sub(r'"\s*"\s*(?="[^"]+"\s*:)', '","', frag)
        try:
            return json.loads(repaired)
        except json.JSONDecodeError:
            st.error(f"JSON parse error: {first_err}")
            st.code(frag)
            return None
def ensure_total_due(invoice_header):
    """Backfill a missing/empty 'total_due' from the first truthy fallback amount.

    Fallbacks are tried in a fixed priority order. The header dict is mutated
    in place and also returned for convenience.
    """
    if invoice_header.get("total_due") not in [None, ""]:
        return invoice_header
    for name in ("invoice_total", "invoice_value", "total_before_tax", "balance_due", "amount_paid"):
        if name in invoice_header and invoice_header[name]:
            invoice_header["total_due"] = invoice_header[name]
            break
    return invoice_header
def clean_num(val):
    """Parse the largest numeric value found in `val`, or None if none found.

    Ints/floats are returned directly as float. For strings, every
    number-like token (optional sign, commas as thousands separators) is
    parsed and the maximum is returned — presumably so that in cells holding
    several numbers (e.g. "2 @ 150.00") the dominant amount wins; TODO
    confirm that intent with the author.

    Returns:
        float or None.
    """
    if val is None:
        return None
    if isinstance(val, (int, float)):
        return float(val)
    as_floats = []
    for token in re.findall(r"[-+]?\d[\d,]*\.?\d*", str(val)):
        try:
            # Bug fix: use float() instead of str.isdigit() so signed values
            # like "-5" are no longer silently dropped.
            as_floats.append(float(token.replace(',', '')))
        except ValueError:
            continue
    return max(as_floats) if as_floats else None
def weighted_fuzzy_score(s1, s2):
    """Whitespace-normalized fuzzy similarity (0-100) between two values.

    Inputs are stringified, lowercased, and have whitespace runs collapsed.
    Exact matches after normalization (including both-empty) short-circuit to
    100; otherwise fuzzywuzzy's token_set_ratio is returned.
    """
    def _canon(value):
        return " ".join(str(value).strip().lower().split())

    a = _canon(s1)
    b = _canon(s2)
    if a == b:
        return 100
    return fuzz.token_set_ratio(a, b)
def find_po_number_in_json(po_number, invoice_json):
    """Return True if `po_number` appears anywhere in the extracted invoice JSON.

    Every scalar leaf of the (arbitrarily nested) JSON is flattened to a
    string; whitespace is removed and a trailing ".0" (float artifact from
    CSV parsing) is stripped before substring matching in both directions.
    An integer form of the PO number is also tried.
    """
    def _flatten(obj):
        # Collect every scalar leaf of nested dicts/lists as a string.
        fields = []
        if isinstance(obj, dict):
            for v in obj.values():
                fields.extend(_flatten(v))
        elif isinstance(obj, list):
            for item in obj:
                fields.extend(_flatten(item))
        elif obj is not None:
            fields.append(str(obj))
        return fields

    def _canon(s):
        # Bug fix: strip ".0" only as a SUFFIX. The previous
        # replace(".0", "") also mangled interior digits ("10.05" -> "105"),
        # creating false positive matches.
        s = str(s).strip().replace(" ", "")
        return s[:-2] if s.endswith(".0") else s

    po_str = _canon(po_number)
    try:
        po_int = str(int(float(po_number)))
    except (TypeError, ValueError):
        # Non-numeric PO numbers fall back to the string form.
        po_int = po_str
    for leaf in (_canon(s) for s in _flatten(invoice_json)):
        if not leaf:
            continue
        if po_str and (po_str in leaf or leaf in po_str):
            return True
        if po_int and (po_int in leaf or leaf in po_int):
            return True
    return False
def drop_embedding_keys(d):
    """Return a plain dict copy of `d` (dict or pandas Series) with every key
    containing the substring "embedding" removed — keeps raw vectors out of
    JSON-serialized output."""
    if hasattr(d, "to_dict"):
        d = d.to_dict()
    cleaned = {}
    for key, value in d.items():
        if "embedding" in key:
            continue
        cleaned[key] = value
    return cleaned
# --- UI Layout ---
st.markdown("<h1 style='font-weight:800; margin-bottom:8px;'>EZOFIS Accounts Payable Agent</h1>", unsafe_allow_html=True)
st.markdown("<div style='font-size:20px; margin-bottom:28px; color:#24345C;'>Modern workflow automation for finance teams</div>", unsafe_allow_html=True)
# Three-column layout: (1) PO upload + weights/thresholds, (2) invoice upload
# + extraction, (3) decision + side-by-side match display.
col1, col2, col3 = st.columns([2,2,3])
# --- col1: PO CSV, Model & Weights, Threshold Sliders ---
with col1:
    st.markdown("<span class='step-num'>1</span> <b>Upload Active Purchase Orders (POs)</b>", unsafe_allow_html=True)
    po_file = st.file_uploader("CSV with PO number, Supplier, Items, etc.", type=["csv"], key="po_csv", label_visibility="collapsed")
    # Cache the raw upload bytes and the parsed DataFrame in session_state so
    # Streamlit reruns skip re-parsing/re-embedding an unchanged file.
    if 'po_file_bytes' not in st.session_state:
        st.session_state['po_file_bytes'] = None
    if 'last_po_df' not in st.session_state:
        st.session_state['last_po_df'] = None
    if po_file is not None:
        file_bytes = po_file.getvalue()
        if st.session_state['po_file_bytes'] != file_bytes:
            # New or changed file: parse, normalize headers, embed descriptions.
            st.session_state['po_file_bytes'] = file_bytes
            df = pd.read_csv(StringIO(po_file.getvalue().decode("utf-8")))
            # Semantic normalization (embeddings)
            df = semantic_header_normalization(
                df, CANON_HEADERS,
                st.session_state['canon_header_vectors'],
                st.session_state['embedding_model']
            )
            # Precompute and store embeddings for PO line item descriptions
            # (first matching description-like column wins).
            for c in ["Line Item Description", "Item Description", "Description"]:
                if c in df.columns:
                    po_line_descs = df[c].astype(str).fillna("").tolist()
                    po_line_desc_embeds = st.session_state['embedding_model'].encode(po_line_descs)
                    df["_desc_embedding"] = list(po_line_desc_embeds)
                    break
            st.session_state['last_po_df'] = df
            st.success(f"Loaded {len(df)} records from uploaded CSV.")
        else:
            # Same file as the previous run: reuse the cached DataFrame.
            df = st.session_state['last_po_df']
            if df is not None:
                st.success(f"Loaded {len(df)} records from uploaded CSV.")
    else:
        # No file uploaded this run.
        df = None
    st.markdown("<span class='step-num'>2</span> <b>Select Model & Scoring Weights</b>", unsafe_allow_html=True)
    mdl = st.selectbox("LLM Model", list(MODELS.keys()), key="extract_model", index=0)
    def int_slider(label, value, key):
        # Integer percentage slider (0-100, step 1).
        return st.slider(label, 0, 100, value, 1, key=key, format="%d")
    weight_supplier = int_slider("Supplier Name (%)", 35, "w_supplier")
    weight_po_number = int_slider("PO Number (%)", 35, "w_po")
    weight_currency = int_slider("Currency (%)", 10, "w_curr")
    weight_total_due = int_slider("Total Due (%)", 10, "w_due")
    weight_line_item = int_slider("Line Item (%)", 10, "w_line")
    weight_sum = weight_supplier + weight_po_number + weight_currency + weight_total_due + weight_line_item
    # The decision tool (col3) only computes a score when weights total 100%.
    if weight_sum != 100:
        st.warning(f"Sum of weights is {weight_sum}%. Adjust so it equals 100%.")
    st.markdown("<span class='step-num'>3</span> <b>Set Decision Thresholds</b>", unsafe_allow_html=True)
    approved_threshold = st.slider(
        "Threshold for 'APPROVED'",
        min_value=0,
        max_value=100,
        value=90,
        format="%d"
    )
    # Cap the PARTIAL threshold below APPROVED so the decision bands never overlap.
    partial_threshold = st.slider(
        "Threshold for 'PARTIALLY APPROVED'",
        min_value=0,
        max_value=approved_threshold - 1,
        value=60,
        format="%d"
    )
# --- col2: Invoice Upload & Extract Button ---
with col2:
    st.markdown("<span class='step-num'>4</span> <b>Upload Invoice/Document</b>", unsafe_allow_html=True)
    inv_file = st.file_uploader("Upload PDF, DOCX, XLSX, PNG, JPG, TIFF", type=["pdf", "docx", "xlsx", "xls", "png", "jpg", "jpeg", "tiff"], key="invoice_file", label_visibility="collapsed")
    st.markdown("<span class='step-num'>5</span> <b>Extract Data</b>", unsafe_allow_html=True)
    if st.button("Extract"):
        if inv_file:
            with st.spinner("Extracting text from document..."):
                # Step 1: submit the document to the Unstract LLMWhisperer OCR API.
                filename = getattr(inv_file, "name", "uploaded_file")
                file_bytes = inv_file.read()
                content_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
                headers = {
                    "unstract-key": os.getenv("UNSTRACT_API_KEY"),
                    "Content-Type": content_type,
                }
                url = "https://llmwhisperer-api.us-central.unstract.com/api/v2/whisper"
                r = requests.post(url, headers=headers, data=file_bytes)
                # The whisper endpoint acknowledges an accepted job with 202.
                if r.status_code != 202:
                    st.error(f"Unstract: Error uploading file: {r.status_code} - {r.text}")
                else:
                    whisper_hash = r.json().get("whisper_hash")
                    if whisper_hash:
                        # Step 2: poll job status (up to 30 tries, 2s apart).
                        status_url = f"https://llmwhisperer-api.us-central.unstract.com/api/v2/whisper-status?whisper_hash={whisper_hash}"
                        status_placeholder = st.empty()
                        for i in range(30):
                            status_r = requests.get(status_url, headers={"unstract-key": os.getenv("UNSTRACT_API_KEY")})
                            if status_r.status_code != 200:
                                st.error(f"Unstract: Error checking status: {status_r.status_code} - {status_r.text}")
                                break
                            status = status_r.json().get("status")
                            if status == "processed":
                                status_placeholder.info("EZOFIS AI OCR AGENT STATUS: processed! 🎉")
                                break
                            status_placeholder.info(f"EZOFIS AI OCR AGENT STATUS: {status or 'waiting'}... ({i+1})")
                            time.sleep(2)
                        else:
                            # for-else: loop exhausted without break -> timed out.
                            status_placeholder.error("Unstract: Timeout waiting for OCR to finish.")
                        # Step 3: fetch the extracted text (attempted even after
                        # a timeout/status error — NOTE(review): may be intentional
                        # best-effort; confirm).
                        retrieve_url = f"https://llmwhisperer-api.us-central.unstract.com/api/v2/whisper-retrieve?whisper_hash={whisper_hash}&text_only=true"
                        r = requests.get(retrieve_url, headers={"unstract-key": os.getenv("UNSTRACT_API_KEY")})
                        if r.status_code != 200:
                            st.error(f"Unstract: Error retrieving extracted text: {r.status_code} - {r.text}")
                        else:
                            try:
                                data = r.json()
                                extracted_text = data.get("result_text") or r.text
                            except Exception:
                                # Non-JSON body: fall back to the raw response text.
                                extracted_text = r.text
                            # Step 4: have the LLM structure the OCR text as JSON.
                            with st.spinner("Fine Tuning The Extracted Output..."):
                                llm = get_llm(mdl)
                                response = llm.invoke([{"role": "user", "content": get_extraction_prompt(extracted_text)}])
                                result = response.content if hasattr(response, "content") else response
                                extracted_info = clean_json_response(result)
                                if extracted_info:
                                    st.success("Extraction Complete")
                                    if "invoice_header" in extracted_info:
                                        extracted_info["invoice_header"] = ensure_total_due(extracted_info["invoice_header"])
                                    # Persist for the decision step in col3.
                                    st.session_state['last_extracted_info'] = extracted_info
        else:
            st.warning("Please upload an invoice/document first.")
# --- col3: Decision, Side-by-Side, and Line Item Hybrid Matching with Navigation ---
with col3:
    st.markdown("<span class='step-num'>6</span> <b>AP Agent Decision (LLM Powered)</b>", unsafe_allow_html=True)
    decision_made = st.button("Make a decision (EZOFIS AP AGENT)")
    if decision_made:
        extracted_info = st.session_state.get('last_extracted_info', None)
        po_df = st.session_state.get('last_po_df', None)
        if extracted_info is not None and po_df is not None:
            def po_match_tool_func(text):
                """Score every PO row against the extracted invoice; return best match as JSON.

                Header fields (supplier, PO number, currency, total) use fuzzy/
                exact matching; line items use a hybrid of embedding cosine
                similarity and fuzzy text matching. `text` (the agent's tool
                input) is unused — all data comes from the enclosing scope.
                """
                inv = extracted_info
                scores = []
                inv_hdr = inv["invoice_header"]
                inv_supplier = inv_hdr.get("supplier_name") or ""
                inv_po_number = inv_hdr.get("purchase_order_number") or inv_hdr.get("po_number") or inv_hdr.get("order_number") or ""
                inv_currency = inv_hdr.get("currency") or ""
                inv_total_due = clean_num(inv_hdr.get("total_due"))
                inv_line_items = inv.get("line_items", [])
                embedding_model = st.session_state['embedding_model']
                matched_po_indices = set()  # PO row indices matched by at least one invoice line
                for idx, row in po_df.iterrows():
                    po_supplier = row.get("Supplier Name", "")
                    po_po_number = str(row.get("PO Number", ""))
                    po_currency = row.get("Currency", "")
                    po_total = clean_num(row.get("PO Total Value", "")) or clean_num(row.get("PO Total", ""))
                    field_details = []
                    s_supplier = weighted_fuzzy_score(inv_supplier, po_supplier)
                    field_details.append({
                        "field": "Supplier Name", "invoice": inv_supplier, "po": po_supplier, "score": s_supplier
                    })
                    # PO number counts as matched if it appears anywhere in the invoice JSON.
                    s_po_number = 100 if find_po_number_in_json(po_po_number, inv) else 0
                    field_details.append({
                        "field": "PO Number", "invoice": po_po_number, "po": po_po_number, "score": s_po_number
                    })
                    s_currency = weighted_fuzzy_score(inv_currency, po_currency)
                    field_details.append({
                        "field": "Currency", "invoice": inv_currency, "po": po_currency, "score": s_currency
                    })
                    # Totals match when within 2 currency units (rounding tolerance).
                    s_total = 100 if inv_total_due is not None and po_total is not None and abs(inv_total_due - po_total) < 2 else 0
                    field_details.append({
                        "field": "Total Due", "invoice": inv_total_due, "po": po_total, "score": s_total
                    })
                    all_line_matches = []
                    # Bug fix: compare as strings — the CSV "PO Number" column may
                    # be numeric while po_po_number is always a str, which made
                    # this filter silently match nothing.
                    po_lines = po_df[po_df["PO Number"].astype(str) == po_po_number]
                    # Hybrid: for each invoice line item, best-match to a PO line.
                    for inv_line in inv_line_items:
                        best_score = 0
                        best_po_line = None
                        best_po_idx = None
                        best_parts = (0, 0, 0, 0)  # (desc, qty, unit, amount) of best candidate
                        desc = inv_line.get("description", "")
                        qty = inv_line.get("quantity", "")
                        price = inv_line.get("price", "")
                        amount = inv_line.get("amount", "")
                        inv_desc_emb = embedding_model.encode(desc) if desc else None
                        # Try all matching PO lines
                        for po_idx, po_line in po_lines.iterrows():
                            po_ldesc = po_line.get("Line Item Description", "") or po_line.get("Item Description", "") or po_line.get("Description", "")
                            po_lqty = str(po_line.get("Qty", "")) or str(po_line.get("Item Quantity", "")) or str(po_line.get("Quantity", ""))
                            po_lunit = str(po_line.get("Rate", "")) or str(po_line.get("Item Unit Price", "")) or str(po_line.get("Unit Price", "")) or str(po_line.get("Price", ""))
                            po_lamt = clean_num(po_line.get("Amount", "")) or clean_num(po_line.get("Line Item Total", "")) or clean_num(po_line.get("Line Amount", ""))
                            po_ldesc_emb = po_line.get("_desc_embedding", None)
                            # Description score: average of semantic (cosine) and fuzzy.
                            desc_score_sem = 0
                            if po_ldesc_emb is not None and inv_desc_emb is not None:
                                desc_score_sem = float(np.dot(po_ldesc_emb, inv_desc_emb) / (np.linalg.norm(po_ldesc_emb)*np.linalg.norm(inv_desc_emb)))
                                desc_score_sem = int(100*desc_score_sem)
                            desc_score_fuz = weighted_fuzzy_score(desc, po_ldesc)
                            desc_score = int((desc_score_sem + desc_score_fuz)/2)
                            qty_score = 100 if clean_num(qty) == clean_num(po_lqty) else 0
                            unit_score = 100 if clean_num(price) == clean_num(po_lunit) else 0
                            amount_score = 100 if clean_num(amount) == po_lamt else 0
                            total = desc_score * 0.5 + qty_score * 0.2 + unit_score * 0.15 + amount_score * 0.15
                            if total > best_score:
                                best_score = total
                                best_po_line = {
                                    "description": po_ldesc, "quantity": po_lqty,
                                    "price": po_lunit, "amount": po_lamt
                                }
                                best_po_idx = po_idx
                                # Bug fix: capture the component scores of the BEST
                                # candidate here. Previously the detail record used
                                # the LAST-iterated line's scores/index (and raised
                                # NameError when po_lines was empty).
                                best_parts = (desc_score, qty_score, unit_score, amount_score)
                        best_detail = {
                            "field": "Line Item",
                            "invoice": {"description": desc, "quantity": qty, "price": price, "amount": amount},
                            "po": best_po_line,
                            "desc_score": best_parts[0], "qty_score": best_parts[1],
                            "unit_score": best_parts[2], "amount_score": best_parts[3],
                            "line_item_score": best_score,
                            "po_df_index": best_po_idx
                        }
                        if best_po_line:
                            all_line_matches.append(best_detail)
                            matched_po_indices.add(best_po_idx)
                    # Choose best line (highest total) as the match for this PO.
                    line_item_score = max([m["line_item_score"] for m in all_line_matches], default=0)
                    best_line_detail = max(all_line_matches, key=lambda m: m["line_item_score"], default=None)
                    # Weighted total only when the user's weights sum to 100%.
                    wsum = weight_supplier + weight_po_number + weight_currency + weight_total_due + weight_line_item
                    total_score = (
                        s_supplier * weight_supplier/100 +
                        s_po_number * weight_po_number/100 +
                        s_currency * weight_currency/100 +
                        s_total * weight_total_due/100 +
                        line_item_score * weight_line_item/100
                    ) if wsum == 100 else 0
                    reason = (
                        f"Supplier match: {s_supplier}/100 (invoice: '{inv_supplier}' vs PO: '{po_supplier}'), "
                        f"PO Number: {s_po_number}/100 ({'found anywhere in JSON' if s_po_number else 'not found'}), "
                        f"Currency: {s_currency}/100 (invoice: '{inv_currency}' vs PO: '{po_currency}'), "
                        f"Total Due: {'match' if s_total else 'no match'} (invoice: {inv_total_due} vs PO: {po_total}), "
                        f"Line item best match: {int(line_item_score)}/100."
                    )
                    debug = {
                        "po_idx": idx, "po_supplier": po_supplier, "po_po_number": po_po_number, "po_total": po_total,
                        "scores": field_details, "line_item_score": line_item_score, "best_line_detail": best_line_detail,
                        "all_line_matches": all_line_matches, "total_score": total_score, "inv_total_due": inv_total_due
                    }
                    scores.append((row, total_score, reason, debug))
                scores.sort(key=lambda tup: tup[1], reverse=True)
                if not scores:
                    return json.dumps({"decision": "REJECTED", "reason": "No POs found.", "debug": {}})
                best_row, best_score, reason, debug = scores[0]
                return json.dumps({
                    "score": best_score,
                    "approved_threshold": approved_threshold,
                    "partial_threshold": partial_threshold,
                    "reason": f"Best match score: {int(best_score)}/100. {reason}",
                    "debug": debug,
                    "po_row": drop_embedding_keys(best_row) if best_row is not None else None,
                    "matched_po_indices": list(matched_po_indices)
                })
            tool = Tool(
                name="po_match_tool",
                func=po_match_tool_func,
                description="Returns JSON with score, thresholds, details. LLM must use this result with strict user logic.",
            )
            agent_llm = get_llm(mdl)
            agent = initialize_agent(
                [tool],
                agent_llm,
                agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
                verbose=True,
                max_iterations=2,
            )
            # The agent is told to apply the threshold bands mechanically to the
            # tool's score — no LLM-invented rejection rules.
            prompt = (
                "You are an expert accounts payable agent.\n"
                "You MUST use the output of po_match_tool to make the approval decision.\n"
                f"User-configured thresholds are: APPROVED = {approved_threshold}, PARTIALLY APPROVED = {partial_threshold}.\n"
                "You must STRICTLY follow this logic with NO exceptions and NO overrides:\n"
                "- If the score from po_match_tool is greater than or equal to the APPROVED threshold, return decision = APPROVED.\n"
                "- Else, if the score is greater than or equal to the PARTIALLY APPROVED threshold, return decision = PARTIALLY APPROVED.\n"
                "- Else, return decision = REJECTED.\n"
                "DO NOT reject just because of a field mismatch. Do not add your own rules.\n"
                "Use only this logic for all decisions, based on the score and thresholds from po_match_tool.\n"
                "Always return JSON using this schema, and nothing else:\n"
                '{\n'
                '  "decision": "APPROVED" | "PARTIALLY APPROVED" | "REJECTED",\n'
                '  "reason": "...",\n'
                '  "debug": {...},\n'
                '  "po_row": {...},\n'
                '  "matched_po_indices": [index, ...]\n'
                '}\n'
                "First, call po_match_tool and get its JSON. Then apply the threshold rules above to its output.\n"
                "Return only the decision JSON.\n"
                f"Invoice JSON:\n{json.dumps(extracted_info, indent=2)}"
            )
            with st.spinner("AP Agent reasoning and making a decision..."):
                result = agent.run(prompt)
            try:
                result_json = clean_json_response(result)
                st.session_state['result_json'] = result_json
                st.session_state['extracted_info_for_decision'] = extracted_info
                st.session_state['line_item_idx'] = 0
            except Exception:
                # Agent output was not parseable JSON; clear state and show raw output.
                st.session_state['result_json'] = None
                st.session_state['extracted_info_for_decision'] = None
                st.session_state['line_item_idx'] = 0
                st.subheader("Decision output not in standard format.")
                st.code(result)
    # --- Always show results if available ---
    # Rendering reads from session_state so results persist across reruns
    # (e.g. when the user clicks the line-item navigation buttons).
    result_json = st.session_state.get('result_json', None)
    extracted_info = st.session_state.get('extracted_info_for_decision', None)
    po_df = st.session_state.get('last_po_df', None)  # <-- make sure this is available
    if result_json and extracted_info:
        status = result_json.get('decision', 'N/A')
        score = result_json.get('debug', {}).get('total_score', None)
        # Green / amber / red banner depending on the decision.
        color = "#4CAF50" if status == "APPROVED" else "#FFC107" if status == "PARTIALLY APPROVED" else "#F44336"
        st.markdown(f"""
        <div style='
            border:2px solid {color};
            border-radius:12px;
            padding:18px 12px;
            margin-bottom:14px;
            text-align:center;
            background: #f8f9fb;
        '>
            <div style='font-size:2.1em;font-weight:800;color:{color};'>{status}</div>
            <div style='font-size:1.2em;color:#24345C;font-weight:500;margin-top:8px;'>
                Total Score: <span style='color:{color};font-weight:700;font-size:1.1em;'>{score if score is not None else '--'}</span>
            </div>
        </div>
        """, unsafe_allow_html=True)
        st.markdown("#### Side-by-Side Field Matching")
        # Header-level field scores computed by po_match_tool.
        field_scores = result_json.get("debug", {}).get("scores", [])
        if field_scores:
            table_rows = ""
            for fs in field_scores:
                table_rows += f"<tr><td><b>{fs['field']}</b></td><td>{fs['invoice']}</td><td>{fs['po']}</td><td>{fs['score']}</td></tr>"
            st.markdown(
                f"""
                <table style='width:100%;border-collapse:collapse;margin-bottom:20px;'>
                <thead>
                <tr style='background:#f7f7fa;'><th>Field</th><th>Invoice Value</th><th>PO Value</th><th>Score</th></tr>
                </thead>
                <tbody>
                {table_rows}
                </tbody>
                </table>
                """,
                unsafe_allow_html=True
            )
        st.markdown("#### Line Item Side-by-Side Matching")
        invoice_line_items = result_json.get("debug", {}).get("all_line_matches", [])
        num_items = len(invoice_line_items)
        key = "line_item_idx"
        # Clamp the carousel index if the item count changed since last run.
        if key not in st.session_state or st.session_state[key] >= num_items:
            st.session_state[key] = 0
        if num_items > 0:
            # Prev / position label / next navigation for the line-item carousel.
            cols = st.columns([1, 3, 1])
            with cols[0]:
                if st.button("⬅️", key="prev_btn", disabled=st.session_state[key] == 0):
                    st.session_state[key] = max(0, st.session_state[key] - 1)
            with cols[1]:
                st.markdown(
                    f"<div style='text-align:center;font-weight:500;font-size:1.1em; margin-top:6px;'>"
                    f"Line Item {st.session_state[key]+1} of {num_items}"
                    f"</div>", unsafe_allow_html=True
                )
            with cols[2]:
                if st.button("➡️", key="next_btn", disabled=st.session_state[key] == num_items - 1):
                    st.session_state[key] = min(num_items - 1, st.session_state[key] + 1)
            idx = st.session_state[key]
            inv_line = invoice_line_items[idx]
            inv = inv_line.get("invoice", {})
            best_po = inv_line.get("po", {})
            desc_score = inv_line.get("desc_score", 0)
            qty_score = inv_line.get("qty_score", 0)
            unit_score = inv_line.get("unit_score", 0)
            amt_score = inv_line.get("amount_score", 0)
            line_score = int(inv_line.get("line_item_score", 0))
            # Build the invoice-vs-PO comparison table for the current line item.
            fields = ["Description", "Quantity", "Price", "Amount"]
            table_html = "<table style='width:90%;border-collapse:collapse;margin-bottom:24px;'>"
            table_html += "<thead><tr><th></th>"
            for f in fields:
                table_html += f"<th style='text-align:center'>{f}</th>"
            table_html += "</tr></thead><tbody>"
            table_html += "<tr><td><b>Invoice Value</b></td>"
            table_html += f"<td>{inv.get('description', '')}</td>"
            table_html += f"<td>{inv.get('quantity', '')}</td>"
            table_html += f"<td>{inv.get('price', '')}</td>"
            table_html += f"<td>{inv.get('amount', '')}</td></tr>"
            table_html += "<tr><td><b>PO Value</b></td>"
            table_html += f"<td>{best_po.get('description', '')}</td>"
            table_html += f"<td>{best_po.get('quantity', '')}</td>"
            table_html += f"<td>{best_po.get('price', '')}</td>"
            table_html += f"<td>{best_po.get('amount', '')}</td></tr>"
            table_html += "<tr><td><b>Score</b></td>"
            table_html += f"<td>{desc_score}</td><td>{qty_score}</td><td>{unit_score}</td><td>{amt_score}</td></tr>"
            table_html += "<tr><td><b>Line Score</b></td><td></td><td></td><td></td><td><b>{}</b></td></tr>".format(line_score)
            table_html += "</tbody></table>"
            st.markdown(table_html, unsafe_allow_html=True)
        else:
            st.info("No line items found.")
        # --------- PO LINE ITEMS MISSING IN INVOICE ---------
        matched_po_row = result_json.get("po_row", None)
        matched_indices = set(result_json.get("matched_po_indices", []))
        if matched_po_row and po_df is not None:
            po_no = matched_po_row.get("PO Number", None)
            # NOTE(review): direct == comparison may miss matches if the CSV
            # column dtype differs from the stored value's type — confirm.
            po_lines_df = po_df[po_df["PO Number"] == po_no]
            # PO lines of the matched PO that no invoice line was matched to.
            missing_po_lines = po_lines_df.loc[~po_lines_df.index.isin(matched_indices)]
            if not missing_po_lines.empty:
                st.markdown("#### PO Line Items Missing in Invoice")
                table_html = "<table style='width:90%;border-collapse:collapse;margin-bottom:24px;'>"
                table_html += "<thead><tr>"
                for f in ["Description", "Quantity", "Price", "Amount"]:
                    table_html += f"<th style='text-align:center'>{f}</th>"
                table_html += "</tr></thead><tbody>"
                for _, row in missing_po_lines.iterrows():
                    table_html += "<tr>"
                    table_html += f"<td>{row.get('Description', '') or row.get('Item Description', '') or row.get('Line Item Description', '')}</td>"
                    table_html += f"<td>{row.get('Quantity', '') or row.get('Item Quantity', '') or row.get('Qty', '')}</td>"
                    table_html += f"<td>{row.get('Price', '') or row.get('Item Unit Price', '') or row.get('Rate', '')}</td>"
                    table_html += f"<td>{row.get('Amount', '') or row.get('Line Item Total', '') or row.get('Line Amount', '')}</td>"
                    table_html += "</tr>"
                table_html += "</tbody></table>"
                st.markdown(table_html, unsafe_allow_html=True)
            else:
                st.info("No PO line items missing in invoice.")
        st.markdown("#### Reason")
        st.write(result_json.get('reason', 'N/A'))
        st.markdown("#### Debug & Matching Details (Full JSON)")
        st.json(result_json.get('debug'))
        st.markdown("#### Extracted Invoice JSON")
        st.json(extracted_info)
        st.markdown("#### Matched PO Row")
        st.json(result_json.get('po_row'))
    else:
        st.info("Run a decision to see results.")