Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import requests | |
| import json | |
| import re | |
| import os | |
| import time | |
| import mimetypes | |
# Page-level configuration: sets the page title to "PDF Tools" and selects
# the "wide" layout.
st.set_page_config(page_title="PDF Tools", layout="wide")
# -------- LLM model registry --------
# Maps the label offered in the model select boxes to the settings used when
# calling that model's chat-completions endpoint:
#   api_url         - URL the request is POSTed to
#   model           - model identifier placed in the request payload
#   key_env         - name of the environment variable holding the API key
#   response_format - copied into the payload when truthy; None leaves it out
#   extra_headers   - optional extra HTTP headers merged into the request
# NOTE(review): "OpenAI GPT-4.1" sends the model id "gpt-4-1106-preview" and
# "Mistral Small" sends "mistralai/ministral-8b" - confirm that the labels
# and the model ids are meant to differ.
_DEEPSEEK_CHAT_URL = "https://api.deepseek.com/v1/chat/completions"

MODELS = {
    "DeepSeek v3": dict(
        api_url=_DEEPSEEK_CHAT_URL,
        model="deepseek-chat",
        key_env="DEEPSEEK_API_KEY",
        response_format={"type": "json_object"},
    ),
    "DeepSeek R1": dict(
        api_url=_DEEPSEEK_CHAT_URL,
        model="deepseek-reasoner",
        key_env="DEEPSEEK_API_KEY",
        response_format=None,
    ),
    "OpenAI GPT-4.1": dict(
        api_url="https://api.openai.com/v1/chat/completions",
        model="gpt-4-1106-preview",
        key_env="OPENAI_API_KEY",
        response_format=None,
        extra_headers={},
    ),
    "Mistral Small": dict(
        api_url="https://openrouter.ai/api/v1/chat/completions",
        model="mistralai/ministral-8b",
        key_env="OPENROUTER_API_KEY",
        response_format={"type": "json_object"},
        extra_headers={
            "HTTP-Referer": "https://huggingface.co",
            "X-Title": "Invoice Extractor",
        },
    ),
}
def get_api_key(model_choice):
    """Return the API key for *model_choice*.

    The key is read from the environment variable named by the model's
    ``key_env`` setting in ``MODELS``. When that variable is unset or empty,
    an error is shown in the UI and ``st.stop()`` is called.
    """
    env_name = MODELS[model_choice]["key_env"]
    api_key = os.getenv(env_name)
    if api_key:
        return api_key
    st.error(f"❌ {env_name} not set")
    st.stop()
    return api_key
def query_llm(model_choice, prompt):
    """Send *prompt* as a single user message to the selected model.

    The request is built from the ``MODELS`` entry for *model_choice*
    (endpoint, model id, optional ``response_format`` and extra headers) and
    posted with temperature 0.1, ``max_tokens`` 2000 and a 90 second timeout.

    On success the message content is stored in ``st.session_state.last_api``,
    the raw response body in ``st.session_state.last_raw``, and the content
    is returned.

    Returns ``None`` after showing an error in the UI when the request fails,
    the API answers with a non-200 status, or the response body does not
    have the expected ``choices[0].message.content`` shape.
    """
    cfg = MODELS[model_choice]
    headers = {
        "Authorization": f"Bearer {get_api_key(model_choice)}",
        "Content-Type": "application/json",
    }
    headers.update(cfg.get("extra_headers") or {})
    payload = {
        "model": cfg["model"],
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.1,
        "max_tokens": 2000,
    }
    if cfg.get("response_format"):
        payload["response_format"] = cfg["response_format"]

    # Network failures only: keep this separate from response parsing so a
    # malformed body is not reported as a connection problem.
    try:
        with st.spinner(f"🔍 Querying {model_choice}..."):
            r = requests.post(cfg["api_url"], headers=headers, json=payload, timeout=90)
    except requests.RequestException as e:
        st.error(f"Connection error: {e}")
        return None

    if r.status_code != 200:
        if "No instances available" in r.text or r.status_code == 503:
            st.error(f"{model_choice} is currently unavailable. Please try again later or select another model.")
        else:
            st.error(f"🚨 API Error {r.status_code}: {r.text}")
        return None

    try:
        content = r.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as e:
        # ValueError covers a body that is not valid JSON; the others cover
        # valid JSON that lacks the expected structure.
        st.error(f"Unexpected response format from {model_choice}: {e!r}")
        return None

    st.session_state.last_api = content
    st.session_state.last_raw = r.text
    return content
# A comma directly before a closing brace/bracket, e.g. `[1, 2,]`.
_TRAILING_COMMA_RE = re.compile(r',\s*([}\]])')
# A closing quote separated only by whitespace from the next `"key":`,
# i.e. a string value whose trailing comma is missing.
_MISSING_COMMA_RE = re.compile(r'"(\s+)(?="[^"]+"\s*:)')


def clean_json_response(text):
    """Extract and parse the JSON object embedded in an LLM reply.

    Markdown code fences are removed and the span from the first ``{`` to the
    last ``}`` is taken as the candidate. Parsing is attempted in stages so
    that already-valid JSON is never altered:

    1. the fragment exactly as returned;
    2. the fragment with trailing commas before ``}`` / ``]`` removed;
    3. additionally with a comma inserted between a string value and the
       following key when only whitespace separates them.

    Returns the parsed object, or ``None`` when *text* is empty or nothing
    parseable is found (an error and the offending text are shown in the UI).
    """
    if not text:
        return None

    stripped = re.sub(r'```(?:json)?', '', text).strip()
    start, end = stripped.find('{'), stripped.rfind('}') + 1
    # `end <= start` also covers a missing '}' (end == 0) and a '}' that
    # only occurs before the first '{'.
    if start < 0 or end <= start:
        st.error("Couldn't locate JSON in response.")
        st.code(text)
        return None
    frag = stripped[start:end]

    first_error = None
    candidate = frag
    # None means "try the fragment untouched"; each later step repairs the
    # output of the previous one.
    for repair in (None, _TRAILING_COMMA_RE, _MISSING_COMMA_RE):
        if repair is _TRAILING_COMMA_RE:
            candidate = repair.sub(r'\1', candidate)
        elif repair is _MISSING_COMMA_RE:
            candidate = repair.sub(r'",\1', candidate)
        try:
            return json.loads(candidate)
        except json.JSONDecodeError as e:
            if first_error is None:
                first_error = e

    st.error(f"JSON parse error: {first_error}")
    st.code(frag)
    return None
def fallback_supplier(text):
    """Return the first non-blank line of *text* with surrounding whitespace
    removed, or ``None`` when every line is blank."""
    stripped_lines = (raw.strip() for raw in text.splitlines())
    return next((candidate for candidate in stripped_lines if candidate), None)
# Field names listed under "invoice_header" in the prompt schema, in order.
_PROMPT_HEADER_FIELDS = (
    "car_number",
    "shipment_number",
    "shipping_point",
    "currency",
    "invoice_number",
    "invoice_date",
    "order_number",
    "customer_order_number",
    "our_order_number",
    "sales_order_number",
    "purchase_order_number",
    "order_date",
    "supplier_name",
    "supplier_address",
    "supplier_phone",
    "supplier_email",
    "supplier_tax_id",
    "customer_name",
    "customer_address",
    "customer_phone",
    "customer_email",
    "customer_tax_id",
    "ship_to_name",
    "ship_to_address",
    "bill_to_name",
    "bill_to_address",
    "remit_to_name",
    "remit_to_address",
    "tax_id",
    "tax_registration_number",
    "vat_number",
    "payment_terms",
    "payment_method",
    "payment_reference",
    "bank_account_number",
    "iban",
    "swift_code",
    "total_before_tax",
    "tax_amount",
    "tax_rate",
    "shipping_charges",
    "discount",
    "total_due",
    "amount_paid",
    "balance_due",
    "due_date",
    "invoice_status",
    "reference_number",
    "project_code",
    "department",
    "contact_person",
    "notes",
    "additional_info",
)

# Field names listed for each entry of "line_items" in the prompt schema.
_PROMPT_LINE_ITEM_FIELDS = (
    "quantity",
    "units",
    "description",
    "footage",
    "price",
    "amount",
    "notes",
)


def get_extraction_prompt(model_choice, txt):
    """Build the invoice-extraction prompt for the document text *txt*.

    The prompt consists of the parsing rules, a JSON schema with an
    ``invoice_header`` object and a ``line_items`` array, closing
    instructions, and finally *txt* itself. *model_choice* is accepted but
    not used: every model receives the same prompt.
    """

    def schema_rows(field_names):
        # One `"name": "string or null"` row per field; every row but the
        # last is followed by a comma.
        return ",\n".join(f' "{name}": "string or null"' for name in field_names) + "\n"

    rules = "".join([
        "You are an expert invoice parser. ",
        "Extract data according to the visible table structure and column headers in the invoice. ",
        "For every line item, only extract fields that correspond to the table columns for that row (do not include header/shipment fields in line items). ",
        "Merge all multi-line content within a single cell into that field (especially for the 'description' and 'notes'). ",
        "Shipment/invoice-level fields such as CAR NUMBER, SHIPPING POINT, SHIPMENT NUMBER, CURRENCY, etc., must go ONLY into the 'invoice_header', not as line item fields.\n",
        "Use this schema:\n",
    ])
    schema = "".join([
        '{\n',
        ' "invoice_header": {\n',
        schema_rows(_PROMPT_HEADER_FIELDS),
        ' },\n',
        ' "line_items": [\n',
        ' {\n',
        schema_rows(_PROMPT_LINE_ITEM_FIELDS),
        ' }\n',
        ' ]\n',
        '}',
    ])
    closing = "".join([
        "\nIf a field is missing for a line item or header, use null. ",
        "Do not invent fields. Do not add any header or shipment data to any line item. Return ONLY the JSON object, no explanation.\n",
        "\nInvoice Text:\n",
    ])
    return rules + schema + closing + f"{txt}"
def _select_invoice_header(data):
    """Return the header dict from a parsed model response.

    Prefers the nested ``invoice_header`` object the prompt asks for; when it
    is missing, empty or not a dict, falls back to treating the remaining
    top-level keys as a flat header. A new dict is returned in the fallback
    case so the header never contains ``line_items`` and ``data`` itself is
    not mutated by later ``setdefault`` calls.
    """
    nested = data.get("invoice_header")
    if isinstance(nested, dict) and nested:
        return nested
    return {k: v for k, v in data.items() if k not in ("invoice_header", "line_items")}


def _line_items_with_defaults(data, default_keys):
    """Return ``data["line_items"]`` with *default_keys* set to None on every
    dict entry that lacks them; a missing or non-list value yields ``[]``."""
    items = data.get("line_items", [])
    if not isinstance(items, list):
        return []
    for item in items:
        if isinstance(item, dict):
            for key in default_keys:
                item.setdefault(key, None)
    return items


def extract_invoice_info(model_choice, text):
    """Run LLM extraction on invoice *text* and normalise the result.

    Builds the prompt, queries *model_choice*, parses the reply and returns
    ``{"invoice_header": dict, "line_items": list}``. Returns ``None`` when
    the query fails or no JSON object can be parsed from the reply.

    For models whose label does not start with "DeepSeek", a fixed set of
    header keys is defaulted to None and a missing supplier name is filled
    with the first non-blank line of *text*.

    NOTE(review): the default key sets below are kept unchanged; several of
    them (e.g. ``unit_price``, ``po_number``) are not part of the schema
    requested by ``get_extraction_prompt`` - confirm they are still wanted.
    """
    prompt = get_extraction_prompt(model_choice, text)
    raw = query_llm(model_choice, prompt)
    if not raw:
        return None
    data = clean_json_response(raw)
    if not data or not isinstance(data, dict):
        return None

    # Every model receives the same nested-schema prompt, so the header is
    # located the same way for all of them.
    header = _select_invoice_header(data)

    if model_choice.startswith("DeepSeek"):
        item_keys = ("description", "quantity", "unit_price", "total_price")
    else:
        for key in ("invoice_number", "invoice_date", "po_number",
                    "invoice_value", "supplier_name", "customer_name"):
            header.setdefault(key, None)
        if not header.get("supplier_name"):
            header["supplier_name"] = fallback_supplier(text)
        item_keys = ("item_number", "description", "quantity", "unit_price", "total_price")

    return {
        "invoice_header": header,
        "line_items": _line_items_with_defaults(data, item_keys),
    }
# --------- File type/content-type detection ---------
def get_content_type(filename):
    """Return the Content-Type value to send for *filename*.

    A name whose last dot-separated part is "pdf" (case-insensitive) maps to
    "text/plain" (special-cased as an Unstract quirk). Anything else uses the
    ``mimetypes`` guess, or "application/octet-stream" when there is none.
    """
    extension = filename.lower().rpartition('.')[2]
    if extension == "pdf":
        return "text/plain"
    guessed = mimetypes.guess_type(filename)[0]
    return "application/octet-stream" if guessed is None else guessed
# --------- Unstract API: multi-format document/image to text ---------
# Base URL of the API; endpoint paths are appended to it by the caller.
UNSTRACT_BASE = "https://llmwhisperer-api.us-central.unstract.com/api/v2"
# API key taken from the environment (set UNSTRACT_API_KEY before starting
# the app); None when the variable is not defined.
UNSTRACT_API_KEY = os.environ.get("UNSTRACT_API_KEY")
def _unstract_request(method, path, *, timeout, headers=None, **kwargs):
    """Issue one authenticated request to ``UNSTRACT_BASE/<path>``.

    Returns the response, or ``None`` after showing an error in the UI when
    the request fails at the network level (including a timeout).
    """
    merged_headers = {"unstract-key": UNSTRACT_API_KEY, **(headers or {})}
    try:
        return requests.request(
            method,
            f"{UNSTRACT_BASE}/{path}",
            headers=merged_headers,
            timeout=timeout,
            **kwargs,
        )
    except requests.RequestException as e:
        st.error(f"Unstract: Connection error: {e}")
        return None


def _unstract_json_field(response, key):
    """Return ``response.json()[key]``, or None when the body is not a JSON
    object or lacks *key*."""
    try:
        body = response.json()
    except ValueError:
        return None
    return body.get(key) if isinstance(body, dict) else None


def extract_text_from_unstract(uploaded_file, *, timeout=90, poll_attempts=30, poll_interval=2):
    """Convert an uploaded document to text through the Unstract API.

    The file bytes are POSTed to ``/whisper``; the returned ``whisper_hash``
    is polled on ``/whisper-status`` until the status is "processed", and the
    text is then fetched from ``/whisper-retrieve`` with ``text_only=true``.

    Args:
        uploaded_file: file-like object with ``read()``; its ``name``
            attribute, when present, selects the Content-Type.
        timeout: per-request timeout in seconds.
        poll_attempts: maximum number of status checks.
        poll_interval: seconds to sleep between status checks (the defaults
            wait up to 60 seconds in total).

    Returns:
        The extracted text, or ``None`` after showing an error in the UI when
        the API key is missing, the file is empty, a request fails, an
        unexpected status code is returned, or polling times out.
    """
    if not UNSTRACT_API_KEY:
        st.error("❌ UNSTRACT_API_KEY not set")
        return None

    filename = getattr(uploaded_file, "name", "uploaded_file")
    # Rewind first so an object that was already read once is sent in full.
    rewind = getattr(uploaded_file, "seek", None)
    if callable(rewind):
        rewind(0)
    file_bytes = uploaded_file.read()
    if not file_bytes:
        st.error("Unstract: The uploaded file is empty.")
        return None

    with st.spinner("Uploading and processing document with Unstract..."):
        r = _unstract_request(
            "POST",
            "whisper",
            timeout=timeout,
            headers={"Content-Type": get_content_type(filename)},
            data=file_bytes,
        )
    if r is None:
        return None
    if r.status_code != 202:
        st.error(f"Unstract: Error uploading file: {r.status_code} - {r.text}")
        return None
    whisper_hash = _unstract_json_field(r, "whisper_hash")
    if not whisper_hash:
        st.error("Unstract: No whisper_hash received.")
        return None

    # A single placeholder is reused so polling updates one message instead
    # of stacking a new info box per attempt.
    status_box = st.empty()
    for attempt in range(poll_attempts):
        status_r = _unstract_request(
            "GET", "whisper-status", timeout=timeout, params={"whisper_hash": whisper_hash}
        )
        if status_r is None:
            return None
        if status_r.status_code != 200:
            st.error(f"Unstract: Error checking status: {status_r.status_code} - {status_r.text}")
            return None
        status = _unstract_json_field(status_r, "status")
        if status == "processed":
            break
        status_box.info(f"Unstract status: {status or 'waiting'}... ({attempt + 1})")
        time.sleep(poll_interval)
    else:
        st.error("Unstract: Timeout waiting for OCR to finish.")
        return None
    status_box.empty()

    r = _unstract_request(
        "GET",
        "whisper-retrieve",
        timeout=timeout,
        params={"whisper_hash": whisper_hash, "text_only": "true"},
    )
    if r is None:
        return None
    if r.status_code != 200:
        st.error(f"Unstract: Error retrieving extracted text: {r.status_code} - {r.text}")
        return None
    # Prefer the JSON "result_text" field; fall back to the raw body when the
    # response is not JSON or the field is missing/empty.
    return _unstract_json_field(r, "result_text") or r.text
# --------- INVOICE EXTRACTOR UI ---------
st.title("Invoice/Document Extractor")
mdl = st.selectbox("Model", list(MODELS.keys()), key="extract_model")
inv_file = st.file_uploader(
    "Invoice or Document File",
    type=["pdf", "docx", "xlsx", "xls", "png", "jpg", "jpeg", "tiff"]
)

extracted_info = None
if st.button("Extract"):
    if not inv_file:
        # Give feedback instead of silently ignoring the click.
        st.warning("Please upload a file before extracting.")
    else:
        with st.spinner("Extracting text from document using Unstract..."):
            text = extract_text_from_unstract(inv_file)
        if text:
            extracted_info = extract_invoice_info(mdl, text)
        if extracted_info:
            st.success("Extraction Complete")
            st.session_state["last_extracted_info"] = extracted_info  # keep across reruns

# Use the fresh result, or the one stored earlier in this session, so the
# tables and the refinement controls survive Streamlit reruns.
extracted_info = extracted_info or st.session_state.get("last_extracted_info", None)
if extracted_info:
    st.subheader("Invoice Metadata")
    st.table([{k.replace("_", " ").title(): v for k, v in extracted_info["invoice_header"].items()}])
    st.subheader("Line Items")
    st.table(extracted_info["line_items"])

    st.markdown("---")
    st.subheader("📝 Fine-tune Extracted Data with Your Own Prompt")
    user_prompt = st.text_area(
        "Enter your prompt for further processing or transformation (the extracted JSON will be available as context).",
        height=120,
        key="custom_prompt"
    )
    model_2 = st.selectbox("Model for Fine-Tuning Prompt", list(MODELS.keys()), key="refine_model")
    if st.button("Run Custom Prompt"):
        if not (user_prompt or "").strip():
            st.warning("Please enter a prompt first.")
        else:
            refine_input = (
                "Here is an extracted invoice in JSON format:\n"
                f"{json.dumps(extracted_info, indent=2)}\n"
                "Follow this instruction and return the result as a JSON object only (no explanation):\n"
                f"{user_prompt}"
            )
            result = query_llm(model_2, refine_input)
            # query_llm has already reported its own failure; only parse and
            # report on parsing when a reply actually came back.
            if result:
                refined_json = clean_json_response(result)
                st.subheader("Fine-Tuned Output")
                if refined_json:
                    st.json(refined_json)
                else:
                    st.error("Could not parse a valid JSON output from the model.")
    st.caption("The prompt is run on the above-extracted fields as JSON. Try instructions like: 'Add a new field for net_amount (amount minus tax) to each line item', or 'Summarize the total quantity ordered', etc.")

# Debug view of the most recent successful model reply and raw response body.
if "last_api" in st.session_state:
    with st.expander("Debug"):
        st.code(st.session_state.last_api)
        st.code(st.session_state.last_raw)