# Seth0330's picture
# Update app.py
# c1904cf verified
import streamlit as st
import requests
import json
import re
import os
import time
import mimetypes
# Streamlit page setup: page title "PDF Tools" and the wide layout.
# NOTE(review): this title differs from the on-page heading set later with
# st.title ("Invoice/Document Extractor") — confirm which name is intended.
st.set_page_config(page_title="PDF Tools", layout="wide")
# -------- LLM model registry --------
# Maps the label offered in the model select boxes to the settings used to
# call that model's chat-completions endpoint:
#   api_url         - URL the request is POSTed to
#   model           - model identifier placed in the request payload
#   key_env         - name of the environment variable holding the API key
#   response_format - value for the payload's "response_format" (None = omit)
#   extra_headers   - optional additional HTTP headers (key may be absent)
# NOTE(review): the labels "OpenAI GPT-4.1" and "Mistral Small" do not
# obviously match their model ids ("gpt-4-1106-preview",
# "mistralai/ministral-8b") — confirm the intended models.
_DEEPSEEK_CHAT_URL = "https://api.deepseek.com/v1/chat/completions"

MODELS = {
    "DeepSeek v3": {
        "api_url": _DEEPSEEK_CHAT_URL,
        "model": "deepseek-chat",
        "key_env": "DEEPSEEK_API_KEY",
        "response_format": {"type": "json_object"},
    },
    "DeepSeek R1": {
        "api_url": _DEEPSEEK_CHAT_URL,
        "model": "deepseek-reasoner",
        "key_env": "DEEPSEEK_API_KEY",
        "response_format": None,
    },
    "OpenAI GPT-4.1": {
        "api_url": "https://api.openai.com/v1/chat/completions",
        "model": "gpt-4-1106-preview",
        "key_env": "OPENAI_API_KEY",
        "response_format": None,
        "extra_headers": {},
    },
    "Mistral Small": {
        "api_url": "https://openrouter.ai/api/v1/chat/completions",
        "model": "mistralai/ministral-8b",
        "key_env": "OPENROUTER_API_KEY",
        "response_format": {"type": "json_object"},
        "extra_headers": {
            "HTTP-Referer": "https://huggingface.co",
            "X-Title": "Invoice Extractor",
        },
    },
}
def get_api_key(model_choice):
    """Return the API key for ``model_choice``.

    The key is read from the environment variable named by the model's
    ``key_env`` entry in ``MODELS``. If that variable is unset or empty, an
    error is shown in the Streamlit UI and ``st.stop()`` is called.
    """
    env_name = MODELS[model_choice]["key_env"]
    api_key = os.getenv(env_name)
    if api_key:
        return api_key
    st.error(f"❌ {env_name} not set")
    st.stop()
    return api_key
def query_llm(model_choice, prompt):
    """Send ``prompt`` to the selected model and return the reply text.

    A chat-completions request is built from the ``MODELS`` entry for
    ``model_choice`` (single user message, temperature 0.1, at most 2000
    tokens, optional ``response_format`` and extra headers) and POSTed with a
    90 second timeout. On success the reply content and the raw response body
    are stored in ``st.session_state`` as ``last_api`` and ``last_raw``.

    Args:
        model_choice: Key into ``MODELS``.
        prompt: Text sent as the content of the user message.

    Returns:
        The ``content`` of the first choice's message, or ``None`` when the
        request could not be sent, the API answered with a non-200 status, or
        the response body did not have the expected shape. Every failure is
        reported through ``st.error``.
    """
    cfg = MODELS[model_choice]
    headers = {
        "Authorization": f"Bearer {get_api_key(model_choice)}",
        "Content-Type": "application/json",
    }
    if cfg.get("extra_headers"):
        headers.update(cfg["extra_headers"])

    payload = {
        "model": cfg["model"],
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.1,
        "max_tokens": 2000,
    }
    if cfg.get("response_format"):
        payload["response_format"] = cfg["response_format"]

    # UI boundary: any failure while sending is shown as a message instead of
    # a traceback, exactly as before.
    try:
        with st.spinner(f"🔍 Querying {model_choice}..."):
            r = requests.post(cfg["api_url"], headers=headers, json=payload, timeout=90)
    except Exception as e:
        st.error(f"Connection error: {e}")
        return None

    if r.status_code != 200:
        if "No instances available" in r.text or r.status_code == 503:
            st.error(f"{model_choice} is currently unavailable. Please try again later or select another model.")
        else:
            st.error(f"🚨 API Error {r.status_code}: {r.text}")
        return None

    # A 200 response with a body that is not JSON, or that lacks
    # choices[0].message.content, is a response-format problem, not a
    # connection problem; report it as such and show the body for diagnosis.
    try:
        content = r.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as e:
        st.error(f"Unexpected response format from {model_choice}: {e!r}")
        st.code(r.text)
        return None

    st.session_state.last_api = content
    st.session_state.last_raw = r.text
    return content
def clean_json_response(text):
    """Extract and parse the JSON object embedded in a model reply.

    Markdown code fences are removed, then the span from the first ``{`` to
    the last ``}`` is parsed. Parsing is attempted in three stages:

    1. the fragment exactly as returned;
    2. the fragment with trailing commas before ``}``/``]`` removed;
    3. stage 2 with an additional quote-repair substitution applied.

    The untouched fragment is tried first because the repair regexes cannot
    tell structure from string content: applied to valid JSON they would
    rewrite string values such as ``"x, ]"``.

    Args:
        text: Raw reply text; may be ``None`` or empty.

    Returns:
        The parsed object, or ``None`` when ``text`` is empty, contains no
        ``{...}`` span, or cannot be parsed even after repair. The latter two
        cases are reported with ``st.error`` and the offending text is shown
        with ``st.code``.
    """
    if not text:
        return None

    stripped = re.sub(r'```(?:json)?', '', text).strip()
    start, end = stripped.find('{'), stripped.rfind('}') + 1
    if start < 0 or end < 1:
        st.error("Couldn't locate JSON in response.")
        st.code(text)
        return None

    frag = stripped[start:end]

    # Stage 1: well-formed JSON must round-trip without any rewriting.
    try:
        return json.loads(frag)
    except json.JSONDecodeError:
        pass  # fall through to the repair stages below

    # Stage 2: drop trailing commas such as `[1, 2,]` or `{"a": 1,}`.
    frag = re.sub(r',\s*([}\]])', r'\1', frag)
    try:
        return json.loads(frag)
    except json.JSONDecodeError as e:
        # Keep this error for the final message (the name bound by `as` is
        # cleared when the except block ends).
        parse_error = e

    # Stage 3: last-resort quote repair.
    repaired = re.sub(r'"\s*"\s*(?="[^"]+"\s*:)', '","', frag)
    try:
        return json.loads(repaired)
    except json.JSONDecodeError:
        st.error(f"JSON parse error: {parse_error}")
        st.code(frag)
        return None
def fallback_supplier(text):
    """Return the first non-blank line of ``text`` (stripped), or ``None``."""
    stripped_lines = (raw.strip() for raw in text.splitlines())
    return next((candidate for candidate in stripped_lines if candidate), None)
def get_extraction_prompt(model_choice, txt):
    """Build the invoice-extraction prompt for the document text ``txt``.

    The prompt is made of fixed parsing instructions, a JSON schema template
    (an ``invoice_header`` object and a one-element ``line_items`` array whose
    fields are all described as "string or null"), closing rules, and finally
    the invoice text itself. ``model_choice`` is accepted but not used.
    """
    header_fields = (
        "car_number", "shipment_number", "shipping_point", "currency",
        "invoice_number", "invoice_date", "order_number",
        "customer_order_number", "our_order_number", "sales_order_number",
        "purchase_order_number", "order_date",
        "supplier_name", "supplier_address", "supplier_phone",
        "supplier_email", "supplier_tax_id",
        "customer_name", "customer_address", "customer_phone",
        "customer_email", "customer_tax_id",
        "ship_to_name", "ship_to_address",
        "bill_to_name", "bill_to_address",
        "remit_to_name", "remit_to_address",
        "tax_id", "tax_registration_number", "vat_number",
        "payment_terms", "payment_method", "payment_reference",
        "bank_account_number", "iban", "swift_code",
        "total_before_tax", "tax_amount", "tax_rate", "shipping_charges",
        "discount", "total_due", "amount_paid", "balance_due", "due_date",
        "invoice_status", "reference_number", "project_code", "department",
        "contact_person", "notes", "additional_info",
    )
    line_item_fields = (
        "quantity", "units", "description", "footage", "price", "amount",
        "notes",
    )

    def field_lines(names):
        # One entry per line; every entry except the last ends with a comma.
        return ",\n".join(f' "{name}": "string or null"' for name in names) + "\n"

    instructions = (
        "You are an expert invoice parser. "
        "Extract data according to the visible table structure and column headers in the invoice. "
        "For every line item, only extract fields that correspond to the table columns for that row (do not include header/shipment fields in line items). "
        "Merge all multi-line content within a single cell into that field (especially for the 'description' and 'notes'). "
        "Shipment/invoice-level fields such as CAR NUMBER, SHIPPING POINT, SHIPMENT NUMBER, CURRENCY, etc., must go ONLY into the 'invoice_header', not as line item fields.\n"
        "Use this schema:\n"
    )
    schema = "".join(
        [
            '{\n',
            ' "invoice_header": {\n',
            field_lines(header_fields),
            ' },\n',
            ' "line_items": [\n',
            ' {\n',
            field_lines(line_item_fields),
            ' }\n',
            ' ]\n',
            '}',
        ]
    )
    closing = (
        "\nIf a field is missing for a line item or header, use null. "
        "Do not invent fields. Do not add any header or shipment data to any line item. Return ONLY the JSON object, no explanation.\n"
        "\nInvoice Text:\n"
    )
    return instructions + schema + closing + f"{txt}"
def extract_invoice_info(model_choice, text):
    """Run the extraction prompt on ``text`` and normalize the model's reply.

    Args:
        model_choice: Key into ``MODELS``; names starting with "DeepSeek"
            skip the header defaults and the supplier-name fallback.
        text: Plain text of the document.

    Returns:
        ``{"invoice_header": dict, "line_items": list}``, or ``None`` when the
        model call fails or its reply cannot be parsed as a JSON object.

    The header is taken from the reply's ``invoice_header`` object when the
    model followed the requested schema. Otherwise the reply is treated as a
    flat object and its top-level keys (without ``line_items``) become the
    header; for non-DeepSeek models this happens only when at least one of
    ``invoice_number`` / ``supplier_name`` / ``customer_name`` is present.
    """
    prompt = get_extraction_prompt(model_choice, text)
    raw = query_llm(model_choice, prompt)
    if not raw:
        return None
    data = clean_json_response(raw)
    if not data:
        return None

    is_deepseek = model_choice.startswith("DeepSeek")

    nested = data.get("invoice_header")
    if isinstance(nested, dict) and nested:
        # Schema-conforming reply: unwrap it for every model so the result is
        # not {"invoice_header": {"invoice_header": {...}}}.
        header = dict(nested)
    else:
        # Flat reply: build a separate dict so `line_items` does not leak into
        # the header and the parsed reply itself is not mutated.
        flat = {
            k: v
            for k, v in data.items()
            if k not in ("invoice_header", "line_items")
        }
        looks_like_header = any(
            k in flat for k in ("invoice_number", "supplier_name", "customer_name")
        )
        header = flat if (is_deepseek or looks_like_header) else {}

    # NOTE(review): the default keys below (po_number, invoice_value,
    # item_number, unit_price, total_price) are not part of the schema that
    # the extraction prompt requests — confirm whether they are still wanted.
    if is_deepseek:
        item_defaults = ("description", "quantity", "unit_price", "total_price")
    else:
        item_defaults = (
            "item_number", "description", "quantity", "unit_price", "total_price",
        )
        for key in (
            "invoice_number", "invoice_date", "po_number",
            "invoice_value", "supplier_name", "customer_name",
        ):
            header.setdefault(key, None)
        if not header.get("supplier_name"):
            # Fall back to the first non-blank line of the document.
            header["supplier_name"] = fallback_supplier(text)

    items = data.get("line_items", [])
    if not isinstance(items, list):
        items = []
    for item in items:
        if isinstance(item, dict):
            for key in item_defaults:
                item.setdefault(key, None)

    return {"invoice_header": header, "line_items": items}
# --------- File type/content-type detection ---------
def get_content_type(filename):
    """Return the Content-Type value to send for ``filename``.

    Names whose last dot-separated part is "pdf" (case-insensitive) always
    map to "text/plain"; the original author marks this as an Unstract quirk.
    Other names use the type guessed by ``mimetypes``, with
    "application/octet-stream" when no type can be guessed.
    """
    guessed, _encoding = mimetypes.guess_type(filename)
    extension = filename.lower().rpartition(".")[2]
    if extension == "pdf":
        return "text/plain"
    return guessed if guessed is not None else "application/octet-stream"
# --------- Unstract API: multi-format PDF/doc/image-to-text ---------
# Base URL that the whisper endpoints are appended to.
UNSTRACT_BASE = "https://llmwhisperer-api.us-central.unstract.com/api/v2"
# API key, read from the environment when this module is loaded (None if the
# variable is not set). Set UNSTRACT_API_KEY in your environment!
UNSTRACT_API_KEY = os.environ.get("UNSTRACT_API_KEY")
def extract_text_from_unstract(
    uploaded_file, *, request_timeout=60, poll_interval=2, max_polls=30
):
    """Convert an uploaded document to text through the Unstract API.

    The file bytes are POSTed to ``/whisper``; the returned ``whisper_hash``
    is then polled on ``/whisper-status`` until the status is "processed",
    and the text is fetched from ``/whisper-retrieve`` with ``text_only=true``.

    Args:
        uploaded_file: File-like object with ``read()`` and, optionally, a
            ``name`` attribute and ``seek()``.
        request_timeout: Timeout in seconds applied to every HTTP request.
        poll_interval: Seconds to wait between status checks.
        max_polls: Maximum number of status checks (defaults give about 60 s).

    Returns:
        The extracted text, or ``None`` on any failure. Failures are reported
        through ``st.error``.
    """
    if not UNSTRACT_API_KEY:
        st.error("❌ UNSTRACT_API_KEY not set")
        return None

    filename = getattr(uploaded_file, "name", "uploaded_file")
    # Rewind before reading so a stream that was already consumed (e.g. on a
    # repeated click) is not uploaded as zero bytes.
    if hasattr(uploaded_file, "seek"):
        uploaded_file.seek(0)
    file_bytes = uploaded_file.read()

    auth_headers = {"unstract-key": UNSTRACT_API_KEY}

    try:
        with st.spinner("Uploading and processing document with Unstract..."):
            r = requests.post(
                f"{UNSTRACT_BASE}/whisper",
                headers={**auth_headers, "Content-Type": get_content_type(filename)},
                data=file_bytes,
                timeout=request_timeout,
            )
        if r.status_code != 202:
            st.error(f"Unstract: Error uploading file: {r.status_code} - {r.text}")
            return None
        upload_body = r.json()
        whisper_hash = (
            upload_body.get("whisper_hash") if isinstance(upload_body, dict) else None
        )
        if not whisper_hash:
            st.error("Unstract: No whisper_hash received.")
            return None

        for i in range(max_polls):
            status_r = requests.get(
                f"{UNSTRACT_BASE}/whisper-status",
                params={"whisper_hash": whisper_hash},
                headers=auth_headers,
                timeout=request_timeout,
            )
            if status_r.status_code != 200:
                st.error(f"Unstract: Error checking status: {status_r.status_code} - {status_r.text}")
                return None
            status_body = status_r.json()
            status = status_body.get("status") if isinstance(status_body, dict) else None
            if status == "processed":
                break
            # presumably terminal failure statuses; verify against the
            # LLMWhisperer status API. Stops early instead of polling on.
            if status in ("error", "failed"):
                st.error(f"Unstract: Processing failed (status: {status}).")
                return None
            st.info(f"Unstract status: {status or 'waiting'}... ({i+1})")
            time.sleep(poll_interval)
        else:
            st.error("Unstract: Timeout waiting for OCR to finish.")
            return None

        r = requests.get(
            f"{UNSTRACT_BASE}/whisper-retrieve",
            params={"whisper_hash": whisper_hash, "text_only": "true"},
            headers=auth_headers,
            timeout=request_timeout,
        )
    except (requests.RequestException, ValueError) as e:
        # Network failure, timeout, or a non-JSON body where JSON was expected.
        st.error(f"Unstract: Connection error: {e}")
        return None

    if r.status_code != 200:
        st.error(f"Unstract: Error retrieving extracted text: {r.status_code} - {r.text}")
        return None

    # The body may be plain text or JSON with a "result_text" field; use the
    # field when it is present and non-empty, otherwise the body as-is.
    try:
        data = r.json()
    except ValueError:
        return r.text
    if isinstance(data, dict) and data.get("result_text"):
        return data["result_text"]
    return r.text
# --------- INVOICE EXTRACTOR UI ---------
# Model picker, file upload and the "Extract" action. A successful extraction
# is rendered as two tables and stored in session state under
# "last_extracted_info".
st.title("Invoice/Document Extractor")
mdl = st.selectbox("Model", list(MODELS), key="extract_model")
inv_file = st.file_uploader(
    "Invoice or Document File",
    type=["pdf", "docx", "xlsx", "xls", "png", "jpg", "jpeg", "tiff"],
)
extracted_info = None
extract_clicked = st.button("Extract")
if extract_clicked and inv_file:
    with st.spinner("Extracting text from document using Unstract..."):
        text = extract_text_from_unstract(inv_file)
    if text:
        extracted_info = extract_invoice_info(mdl, text)
        if extracted_info:
            st.success("Extraction Complete")
            st.subheader("Invoice Metadata")
            # Single-row table with human-readable column titles.
            header_row = {
                field.replace("_", " ").title(): value
                for field, value in extracted_info["invoice_header"].items()
            }
            st.table([header_row])
            st.subheader("Line Items")
            st.table(extracted_info["line_items"])
            # Keep the result across reruns of the script.
            st.session_state["last_extracted_info"] = extracted_info
# Follow-up controls: shown when data was extracted in this run or is still
# stored in session state from an earlier run.
extracted_info = extracted_info or st.session_state.get("last_extracted_info", None)
if extracted_info:
    st.markdown("---")
    st.subheader("📝 Fine-tune Extracted Data with Your Own Prompt")
    user_prompt = st.text_area(
        "Enter your prompt for further processing or transformation (the extracted JSON will be available as context).",
        height=120,
        key="custom_prompt",
    )
    model_2 = st.selectbox("Model for Fine-Tuning Prompt", list(MODELS.keys()), key="refine_model")
    if st.button("Run Custom Prompt"):
        if not user_prompt or not user_prompt.strip():
            # Nothing to ask the model; avoid a pointless API call.
            st.warning("Please enter a prompt before running.")
        else:
            refine_input = (
                "Here is an extracted invoice in JSON format:\n"
                f"{json.dumps(extracted_info, indent=2)}\n"
                "Follow this instruction and return the result as a JSON object only (no explanation):\n"
                f"{user_prompt}"
            )
            result = query_llm(model_2, refine_input)
            # query_llm returns None after reporting its own error; a second
            # "could not parse" message would be misleading in that case.
            if result is not None:
                refined_json = clean_json_response(result)
                st.subheader("Fine-Tuned Output")
                if refined_json:
                    st.json(refined_json)
                else:
                    st.error("Could not parse a valid JSON output from the model.")
    st.caption("The prompt is run on the above-extracted fields as JSON. Try instructions like: 'Add a new field for net_amount (amount minus tax) to each line item', or 'Summarize the total quantity ordered', etc.")
# Debug panel: once "last_api" is present in session state, show the stored
# reply content followed by the stored raw response body.
if "last_api" in st.session_state:
    with st.expander("Debug"):
        for state_key in ("last_api", "last_raw"):
            st.code(st.session_state[state_key])