Amanda committed on
Commit
6212313
·
1 Parent(s): dd15c12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +400 -49
app.py CHANGED
@@ -1,57 +1,408 @@
1
- import re
 
 
 
 
 
 
2
  import gradio as gr
3
 
4
  import torch
5
- from transformers import DonutProcessor, VisionEncoderDecoderModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
8
- model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
9
 
10
- device = "cuda" if torch.cuda.is_available() else "cpu"
11
- model.to(device)
 
12
 
13
- def process_document(image, question):
14
- # prepare encoder inputs
15
- pixel_values = processor(image, return_tensors="pt").pixel_values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- # prepare decoder inputs
18
- task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
19
- prompt = task_prompt.replace("{user_input}", question)
20
- decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
21
-
22
- # generate answer
23
- outputs = model.generate(
24
- pixel_values.to(device),
25
- decoder_input_ids=decoder_input_ids.to(device),
26
- max_length=model.decoder.config.max_position_embeddings,
27
- early_stopping=True,
28
- pad_token_id=processor.tokenizer.pad_token_id,
29
- eos_token_id=processor.tokenizer.eos_token_id,
30
- use_cache=True,
31
- num_beams=1,
32
- bad_words_ids=[[processor.tokenizer.unk_token_id]],
33
- return_dict_in_generate=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  )
35
-
36
- # postprocess
37
- sequence = processor.batch_decode(outputs.sequences)[0]
38
- sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
39
- sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
40
-
41
- return processor.token2json(sequence)
42
-
43
- description = "Gradio Demo for Donut, an instance of `VisionEncoderDecoderModel` fine-tuned on DocVQA (document visual question answering). To use it, simply upload your image and type a question and click 'submit', or click one of the examples to load them. Read more at the links below."
44
- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
45
-
46
- demo = gr.Interface(
47
- fn=process_document,
48
- inputs=["image", "text"],
49
- outputs="json",
50
- title="Demo: Donut 🍩 for DocVQA",
51
- description=description,
52
- article=article,
53
- enable_queue=True,
54
- examples=[["Invoice.jpg", "What is the Invoice Number?"], ["DL.jpg", "What is the Address?"], ["EAC.png", "What is the Document Type?"]],
55
- cache_examples=False)
56
-
57
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
+
5
+ from PIL import Image, ImageDraw
6
+ import traceback
7
+
8
  import gradio as gr
9
 
10
  import torch
11
+ from docquery import pipeline
12
+ from docquery.document import load_document, ImageDocument
13
+ from docquery.ocr_reader import get_ocr_reader
14
+
15
+
16
def ensure_list(x):
    """Return `x` unchanged if it is already a list, otherwise wrap it in one."""
    return x if isinstance(x, list) else [x]
21
+
22
+
23
# Model checkpoints selectable in the UI, keyed by the display label shown
# in the "Model" radio group.
CHECKPOINTS = {
    "LayoutLMv1 🦉": "impira/layoutlm-document-qa",
    "LayoutLMv1 for Invoices 💸": "impira/layoutlm-invoices",
    "Donut 🍩": "naver-clova-ix/donut-base-finetuned-docvqa",
}

# Cache of already-constructed pipelines, keyed by model display label, so a
# model is only loaded once per process.
PIPELINES = {}
30
+
31
+
32
def construct_pipeline(task, model):
    """Return a docquery pipeline for `model`, creating and caching it on first use.

    `model` is a display label from CHECKPOINTS; `task` is a pipeline task name.
    """
    global PIPELINES
    try:
        # Cached: reuse the previously constructed pipeline.
        return PIPELINES[model]
    except KeyError:
        pass

    runtime = "cuda" if torch.cuda.is_available() else "cpu"
    created = pipeline(task=task, model=CHECKPOINTS[model], device=runtime)
    PIPELINES[model] = created
    return created
41
+
42
+
43
def run_pipeline(model, question, document, top_k):
    """Run document question answering and return up to `top_k` predictions.

    `model` is a display label from CHECKPOINTS; `document` supplies its
    pipeline inputs via `document.context`.
    """
    # Renamed the local from `pipeline` — it shadowed the imported
    # docquery.pipeline factory used by construct_pipeline.
    qa_pipeline = construct_pipeline("document-question-answering", model)
    return qa_pipeline(question=question, **document.context, top_k=top_k)
46
+
47
+
48
# TODO: Move into docquery
# TODO: Support words past the first page (or window?)
def lift_word_boxes(document, page):
    """Return the OCR word boxes for page index `page` of `document`."""
    page_entry = document.context["image"][page]
    return page_entry[1]
52
 
 
 
53
 
54
def expand_bbox(word_boxes):
    """Return the union box [min_x, min_y, max_x, max_y] of all word boxes.

    Each item in `word_boxes` is a (word, [x1, y1, x2, y2]) pair; returns
    None when the input is empty.
    """
    if not word_boxes:
        return None

    xs1, ys1, xs2, ys2 = zip(*(wb[1] for wb in word_boxes))
    return [min(xs1), min(ys1), max(xs2), max(ys2)]
61
+
62
+
63
# LayoutLM boxes are normalized to 0, 1000
def normalize_bbox(box, width, height, padding=0.005):
    """Convert a 0–1000 normalized LayoutLM box to pixel coordinates.

    Each side is optionally expanded by `padding` (in normalized 0–1 units),
    clamped to the image bounds before scaling by `width`/`height`.
    """
    scaled = [coord / 1000 for coord in box]
    if padding != 0:
        scaled = [
            max(0, scaled[0] - padding),
            max(0, scaled[1] - padding),
            min(scaled[2] + padding, 1),
            min(scaled[3] + padding, 1),
        ]
    x1, y1, x2, y2 = scaled
    return [x1 * width, y1 * height, x2 * width, y2 * height]
72
+
73
+
74
# (image filename, question) pairs shown as clickable examples in the UI.
examples = [
    [
        "Invoice.jpg",
        "What is the invoice number?",
    ],
    [
        # Fixed typo: was "DL.jpf"; the bundled file is a .jpg (the previous
        # version of this app referenced "DL.jpg").
        "DL.jpg",
        "What is the First Name?",
    ],
    [
        "EAC.png",
        "What is the Document Type?",
    ],
]
89
+
90
# Example questions whose document should be loaded from a live URL instead
# of the bundled example image.
question_files = {
    # Fixed malformed URL: the scheme was duplicated ("https://https://...").
    "How many likes does the space have?": "https://huggingface.co/spaces/AshlingAHydar/IDP",
    "What is the title of post number 5?": "https://news.ycombinator.com",
}
94
+
95
+
96
def process_path(path):
    """Load a document from `path` (local file or URL) and build UI updates.

    Returns a 6-tuple matching the registered outputs
    [document, image, img_clear_button, output, output_text, url_error].
    """
    error = None
    if path:
        try:
            document = load_document(path)
            return (
                document,
                gr.update(visible=True, value=document.preview),
                gr.update(visible=True),
                gr.update(visible=False, value=None),
                gr.update(visible=False, value=None),
                None,
            )
        except Exception as e:
            traceback.print_exc()
            error = str(e)
    # Failure path: clear/hide everything and surface the error, if any.
    # BUG FIX: this branch previously returned 7 values (a stray trailing
    # None) while the success branch and every outputs= list expect 6.
    return (
        None,
        gr.update(visible=False, value=None),
        gr.update(visible=False),
        gr.update(visible=False, value=None),
        gr.update(visible=False, value=None),
        gr.update(visible=True, value=error) if error is not None else None,
    )
121
+
122
+
123
def process_upload(file):
    """Handle an upload event; delegate to process_path when a file is present."""
    if not file:
        # No file: clear/hide every output component.
        return (
            None,
            gr.update(visible=False, value=None),
            gr.update(visible=False),
            gr.update(visible=False, value=None),
            gr.update(visible=False, value=None),
            None,
        )
    return process_path(file.name)
135
+
136
+
137
# NOTE(review): appears unused in this file — process_question hardcodes a
# green highlight; possibly a leftover palette for multi-answer boxes. Confirm
# before removing.
colors = ["#64A087", "green", "black"]
138
+
139
+
140
def process_question(question, document, model=list(CHECKPOINTS.keys())[0]):
    """Answer `question` over `document` and return UI updates.

    Returns (gallery update, raw-predictions update, top-answer textbox
    update), or (None, None, None) when there is no question or document.
    """
    if not question or document is None:
        return None, None, None

    predictions = run_pipeline(model, question, document, 3)
    pages = [page.copy().convert("RGB") for page in document.preview]

    # Keep the code around to produce multiple boxes, but only show the top
    # prediction for now.
    top_answer = None
    for rank, prediction in enumerate(ensure_list(predictions)):
        if rank > 0:
            break
        top_answer = prediction["answer"]

        if "word_ids" in prediction:
            target = pages[prediction["page"]]
            boxes = lift_word_boxes(document, prediction["page"])
            selected = [boxes[idx] for idx in prediction["word_ids"]]
            x1, y1, x2, y2 = normalize_bbox(
                expand_bbox(selected),
                target.width,
                target.height,
            )
            painter = ImageDraw.Draw(target, "RGBA")
            painter.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))

    return (
        gr.update(visible=True, value=pages),
        gr.update(visible=True, value=predictions),
        gr.update(visible=True, value=top_answer),
    )
174
+
175
+
176
def load_example_document(img, question, model):
    """Load a clicked example (numpy image + question), run QA, return UI updates."""
    if img is None:
        return None, None, None, gr.update(visible=False), None, None

    # Some example questions map to a live URL document rather than the
    # bundled image.
    if question in question_files:
        document = load_document(question_files[question])
    else:
        document = ImageDocument(Image.fromarray(img), get_ocr_reader())
    preview, answer, answer_text = process_question(question, document, model)
    return document, question, preview, gr.update(visible=True), answer, answer_text
186
+
187
+
188
+ CSS = """
189
+ #question input {
190
+ font-size: 16px;
191
+ }
192
+ #url-textbox {
193
+ padding: 0 !important;
194
+ }
195
+ #short-upload-box .w-full {
196
+ min-height: 10rem !important;
197
+ }
198
+ /* I think something like this can be used to re-shape
199
+ * the table
200
+ */
201
+ /*
202
+ .gr-samples-table tr {
203
+ display: inline;
204
+ }
205
+ .gr-samples-table .p-2 {
206
+ width: 100px;
207
+ }
208
+ */
209
+ #select-a-file {
210
+ width: 100%;
211
+ }
212
+ #file-clear {
213
+ padding-top: 2px !important;
214
+ padding-bottom: 2px !important;
215
+ padding-left: 8px !important;
216
+ padding-right: 8px !important;
217
+ margin-top: 10px;
218
+ }
219
+ .gradio-container .gr-button-primary {
220
+ background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
221
+ border: 1px solid #B0DCCC;
222
+ border-radius: 8px;
223
+ color: #1B8700;
224
+ }
225
+ .gradio-container.dark button#submit-button {
226
+ background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
227
+ border: 1px solid #B0DCCC;
228
+ border-radius: 8px;
229
+ color: #1B8700
230
+ }
231
+ table.gr-samples-table tr td {
232
+ border: none;
233
+ outline: none;
234
+ }
235
+ table.gr-samples-table tr td:first-of-type {
236
+ width: 0%;
237
+ }
238
+ div#short-upload-box div.absolute {
239
+ display: none !important;
240
+ }
241
+ gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
242
+ gap: 0px 2%;
243
+ }
244
+ gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
245
+ gap: 0px;
246
+ }
247
+ gradio-app h2, .gradio-app h2 {
248
+ padding-top: 10px;
249
+ }
250
+ #answer {
251
+ overflow-y: scroll;
252
+ color: white;
253
+ background: #666;
254
+ border-color: #666;
255
+ font-size: 20px;
256
+ font-weight: bold;
257
+ }
258
+ #answer span {
259
+ color: white;
260
+ }
261
+ #answer textarea {
262
+ color:white;
263
+ background: #777;
264
+ border-color: #777;
265
+ font-size: 18px;
266
+ }
267
+ #url-error input {
268
+ color: red;
269
+ }
270
+ """
271
+
272
# UI layout and event wiring. Every handler's outputs list must match the
# arity of the tuple its fn returns (see process_path/process_question).
with gr.Blocks(css=CSS) as demo:
    gr.Markdown("# Document Parser: Document Parser Engine")
    # BUG FIX: the description string had a stray closing paren and garbled
    # grammar ("built on top of DocQuery library)").
    gr.Markdown(
        "Document Parser is built on top of the DocQuery library and"
        " uses LayoutLMv1 fine-tuned on DocVQA, a document visual question"
        " answering dataset, as well as SQuAD, which boosts its English-language comprehension."
    )

    # Hidden state: the loaded document plus the example image/question used
    # to drive the Examples widget.
    document = gr.Variable()
    example_question = gr.Textbox(visible=False)
    example_image = gr.Image(visible=False)

    with gr.Row(equal_height=True):
        with gr.Column():
            with gr.Row():
                gr.Markdown("## 1. Select a file", elem_id="select-a-file")
                img_clear_button = gr.Button(
                    "Clear", variant="secondary", elem_id="file-clear", visible=False
                )
            image = gr.Gallery(visible=False)
            with gr.Row(equal_height=True):
                with gr.Column():
                    with gr.Row():
                        url = gr.Textbox(
                            show_label=False,
                            placeholder="URL",
                            lines=1,
                            max_lines=1,
                            elem_id="url-textbox",
                        )
                        submit = gr.Button("Get")
                    url_error = gr.Textbox(
                        visible=False,
                        elem_id="url-error",
                        max_lines=1,
                        interactive=False,
                        label="Error",
                    )
                gr.Markdown("— or —")
                upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
                gr.Examples(
                    examples=examples,
                    inputs=[example_image, example_question],
                )

        with gr.Column() as col:
            gr.Markdown("## 2. Ask a question")
            question = gr.Textbox(
                label="Question",
                placeholder="e.g. What is the invoice number?",
                lines=1,
                max_lines=1,
            )
            model = gr.Radio(
                choices=list(CHECKPOINTS.keys()),
                value=list(CHECKPOINTS.keys())[0],
                label="Model",
            )

            with gr.Row():
                clear_button = gr.Button("Clear", variant="secondary")
                submit_button = gr.Button(
                    "Submit", variant="primary", elem_id="submit-button"
                )
        with gr.Column():
            output_text = gr.Textbox(
                label="Top Answer", visible=False, elem_id="answer"
            )
            output = gr.JSON(label="Output", visible=False)

    # Both clear buttons reset the whole UI; the lambda ignores its input and
    # returns one value per registered output.
    for cb in [img_clear_button, clear_button]:
        cb.click(
            lambda _: (
                gr.update(visible=False, value=None),
                None,
                gr.update(visible=False, value=None),
                gr.update(visible=False, value=None),
                gr.update(visible=False),
                None,
                None,
                None,
                gr.update(visible=False, value=None),
                None,
            ),
            inputs=clear_button,
            outputs=[
                image,
                document,
                output,
                output_text,
                img_clear_button,
                example_image,
                upload,
                url,
                url_error,
                question,
            ],
        )

    upload.change(
        fn=process_upload,
        inputs=[upload],
        outputs=[document, image, img_clear_button, output, output_text, url_error],
    )
    submit.click(
        fn=process_path,
        inputs=[url],
        outputs=[document, image, img_clear_button, output, output_text, url_error],
    )

    question.submit(
        fn=process_question,
        inputs=[question, document, model],
        outputs=[image, output, output_text],
    )

    submit_button.click(
        process_question,
        inputs=[question, document, model],
        outputs=[image, output, output_text],
    )

    # Re-run the current question whenever the model selection changes.
    model.change(
        process_question,
        inputs=[question, document, model],
        outputs=[image, output, output_text],
    )

    example_image.change(
        fn=load_example_document,
        inputs=[example_image, example_question, model],
        outputs=[document, question, image, img_clear_button, output, output_text],
    )
406
+
407
if __name__ == "__main__":
    # NOTE(review): `enable_queue` is a legacy gradio launch argument removed
    # in newer gradio versions (queueing is configured via demo.queue()) —
    # confirm against the pinned gradio version before upgrading.
    demo.launch(enable_queue=False)