Amanda committed on
Commit
c56c5b2
·
1 Parent(s): f872c26

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +487 -0
app.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
+
5
+ from PIL import Image, ImageDraw
6
+ import traceback
7
+
8
+ import gradio as gr
9
+ from gradio import processing_utils
10
+
11
+ import torch
12
+ from docquery import pipeline
13
+ from docquery.document import load_bytes, load_document, ImageDocument
14
+ from docquery.ocr_reader import get_ocr_reader
15
+
16
+
17
def ensure_list(x):
    """Return *x* unchanged if it is already a list, otherwise wrap it in one."""
    return x if isinstance(x, list) else [x]
22
+
23
+
24
# Display name (shown in the model picker) -> Hugging Face checkpoint id.
CHECKPOINTS = {
    "LayoutLMv1 for Invoices 🧾": "impira/layoutlm-invoices",
}

# Lazily-populated cache of constructed QA pipelines, keyed by display name
# (filled by construct_pipeline so each model is loaded at most once).
PIPELINES = {}
29
+
30
+
31
def construct_pipeline(task, model):
    """Return a docquery pipeline for *model*, creating and caching it on first use.

    Uses CUDA when available, otherwise CPU. *model* is a display name key
    into CHECKPOINTS; the constructed pipeline is memoized in PIPELINES.
    """
    cached = PIPELINES.get(model)
    if cached is not None:
        return cached

    device = "cuda" if torch.cuda.is_available() else "cpu"
    built = pipeline(task=task, model=CHECKPOINTS[model], device=device)
    PIPELINES[model] = built
    return built
40
+
41
+
42
def run_pipeline(model, question, document, top_k):
    """Ask *question* of *document* using the cached QA pipeline for *model*.

    Returns the pipeline prediction(s); callers treat the result as either a
    single dict or a list (see ensure_list usage in process_fields).
    """
    # Named `pipe` so we don't shadow the `pipeline` factory imported from docquery.
    pipe = construct_pipeline("document-question-answering", model)
    return pipe(question=question, **document.context, top_k=top_k)
45
+
46
+
47
# TODO: Move into docquery
# TODO: Support words past the first page (or window?)
def lift_word_boxes(document, page):
    """Return the OCR word boxes for *page* of *document*.

    Each entry of document.context["image"] is a (image, word_boxes) pair;
    this lifts out the second element.
    """
    _, word_boxes = document.context["image"][page]
    return word_boxes
51
+
52
+
53
def expand_bbox(word_boxes):
    """Return the bounding box enclosing every (word, box) pair, or None if empty.

    Each element's index 1 is a [min_x, min_y, max_x, max_y] box; the result
    is the elementwise union of all of them.
    """
    if not word_boxes:
        return None

    xs0, ys0, xs1, ys1 = zip(*(entry[1] for entry in word_boxes))
    return [min(xs0), min(ys0), max(xs1), max(ys1)]
60
+
61
+
62
# LayoutLM boxes are normalized to 0, 1000
def normalize_bbox(box, width, height, padding=0.005):
    """Convert a LayoutLM box (0-1000 coordinate space) to pixel coordinates.

    A small fractional *padding* is applied on every side, clamped to the
    [0, 1] normalized range, before scaling to the image's width/height.
    """
    x0, y0, x1, y1 = (coord / 1000 for coord in box)
    if padding:
        x0, y0 = max(0, x0 - padding), max(0, y0 - padding)
        x1, y1 = min(x1 + padding, 1), min(y1 + padding, 1)
    return [x0 * width, y0 * height, x1 * width, y1 * height]
71
+
72
+
73
# Bundled example documents shown in the gr.Examples widget: [image file, title].
EXAMPLES = [
    [
        "DL.png",
        "Driver's License",
    ],
    [
        "BC.png",
        "Birth Certificate",
    ],
    [
        "EAC.png",
        "Employment Authorization Card",
    ],
]

# Example titles that should be loaded from their source PDF instead of the PNG
# preview (see load_example_document). BUG FIX: the original code hard-asserted
# that every key here matched an EXAMPLES title, but none of the current titles
# ("Tech Invoice", "Energy Invoice") do, so the module crashed on import.
# Filter to matching entries instead — reachable behavior is unchanged because
# no example title ever hit these keys.
QUESTION_FILES = {
    title: path
    for title, path in {
        "Tech Invoice": "acze_tech.pdf",
        "Energy Invoice": "north_sea.pdf",
    }.items()
    if any(example[1] == title for example in EXAMPLES)
}
95
+
96
# Default extraction fields: display name -> list of question phrasings.
# Each phrasing is run through the QA pipeline and the best-scoring answer
# wins (see process_fields).
FIELDS = {
    "Vendor Name": ["Vendor Name - Logo?", "Vendor Name - Address?"],
    "Vendor Address": ["Vendor Address?"],
    "Customer Name": ["Customer Name?"],
    "Customer Address": ["Customer Address?"],
    "Invoice Number": ["Invoice Number?"],
    "Invoice Date": ["Invoice Date?"],
    "Due Date": ["Due Date?"],
    "Subtotal": ["Subtotal?"],
    "Total Tax": ["Total Tax?"],
    "Invoice Total": ["Invoice Total?"],
    "Amount Due": ["Amount Due?"],
    "Payment Terms": ["Payment Terms?"],
    "Remit To Name": ["Remit To Name?"],
    "Remit To Address": ["Remit To Address?"],
}
112
+
113
+
114
def empty_table(fields):
    """Build gr.Dataframe kwargs for an empty results table.

    One [field_name, None] row per configured field, marked non-interactive.
    """
    rows = [[field_name, None] for field_name in fields]
    return {"value": rows, "interactive": False}
116
+
117
+
118
def process_document(document, fields, model, error=None):
    """Extract all fields from *document* and build the full tuple of UI updates.

    Returns updates for (document state, fields state, image gallery,
    clear-button visibility, url error box, JSON output, results table).
    With no document or a pending *error*, all outputs are cleared and the
    error (if any) is surfaced in the url error textbox.
    """
    # Failure / empty path first: clear everything and show the error if present.
    if document is None or error is not None:
        error_update = (
            gr.update(visible=True, value=error) if error is not None else None
        )
        return (
            None,
            fields,
            None,
            gr.update(visible=False),
            error_update,
            None,
            gr.update(**empty_table(fields)),
        )

    preview, json_output, table = process_fields(document, fields, model)
    return (
        document,
        fields,
        preview,
        gr.update(visible=True),
        gr.update(visible=False, value=None),
        json_output,
        table,
    )
140
+
141
+
142
def process_path(path, fields, model):
    """Load a document from *path* (local file or URL) and process it.

    Load failures are caught: the traceback is printed and the error message
    is forwarded to process_document for display instead of raising.
    """
    document, error = None, None
    if path:
        try:
            document = load_document(path)
        except Exception as exc:
            traceback.print_exc()
            error = str(exc)

    return process_document(document, fields, model, error)
153
+
154
+
155
def process_upload(file, fields, model):
    """Process a gradio-uploaded file by handing its temp-file path to process_path."""
    path = file.name if file else None
    return process_path(path, fields, model)
157
+
158
+
159
# Highlight color palette. NOTE(review): appears unused in this file —
# likely a leftover from the upstream Space; candidate for removal.
colors = ["#64A087", "green", "black"]
160
+
161
+
162
def annotate_page(prediction, pages, document):
    """Draw a translucent highlight over the predicted answer's words.

    Mutates the relevant PIL image in *pages* in place. A None prediction,
    or one without "word_ids", leaves the pages untouched.
    """
    if prediction is None or "word_ids" not in prediction:
        return

    page_no = prediction["page"]
    image = pages[page_no]
    word_boxes = lift_word_boxes(document, page_no)
    selected = [word_boxes[i] for i in prediction["word_ids"]]
    x1, y1, x2, y2 = normalize_bbox(
        expand_bbox(selected), image.width, image.height
    )
    # Semi-transparent green overlay (40% alpha) on the answer region.
    draw = ImageDraw.Draw(image, "RGBA")
    draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
173
+
174
+
175
def process_question(
    question, document, img_gallery, model, fields, output, output_table
):
    """Answer one ad-hoc *question* against *document* and prepend the result.

    The question is also registered as a new field (keyed by its own text) so
    it is re-run on the next document. Returns updates for (question box,
    image gallery, fields state, JSON output, results table).
    """
    field_name = question
    # BUG FIX: guard with truthiness instead of `is not None` — an empty
    # textbox submits "", which previously registered a bogus "" field.
    if field_name:
        fields = {field_name: [question], **fields}

    if not question or document is None:
        return None, document, fields, output, gr.update(value=output_table)

    pages = [processing_utils.decode_base64_to_image(p) for p in img_gallery]
    prediction = run_pipeline(model, question, document, 1)
    annotate_page(prediction, pages, document)

    output = {field_name: prediction, **output}
    # Prepend the new row to the existing (pandas-backed) table contents.
    table = [[field_name, prediction.get("answer")]] + output_table.values.tolist()
    return (
        None,
        gr.update(visible=True, value=pages),
        fields,
        output,
        gr.update(value=table, interactive=False),
    )
199
+
200
+
201
def process_fields(document, fields, model=list(CHECKPOINTS.keys())[0]):
    """Run every configured field's questions against *document*.

    Returns (gallery update with annotated page previews, JSON update with
    each field's top raw prediction, table rows of [field, answer]).
    """
    pages = [page.copy().convert("RGB") for page in document.preview]

    predictions = {}
    rows = []
    for field_name, questions in fields.items():
        # Pool answers across all phrasings, keeping only confident ones.
        candidates = [
            answer
            for q in questions
            for answer in ensure_list(run_pipeline(model, q, document, top_k=1))
            if answer.get("score", 1) > 0.5
        ]
        candidates.sort(key=lambda a: -a.get("score", 0) if a else 0)
        best = candidates[0] if candidates else None
        annotate_page(best, pages, document)
        predictions[field_name] = best
        rows.append([field_name, best.get("answer") if best is not None else None])

    return (
        gr.update(visible=True, value=pages),
        gr.update(visible=True, value=predictions),
        rows,
    )
225
+
226
+
227
def load_example_document(img, title, fields, model):
    """Process a bundled example: its source PDF when *title* has one, else the image."""
    if img is None:
        return process_document(None, fields, model)

    if title in QUESTION_FILES:
        document = load_document(QUESTION_FILES[title])
    else:
        document = ImageDocument(Image.fromarray(img), ocr_reader=get_ocr_reader())
    return process_document(document, fields, model)
236
+
237
+
238
# Custom stylesheet passed to gr.Blocks(css=...). Selectors target gradio's
# generated classes and the elem_id= values set on components below — keep
# the two in sync when renaming ids.
CSS = """
#question input {
    font-size: 16px;
}
#url-textbox, #question-textbox {
    padding: 0 !important;
}
#short-upload-box .w-full {
    min-height: 10rem !important;
}
/* I think something like this can be used to re-shape
 * the table
 */
/*
.gr-samples-table tr {
    display: inline;
}
.gr-samples-table .p-2 {
    width: 100px;
}
*/
#select-a-file {
    width: 100%;
}
#file-clear {
    padding-top: 2px !important;
    padding-bottom: 2px !important;
    padding-left: 8px !important;
    padding-right: 8px !important;
    margin-top: 10px;
}
.gradio-container .gr-button-primary {
    background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
    border: 1px solid #B0DCCC;
    border-radius: 8px;
    color: #1B8700;
}
.gradio-container.dark button#submit-button {
    background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
    border: 1px solid #B0DCCC;
    border-radius: 8px;
    color: #1B8700
}
table.gr-samples-table tr td {
    border: none;
    outline: none;
}
table.gr-samples-table tr td:first-of-type {
    width: 0%;
}
div#short-upload-box div.absolute {
    display: none !important;
}
gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
    gap: 0px 2%;
}
gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
    gap: 0px;
}
gradio-app h2, .gradio-app h2 {
    padding-top: 10px;
}
#answer {
    overflow-y: scroll;
    color: white;
    background: #666;
    border-color: #666;
    font-size: 20px;
    font-weight: bold;
}
#answer span {
    color: white;
}
#answer textarea {
    color:white;
    background: #777;
    border-color: #777;
    font-size: 18px;
}
#url-error input {
    color: red;
}
#results-table {
    max-height: 600px;
    overflow-y: scroll;
}
"""
325
+
326
# UI definition and event wiring.
# NOTE(review): this file was recovered from an unindented scrape, so the
# nesting of the layout context managers below was reconstructed — verify
# against the upstream Impira invoices Space before relying on exact layout.
with gr.Blocks(css=CSS) as demo:
    gr.Markdown("# DocQuery for Invoices")
    gr.Markdown(
        "DocQuery (created by [Impira](https://impira.com?utm_source=huggingface&utm_medium=referral&utm_campaign=invoices_space))"
        " uses LayoutLMv1 fine-tuned on an invoice dataset"
        " as well as DocVQA and SQuAD, which boot its general comprehension skills. The model is an enhanced"
        " QA architecture that supports selecting blocks of text which may be non-consecutive, which is a major"
        " issue when dealing with invoice documents (e.g. addresses)."
        " To use it, simply upload an image or PDF invoice and the model will predict values for several fields."
        " You can also create additional fields by simply typing in a question."
        " DocQuery is available on [Github](https://github.com/impira/docquery)."
    )

    # Session state and hidden components used to feed gr.Examples.
    document = gr.Variable()
    fields = gr.Variable(value={**FIELDS})
    example_question = gr.Textbox(visible=False)
    example_image = gr.Image(visible=False)

    with gr.Row(equal_height=True):
        # Left column: document selection (URL / upload / examples) and preview.
        with gr.Column():
            with gr.Row():
                gr.Markdown("## Select an invoice", elem_id="select-a-file")
                img_clear_button = gr.Button(
                    "Clear", variant="secondary", elem_id="file-clear", visible=False
                )
            image = gr.Gallery(visible=False)
            with gr.Row(equal_height=True):
                with gr.Column():
                    with gr.Row():
                        url = gr.Textbox(
                            show_label=False,
                            placeholder="URL",
                            lines=1,
                            max_lines=1,
                            elem_id="url-textbox",
                        )
                        submit = gr.Button("Get")
                    url_error = gr.Textbox(
                        visible=False,
                        elem_id="url-error",
                        max_lines=1,
                        interactive=False,
                        label="Error",
                    )
            gr.Markdown("— or —")
            upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
            gr.Examples(
                examples=EXAMPLES,
                inputs=[example_image, example_question],
            )

        # Right column: extraction results and the ad-hoc question box.
        with gr.Column() as col:
            gr.Markdown("## Results")
            with gr.Tabs():
                with gr.TabItem("Table"):
                    output_table = gr.Dataframe(
                        headers=["Field", "Value"],
                        **empty_table(fields.value),
                        elem_id="results-table"
                    )

                with gr.TabItem("JSON"):
                    output = gr.JSON(label="Output", visible=True)

            # Hidden while only one checkpoint is configured.
            model = gr.Radio(
                choices=list(CHECKPOINTS.keys()),
                value=list(CHECKPOINTS.keys())[0],
                label="Model",
                visible=False,
            )

            gr.Markdown("### Ask a question")
            with gr.Row():
                question = gr.Textbox(
                    label="Question",
                    show_label=False,
                    placeholder="e.g. What is the invoice number?",
                    lines=1,
                    max_lines=1,
                    elem_id="question-textbox",
                )
                clear_button = gr.Button("Clear", variant="secondary", visible=False)
                submit_button = gr.Button(
                    "Add", variant="primary", elem_id="submit-button"
                )

    # Both clear buttons reset every input/output component to its empty state.
    for cb in [img_clear_button, clear_button]:
        cb.click(
            lambda _: (
                gr.update(visible=False, value=None),  # image
                None,  # document
                # {**FIELDS},  # fields
                gr.update(value=None),  # output
                gr.update(**empty_table(fields.value)),  # output_table
                gr.update(visible=False),
                None,
                None,
                None,
                gr.update(visible=False, value=None),
                None,
            ),
            inputs=clear_button,
            outputs=[
                image,
                document,
                # fields,
                output,
                output_table,
                img_clear_button,
                example_image,
                upload,
                url,
                url_error,
                question,
            ],
        )

    # Shared output list for every handler that (re)processes a document —
    # must stay in sync with the tuple returned by process_document.
    submit_outputs = [
        document,
        fields,
        image,
        img_clear_button,
        url_error,
        output,
        output_table,
    ]

    upload.change(
        fn=process_upload,
        inputs=[upload, fields, model],
        outputs=submit_outputs,
    )

    submit.click(
        fn=process_path,
        inputs=[url, fields, model],
        outputs=submit_outputs,
    )

    # Enter in the textbox and the "Add" button share one handler.
    for action in [question.submit, submit_button.click]:
        action(
            fn=process_question,
            inputs=[question, document, image, model, fields, output, output_table],
            outputs=[question, image, fields, output, output_table],
        )

    # model.change(
    #     process_question,
    #     inputs=[question, document, model],
    #     outputs=[image, output, output_table],
    # )

    example_image.change(
        fn=load_example_document,
        inputs=[example_image, example_question, fields, model],
        outputs=submit_outputs,
    )
483
+
484
if __name__ == "__main__":
    # enable_queue=False: handle requests synchronously.
    # NOTE(review): this argument is deprecated in newer gradio releases
    # (queueing is configured via demo.queue()) — confirm the pinned version.
    demo.launch(enable_queue=False)

# Code adapted from the Impira/invoices Space.