File size: 6,521 Bytes
80d6bdb
 
 
3a86ca3
7080e09
80d6bdb
3a86ca3
7080e09
80d6bdb
 
d0c8e87
7080e09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a86ca3
80d6bdb
 
7080e09
 
 
 
 
 
 
 
 
 
d0c8e87
7080e09
 
 
 
 
 
 
 
 
3a86ca3
d0c8e87
 
3a86ca3
 
 
80d6bdb
7080e09
3a86ca3
d0c8e87
 
 
80d6bdb
d0c8e87
 
7080e09
 
3a86ca3
7080e09
3a86ca3
d45f115
 
 
 
 
 
 
 
80d6bdb
d0c8e87
80d6bdb
 
 
 
3a86ca3
80d6bdb
 
d45f115
 
 
 
 
 
 
 
80d6bdb
 
d45f115
d0c8e87
7080e09
 
 
80d6bdb
 
 
 
3a86ca3
80d6bdb
 
d0c8e87
 
 
80d6bdb
d0c8e87
 
 
80d6bdb
d0c8e87
 
80d6bdb
7080e09
d0c8e87
 
1126656
d0c8e87
80d6bdb
41e8072
3a86ca3
d0c8e87
 
80d6bdb
41e8072
 
 
1126656
41e8072
1126656
 
 
 
41e8072
 
 
 
 
 
 
d0c8e87
 
 
 
 
 
80d6bdb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import gradio as gr
import torch
import json
import spaces
import fitz  # PyMuPDF for PDF handling
from PIL import Image
import io
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from qwen_vl_utils import process_vision_info

# --- DETAILED SCHEMAS ---
# Each schema is serialized with json.dumps and pasted into the extraction
# prompt, so insertion order here is the field order the model sees.
# Values mix type hints ("string", "number") with fixed literals (the vendor
# name, "RON") that are presumably meant to be echoed back verbatim.
_VODAFONE_SCHEMA = {
    "vendor": "VODAFONE ROMANIA",
    "invoice_number": "string",
    "date": "string (DD-MM-YYYY)",
    "client_name": "string",
    "client_address": "string",
    "account_id": "string",
    "billing_period": "string",
    "totals": {
        "subtotal_no_vat": "number",
        "vat_amount": "number",
        "grand_total": "number",
        "currency": "RON",
    },
}

_DIGI_SCHEMA = {
    "vendor": "DIGI (RCS & RDS)",
    "invoice_number": "string",
    "contract_id": "string",
    "total_amount": "number",
    "iban": "string",
}

_GENERAL_SCHEMA = {
    "vendor_name": "string",
    "invoice_id": "string",
    "date": "string",
    "total_with_vat": "number",
    "client_name": "string",
}

# Router label -> extraction schema. The keys are exactly the labels the
# routing prompt asks the model to answer with.
SCHEMAS = {
    "VODAFONE": _VODAFONE_SCHEMA,
    "DIGI": _DIGI_SCHEMA,
    "GENERAL": _GENERAL_SCHEMA,
}

# --- MODEL LOADING ---
# Hugging Face Hub id of the vision-language model used for both the routing
# pass and the extraction pass.
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"

def load_model():
    """Load the 4-bit quantized Qwen2.5-VL model and its processor.

    Returns:
        tuple: ``(model, processor)`` - the model, placed on CUDA via
        ``device_map``, and an ``AutoProcessor`` created with
        ``max_pixels=1280 * 1280``.
    """
    # NF4 weights with nested (double) quantization; compute runs in float16.
    bnb_settings = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    vl_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_settings,
        device_map="cuda",
        torch_dtype="auto",
    )
    # max_pixels caps the resized image area fed to the vision encoder.
    vl_processor = AutoProcessor.from_pretrained(MODEL_ID, max_pixels=1280 * 1280)
    return vl_model, vl_processor


# Loaded once at import time; process_invoice uses these module-level globals.
model, processor = load_model()

# --- PDF TO IMAGE HELPER ---
def get_pdf_page_image(pdf_path, page_number=0, zoom=2):
    """Render a single PDF page to a PIL image.

    Args:
        pdf_path: Path to the PDF file.
        page_number: Zero-based index of the page to render (default: the
            first page, which is what the app uses).
        zoom: Scale factor applied on both axes of the render matrix; larger
            values give a higher-resolution image.

    Returns:
        PIL.Image.Image: The rendered page, decoded from PNG bytes.

    Raises:
        ValueError: If the document has no page ``page_number``.
    """
    # The context manager closes the document even if rendering raises,
    # which the previous explicit close() did not guarantee.
    with fitz.open(pdf_path) as doc:
        if not 0 <= page_number < doc.page_count:
            raise ValueError(
                f"PDF has {doc.page_count} page(s); cannot render page {page_number}"
            )
        page = doc.load_page(page_number)
        pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
        png_bytes = pix.tobytes("png")
    # The PNG bytes live in memory, so the image stays valid after the
    # document has been closed.
    return Image.open(io.BytesIO(png_bytes))

# --- INFERENCE ---
def _generate_text(messages, image_inputs, max_new_tokens):
    """Run one generation pass over a single-turn chat and return the reply text.

    Args:
        messages: Chat in the Qwen2.5-VL message format (one user turn).
        image_inputs: Vision inputs from ``process_vision_info`` for the image
            referenced in ``messages``; reusable across prompts on that image.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        str: The decoded reply with prompt tokens and special tokens removed.
    """
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt], images=image_inputs, padding=True, return_tensors="pt"
    ).to(model.device)
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation; keep only the continuation.
    reply_ids = generated_ids[:, inputs.input_ids.shape[1]:]
    return processor.batch_decode(reply_ids, skip_special_tokens=True)[0]


def _parse_json_reply(reply):
    """Parse a model reply as JSON, tolerating Markdown fences and extra text.

    Raises:
        json.JSONDecodeError: If no valid JSON can be recovered from ``reply``.
    """
    cleaned = reply.strip().replace("```json", "").replace("```", "")
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        # The model sometimes wraps the object in prose; retry on the
        # outermost {...} span before giving up.
        start, end = cleaned.find("{"), cleaned.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(cleaned[start:end + 1])


@spaces.GPU(duration=60)
def process_invoice(file_info, progress=gr.Progress()):
    """Classify an uploaded invoice and extract its fields as JSON.

    Two model passes run on the uploaded image (or the first page of a PDF).
    First, a router pass either rejects the document as a non-invoice or
    picks a vendor schema from ``SCHEMAS``. Second, an extraction pass asks
    for JSON matching that schema.

    Args:
        file_info: The ``gr.File`` value, either a path string or an object
            exposing the path as ``.name``. It is ``None`` when nothing was
            uploaded.
        progress: Gradio progress tracker used to report the pipeline stage.

    Returns:
        tuple: ``(preview_image, result)``. ``result`` is one of:
            - the parsed JSON;
            - an ``{"error": ...}`` dict;
            - ``{"raw_output": ...}`` when the reply could not be parsed.
    """
    if file_info is None:
        return None, {"error": "No file uploaded"}

    # 1. Handle File Type and Preview
    progress(0.1, desc="📄 Processing document...")
    # Depending on the Gradio version, the upload arrives as a path string
    # (possibly a str subclass) or as an object with a .name path; accept both.
    file_path = file_info if isinstance(file_info, str) else file_info.name
    if file_path.lower().endswith(".pdf"):
        image = get_pdf_page_image(file_path)
    else:
        from PIL import ImageOps  # only needed for raster uploads

        # Honor the EXIF orientation tag so phone photos aren't read sideways.
        image = ImageOps.exif_transpose(Image.open(file_path))

    # 2. Router & Validation (Identify Vendor or Reject)
    progress(0.3, desc="🔍 Validating and Identifying Vendor...")

    # The prompt offers an explicit 'INVALID' answer so non-invoices are rejected.
    decision_prompt = """Analyze this image. Is it a financial invoice or receipt?
    - If NO (e.g. random photo, object, landscape): Reply 'INVALID'.
    - If YES: Reply ONLY with 'VODAFONE', 'DIGI', or 'GENERAL'."""

    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": decision_prompt}]}]
    image_inputs, _ = process_vision_info(messages)

    raw_choice = _generate_text(messages, image_inputs, max_new_tokens=10).strip().upper()

    # VALIDATION CHECK: If model says INVALID, stop here
    if "INVALID" in raw_choice:
        progress(1.0, desc="❌ Invalid Document")
        return image, {
            "error": "Validation Failed",
            "message": "The uploaded image does not appear to be an invoice. Extraction cancelled to prevent hallucinations."
        }

    vendor_key = "VODAFONE" if "VODAFONE" in raw_choice else ("DIGI" if "DIGI" in raw_choice else "GENERAL")

    # 3. Specialist (Extract Data) - Only runs for valid documents
    progress(0.6, desc=f"🤖 Extracting {vendor_key} details...")
    schema_json = json.dumps(SCHEMAS[vendor_key], indent=2)
    extract_prompt = f"Extract details as JSON strictly following this schema: {schema_json}. Return ONLY valid JSON."

    # Same image and preprocessed vision inputs; only the text prompt changes.
    messages[0]["content"][1]["text"] = extract_prompt
    result = _generate_text(messages, image_inputs, max_new_tokens=1536)

    progress(0.9, desc="⚙️ Finalizing result...")

    # 4. Return Image for Preview and JSON for data
    try:
        data = _parse_json_reply(result)
    except json.JSONDecodeError:
        progress(1.0, desc="⚠️ Extraction complete with formatting issues")
        return image, {"raw_output": result}
    progress(1.0, desc="✅ Success!")
    return image, data

# --- INTERFACE ---
with gr.Blocks(title="InvoiceRecon", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📑 IntelliReceipt: Real-Time Invoice AI")
    gr.Markdown("Upload an invoice (PDF or Image) to extract structured data using Qwen2.5-VL.")

    with gr.Row():
        # LEFT COLUMN: Inputs and Preview
        with gr.Column(scale=1):
            # The picker filters by extension, so ".jpeg" must be listed next
            # to ".jpg"; process_invoice opens any non-PDF upload with PIL.
            file_input = gr.File(label="1. Upload Invoice", file_types=[".pdf", ".png", ".jpg", ".jpeg"])
            preview_output = gr.Image(label="2. Document Preview", type="pil")
            run_btn = gr.Button("🚀 Extract Data", variant="primary")

            # json_output is created later, in the right column, so it cannot
            # be listed here; it is attached with reset_btn.add() below.
            reset_btn = gr.ClearButton(
                components=[file_input, preview_output],
                value="🗑️ Reset All",
                variant="secondary"
            )

        # RIGHT COLUMN: JSON Result
        with gr.Column(scale=1):
            json_output = gr.JSON(label="3. Extracted JSON Result")

    # Now that the JSON panel exists, make "Reset All" clear it too.
    reset_btn.add(json_output)

    # process_invoice(file) returns (preview image, result), matching these outputs.
    run_btn.click(
        fn=process_invoice,
        inputs=file_input,
        outputs=[preview_output, json_output]
    )

demo.launch()