|
|
import os |
|
|
import gradio as gr |
|
|
import torch |
|
|
from PIL import Image |
|
|
import traceback |
|
|
|
|
|
|
|
|
# Disable HuggingFace tokenizers' internal parallelism — presumably to avoid
# the fork/deadlock warning when the web server spawns workers. TODO confirm.
# Must be set before the tokenizer libraries are initialized.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
|
|
|
|
|
|
|
from docquery import pipeline |
|
|
from docquery.document import ImageDocument |
|
|
from docquery.ocr_reader import get_ocr_reader |
|
|
|
|
|
|
|
|
# Display name for the model; also used as the cache key in construct_pipeline().
MODEL_NAME = "LayoutLMv1 for Invoices"


# Hugging Face checkpoint the QA pipeline loads its weights from.
CHECKPOINT = "impira/layoutlm-invoices"


# Cache of constructed pipelines, keyed by model name (see construct_pipeline).
PIPELINES = {}
|
|
|
|
|
def construct_pipeline(model_name):
    """Create — or fetch from the module-level cache — a document QA pipeline.

    Args:
        model_name: Cache key under which the pipeline is stored in
            ``PIPELINES``. The model weights themselves always come from
            the module-level ``CHECKPOINT``.

    Returns:
        A docquery ``document-question-answering`` pipeline, placed on the
        GPU when CUDA is available, otherwise on the CPU.
    """
    # No ``global`` statement needed: PIPELINES is only mutated in place,
    # never rebound, so the module-level dict is reached through normal
    # name resolution.
    if model_name in PIPELINES:
        return PIPELINES[model_name]

    # Prefer the GPU when available; inference is substantially faster there.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    qa_pipeline = pipeline(
        task="document-question-answering",
        model=CHECKPOINT,
        device=device,
    )
    # Memoize so repeated questions reuse the (expensive-to-build) pipeline.
    PIPELINES[model_name] = qa_pipeline
    return qa_pipeline
|
|
|
|
|
def process_document(file):
    """Turn an uploaded file into an OCR-ready document for question answering.

    Args:
        file: The object emitted by the Gradio ``File`` component (exposes the
            temp path via ``.name``), or ``None`` when the upload is cleared.

    Returns:
        A ``(document, preview_update)`` pair: the ``ImageDocument`` (or
        ``None`` on failure/clearing) and a ``gr.update`` driving the image
        preview component's visibility and content.
    """
    if file is None:
        return None, gr.update(visible=False)

    try:
        image = Image.open(file.name)
        # The OCR / model stack works on 3-channel RGB input.
        if image.mode != "RGB":
            image = image.convert("RGB")

        document = ImageDocument(image, get_ocr_reader())
        return document, gr.update(visible=True, value=image)
    except Exception:
        # UI boundary: log the full traceback and hide the preview.
        # BUG FIX: the original passed an error *string* as the value of the
        # gr.Image preview component, which an Image cannot render; simply
        # hide the preview instead.
        traceback.print_exc()
        return None, gr.update(visible=False)
|
|
|
|
|
def answer_question(question, document):
    """Answer *question* against the previously processed *document*.

    Args:
        question: The user's natural-language question.
        document: The ``ImageDocument`` held in Gradio session state, or
            ``None`` when nothing has been uploaded yet.

    Returns:
        The extracted answer string, or a human-readable message when inputs
        are missing, no answer is found, or inference fails.
    """
    # Guard clause: both a document and a non-empty question are required.
    if document is None or not question:
        return "Please upload a document and enter a question"

    try:
        pipe = construct_pipeline(MODEL_NAME)
        # top_k=1: keep only the single best-scoring span.
        predictions = pipe(question=question, **document.context, top_k=1)

        if not predictions:
            return "No answer found in the document"
        return predictions[0]["answer"]
    except Exception as err:
        # UI boundary: log for the operator, show a friendly message to the user.
        traceback.print_exc()
        return f"Error processing document: {str(err)}"
|
|
|
|
|
|
|
|
# --- Gradio UI definition: layout plus event wiring for the demo app. ---
with gr.Blocks() as demo:
    gr.Markdown("# Invoice Question Answering")
    gr.Markdown("Upload an invoice image and ask questions like 'What is the invoice number?', 'What is the total amount?', etc.")

    # Per-session state holding the processed ImageDocument (None until upload).
    document = gr.State(None)

    with gr.Row():
        # Left column: upload + question input.
        with gr.Column():
            gr.Markdown("## 1. Upload a document")
            upload = gr.File(label="Upload Invoice Image")
            # Hidden until a document is successfully processed.
            image_preview = gr.Image(label="Preview", visible=False)

            gr.Markdown("## 2. Ask a question")
            question = gr.Textbox(
                label="Question",
                placeholder="e.g. What is the invoice number?",
                lines=1
            )

            submit_button = gr.Button("Submit", variant="primary")

        # Right column: answer output.
        with gr.Column():
            gr.Markdown("## Results")
            answer_text = gr.Textbox(label="Answer", lines=2)

    # On upload: OCR the file, stash the document in state, show the preview.
    upload.change(
        fn=process_document,
        inputs=[upload],
        outputs=[document, image_preview]
    )

    # Clicking Submit runs the QA pipeline over the stored document.
    submit_button.click(
        fn=answer_question,
        inputs=[question, document],
        outputs=[answer_text]
    )

    # Pressing Enter in the question box triggers the same handler.
    question.submit(
        fn=answer_question,
        inputs=[question, document],
        outputs=[answer_text]
    )
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    demo.launch()