anirudh-valyx committed on
Commit
28cce98
·
1 Parent(s): 1327cd1

fix input text

Browse files
Files changed (1) hide show
  1. app.py +107 -51
app.py CHANGED
@@ -1,43 +1,66 @@
1
  import gradio as gr
 
 
2
  from transformers import AutoProcessor, AutoModelForDocumentQuestionAnswering
3
- from PIL import Image
 
 
 
4
 
5
  # Load processor and model
6
- processor = AutoProcessor.from_pretrained("impira/layoutlm-invoices")
7
- model = AutoModelForDocumentQuestionAnswering.from_pretrained("impira/layoutlm-invoices")
 
8
 
9
- def answer_question(image, question):
10
- """
11
- Process an invoice image and answer a question about its content
12
-
13
- Args:
14
- image: PIL image of the invoice
15
- question: String question about the invoice
16
-
17
- Returns:
18
- String answer extracted from the invoice
19
- """
20
- # Input validation
21
- if image is None:
22
- return "Please upload an image"
23
-
24
- if question is None or question.strip() == "":
25
- return "Please enter a question"
26
-
27
- # Ensure RGB mode
28
- if image.mode != "RGB":
29
- image = image.convert("RGB")
 
 
 
 
30
 
31
- # Ensure question is a string
32
- if not isinstance(question, str):
33
- question = str(question)
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  try:
36
- # Following the exact format from the model documentation
37
- # First position is for image, second is for text/question
38
  encoding = processor(image, question, return_tensors="pt")
 
 
39
 
40
- # Forward pass through model
41
  outputs = model(**encoding)
42
 
43
  # Extract answer span
@@ -50,30 +73,63 @@ def answer_question(image, question):
50
  # Clean up answer
51
  answer = answer.replace("[CLS]", "").replace("[SEP]", "").strip()
52
 
53
- if not answer:
54
- return "No answer found in the document"
55
-
56
- return answer
 
57
  except Exception as e:
58
  import traceback
59
- tb = traceback.format_exc()
60
- return f"Error processing document: {str(e)}\n\nDetails:\n{tb}"
61
 
62
  # Create Gradio interface
63
- iface = gr.Interface(
64
- fn=answer_question,
65
- inputs=[
66
- gr.Image(type="pil", label="Upload Invoice Image"),
67
- gr.Textbox(placeholder="Ask a question about the invoice...", label="Question")
68
- ],
69
- outputs=gr.Textbox(label="Answer"),
70
- title="Invoice Question Answering with LayoutLM",
71
- description="Upload an invoice image and ask questions like 'What is the invoice number?', 'What is the total amount?', 'Who is the vendor?', etc.",
72
- # No hardcoded examples since we don't have sample files
73
- examples=None,
74
- allow_flagging="never"
75
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- # Launch the app
78
  if __name__ == "__main__":
79
- iface.launch()
 
1
  import gradio as gr
2
+ import os
3
+ import torch
4
  from transformers import AutoProcessor, AutoModelForDocumentQuestionAnswering
5
+ from PIL import Image, ImageDraw
6
+
7
+ # Disable tokenizers parallelism to avoid warnings
8
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
9
 
10
  # Load processor and model
11
+ MODEL_NAME = "impira/layoutlm-invoices"
12
+ processor = AutoProcessor.from_pretrained(MODEL_NAME)
13
+ model = AutoModelForDocumentQuestionAnswering.from_pretrained(MODEL_NAME)
14
 
15
+ # Use GPU if available
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ model = model.to(device)
18
+
19
+ def ensure_list(x):
20
+ """Ensure input is a list"""
21
+ if isinstance(x, list):
22
+ return x
23
+ else:
24
+ return [x]
25
+
26
+ def normalize_bbox(box, width, height, padding=0.005):
27
+ """Normalize bounding box coordinates"""
28
+ min_x, min_y, max_x, max_y = [c / 1000 for c in box]
29
+ if padding != 0:
30
+ min_x = max(0, min_x - padding)
31
+ min_y = max(0, min_y - padding)
32
+ max_x = min(max_x + padding, 1)
33
+ max_y = min(max_y + padding, 1)
34
+ return [min_x * width, min_y * height, max_x * width, max_y * height]
35
+
36
+ def process_document(image_file):
37
+ """Process uploaded document"""
38
+ if image_file is None:
39
+ return None, gr.update(visible=False)
40
 
41
+ try:
42
+ # Load image
43
+ image = Image.open(image_file.name)
44
+ if image.mode != "RGB":
45
+ image = image.convert("RGB")
46
+
47
+ # Return the document and show the image
48
+ return image, gr.update(visible=True, value=image)
49
+ except Exception as e:
50
+ return None, gr.update(visible=False, value=f"Error: {str(e)}")
51
+
52
+ def answer_question(question, image):
53
+ """Process question with LayoutLM model"""
54
+ if image is None or question.strip() == "":
55
+ return None, None
56
 
57
  try:
58
+ # Process inputs
 
59
  encoding = processor(image, question, return_tensors="pt")
60
+ for key in encoding.keys():
61
+ encoding[key] = encoding[key].to(device)
62
 
63
+ # Get model predictions
64
  outputs = model(**encoding)
65
 
66
  # Extract answer span
 
73
  # Clean up answer
74
  answer = answer.replace("[CLS]", "").replace("[SEP]", "").strip()
75
 
76
+ # Return an RGB copy of the input image (answer highlighting is not implemented here)
77
+ result_image = image.copy().convert("RGB")
78
+
79
+ # Return results
80
+ return answer, result_image
81
  except Exception as e:
82
  import traceback
83
+ error_msg = f"Error processing document: {str(e)}\n{traceback.format_exc()}"
84
+ return error_msg, None
85
 
86
  # Create Gradio interface
87
+ with gr.Blocks() as demo:
88
+ gr.Markdown("# Invoice Question Answering with LayoutLM")
89
+ gr.Markdown("Upload an invoice image and ask questions like 'What is the invoice number?', 'What is the total amount?', etc.")
90
+
91
+ # Document storage
92
+ document = gr.State(None)
93
+
94
+ with gr.Row():
95
+ with gr.Column():
96
+ gr.Markdown("## 1. Upload a document")
97
+ upload = gr.File(label="Upload Invoice Image")
98
+ image_preview = gr.Image(label="Preview", visible=False)
99
+
100
+ gr.Markdown("## 2. Ask a question")
101
+ question = gr.Textbox(
102
+ label="Question",
103
+ placeholder="e.g. What is the invoice number?",
104
+ lines=1
105
+ )
106
+
107
+ submit_button = gr.Button("Submit", variant="primary")
108
+
109
+ with gr.Column():
110
+ gr.Markdown("## Results")
111
+ answer_text = gr.Textbox(label="Answer", lines=2)
112
+ result_image = gr.Image(label="Document with Answer")
113
+
114
+ # Set up event handlers
115
+ upload.change(
116
+ fn=process_document,
117
+ inputs=[upload],
118
+ outputs=[document, image_preview]
119
+ )
120
+
121
+ submit_button.click(
122
+ fn=answer_question,
123
+ inputs=[question, document],
124
+ outputs=[answer_text, result_image]
125
+ )
126
+
127
+ # Also trigger on pressing Enter in question box
128
+ question.submit(
129
+ fn=answer_question,
130
+ inputs=[question, document],
131
+ outputs=[answer_text, result_image]
132
+ )
133
 
 
134
  if __name__ == "__main__":
135
+ demo.launch(debug=True)