Amanda committed on
Commit
38ac80e
·
1 Parent(s): 411abf2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -483
app.py CHANGED
@@ -1,487 +1,56 @@
1
- import os
2
-
3
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
-
5
- from PIL import Image, ImageDraw
6
- import traceback
7
-
8
  import gradio as gr
9
- from gradio import processing_utils
10
 
11
  import torch
12
- from docquery import pipeline
13
- from docquery.document import load_bytes, load_document, ImageDocument
14
- from docquery.ocr_reader import get_ocr_reader
15
-
16
-
17
- def ensure_list(x):
18
- if isinstance(x, list):
19
- return x
20
- else:
21
- return [x]
22
-
23
-
24
- CHECKPOINTS = {
25
- "LayoutLMv1 for Invoices 🧾": "impira/layoutlm-invoices",
26
- }
27
-
28
- PIPELINES = {}
29
-
30
-
31
- def construct_pipeline(task, model):
32
- global PIPELINES
33
- if model in PIPELINES:
34
- return PIPELINES[model]
35
-
36
- device = "cuda" if torch.cuda.is_available() else "cpu"
37
- ret = pipeline(task=task, model=CHECKPOINTS[model], device=device)
38
- PIPELINES[model] = ret
39
- return ret
40
-
41
-
42
- def run_pipeline(model, question, document, top_k):
43
- pipeline = construct_pipeline("document-question-answering", model)
44
- return pipeline(question=question, **document.context, top_k=top_k)
45
-
46
-
47
- # TODO: Move into docquery
48
- # TODO: Support words past the first page (or window?)
49
- def lift_word_boxes(document, page):
50
- return document.context["image"][page][1]
51
-
52
-
53
- def expand_bbox(word_boxes):
54
- if len(word_boxes) == 0:
55
- return None
56
-
57
- min_x, min_y, max_x, max_y = zip(*[x[1] for x in word_boxes])
58
- min_x, min_y, max_x, max_y = [min(min_x), min(min_y), max(max_x), max(max_y)]
59
- return [min_x, min_y, max_x, max_y]
60
-
61
-
62
- # LayoutLM boxes are normalized to 0, 1000
63
- def normalize_bbox(box, width, height, padding=0.005):
64
- min_x, min_y, max_x, max_y = [c / 1000 for c in box]
65
- if padding != 0:
66
- min_x = max(0, min_x - padding)
67
- min_y = max(0, min_y - padding)
68
- max_x = min(max_x + padding, 1)
69
- max_y = min(max_y + padding, 1)
70
- return [min_x * width, min_y * height, max_x * width, max_y * height]
71
-
72
-
73
- EXAMPLES = [
74
- [
75
- "DL.jpg",
76
- "Driver's License",
77
- ],
78
- [
79
- "BC.jfif",
80
- "Birth Certificate",
81
- ],
82
- [
83
- "EAC.png",
84
- "Employment Authorization Card",
85
- ],
86
- ]
87
-
88
- QUESTION_FILES = {
89
- "Tech Invoice": "acze_tech.pdf",
90
- "Energy Invoice": "north_sea.pdf",
91
- }
92
-
93
- for q in QUESTION_FILES.keys():
94
- assert any(x[1] == q for x in EXAMPLES)
95
-
96
- FIELDS = {
97
- "Vendor Name": ["Vendor Name - Logo?", "Vendor Name - Address?"],
98
- "Vendor Address": ["Vendor Address?"],
99
- "Customer Name": ["Customer Name?"],
100
- "Customer Address": ["Customer Address?"],
101
- "Invoice Number": ["Invoice Number?"],
102
- "Invoice Date": ["Invoice Date?"],
103
- "Due Date": ["Due Date?"],
104
- "Subtotal": ["Subtotal?"],
105
- "Total Tax": ["Total Tax?"],
106
- "Invoice Total": ["Invoice Total?"],
107
- "Amount Due": ["Amount Due?"],
108
- "Payment Terms": ["Payment Terms?"],
109
- "Remit To Name": ["Remit To Name?"],
110
- "Remit To Address": ["Remit To Address?"],
111
- }
112
-
113
-
114
- def empty_table(fields):
115
- return {"value": [[name, None] for name in fields.keys()], "interactive": False}
116
-
117
-
118
- def process_document(document, fields, model, error=None):
119
- if document is not None and error is None:
120
- preview, json_output, table = process_fields(document, fields, model)
121
- return (
122
- document,
123
- fields,
124
- preview,
125
- gr.update(visible=True),
126
- gr.update(visible=False, value=None),
127
- json_output,
128
- table,
129
- )
130
- else:
131
- return (
132
- None,
133
- fields,
134
- None,
135
- gr.update(visible=False),
136
- gr.update(visible=True, value=error) if error is not None else None,
137
- None,
138
- gr.update(**empty_table(fields)),
139
- )
140
-
141
-
142
- def process_path(path, fields, model):
143
- error = None
144
- document = None
145
- if path:
146
- try:
147
- document = load_document(path)
148
- except Exception as e:
149
- traceback.print_exc()
150
- error = str(e)
151
-
152
- return process_document(document, fields, model, error)
153
-
154
-
155
- def process_upload(file, fields, model):
156
- return process_path(file.name if file else None, fields, model)
157
-
158
-
159
- colors = ["#64A087", "green", "black"]
160
-
161
-
162
- def annotate_page(prediction, pages, document):
163
- if prediction is not None and "word_ids" in prediction:
164
- image = pages[prediction["page"]]
165
- draw = ImageDraw.Draw(image, "RGBA")
166
- word_boxes = lift_word_boxes(document, prediction["page"])
167
- x1, y1, x2, y2 = normalize_bbox(
168
- expand_bbox([word_boxes[i] for i in prediction["word_ids"]]),
169
- image.width,
170
- image.height,
171
- )
172
- draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
173
-
174
-
175
- def process_question(
176
- question, document, img_gallery, model, fields, output, output_table
177
- ):
178
- field_name = question
179
- if field_name is not None:
180
- fields = {field_name: [question], **fields}
181
-
182
- if not question or document is None:
183
- return None, document, fields, output, gr.update(value=output_table)
184
-
185
- text_value = None
186
- pages = [processing_utils.decode_base64_to_image(p) for p in img_gallery]
187
- prediction = run_pipeline(model, question, document, 1)
188
- annotate_page(prediction, pages, document)
189
-
190
- output = {field_name: prediction, **output}
191
- table = [[field_name, prediction.get("answer")]] + output_table.values.tolist()
192
- return (
193
- None,
194
- gr.update(visible=True, value=pages),
195
- fields,
196
- output,
197
- gr.update(value=table, interactive=False),
198
- )
199
-
200
-
201
- def process_fields(document, fields, model=list(CHECKPOINTS.keys())[0]):
202
- pages = [x.copy().convert("RGB") for x in document.preview]
203
-
204
- ret = {}
205
- table = []
206
-
207
- for (field_name, questions) in fields.items():
208
- answers = [
209
- a
210
- for q in questions
211
- for a in ensure_list(run_pipeline(model, q, document, top_k=1))
212
- if a.get("score", 1) > 0.5
213
- ]
214
- answers.sort(key=lambda x: -x.get("score", 0) if x else 0)
215
- top = answers[0] if len(answers) > 0 else None
216
- annotate_page(top, pages, document)
217
- ret[field_name] = top
218
- table.append([field_name, top.get("answer") if top is not None else None])
219
-
220
- return (
221
- gr.update(visible=True, value=pages),
222
- gr.update(visible=True, value=ret),
223
- table
224
- )
225
-
226
-
227
- def load_example_document(img, title, fields, model):
228
- document = None
229
- if img is not None:
230
- if title in QUESTION_FILES:
231
- document = load_document(QUESTION_FILES[title])
232
- else:
233
- document = ImageDocument(Image.fromarray(img), ocr_reader=get_ocr_reader())
234
-
235
- return process_document(document, fields, model)
236
-
237
-
238
- CSS = """
239
- #question input {
240
- font-size: 16px;
241
- }
242
- #url-textbox, #question-textbox {
243
- padding: 0 !important;
244
- }
245
- #short-upload-box .w-full {
246
- min-height: 10rem !important;
247
- }
248
- /* I think something like this can be used to re-shape
249
- * the table
250
- */
251
- /*
252
- .gr-samples-table tr {
253
- display: inline;
254
- }
255
- .gr-samples-table .p-2 {
256
- width: 100px;
257
- }
258
- */
259
- #select-a-file {
260
- width: 100%;
261
- }
262
- #file-clear {
263
- padding-top: 2px !important;
264
- padding-bottom: 2px !important;
265
- padding-left: 8px !important;
266
- padding-right: 8px !important;
267
- margin-top: 10px;
268
- }
269
- .gradio-container .gr-button-primary {
270
- background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
271
- border: 1px solid #B0DCCC;
272
- border-radius: 8px;
273
- color: #1B8700;
274
- }
275
- .gradio-container.dark button#submit-button {
276
- background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
277
- border: 1px solid #B0DCCC;
278
- border-radius: 8px;
279
- color: #1B8700
280
- }
281
- table.gr-samples-table tr td {
282
- border: none;
283
- outline: none;
284
- }
285
- table.gr-samples-table tr td:first-of-type {
286
- width: 0%;
287
- }
288
- div#short-upload-box div.absolute {
289
- display: none !important;
290
- }
291
- gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
292
- gap: 0px 2%;
293
- }
294
- gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
295
- gap: 0px;
296
- }
297
- gradio-app h2, .gradio-app h2 {
298
- padding-top: 10px;
299
- }
300
- #answer {
301
- overflow-y: scroll;
302
- color: white;
303
- background: #666;
304
- border-color: #666;
305
- font-size: 20px;
306
- font-weight: bold;
307
- }
308
- #answer span {
309
- color: white;
310
- }
311
- #answer textarea {
312
- color:white;
313
- background: #777;
314
- border-color: #777;
315
- font-size: 18px;
316
- }
317
- #url-error input {
318
- color: red;
319
- }
320
- #results-table {
321
- max-height: 600px;
322
- overflow-y: scroll;
323
- }
324
- """
325
-
326
- with gr.Blocks(css=CSS) as demo:
327
- gr.Markdown("# DocQuery for Invoices")
328
- gr.Markdown(
329
- "DocQuery (created by [Impira](https://impira.com?utm_source=huggingface&utm_medium=referral&utm_campaign=invoices_space))"
330
- " uses LayoutLMv1 fine-tuned on an invoice dataset"
331
- " as well as DocVQA and SQuAD, which boot its general comprehension skills. The model is an enhanced"
332
- " QA architecture that supports selecting blocks of text which may be non-consecutive, which is a major"
333
- " issue when dealing with invoice documents (e.g. addresses)."
334
- " To use it, simply upload an image or PDF invoice and the model will predict values for several fields."
335
- " You can also create additional fields by simply typing in a question."
336
- " DocQuery is available on [Github](https://github.com/impira/docquery)."
337
  )
338
-
339
- document = gr.Variable()
340
- fields = gr.Variable(value={**FIELDS})
341
- example_question = gr.Textbox(visible=False)
342
- example_image = gr.Image(visible=False)
343
-
344
- with gr.Row(equal_height=True):
345
- with gr.Column():
346
- with gr.Row():
347
- gr.Markdown("## Select an invoice", elem_id="select-a-file")
348
- img_clear_button = gr.Button(
349
- "Clear", variant="secondary", elem_id="file-clear", visible=False
350
- )
351
- image = gr.Gallery(visible=False)
352
- with gr.Row(equal_height=True):
353
- with gr.Column():
354
- with gr.Row():
355
- url = gr.Textbox(
356
- show_label=False,
357
- placeholder="URL",
358
- lines=1,
359
- max_lines=1,
360
- elem_id="url-textbox",
361
- )
362
- submit = gr.Button("Get")
363
- url_error = gr.Textbox(
364
- visible=False,
365
- elem_id="url-error",
366
- max_lines=1,
367
- interactive=False,
368
- label="Error",
369
- )
370
- gr.Markdown("— or —")
371
- upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
372
- gr.Examples(
373
- examples=EXAMPLES,
374
- inputs=[example_image, example_question],
375
- )
376
-
377
- with gr.Column() as col:
378
- gr.Markdown("## Results")
379
- with gr.Tabs():
380
- with gr.TabItem("Table"):
381
- output_table = gr.Dataframe(
382
- headers=["Field", "Value"],
383
- **empty_table(fields.value),
384
- elem_id="results-table"
385
- )
386
-
387
- with gr.TabItem("JSON"):
388
- output = gr.JSON(label="Output", visible=True)
389
-
390
- model = gr.Radio(
391
- choices=list(CHECKPOINTS.keys()),
392
- value=list(CHECKPOINTS.keys())[0],
393
- label="Model",
394
- visible=False,
395
- )
396
-
397
- gr.Markdown("### Ask a question")
398
- with gr.Row():
399
- question = gr.Textbox(
400
- label="Question",
401
- show_label=False,
402
- placeholder="e.g. What is the invoice number?",
403
- lines=1,
404
- max_lines=1,
405
- elem_id="question-textbox",
406
- )
407
- clear_button = gr.Button("Clear", variant="secondary", visible=False)
408
- submit_button = gr.Button(
409
- "Add", variant="primary", elem_id="submit-button"
410
- )
411
-
412
- for cb in [img_clear_button, clear_button]:
413
- cb.click(
414
- lambda _: (
415
- gr.update(visible=False, value=None), # image
416
- None, # document
417
- # {**FIELDS}, # fields
418
- gr.update(value=None), # output
419
- gr.update(**empty_table(fields.value)), # output_table
420
- gr.update(visible=False),
421
- None,
422
- None,
423
- None,
424
- gr.update(visible=False, value=None),
425
- None,
426
- ),
427
- inputs=clear_button,
428
- outputs=[
429
- image,
430
- document,
431
- # fields,
432
- output,
433
- output_table,
434
- img_clear_button,
435
- example_image,
436
- upload,
437
- url,
438
- url_error,
439
- question,
440
- ],
441
- )
442
-
443
- submit_outputs = [
444
- document,
445
- fields,
446
- image,
447
- img_clear_button,
448
- url_error,
449
- output,
450
- output_table,
451
- ]
452
-
453
- upload.change(
454
- fn=process_upload,
455
- inputs=[upload, fields, model],
456
- outputs=submit_outputs,
457
- )
458
-
459
- submit.click(
460
- fn=process_path,
461
- inputs=[url, fields, model],
462
- outputs=submit_outputs,
463
- )
464
-
465
- for action in [question.submit, submit_button.click]:
466
- action(
467
- fn=process_question,
468
- inputs=[question, document, image, model, fields, output, output_table],
469
- outputs=[question, image, fields, output, output_table],
470
- )
471
-
472
- # model.change(
473
- # process_question,
474
- # inputs=[question, document, model],
475
- # outputs=[image, output, output_table],
476
- # )
477
-
478
- example_image.change(
479
- fn=load_example_document,
480
- inputs=[example_image, example_question, fields, model],
481
- outputs=submit_outputs,
482
- )
483
-
484
- if __name__ == "__main__":
485
- demo.launch(enable_queue=False)
486
-
487
- #code modified from Impira/invoices space
 
1
import re

import gradio as gr
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Donut checkpoint fine-tuned on CORD v2 (receipt/document parsing).
CHECKPOINT = "naver-clova-ix/donut-base-finetuned-cord-v2"

processor = DonutProcessor.from_pretrained(CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)

# Run on GPU when one is available, otherwise stay on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
13
def process_document(image):
    """Parse a document image with Donut and return the extracted fields.

    The image is encoded to pixel values, the decoder is primed with the
    CORD-v2 task prompt, and the generated token sequence is cleaned up and
    structured via ``processor.token2json``.
    """
    tokenizer = processor.tokenizer

    # Encoder input: image -> pixel values tensor.
    encoder_inputs = processor(image, return_tensors="pt").pixel_values

    # Decoder input: the task start prompt for the CORD-v2 checkpoint.
    prompt_ids = tokenizer(
        "<s_cord-v2>", add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # Greedy autoregressive decoding (num_beams=1), capped at the decoder's
    # maximum sequence length; <unk> is suppressed via bad_words_ids.
    generated = model.generate(
        encoder_inputs.to(device),
        decoder_input_ids=prompt_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Postprocess: drop special tokens, strip the leading task-start token,
    # then convert the tag-style output into a JSON-like dict.
    text = processor.batch_decode(generated.sequences)[0]
    text = text.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    text = re.sub(r"<.*?>", "", text, count=1).strip()
    return processor.token2json(text)
41
+
42
# UI copy rendered above / below the Gradio widget.
description = "Gradio Demo for Donut, an instance of `VisionEncoderDecoderModel` fine-tuned on CORD (document parsing). To use it, simply upload your image and click 'submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"

# Bundled sample documents users can click instead of uploading their own.
example_images = [["DL.jpg"], ["EAC.png"], ["BC.jfif"]]

demo = gr.Interface(
    fn=process_document,
    inputs="image",
    outputs="json",
    title="Demo: Donut 🍩 for Document Parsing",
    description=description,
    article=article,
    enable_queue=True,
    examples=example_images,
    cache_examples=False,
)

demo.launch()