| | import gradio as gr |
| | import torch, os, json, requests, sys |
| | from PIL import Image |
| | from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig |
| | from torchvision import transforms |
| |
|
def load_image_from_URL(url, timeout=10):
    """Download an image from ``url`` and return it as an RGB PIL image.

    Args:
        url: HTTP(S) address of the image.
        timeout: seconds to wait for the server before giving up, so a slow
            or unresponsive host cannot block the request handler forever.

    Returns:
        The decoded image converted to RGB, or None when the request fails,
        the server does not answer with status 200, or the response body is
        not a readable image.
    """
    from io import BytesIO

    # Download once; `res.content` is already content-decoded (gzip/deflate),
    # unlike the raw socket stream.
    try:
        res = requests.get(url, timeout=timeout)
    except requests.RequestException:
        return None

    if res.status_code != 200:
        return None

    try:
        img = Image.open(BytesIO(res.content))
        # Force decoding now so corrupt/truncated data fails here, not later
        # inside the model's preprocessing.
        img.load()
    except OSError:  # includes PIL.UnidentifiedImageError
        return None

    # Normalise every mode (RGBA, P, L, CMYK, ...) to 3-channel RGB before
    # the image reaches the image processor.
    if img.mode != "RGB":
        img = img.convert("RGB")

    return img
| |
|
class OCRVQAModel(torch.nn.Module):
    """Donut (VisionEncoderDecoder) wrapper that answers questions about images.

    ``config`` is a dict with the keys:
        donut:      name/path of the VisionEncoderDecoderModel weights.
        processor:  name/path of the DonutProcessor.
        config:     name/path of the VisionEncoderDecoderConfig to start from.
        image_size: optional encoder input size ``[height, width]``;
                    defaults to ``[800, 600]``.
        max_length: optional decoder ``max_length`` config value; defaults to 64.
    """

    def __init__(self, config):
        super().__init__()

        self.model_name_or_path = config["donut"]
        self.processor_name_or_path = config["processor"]
        self.config_name_or_path = config["config"]

        self.donut_config = VisionEncoderDecoderConfig.from_pretrained(self.config_name_or_path)
        # Override the base config's input size / decoder length; presumably
        # these match the fine-tuning setup of the checkpoint — confirm.
        self.donut_config.encoder.image_size = list(config.get("image_size", [800, 600]))
        self.donut_config.decoder.max_length = config.get("max_length", 64)

        self.processor = DonutProcessor.from_pretrained(self.processor_name_or_path)
        self.donut = VisionEncoderDecoderModel.from_pretrained(
            self.model_name_or_path, config=self.donut_config
        )

        self.added_tokens = set()
        self.setup()

    def add_tokens(self, list_of_tokens):
        """Register extra tokens with the tokenizer, growing the decoder embeddings if needed."""
        self.added_tokens.update(list_of_tokens)
        newly_added_num = self.processor.tokenizer.add_tokens(list_of_tokens)

        # Tokens already in the vocabulary are not re-added; only resize when
        # the tokenizer actually grew.
        if newly_added_num > 0:
            self.donut.decoder.resize_token_embeddings(len(self.processor.tokenizer))

    def setup(self):
        """Add the yes/no answer tokens and configure image preprocessing."""
        self.add_tokens(["<yes/>", "<no/>"])

        # Newer transformers expose `image_processor`; `feature_extractor` is
        # the older (deprecated) name for the same object.
        image_processor = (
            getattr(self.processor, "image_processor", None)
            or self.processor.feature_extractor
        )
        # encoder.image_size is [height, width]; the Donut processor's list-valued
        # `size` uses (width, height) order, hence the reversal.
        image_processor.size = self.donut_config.encoder.image_size[::-1]
        image_processor.do_align_long_axis = False

    def inference(self, image, prompt, device):
        """Answer ``prompt`` about ``image``.

        Args:
            image: image accepted by the DonutProcessor (a PIL image in this app).
            prompt: question text inserted into the DocVQA prompt template.
            device: torch device the model lives on; inputs are moved there.

        Returns:
            The output of ``processor.token2json`` for the generated sequence,
            typically a dict with ``question`` and ``answer`` keys.
        """
        self.donut.eval()
        tokenizer = self.processor.tokenizer

        with torch.no_grad():
            pixel_values = self.processor(image, return_tensors="pt").pixel_values.to(device)

            # The decoder is primed with the task token and the question; it
            # continues generating the answer after <s_answer>.
            question = f'<s_docvqa><s_question>{prompt}</s_question><s_answer>'
            decoder_input_ids = tokenizer(
                question,
                add_special_tokens=False,
                return_tensors="pt",
            )["input_ids"].to(device)

            # Greedy decoding (num_beams=1); `early_stopping` only affects beam
            # search, so it is not passed.
            outputs = self.donut.generate(
                pixel_values,
                decoder_input_ids=decoder_input_ids,
                max_length=self.donut.decoder.config.max_position_embeddings,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )

        sequence = self.processor.batch_decode(outputs.sequences)[0]
        # Strip end-of-sequence / padding markers before parsing the tag structure.
        for special in (tokenizer.eos_token, tokenizer.pad_token):
            if special:
                sequence = sequence.replace(special, "")
        return self.processor.token2json(sequence)
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
# Choose the device first; the model is moved there and get_answer passes the
# same device to every inference call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fine-tuned OCR-VQA weights + processor, built on the DocVQA Donut config.
# nn.Module.to() returns the module itself, so the moved model is bound directly.
model = OCRVQAModel(
    {
        "donut": "ndtran/donut_ocr-vqa-200k",
        "processor": "ndtran/donut_ocr-vqa-200k",
        "config": "naver-clova-ix/donut-base-finetuned-docvqa",
    }
).to(device)
| |
|
def get_answer(image, url, question) -> str:
    """Answer ``question`` about an uploaded image or an image URL.

    Args:
        image: PIL image from the Gradio ``Image`` input, or None.
        url: text from the URL box; a valid ``http://``/``https://`` URL takes
            precedence over the uploaded image.
        question: the question to ask about the image.

    Returns:
        The model's answer, or a short user-facing message when no image is
        available or the model produced no ``answer`` field.
    """
    # Gradio textboxes yield "" (not None) when empty; pasted URLs often carry
    # surrounding whitespace.
    url = (url or "").strip()

    if url.startswith(("http://", "https://")):
        image = load_image_from_URL(url)
        if image is None:
            return "Could not load an image from that URL."

    if image is None:
        return "Please upload an image or enter an image URL."

    result = model.inference(image, question, device)
    # token2json may return a list of dicts (when the model emits <sep/>);
    # use the first entry in that case.
    if isinstance(result, list):
        result = result[0] if result else {}
    return result.get("answer", "I don't know :<")
| |
|
| |
|
_DESCRIPTION = """
## Donut-OCR-VQA
- This demo uses fine-tuned OCR-VQA-Donut model on the OCR-VQA-200k dataset to answer questions about images.

## IO description
- Input is an image or URL that represents a book cover (recommended) and a question that asks about information on the image.
- Output: an answer to the question.
"""

with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(_DESCRIPTION)

    with gr.Row():
        with gr.Column():
            # No `shape=`: it would resize uploads to a tiny fixed size, making
            # small text unreadable for OCR (and it is gone in Gradio 4).
            # The Donut processor does its own resizing.
            image = gr.Image(type="pil", label="Pick an image")
            image_url = gr.Textbox(lines=1, label="Or use this option!", placeholder="Enter the image URL here")
            question = gr.Textbox(lines=5, label="Question")

            # A button's caption is its `value`; `label` is not displayed.
            ask = gr.Button("Get the answer")

        with gr.Column():
            answer = gr.Label(label="Answer")

    # Event listeners must be registered inside the Blocks context.
    ask.click(get_answer, inputs=[image, image_url, question], outputs=[answer])

demo.launch()