# layout-lm / app.py
# Author: anirudh-valyx
# Commit 24591ec: "fix input text"
import os
import gradio as gr
import torch
from PIL import Image
import traceback
# Turn off Hugging Face tokenizers parallelism up front (before docquery and
# its transformers backend are imported below) to silence its warnings.
os.environ.update(TOKENIZERS_PARALLELISM="false")
# Use the docquery library rather than calling transformers directly: it
# provides the document-QA pipeline, the document wrapper, and the OCR reader used below.
from docquery import pipeline
from docquery.document import ImageDocument
from docquery.ocr_reader import get_ocr_reader
# Global variables
# Display name of the model this demo exposes and the checkpoint behind it.
MODEL_NAME = "LayoutLMv1 for Invoices"
CHECKPOINT = "impira/layoutlm-invoices"
# Display name -> checkpoint. Register additional models here; names that are
# not listed fall back to CHECKPOINT in construct_pipeline().
CHECKPOINTS = {MODEL_NAME: CHECKPOINT}
# Loaded pipelines keyed by checkpoint, so each set of weights is loaded once.
PIPELINES = {}


def construct_pipeline(model_name):
    """Return a document-QA pipeline for ``model_name``, creating it on first use.

    The checkpoint is resolved through ``CHECKPOINTS``; names not listed there
    fall back to ``CHECKPOINT``. Pipelines are cached in ``PIPELINES`` keyed by
    checkpoint (not by display name), so requesting the same weights under a
    different name reuses the existing pipeline instead of loading a copy.

    Args:
        model_name: Display name of the model, e.g. ``MODEL_NAME``.

    Returns:
        The pipeline object created by ``docquery.pipeline``.
    """
    checkpoint = CHECKPOINTS.get(model_name, CHECKPOINT)
    if checkpoint in PIPELINES:
        return PIPELINES[checkpoint]

    # Use the GPU when torch can see one; otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    qa_pipeline = pipeline(
        task="document-question-answering",
        model=checkpoint,
        device=device,
    )
    PIPELINES[checkpoint] = qa_pipeline
    return qa_pipeline
def process_document(file):
    """Load an uploaded image and wrap it as a docquery ``ImageDocument``.

    Args:
        file: Value delivered by the ``gr.File`` upload component, or ``None``
            when the upload is cleared. Older Gradio releases pass a
            tempfile-like object exposing ``.name``; newer releases pass the
            path as a plain string. Both forms are accepted.

    Returns:
        A ``(document, preview_update)`` pair:
        - ``document`` is the ``ImageDocument``, or ``None`` when the upload was
          cleared or could not be processed.
        - ``preview_update`` is a ``gr.update`` for the preview image.
    """
    if file is None:
        return None, gr.update(visible=False)

    # Accept both a tempfile-like wrapper and a bare path string.
    path = getattr(file, "name", file)
    try:
        # The context manager closes the underlying file; convert() returns a
        # fully loaded copy, so the image stays usable after the file closes.
        with Image.open(path) as raw_image:
            image = raw_image.convert("RGB")
        # Create a document using docquery
        document = ImageDocument(image, get_ocr_reader())
    except Exception:
        # UI boundary: log the failure. The preview is a gr.Image and cannot
        # show an error string, so clear and hide it instead.
        traceback.print_exc()
        return None, gr.update(visible=False, value=None)
    return document, gr.update(visible=True, value=image)
def answer_question(question, document):
    """Answer ``question`` about ``document`` using the cached QA pipeline.

    Args:
        question: Text from the question box; may be ``None``, empty or blank.
        document: The ``ImageDocument`` kept in ``gr.State`` (``None`` until an
            upload has been processed successfully).

    Returns:
        The top-ranked answer string, or a human-readable message when input
        is missing, no answer is found, or the pipeline raises.
    """
    # Treat whitespace-only questions as missing rather than sending them
    # to the model.
    question = (question or "").strip()
    if document is None or not question:
        return "Please upload a document and enter a question"
    try:
        qa_pipeline = construct_pipeline(MODEL_NAME)
        results = qa_pipeline(question=question, **document.context, top_k=1)
        # Some pipeline versions may return a single prediction as a bare
        # dict instead of a one-element list; normalise so indexing is safe.
        if isinstance(results, dict):
            results = [results]
        if not results:
            return "No answer found in the document"
        return results[0]["answer"]
    except Exception as e:
        # UI boundary: log the traceback and report the failure as text.
        traceback.print_exc()
        return f"Error processing document: {str(e)}"
# Create Gradio interface: upload and question on the left, answer on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Invoice Question Answering")
    gr.Markdown("Upload an invoice image and ask questions like 'What is the invoice number?', 'What is the total amount?', etc.")
    # Per-session storage for the ImageDocument built by process_document.
    document = gr.State(None)
    with gr.Row():
        with gr.Column():
            gr.Markdown("## 1. Upload a document")
            # process_document opens uploads with PIL, so only offer images.
            upload = gr.File(label="Upload Invoice Image", file_types=["image"])
            image_preview = gr.Image(label="Preview", visible=False)
            gr.Markdown("## 2. Ask a question")
            question = gr.Textbox(
                label="Question",
                placeholder="e.g. What is the invoice number?",
                lines=1,
            )
            submit_button = gr.Button("Submit", variant="primary")
        with gr.Column():
            gr.Markdown("## Results")
            answer_text = gr.Textbox(label="Answer", lines=2)

    # Rebuild the document whenever the upload changes (including clears).
    upload.change(
        fn=process_document,
        inputs=[upload],
        outputs=[document, image_preview],
    )
    # Clicking Submit and pressing Enter in the question box do the same
    # thing; fresh input/output lists are built for each registration.
    for trigger in (submit_button.click, question.submit):
        trigger(
            fn=answer_question,
            inputs=[question, document],
            outputs=[answer_text],
        )
def main():
    """Start the Gradio server for the ``demo`` Blocks app built above."""
    demo.launch()


if __name__ == "__main__":
    main()