# Hugging Face Space status: Running on Zero (ZeroGPU hardware).
import io
import json

import fitz  # PyMuPDF for PDF handling
import gradio as gr
import spaces
import torch
from PIL import Image, ImageOps
from qwen_vl_utils import process_vision_info
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
# --- DETAILED SCHEMAS ---
# Per-vendor JSON templates embedded in the extraction prompt. Keys are the
# field names the model is asked to fill; leaf values are either type hints
# ("string", "number", "string (DD-MM-YYYY)") or literal values such as the
# vendor name and the currency. Insertion order is significant: it is the
# order in which json.dumps() presents the fields to the model.
SCHEMAS = dict(
    VODAFONE=dict(
        vendor="VODAFONE ROMANIA",
        invoice_number="string",
        date="string (DD-MM-YYYY)",
        client_name="string",
        client_address="string",
        account_id="string",
        billing_period="string",
        totals=dict(
            subtotal_no_vat="number",
            vat_amount="number",
            grand_total="number",
            currency="RON",
        ),
    ),
    DIGI=dict(
        vendor="DIGI (RCS & RDS)",
        invoice_number="string",
        contract_id="string",
        total_amount="number",
        iban="string",
    ),
    # Fallback template for any invoice the router does not attribute to a known vendor.
    GENERAL=dict(
        vendor_name="string",
        invoice_id="string",
        date="string",
        total_with_vat="number",
        client_name="string",
    ),
)
# --- MODEL LOADING ---
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"


def load_model():
    """Load the 4-bit quantized Qwen2.5-VL model together with its processor.

    The weights are quantized with bitsandbytes (NF4, double quantization,
    float16 compute) and placed on CUDA. The processor is created with
    ``max_pixels=1280*1280``, bounding the resolution of images it prepares.

    Returns:
        A ``(model, processor)`` tuple.
    """
    bnb_settings = dict(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    vl_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",
        device_map="cuda",
        quantization_config=BitsAndBytesConfig(**bnb_settings),
    )
    vl_processor = AutoProcessor.from_pretrained(MODEL_ID, max_pixels=1280 * 1280)
    return vl_model, vl_processor


# Loaded once at import time; every request below reuses these two objects.
model, processor = load_model()
# --- PDF TO IMAGE HELPER ---
def get_pdf_page_image(pdf_path, page_number=0, zoom=2):
    """Render one page of a PDF file to a PIL image.

    Args:
        pdf_path: Path to the PDF file.
        page_number: Index of the page to render; defaults to the first page.
            Negative values count from the end, as in Python sequences.
        zoom: Scale factor applied in both directions through
            ``fitz.Matrix(zoom, zoom)``; the default of 2 doubles PyMuPDF's
            base rendering resolution so small invoice text stays legible.

    Returns:
        A ``PIL.Image.Image`` decoded from the page rendered as PNG.

    Raises:
        ValueError: If the document has no page at ``page_number`` (including
            a PDF with zero pages).
    """
    # The context manager closes the document even when loading or rendering
    # raises; the previous open()/close() pair leaked it on any error.
    with fitz.open(pdf_path) as doc:
        page_count = doc.page_count
        if not -page_count <= page_number < page_count:
            raise ValueError(
                f"PDF has {page_count} page(s); cannot render page index {page_number}"
            )
        pix = doc.load_page(page_number).get_pixmap(matrix=fitz.Matrix(zoom, zoom))
        png_bytes = pix.tobytes("png")
    # The PNG bytes are fully in memory, so the image stays valid after the
    # document is closed.
    return Image.open(io.BytesIO(png_bytes))
# --- INFERENCE ---
def _load_document_image(path):
    """Return the document at *path* as a PIL image (first page for PDFs)."""
    if path.lower().endswith(".pdf"):
        return get_pdf_page_image(path)
    # Image.open keeps the file handle open lazily; the ``with`` closes it.
    # exif_transpose returns a new, fully loaded image, so the result remains
    # usable after the file is closed.
    with Image.open(path) as src:
        # Phone photos of receipts often record their rotation only as EXIF
        # metadata; applying it lets the preview and the model see the page upright.
        return ImageOps.exif_transpose(src)


def _build_messages(image, prompt):
    """Wrap *image* and a text *prompt* into a single-turn chat message list."""
    return [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]


def _generate(messages, image_inputs, max_new_tokens):
    """Run one generation over the chat *messages* and return the decoded reply text."""
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(model.device)
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation; keep only the newly generated tokens.
    new_tokens = generated_ids[:, inputs.input_ids.shape[1]:]
    return processor.batch_decode(new_tokens, skip_special_tokens=True)[0]


def _parse_json_output(text):
    """Parse the model's JSON reply, tolerating Markdown fences and surrounding prose.

    Raises:
        json.JSONDecodeError: If no parseable JSON can be recovered from *text*.
    """
    cleaned = text.strip().replace("```json", "").replace("```", "").strip()
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        # Models sometimes wrap the object in explanatory text; retry on the
        # outermost {...} span before giving up.
        start, end = cleaned.find("{"), cleaned.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(cleaned[start:end + 1])


# On ZeroGPU hardware a GPU is attached only while a @spaces.GPU function runs.
# The budget is raised above the default because one request performs two
# generations (up to 10 + 1536 new tokens) on a 4-bit quantized model.
@spaces.GPU(duration=120)
def process_invoice(file_info, progress=gr.Progress()):
    """Classify an uploaded document and extract its invoice fields as JSON.

    Args:
        file_info: Value of the ``gr.File`` input. This is either an object
            exposing the uploaded file's path as ``.name`` or a plain path
            string. It is ``None`` when nothing was uploaded.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A ``(image, result)`` tuple for the preview and JSON outputs. ``image``
        is ``None`` when no file was uploaded or the file could not be read.
        ``result`` takes one of three forms:
            * the parsed JSON on success;
            * ``{"raw_output": ...}`` when the reply is not valid JSON;
            * an ``{"error": ...}`` dict when no file was uploaded, the file
              is unreadable, or validation fails.
    """
    if file_info is None:
        return None, {"error": "No file uploaded"}

    # 1. Handle file type and preview.
    progress(0.1, desc="📄 Processing document...")
    path = getattr(file_info, "name", file_info)
    try:
        image = _load_document_image(path)
    except (OSError, RuntimeError, ValueError) as err:
        # Possible causes include unsupported or corrupt images (PIL raises
        # OSError subclasses), broken PDFs and empty PDFs.
        # NOTE(review): assumes PyMuPDF's file errors derive from RuntimeError; confirm for the pinned version.
        progress(1.0, desc="❌ Unreadable Document")
        return None, {"error": "Unreadable Document", "message": str(err)}

    # 2. Router & validation: identify the vendor or reject non-invoices.
    progress(0.3, desc="🔍 Validating and Identifying Vendor...")
    # The prompt offers an explicit 'INVALID' exit so non-invoices are rejected
    # before the extraction pass can invent fields for them.
    decision_prompt = """Analyze this image. Is it a financial invoice or receipt?
- If NO (e.g. random photo, object, landscape): Reply 'INVALID'.
- If YES: Reply ONLY with 'VODAFONE', 'DIGI', or 'GENERAL'."""
    messages = _build_messages(image, decision_prompt)
    # Vision preprocessing is computed once and reused by both generation passes.
    image_inputs, _ = process_vision_info(messages)
    raw_choice = _generate(messages, image_inputs, max_new_tokens=10).strip().upper()

    # VALIDATION CHECK: if the model says INVALID, stop here.
    if "INVALID" in raw_choice:
        progress(1.0, desc="❌ Invalid Document")
        return image, {
            "error": "Validation Failed",
            "message": "The uploaded image does not appear to be an invoice. Extraction cancelled to prevent hallucinations.",
        }
    vendor_key = "VODAFONE" if "VODAFONE" in raw_choice else ("DIGI" if "DIGI" in raw_choice else "GENERAL")

    # 3. Specialist: extract data with the vendor-specific schema (valid documents only).
    progress(0.6, desc=f"🤖 Extracting {vendor_key} details...")
    schema_json = json.dumps(SCHEMAS[vendor_key], indent=2)
    extract_prompt = f"Extract details as JSON strictly following this schema: {schema_json}. Return ONLY valid JSON."
    result = _generate(_build_messages(image, extract_prompt), image_inputs, max_new_tokens=1536)

    progress(0.9, desc="⚙️ Finalizing result...")

    # 4. Return the image for preview and the JSON for data. Only JSON
    # decoding errors fall back to the raw text; anything else propagates.
    try:
        data = _parse_json_output(result)
    except json.JSONDecodeError:
        progress(1.0, desc="⚠️ Extraction complete with formatting issues")
        return image, {"raw_output": result}
    progress(1.0, desc="✅ Success!")
    return image, data
# --- INTERFACE ---
with gr.Blocks(title="InvoiceRecon", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧾 IntelliReceipt: Real-Time Invoice AI")
    gr.Markdown("Upload an invoice (PDF or Image) to extract structured data using Qwen2.5-VL.")
    with gr.Row():
        # LEFT COLUMN: upload, preview and action buttons.
        with gr.Column(scale=1):
            # process_invoice renders .pdf files and opens every other type
            # with Pillow, so the common raster formats are accepted.
            file_input = gr.File(
                label="1. Upload Invoice",
                file_types=[".pdf", ".png", ".jpg", ".jpeg", ".webp"],
            )
            preview_output = gr.Image(label="2. Document Preview", type="pil")
            run_btn = gr.Button("🚀 Extract Data", variant="primary")
            # json_output does not exist yet (it is created in the right
            # column below), so it is attached later with reset_btn.add().
            reset_btn = gr.ClearButton(
                components=[file_input, preview_output],
                value="🗑️ Reset All",
                variant="secondary",
            )
        # RIGHT COLUMN: JSON result.
        with gr.Column(scale=1):
            json_output = gr.JSON(label="3. Extracted JSON Result")

    # Now that json_output exists, make "Reset All" clear it as well.
    reset_btn.add(json_output)
    # Outputs match process_invoice's (image, result) return tuple.
    run_btn.click(
        fn=process_invoice,
        inputs=file_input,
        outputs=[preview_output, json_output],
    )

demo.launch()