isslao committed on
Commit
1256e60
·
verified ·
1 Parent(s): ece6a1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -42
app.py CHANGED
@@ -2,60 +2,64 @@ import re
2
  import gradio as gr
3
  import torch
4
  from transformers import DonutProcessor, VisionEncoderDecoderModel
5
- from transformers import AutoProcessor, AutoModelForVision2Seq
6
-
7
-
8
- processor = AutoProcessor.from_pretrained("debu-das/donut_receipt_v1.20")
9
- model = AutoModelForVision2Seq.from_pretrained("debu-das/donut_receipt_v1.20")
10
 
 
 
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  model.to(device)
14
 
15
- def process_documents(images):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  results = []
17
  for image in images:
18
- # Prepare encoder inputs
19
- pixel_values = processor(image, return_tensors="pt").pixel_values
20
-
21
- # Prepare decoder inputs
22
- task_prompt = "<s_cord-v2>"
23
- decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
24
-
25
- # Generate answer
26
- outputs = model.generate(
27
- pixel_values.to(device),
28
- decoder_input_ids=decoder_input_ids.to(device),
29
- max_length=model.decoder.config.max_position_embeddings,
30
- early_stopping=True,
31
- pad_token_id=processor.tokenizer.pad_token_id,
32
- eos_token_id=processor.tokenizer.eos_token_id,
33
- use_cache=True,
34
- num_beams=1,
35
- bad_words_ids=[[processor.tokenizer.unk_token_id]],
36
- return_dict_in_generate=True,
37
- )
38
-
39
- # Postprocess
40
- sequence = processor.batch_decode(outputs.sequences)[0]
41
- sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
42
- sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # Remove first task start token
43
-
44
- results.append(processor.token2json(sequence))
45
-
46
  return results
47
 
48
- description = "Gradio Demo for Donut, an instance of `VisionEncoderDecoderModel` fine-tuned on CORD (document parsing). To use it, simply upload multiple images and click 'submit'."
49
- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
50
 
51
  demo = gr.Interface(
52
- fn=process_documents,
53
- inputs=gr.Files(label="Upload Images"), # Use Files to handle multiple files
54
  outputs="json",
55
- title="Batch Demo: Donut 🍩 for Document Parsing",
56
  description=description,
57
  article=article,
58
- examples=[["example.png"], ["example_1.png"],["example_2.png"], ["example_3.png"],["example_4.png"]],
59
- cache_examples=False)
 
 
 
 
60
 
61
- demo.launch()
 
 
2
  import gradio as gr
3
  import torch
4
  from transformers import DonutProcessor, VisionEncoderDecoderModel
5
+ import os
6
+ from PIL import Image
 
 
 
7
 
8
# Hub checkpoint id shared by the processor and the model.
_CHECKPOINT = "debu-das/donut_receipt_v1.20"

processor = DonutProcessor.from_pretrained(_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
13
 
14
def process_document(image):
    """Parse a single document image with the Donut model.

    Args:
        image: Either a filesystem path (``str`` or ``os.PathLike``) to an
            image file, or an already-loaded image object that the processor
            accepts directly.

    Returns:
        The result of ``processor.token2json`` applied to the decoded and
        cleaned-up output sequence.
    """
    if isinstance(image, (str, os.PathLike)):
        # Open the file in a context manager so its handle is released
        # deterministically; ``convert`` returns an independent RGB copy that
        # remains usable after the file is closed.
        with Image.open(image) as source:
            image = source.convert("RGB")

    # Encoder input: pixel tensor built by the processor.
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Decoder input: the task start token, tokenized without extra specials.
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # Forbid the unknown token from being generated.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Decode, drop EOS/PAD markers, then strip the first <...> tag
    # (the task start token) before converting the sequence to JSON.
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

    return processor.token2json(sequence)
39
+
40
def process_batch(images):
    """Run ``process_document`` on every image of a batch.

    Args:
        images: Iterable of images to parse (for example the list of file
            paths delivered by the upload input), or ``None``.

    Returns:
        A list with one ``process_document`` result per input, in input
        order. An empty list is returned when ``images`` is ``None`` or empty.
    """
    # NOTE(review): the upload widget may hand over None when the form is
    # submitted without any file; iterating None would raise TypeError.
    if not images:
        return []
    return [process_document(image) for image in images]
46
 
47
# User-facing texts displayed on the demo page.
description = "Démo Gradio pour Donut, une instance du modèle `VisionEncoderDecoderModel` affiné sur CORD (analyse de documents). Pour l'utiliser, téléchargez une ou plusieurs images et cliquez sur `Submit`, ou cliquez sur l'un des exemples pour les charger."
article = "Cette application permet maintenant de traiter plusieurs images de tickets de caisse à la fois."

# Each example row carries one value for the single file input: a list of paths.
_EXAMPLE_BATCHES = [
    [["example.jpg", "example_1.jpg"]],
    [["example_2.jpg", "example_3.jpg", "example_4.jpg"]],
]

demo = gr.Interface(
    fn=process_batch,
    inputs=gr.File(file_count="multiple", type="filepath"),
    outputs="json",
    title="Reconnaissance des tickets de caisse en lot 🧾",
    description=description,
    article=article,
    examples=_EXAMPLE_BATCHES,
    cache_examples=False,
)

# Bind to all interfaces; the port is read from PORT, defaulting to 7860.
port = int(os.getenv("PORT", 7860))
demo.launch(server_name="0.0.0.0", server_port=port)