File size: 1,964 Bytes
057c981
37b6db0
 
 
 
35c9c68
057c981
7e6ce5c
37b6db0
 
 
 
 
 
 
 
 
 
 
 
 
58e3387
37b6db0
df0427f
37b6db0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424e968
37b6db0
 
 
 
 
 
 
057c981
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import gradio as gr
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
from PIL import Image
import json
import re

MODEL_ID = "LLMTestSaurav/donut-6"

# Pick the execution device once, at import time: the GPU when torch can
# see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the processor (used by run_donut both for image preprocessing and via
# its .tokenizer) and the encoder-decoder model from the same checkpoint,
# move the model onto the chosen device and put it in evaluation mode.
processor = DonutProcessor.from_pretrained(MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
model = model.to(device)
model.eval()

def run_donut(image):
    """Run the Donut passport-front parser on one image and return JSON text.

    Args:
        image: A PIL image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when nothing has been uploaded.

    Returns:
        A pretty-printed JSON string (2-space indent, non-ASCII characters
        kept verbatim). On success it contains whatever
        ``processor.token2json`` extracted from the generated sequence; when
        ``image`` is ``None`` it contains ``{"error": "No image provided"}``.
        A string is always returned because the output component is a
        ``gr.Textbox``.
    """
    if image is None:
        # Serialize the error exactly like a successful result so the Textbox
        # always displays valid JSON (a bare dict would render as a repr).
        return json.dumps(
            {"error": "No image provided"}, indent=2, ensure_ascii=False
        )

    # Normalize to 3-channel RGB before preprocessing (handles RGBA,
    # grayscale, palette images, ...).
    rgb_image = image.convert("RGB")
    pixel_values = processor(
        images=rgb_image, return_tensors="pt"
    ).pixel_values.to(device)

    # Prime the decoder with the task start token; presumably this is the
    # prompt token the checkpoint was fine-tuned with — confirm if the model
    # is swapped.
    task_prompt = "<passport_front>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(device)

    # Greedy decoding (num_beams=1). `early_stopping` is intentionally not
    # passed: it only applies to beam search and is ignored here.
    outputs = model.generate(
        pixel_values=pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=512,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # Forbid the unknown token so it never appears in the output.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Decode the single generated sequence and strip EOS/padding markers.
    raw_text = processor.batch_decode(outputs.sequences)[0]
    raw_text = raw_text.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    # Remove only the first tag (the task start token that the sequence
    # begins with); the remaining tags carry the field structure that
    # token2json parses.
    raw_text = re.sub(r"<.*?>", "", raw_text, count=1).strip()
    parsed = processor.token2json(raw_text)

    return json.dumps(parsed, indent=2, ensure_ascii=False)

# Minimal Gradio UI: a heading, one image upload, one text box for the JSON
# produced by run_donut, and a button wiring the two together.
with gr.Blocks() as demo:
    gr.Markdown("# Donut Sanity Check\nUpload an image → get JSON output")
    image_input = gr.Image(type="pil", label="Upload Document Image")
    json_output = gr.Textbox(label="Parsed JSON", lines=20)
    run_button = gr.Button("Run Donut")
    run_button.click(fn=run_donut, inputs=image_input, outputs=json_output)

demo.launch()