suthawadee committed on
Commit
87d2548
·
verified ·
1 Parent(s): caaf518

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import gradio as gr
3
+ import torch
4
+ from PIL import Image
5
+ import re
6
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
7
+ from donut import DonutModel
8
+ import json
9
+
10
+
11
def demo_process(input_img, question=None):
    """Run Donut inference on an uploaded image and return parsed JSON.

    Args:
        input_img: Image as a numpy array (what the Gradio ``image`` component
            delivers to the callback).
        question: Optional task name. When given, the decoder prompt becomes
            ``<s_{question}>``; otherwise the CORD-v2 receipt-parsing prompt
            ``<s_cord-v2>`` is used.

    Returns:
        The generated token sequence converted to a dict/JSON structure via
        ``processor.token2json``.
    """
    global processor, model

    image = Image.fromarray(input_img)
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

    # The two original branches were identical except for the prompt string,
    # so build the prompt once and tokenize it once.
    task_prompt = f"<s_{question}>" if question else "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(device)

    with torch.no_grad():
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=1024,  # adjust as needed
            # NOTE: early_stopping was removed — it has no effect with
            # num_beams=1 (greedy search) and only emits a warning.
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

    seq = processor.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    # Strip the leading task-start token (e.g. "<s_cord-v2>") before parsing.
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()
    return processor.token2json(seq)
43
+
44
+
45
# ---- CLI arguments -------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="cord-v2")  # which Donut task prompt to use
parser.add_argument("--pretrained_path", type=str, default="suthawadee/donut-demo_new")
args, left_argv = parser.parse_known_args()

# ---- Model setup ---------------------------------------------------------
processor = DonutProcessor.from_pretrained(args.pretrained_path)
model = VisionEncoderDecoderModel.from_pretrained(args.pretrained_path)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

# DocVQA needs an extra text box for the question; every other task takes
# only the image input.
if args.task == "docvqa":
    inputs = ["image", "text"]
else:
    inputs = "image"
outputs = "json"

# ---- Gradio UI -----------------------------------------------------------
demo = gr.Interface(
    fn=demo_process,
    inputs=inputs,
    outputs=outputs,
    title="🇹🇭🧾ThaiReceipt",
    description="Upload an image.",
)
demo.launch(debug=True)