suthawadee committed on
Commit
462d7a2
·
verified ·
1 Parent(s): 5704f67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -25
app.py CHANGED
@@ -4,8 +4,6 @@ import torch
4
  from PIL import Image
5
  import re
6
  from transformers import DonutProcessor, VisionEncoderDecoderModel
7
- import json
8
-
9
 
10
  def demo_process(input_img, question=None):
11
  global processor, model
@@ -22,17 +20,17 @@ def demo_process(input_img, question=None):
22
 
23
  with torch.no_grad():
24
  outputs = model.generate(
25
- pixel_values,
26
- decoder_input_ids=decoder_input_ids,
27
- max_length=1024, # เปลี่ยนตามความต้องการ
28
- early_stopping=True,
29
- pad_token_id=processor.tokenizer.pad_token_id,
30
- eos_token_id=processor.tokenizer.eos_token_id,
31
- use_cache=True,
32
- num_beams=1,
33
- bad_words_ids=[[processor.tokenizer.unk_token_id]],
34
- return_dict_in_generate=True,
35
- )
36
 
37
  seq = processor.batch_decode(outputs.sequences)[0]
38
  seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
@@ -40,9 +38,8 @@ def demo_process(input_img, question=None):
40
  seq = processor.token2json(seq)
41
  return seq
42
 
43
-
44
  parser = argparse.ArgumentParser()
45
- parser.add_argument("--task", type=str, default="cord-v2") # Add argument for task
46
  parser.add_argument("--pretrained_path", type=str, default="suthawadee/donut-demo_new")
47
  args, left_argv = parser.parse_known_args()
48
 
@@ -52,14 +49,24 @@ device = "cpu" if not torch.cuda.is_available() else "cuda"
52
  model.to(device)
53
  model.eval()
54
 
55
- inputs = ["image", "text"] if args.task == "docvqa" else "image"
56
- outputs = "json"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- demo = gr.Interface(
59
- fn=demo_process,
60
- inputs=inputs,
61
- outputs=outputs,
62
- title="🇹🇭🧾ThaiReceipt",
63
- description="Upload an image."
64
- )
65
- demo.launch(debug=True)
 
4
  from PIL import Image
5
  import re
6
  from transformers import DonutProcessor, VisionEncoderDecoderModel
 
 
7
 
8
  def demo_process(input_img, question=None):
9
  global processor, model
 
20
 
21
  with torch.no_grad():
22
  outputs = model.generate(
23
+ pixel_values,
24
+ decoder_input_ids=decoder_input_ids,
25
+ max_length=1024,
26
+ early_stopping=True,
27
+ pad_token_id=processor.tokenizer.pad_token_id,
28
+ eos_token_id=processor.tokenizer.eos_token_id,
29
+ use_cache=True,
30
+ num_beams=1,
31
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
32
+ return_dict_in_generate=True,
33
+ )
34
 
35
  seq = processor.batch_decode(outputs.sequences)[0]
36
  seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
 
38
  seq = processor.token2json(seq)
39
  return seq
40
 
 
41
  parser = argparse.ArgumentParser()
42
+ parser.add_argument("--task", type=str, default="cord-v2")
43
  parser.add_argument("--pretrained_path", type=str, default="suthawadee/donut-demo_new")
44
  args, left_argv = parser.parse_known_args()
45
 
 
49
  model.to(device)
50
  model.eval()
51
 
52
# Sample images shipped alongside the app, offered as clickable examples for testing.
image1 = "8.jpg"
image2 = "15.jpg"

# One inner list per example row. File paths are passed (rather than opened
# PIL images) so no image file handles are held open at module level and
# Gradio loads each example itself when it is selected.
examples = [
    [image1],
    [image2],
]


def main(pretrained_path, examples):
    """Build the Gradio interface around ``demo_process`` and launch it.

    Args:
        pretrained_path: Checkpoint identifier. Accepted so existing callers
            keep working, but not used inside this function.
        examples: Example rows forwarded unchanged to ``gr.Interface``.

    The input layout depends on the module-level ``args.task``: an image plus
    a text box for ``"docvqa"``, otherwise a single image.
    """
    if args.task == "docvqa":
        input_components = ["image", "text"]
    else:
        input_components = "image"

    demo = gr.Interface(
        fn=demo_process,
        inputs=input_components,
        outputs="json",
        title="🇹🇭🧾ThaiReceipt",
        description="Upload image.",
        examples=examples,
    )
    demo.launch(debug=True)


main(args.pretrained_path, examples)