File size: 1,928 Bytes
2960296
 
 
 
 
48fff0d
2960296
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d5b98b
 
2960296
 
 
 
 
 
 
2d5b98b
2960296
 
d3c10c0
2960296
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import re
import gradio as gr

import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
import os


# Load the DocVQA-finetuned Donut checkpoint once, at import time, so every
# request handled by `vqa` reuses the same processor and model instances.
_CHECKPOINT = "naver-clova-ix/donut-base-finetuned-docvqa"
processor = DonutProcessor.from_pretrained(_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)

# Prefer the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)


def vqa(image, question):
    """Answer a natural-language question about a document image with Donut.

    Args:
        image: The document image from the Gradio ``"image"`` input, passed
            unchanged to ``processor`` (presumably a numpy array or PIL
            image depending on the Gradio version -- confirm).
        question: The question text from the Gradio ``"text"`` input.

    Returns:
        Whatever ``processor.token2json`` produces for the decoded sequence;
        with the DocVQA prompt used here this is expected to hold the
        question and the generated answer (verify against the checkpoint).
    """
    # Build the DocVQA task prompt. It ends with an opening <s_answer> tag,
    # so generation continues from there and produces the answer.
    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
    prompt = task_prompt.replace("{user_input}", question)
    # The prompt already carries its own task tokens, so the tokenizer must
    # not wrap it in additional special tokens.
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Greedy decoding (num_beams=1). `early_stopping` is intentionally not
    # passed: it only applies to beam search, and recent transformers
    # versions warn when it is combined with num_beams=1.
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # Forbid the tokenizer's unknown token in the output.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        # Needed so the result exposes `.sequences` below.
        return_dict_in_generate=True,
    )

    # Post-process: drop EOS/padding tokens, remove the first tag (the task
    # start token), then parse the remaining tag sequence into JSON.
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)

# Example inputs can be added to the UI by passing
# `examples=[[image_path, question], ...]` to gr.Interface below.

# Web UI: an image plus a question in, the JSON produced by `vqa` out.
demo = gr.Interface(
    fn=vqa,
    inputs=["image", "text"],
    outputs="json",
    title="Donut 🍩 demonstration for VQA task",
)

# Start the Gradio server.
demo.launch()