import os
import gradio as gr
import torch
from PIL import Image
import traceback
# Disable tokenizers parallelism to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Use docquery library instead of direct transformer usage
# This should handle the model correctly as shown in your reference code
from docquery import pipeline
from docquery.document import ImageDocument
from docquery.ocr_reader import get_ocr_reader
# Global variables
MODEL_NAME = "LayoutLMv1 for Invoices"
CHECKPOINT = "impira/layoutlm-invoices"
PIPELINES = {}
def construct_pipeline(model_name):
"""Create and cache a document QA pipeline"""
global PIPELINES
if model_name in PIPELINES:
return PIPELINES[model_name]
device = "cuda" if torch.cuda.is_available() else "cpu"
qa_pipeline = pipeline(
task="document-question-answering",
model=CHECKPOINT,
device=device
)
PIPELINES[model_name] = qa_pipeline
return qa_pipeline
def process_document(file):
"""Process the uploaded document"""
if file is None:
return None, gr.update(visible=False)
try:
# Open the image
image = Image.open(file.name)
if image.mode != "RGB":
image = image.convert("RGB")
# Create a document using docquery
document = ImageDocument(image, get_ocr_reader())
return document, gr.update(visible=True, value=image)
except Exception as e:
traceback.print_exc()
return None, gr.update(visible=False, value=f"Error: {str(e)}")
def answer_question(question, document):
"""Process question using the document QA pipeline"""
if document is None or not question:
return "Please upload a document and enter a question"
try:
# Get the pipeline
qa_pipeline = construct_pipeline(MODEL_NAME)
# Run question answering
results = qa_pipeline(question=question, **document.context, top_k=1)
# Extract the answer
if results:
answer = results[0]["answer"]
return answer
else:
return "No answer found in the document"
except Exception as e:
traceback.print_exc()
return f"Error processing document: {str(e)}"
# Create Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# Invoice Question Answering")
gr.Markdown("Upload an invoice image and ask questions like 'What is the invoice number?', 'What is the total amount?', etc.")
# Document storage
document = gr.State(None)
with gr.Row():
with gr.Column():
gr.Markdown("## 1. Upload a document")
upload = gr.File(label="Upload Invoice Image")
image_preview = gr.Image(label="Preview", visible=False)
gr.Markdown("## 2. Ask a question")
question = gr.Textbox(
label="Question",
placeholder="e.g. What is the invoice number?",
lines=1
)
submit_button = gr.Button("Submit", variant="primary")
with gr.Column():
gr.Markdown("## Results")
answer_text = gr.Textbox(label="Answer", lines=2)
# Set up event handlers
upload.change(
fn=process_document,
inputs=[upload],
outputs=[document, image_preview]
)
submit_button.click(
fn=answer_question,
inputs=[question, document],
outputs=[answer_text]
)
# Also trigger on pressing Enter in question box
question.submit(
fn=answer_question,
inputs=[question, document],
outputs=[answer_text]
)
if __name__ == "__main__":
demo.launch() |