Saurav Chaudhari commited on
Commit
37b6db0
·
1 Parent(s): 057c981

Add application file

Browse files
Files changed (1) hide show
  1. app.py +49 -3
app.py CHANGED
@@ -1,7 +1,53 @@
1
  import gradio as gr
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
3
+ import torch
4
+ from PIL import Image
5
+ import json
6
 
7
+ MODEL_ID = "LLMTestSaurav/donut-base-finetuned-ind-cod-tza-rwa"
8
+
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+ processor = DonutProcessor.from_pretrained(MODEL_ID)
12
+ model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID).to(device)
13
+ model.eval()
14
+
15
+ def run_donut(image):
16
+ if image is None:
17
+ return {"error": "No image provided"}
18
+
19
+ image = image.convert("RGB")
20
+ pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
21
+ pixel_values = processor(image, return_tensors="pt").pixel_values.to("cuda")
22
+ task_prompt = '<passport_front>'
23
+ decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids.to("cuda")
24
+
25
+ outputs = model.generate(
26
+ pixel_values=pixel_values,
27
+ decoder_input_ids=decoder_input_ids,
28
+ max_length=512,
29
+ early_stopping=True,
30
+ pad_token_id=processor.tokenizer.pad_token_id,
31
+ eos_token_id=processor.tokenizer.eos_token_id,
32
+ use_cache=True,
33
+ num_beams=1,
34
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
35
+ return_dict_in_generate=True,
36
+ )
37
+
38
+ # Decode output
39
+ sequence = processor.batch_decode(outputs.sequences)[0]
40
+ sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
41
+ sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
42
+ sequence = processor.token2json(sequence)
43
+
44
+ return json.dumps(parsed, indent=2, ensure_ascii=False)
45
+
46
+ with gr.Blocks() as demo:
47
+ gr.Markdown("# Donut Sanity Check\nUpload an image → get JSON output")
48
+ inp = gr.Image(type="pil", label="Upload Document Image")
49
+ out = gr.Textbox(label="Parsed JSON", lines=20)
50
+ btn = gr.Button("Run Donut")
51
+ btn.click(run_donut, inputs=inp, outputs=out)
52
 
 
53
  demo.launch()