# Invoice_parser / app.py
# (Hugging Face Space by ashvin-savani — revision 294474c "Alpha", 3.38 kB;
#  the lines above were page chrome captured by scraping, commented out so the file parses)
import os
import time
import base64
import json
import gc
import torch
import io
from transformers import AutoProcessor, AutoModelForImageTextToText
from qwen_vl_utils import process_vision_info
import gradio as gr
import spaces
# Model setup: NuExtract-2.0-4B, a vision-language model for structured extraction.
MODEL_NAME = "numind/NuExtract-2.0-4B"
# Target device for inference tensors; inputs are .to(device)'d in process_image.
device = "cuda" # ZeroGPU provides GPU
model = AutoModelForImageTextToText.from_pretrained(
MODEL_NAME,
trust_remote_code=True,
dtype=torch.bfloat16,
device_map=None, # Load on CPU, move to GPU in function
# NOTE(review): process_image never actually calls model.to(device), so the
# weights stay on CPU while inputs go to CUDA — verify this is handled.
)
# Processor bundles tokenizer + image preprocessing; left-padding is required
# for correct batched decoder-only generation.
processor = AutoProcessor.from_pretrained(
MODEL_NAME,
trust_remote_code=True,
padding_side='left',
use_fast=True,
)
# Invoice schema: passed (as indented JSON) to apply_chat_template via its
# `template=` argument, telling NuExtract which fields to fill in.
# Empty strings mark scalar fields; the single-element "items" list is the
# per-line-item template the model repeats for each row it finds.
invoice_schema = {
    "invoice_number": "",
    "invoice_date": "",
    "supplier_name": "",
    "supplier_address": "",
    "total_amount": "",
    "currency": "",
    "items": [
        {
            "description": "",
            "quantity": "",
            "unit_price": "",
            "total_price": ""
        }
    ]
}
def encode_image_to_base64(image_path):
    """Return the contents of the file at *image_path* as a base64 string."""
    with open(image_path, "rb") as fh:
        raw_bytes = fh.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
def encode_image_from_pil(image):
    """Serialize *image* (a PIL image) to PNG in memory, base64-encoded."""
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
def prepare_prompt(image_path):
    """Build the chat messages and templated prompt for an invoice on disk.

    Returns a ``(messages, text)`` tuple: the chat-format message list
    embedding the image as a data URI, and the prompt string produced by
    applying the tokenizer's chat template with the invoice schema.
    """
    data_uri = f"data:image;base64,{encode_image_to_base64(image_path)}"
    messages = [
        {
            "role": "user",
            "content": [{"type": "image", "image": data_uri}],
        }
    ]
    prompt_text = processor.tokenizer.apply_chat_template(
        messages,
        template=json.dumps(invoice_schema, indent=4),
        tokenize=False,
        add_generation_prompt=True,
    )
    return messages, prompt_text
@spaces.GPU
def process_image(image):
    """Extract structured invoice data from a PIL image.

    Builds a NuExtract chat prompt embedding *image* as a base64 data URI,
    runs greedy generation, and returns the decoded model output (expected
    to be JSON text matching ``invoice_schema``). Returns a plain message
    string when no image is provided.
    """
    if image is None:
        return "No image provided."
    # BUG FIX: the model is loaded on CPU at import time (device_map=None,
    # "move to GPU in function"), but was never actually moved — inputs on
    # CUDA vs. weights on CPU makes generate() fail with a device mismatch.
    # Move it here, inside the ZeroGPU-allocated function (no-op if already there).
    model.to(device)
    base64_str = encode_image_from_pil(image)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": f"data:image;base64,{base64_str}"}
            ]
        }
    ]
    # The invoice schema is injected as the extraction template for NuExtract.
    text = processor.tokenizer.apply_chat_template(
        messages,
        template=json.dumps(invoice_schema, indent=4),
        tokenize=False,
        add_generation_prompt=True
    )
    image_inputs = process_vision_info(messages)[0] or []
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)
    # Deterministic (greedy) decoding; cap output length for safety.
    generation_config = {
        "do_sample": False,
        "num_beams": 1,
        "max_new_tokens": 2048,
    }
    # Inference only — skip autograd bookkeeping to save memory.
    with torch.no_grad():
        generated_ids = model.generate(**inputs, **generation_config)
    # Strip the echoed prompt tokens, keeping only the newly generated tail.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return output_text
# Gradio interface: single image input -> single textbox output, wired to
# the GPU-decorated process_image above.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload Invoice Image"),  # PIL, as encode_image_from_pil expects
    outputs=gr.Textbox(label="Extracted Invoice Data (JSON)"),
    title="Invoice Parser with NuExtract",
    description="Upload an invoice image to extract structured data using AI."
)
# Launch at import time — standard for a Hugging Face Space app.py entry point.
iface.launch()