athaneduc commited on
Commit
ceb2483
·
verified ·
1 Parent(s): 43e74c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -28
app.py CHANGED
@@ -1,46 +1,24 @@
1
  import gradio as gr
2
- import numpy as np
3
- import random
4
-
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from transformers import DonutProcessor, VisionEncoderDecoderModel
7
- from PIL import Image
8
  import torch
 
 
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  model_repo_id = "selvakumarcts/sk_invoice_receipts"
12
 
 
13
  processor = DonutProcessor.from_pretrained(model_repo_id)
14
  model = VisionEncoderDecoderModel.from_pretrained(model_repo_id).to(device)
15
 
16
-
17
- if torch.cuda.is_available():
18
- torch_dtype = torch.float16
19
- else:
20
- torch_dtype = torch.float32
21
-
22
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
23
- pipe = pipe.to(device)
24
-
25
- MAX_SEED = np.iinfo(np.int32).max
26
- MAX_IMAGE_SIZE = 1024
27
-
28
-
29
- # @spaces.GPU #[uncomment to use ZeroGPU]
30
- def infer(image, progress=gr.Progress(track_tqdm=True)):
31
- # Preprocess the image
32
  image = image.convert("RGB")
33
  pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
34
-
35
- # Run the model
36
  output = model.generate(pixel_values, max_length=512)
37
-
38
- # Decode the output (JSON)
39
  result = processor.batch_decode(output, skip_special_tokens=True)[0]
40
-
41
  return result
42
 
43
-
44
  with gr.Blocks() as demo:
45
  gr.Markdown(" # Invoice/Receipt Reader (Donut Model)")
46
  with gr.Column():
 
1
  import gradio as gr
 
 
 
 
 
 
2
  import torch
3
+ from PIL import Image
4
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
5
 
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
  model_repo_id = "selvakumarcts/sk_invoice_receipts"
8
 
9
+ # Load model and processor
10
  processor = DonutProcessor.from_pretrained(model_repo_id)
11
  model = VisionEncoderDecoderModel.from_pretrained(model_repo_id).to(device)
12
 
13
+ # Inference function
14
+ def infer(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  image = image.convert("RGB")
16
  pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
 
 
17
  output = model.generate(pixel_values, max_length=512)
 
 
18
  result = processor.batch_decode(output, skip_special_tokens=True)[0]
 
19
  return result
20
 
21
+ # UI
22
  with gr.Blocks() as demo:
23
  gr.Markdown(" # Invoice/Receipt Reader (Donut Model)")
24
  with gr.Column():