File size: 3,780 Bytes
28cce98
24591ec
28cce98
e7e394c
24591ec
28cce98
 
 
8eb6077
24591ec
 
 
 
 
8eb6077
24591ec
 
 
 
28cce98
24591ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28cce98
0474ce4
28cce98
24591ec
 
28cce98
 
24591ec
 
 
28cce98
24591ec
28cce98
24591ec
28cce98
 
24591ec
 
 
 
8a3d98d
 
24591ec
 
e7e394c
24591ec
 
8a3d98d
24591ec
 
 
 
 
 
8a3d98d
24591ec
 
8eb6077
8a3d98d
28cce98
24591ec
28cce98
 
 
 
 
 
24591ec
28cce98
 
 
 
 
 
 
 
 
 
 
 
 
24591ec
28cce98
 
 
 
 
 
 
 
 
 
 
 
 
24591ec
28cce98
 
 
 
 
 
24591ec
28cce98
8eb6077
8a3d98d
24591ec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import gradio as gr
import torch
from PIL import Image
import traceback

# Switch off Hugging Face tokenizers' internal parallelism before any model
# code runs; the tokenizers library reads this variable and otherwise emits
# parallelism warnings.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Use the docquery library rather than calling transformers directly: its
# pipeline wrapper, document classes and OCR reader handle model loading and
# OCR for us.
from docquery import pipeline
from docquery.document import ImageDocument
from docquery.ocr_reader import get_ocr_reader

# Module-level configuration and shared state.
# Label for the single model this app serves; also the PIPELINES cache key.
MODEL_NAME = "LayoutLMv1 for Invoices"
# Model id handed to docquery's pipeline() as `model=` when it is first built.
CHECKPOINT = "impira/layoutlm-invoices"
# Cache of constructed QA pipelines keyed by model name; filled lazily by
# construct_pipeline so the model is loaded at most once per process.
PIPELINES = {}

def construct_pipeline(model_name, checkpoint=None):
    """Return a document-QA pipeline for ``model_name``, building it once.

    Pipelines are cached in the module-level ``PIPELINES`` dict keyed by
    ``model_name``, so the expensive model load only happens on first use.

    Args:
        model_name: Cache key identifying the pipeline (e.g. ``MODEL_NAME``).
        checkpoint: Model id or path to load when no pipeline is cached yet
            for ``model_name``. Defaults to the module-level ``CHECKPOINT``.

    Returns:
        The cached or newly created docquery document-QA pipeline.
    """
    # PIPELINES is only mutated, never rebound, so no `global` is needed.
    cached = PIPELINES.get(model_name)
    if cached is not None:
        return cached

    if checkpoint is None:
        checkpoint = CHECKPOINT

    # Prefer the GPU whenever torch can see one.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    qa_pipeline = pipeline(
        task="document-question-answering",
        model=checkpoint,
        device=device,
    )
    PIPELINES[model_name] = qa_pipeline
    return qa_pipeline

def process_document(file):
    """Load an uploaded image and wrap it as a docquery document.

    Args:
        file: Value from the ``gr.File`` component. It is ``None`` when the
            upload is cleared. Otherwise it is either a filepath string
            (newer Gradio releases) or a tempfile-like object whose ``.name``
            is the path (older releases).

    Returns:
        A ``(document, preview_update)`` pair. On success ``document`` is an
        ``ImageDocument`` and the preview is shown with the RGB image. When
        there is no file or loading fails, ``document`` is ``None`` and the
        preview is hidden.
    """
    if file is None:
        return None, gr.update(visible=False)

    # Accept both a plain path and an object carrying the path in `.name`.
    path = getattr(file, "name", file)

    try:
        # convert() loads the pixels and always returns a new image (even
        # when the source is already RGB), so the file handle can be closed
        # as soon as the `with` block exits.
        with Image.open(path) as source:
            image = source.convert("RGB")

        # Pair the image with docquery's OCR reader; answer_question later
        # feeds `document.context` to the QA pipeline.
        document = ImageDocument(image, get_ocr_reader())
    except Exception:
        traceback.print_exc()
        # The preview is a gr.Image, so an error string is not a valid value
        # for it; clear and hide it instead. A None document makes
        # answer_question ask the user to upload a document.
        return None, gr.update(visible=False, value=None)

    return document, gr.update(visible=True, value=image)

def answer_question(question, document):
    """Answer a natural-language question about the uploaded document.

    Args:
        question: Text from the question box.
        document: The ``ImageDocument`` stored by ``process_document``, or
            ``None`` when no valid document has been uploaded.

    Returns:
        The top answer string, or a human-readable message when input is
        missing, no answer is found, or the pipeline raises.
    """
    # A whitespace-only question counts as no question at all.
    if document is None or not question or not question.strip():
        return "Please upload a document and enter a question"

    try:
        qa_pipeline = construct_pipeline(MODEL_NAME)

        # document.context supplies the pipeline's image/OCR inputs.
        results = qa_pipeline(question=question, **document.context, top_k=1)

        # With top_k=1 the pipeline may hand back a single prediction dict
        # instead of a one-element list; normalise before indexing.
        if isinstance(results, dict):
            results = [results]

        if not results:
            return "No answer found in the document"
        return results[0]["answer"]
    except Exception as e:
        traceback.print_exc()
        return f"Error processing document: {str(e)}"

# Assemble the Gradio UI: upload + question on the left, answer on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Invoice Question Answering")
    gr.Markdown("Upload an invoice image and ask questions like 'What is the invoice number?', 'What is the total amount?', etc.")

    # Per-session slot holding the document returned by process_document.
    document = gr.State(None)

    with gr.Row():
        # Input column.
        with gr.Column():
            gr.Markdown("## 1. Upload a document")
            upload = gr.File(label="Upload Invoice Image")
            image_preview = gr.Image(label="Preview", visible=False)

            gr.Markdown("## 2. Ask a question")
            question = gr.Textbox(
                label="Question",
                placeholder="e.g. What is the invoice number?",
                lines=1,
            )
            submit_button = gr.Button("Submit", variant="primary")

        # Output column.
        with gr.Column():
            gr.Markdown("## Results")
            answer_text = gr.Textbox(label="Answer", lines=2)

    # Re-load the stored document and its preview whenever the upload changes.
    upload.change(
        fn=process_document,
        inputs=[upload],
        outputs=[document, image_preview],
    )

    # Clicking Submit and pressing Enter in the question box run the same
    # QA handler; register both triggers in the original order.
    for register_trigger in (submit_button.click, question.submit):
        register_trigger(
            fn=answer_question,
            inputs=[question, document],
            outputs=[answer_text],
        )

if __name__ == "__main__":
    # Serve the app when the file is executed as a script.
    demo.launch()