# InvoiceRecon / app.py
# Author: Liviu16
# Last commit: d45f115 "Update app.py" (verified)
import gradio as gr
import torch
import json
import spaces
import fitz # PyMuPDF for PDF handling
from PIL import Image
import io
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from qwen_vl_utils import process_vision_info
# --- DETAILED SCHEMAS ---
# Extraction templates keyed by the vendor label the router pass can return.
# Each template is serialised with json.dumps and embedded in the extraction
# prompt, so the key order below is the order the model sees.
# Leaf values are type descriptions for the model ("string", "number", ...),
# except where a fixed literal (vendor name, currency) is given instead.
SCHEMAS = dict(
    VODAFONE=dict(
        vendor="VODAFONE ROMANIA",
        invoice_number="string",
        date="string (DD-MM-YYYY)",
        client_name="string",
        client_address="string",
        account_id="string",
        billing_period="string",
        totals=dict(
            subtotal_no_vat="number",
            vat_amount="number",
            grand_total="number",
            currency="RON",
        ),
    ),
    DIGI=dict(
        vendor="DIGI (RCS & RDS)",
        invoice_number="string",
        contract_id="string",
        total_amount="number",
        iban="string",
    ),
    GENERAL=dict(
        vendor_name="string",
        invoice_id="string",
        date="string",
        total_with_vat="number",
        client_name="string",
    ),
)
# --- MODEL LOADING ---
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"


def load_model():
    """Load the Qwen2.5-VL checkpoint in 4-bit NF4 form together with its processor.

    Quantisation uses fp16 compute and bitsandbytes' nested ("double")
    quantisation; weights are placed with ``device_map="cuda"``. The processor
    is created with ``max_pixels=1280*1280`` to bound the image resolution it
    feeds to the model.

    Returns:
        A ``(model, processor)`` tuple.
    """
    four_bit_settings = {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_compute_dtype": torch.float16,
    }
    vl_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        quantization_config=BitsAndBytesConfig(**four_bit_settings),
        device_map="cuda",
        torch_dtype="auto",
    )
    vl_processor = AutoProcessor.from_pretrained(MODEL_ID, max_pixels=1280 * 1280)
    return vl_model, vl_processor


# Loaded once at import time; every request below reuses these globals.
model, processor = load_model()
# --- PDF TO IMAGE HELPER ---
def get_pdf_page_image(pdf_path, page_number=0, zoom=2):
    """Render a single PDF page to a PIL image.

    Args:
        pdf_path: Path to the PDF file.
        page_number: Zero-based index of the page to render (default: the first page).
        zoom: Scale factor applied to both axes when rasterising. The default of 2
            doubles PyMuPDF's base resolution (presumably to make small print
            legible to the vision model).

    Returns:
        A PIL image decoded from the rendered page's PNG bytes
        (``Pixmap.tobytes()`` emits PNG by default).

    Raises:
        ValueError: If the document has no page at ``page_number`` (e.g. an empty PDF).
    """
    # The context manager closes the document even when rendering raises,
    # so a malformed upload cannot leak an open file handle.
    with fitz.open(pdf_path) as doc:
        if not 0 <= page_number < doc.page_count:
            raise ValueError(
                f"PDF has {doc.page_count} page(s); cannot render page index {page_number}"
            )
        page = doc.load_page(page_number)
        pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
        png_bytes = pix.tobytes()
    # The bytes are fully in memory, so the image stays valid after the document closes.
    return Image.open(io.BytesIO(png_bytes))
# --- INFERENCE ---
def _build_messages(image, prompt):
    """Wrap one image and one text prompt in the single-turn chat format given to the processor."""
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]


def _generate(messages, image_inputs, max_new_tokens):
    """Run one generation pass over ``messages`` and return only the newly generated text."""
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(model.device)
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + completion; slice off the prompt tokens before decoding.
    completion_ids = generated_ids[:, inputs.input_ids.shape[1]:]
    return processor.batch_decode(completion_ids, skip_special_tokens=True)[0]


def _load_document_image(path):
    """Open an uploaded file as a PIL image, rendering the first page for PDFs."""
    if path.lower().endswith(".pdf"):
        return get_pdf_page_image(path)
    # convert() reads the pixels immediately (so the file handle can be closed
    # here) and normalises palette/alpha/CMYK uploads to plain RGB.
    with Image.open(path) as src:
        return src.convert("RGB")


def _parse_json_output(raw_text):
    """Best-effort parse of the model's reply as JSON.

    Markdown code fences are stripped first. If the remainder is still not valid
    JSON, parsing is retried on the outermost ``{...}`` span to tolerate prose
    the model may add around the object.

    Returns:
        The parsed value, or None when no valid JSON could be recovered
        (a literal JSON ``null`` reply is also reported as None).
    """
    cleaned = raw_text.strip().replace("```json", "").replace("```", "").strip()
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        pass
    start, end = cleaned.find("{"), cleaned.rfind("}")
    if start != -1 and end > start:
        try:
            return json.loads(cleaned[start:end + 1])
        except json.JSONDecodeError:
            pass
    return None


@spaces.GPU(duration=60)
def process_invoice(file_info, progress=gr.Progress()):
    """Validate an uploaded document, pick a vendor schema, and extract its fields as JSON.

    Two model passes are made. First, a short router pass rejects non-invoices
    ('INVALID') or selects a schema key (VODAFONE / DIGI / GENERAL). Second, an
    extraction pass is prompted with that schema.

    Args:
        file_info: The gr.File value. Either a path string, or an object exposing
            the path as ``.name``; both forms are accepted.
        progress: Gradio progress tracker used to report each stage.

    Returns:
        ``(preview_image, payload)``. ``preview_image`` is None when nothing was
        uploaded. ``payload`` is one of: the extracted data, an error dict, or
        ``{"raw_output": ...}`` when the model's reply was not parseable JSON.
    """
    if file_info is None:
        return None, {"error": "No file uploaded"}
    # Gradio 4+ passes a path string (NamedString, which also carries .name);
    # older versions pass a tempfile wrapper whose .name is the path.
    path = getattr(file_info, "name", file_info)

    # 1. Handle File Type and Preview
    progress(0.1, desc="📄 Processing document...")
    image = _load_document_image(path)

    # 2. Router & Validation (Identify Vendor or Reject)
    progress(0.3, desc="🔍 Validating and Identifying Vendor...")
    # The prompt offers an explicit 'INVALID' exit so non-invoices are rejected
    # instead of being forced through extraction.
    decision_prompt = """Analyze this image. Is it a financial invoice or receipt?
- If NO (e.g. random photo, object, landscape): Reply 'INVALID'.
- If YES: Reply ONLY with 'VODAFONE', 'DIGI', or 'GENERAL'."""
    router_messages = _build_messages(image, decision_prompt)
    image_inputs, _ = process_vision_info(router_messages)
    raw_choice = _generate(router_messages, image_inputs, max_new_tokens=10).strip().upper()

    # VALIDATION CHECK: if the model says INVALID, stop here.
    if "INVALID" in raw_choice:
        progress(1.0, desc="❌ Invalid Document")
        return image, {
            "error": "Validation Failed",
            "message": "The uploaded image does not appear to be an invoice. Extraction cancelled to prevent hallucinations."
        }
    # Anything not recognised as a specific vendor falls back to the GENERAL schema.
    vendor_key = next((key for key in ("VODAFONE", "DIGI") if key in raw_choice), "GENERAL")

    # 3. Specialist (Extract Data) - only runs for valid documents.
    progress(0.6, desc=f"🤖 Extracting {vendor_key} details...")
    schema_json = json.dumps(SCHEMAS[vendor_key], indent=2)
    extract_prompt = f"Extract details as JSON strictly following this schema: {schema_json}. Return ONLY valid JSON."
    # Same image as the router pass, so its preprocessed vision inputs are reused.
    result = _generate(_build_messages(image, extract_prompt), image_inputs, max_new_tokens=1536)

    progress(0.9, desc="⚙️ Finalizing result...")
    # 4. Return the image for preview and the JSON for data.
    data = _parse_json_output(result)
    if data is None:
        progress(1.0, desc="⚠️ Extraction complete with formatting issues")
        return image, {"raw_output": result}
    progress(1.0, desc="✅ Success!")
    return image, data
# --- INTERFACE ---
with gr.Blocks(title="InvoiceRecon", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📑 IntelliReceipt: Real-Time Invoice AI")
    gr.Markdown("Upload an invoice (PDF or Image) to extract structured data using Qwen2.5-VL.")
    with gr.Row():
        # LEFT COLUMN: Inputs and Preview
        with gr.Column(scale=1):
            # ".jpeg" is listed alongside ".jpg" so both common JPEG extensions
            # are accepted by the upload picker.
            file_input = gr.File(label="1. Upload Invoice", file_types=[".pdf", ".png", ".jpg", ".jpeg"])
            preview_output = gr.Image(label="2. Document Preview", type="pil")
            run_btn = gr.Button("🚀 Extract Data", variant="primary")
            # json_output is created later in the right column and does not exist
            # yet, so it is registered with this button via reset_btn.add() below.
            reset_btn = gr.ClearButton(
                components=[file_input, preview_output],
                value="🗑️ Reset All",
                variant="secondary",
            )
        # RIGHT COLUMN: JSON Result
        with gr.Column(scale=1):
            json_output = gr.JSON(label="3. Extracted JSON Result")
    # Now that json_output exists, make "Reset All" clear it too.
    reset_btn.add(json_output)
    # process_invoice returns (preview image, JSON payload); outputs follow that order.
    run_btn.click(
        fn=process_invoice,
        inputs=file_input,
        outputs=[preview_output, json_output],
    )
demo.launch()