ashvin-savani committed on
Commit
294474c
·
1 Parent(s): 3889a50
Files changed (2) hide show
  1. app.py +132 -4
  2. requirements.txt +7 -0
app.py CHANGED
@@ -1,7 +1,135 @@
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import base64
4
+ import json
5
+ import gc
6
+ import torch
7
+ import io
8
+ from transformers import AutoProcessor, AutoModelForImageTextToText
9
+ from qwen_vl_utils import process_vision_info
10
  import gradio as gr
11
+ import spaces
12
 
13
# Model setup: checkpoint identifier and the device tensors are sent to.
MODEL_NAME = "numind/NuExtract-2.0-4B"
device = "cuda"  # ZeroGPU provides GPU

# Image-text-to-text model, loaded once at import time in bfloat16.
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,
    device_map=None,  # Load on CPU, move to GPU in function
    trust_remote_code=True,
)

# Matching processor; left padding and the fast implementation are requested.
processor = AutoProcessor.from_pretrained(
    MODEL_NAME,
    use_fast=True,
    padding_side='left',
    trust_remote_code=True,
)
30
+
31
# Invoice schema: serialized to JSON and passed as the extraction template.
# Every field is an empty-string placeholder; "items" holds one example
# line-item entry.
invoice_schema = dict(
    invoice_number="",
    invoice_date="",
    supplier_name="",
    supplier_address="",
    total_amount="",
    currency="",
    items=[
        dict(
            description="",
            quantity="",
            unit_price="",
            total_price="",
        ),
    ],
)
48
+
49
def encode_image_to_base64(image_path):
    """Return the contents of the file at ``image_path`` as base64 text.

    The file is read in binary mode; the base64 bytes are decoded as UTF-8
    so the result is a ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
52
+
53
def encode_image_from_pil(image):
    """Serialize ``image`` as PNG in memory and return it as base64 text.

    ``image`` only needs a ``save(fileobj, format=...)`` method; the bytes
    it writes are base64-encoded and decoded as UTF-8 into a ``str``.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
57
+
58
def prepare_prompt(image_path):
    """Build the chat messages and prompt text for an invoice image file.

    The file is base64-encoded and embedded as a data URL in a single user
    message. The prompt is rendered by the tokenizer's chat template, with
    the invoice schema (as indented JSON) passed as ``template``.

    Returns:
        A ``(messages, text)`` tuple: the message list and the rendered,
        untokenized prompt string.
    """
    encoded = encode_image_to_base64(image_path)
    image_entry = {"type": "image", "image": f"data:image;base64,{encoded}"}
    messages = [{"role": "user", "content": [image_entry]}]

    schema_json = json.dumps(invoice_schema, indent=4)
    text = processor.tokenizer.apply_chat_template(
        messages,
        template=schema_json,
        tokenize=False,
        add_generation_prompt=True,
    )
    return messages, text
75
+
76
@spaces.GPU
def process_image(image):
    """Extract invoice fields from an uploaded image.

    Args:
        image: The image supplied by the Gradio input (configured with
            ``type="pil"``), or ``None`` when nothing was uploaded.

    Returns:
        The decoded model output as a string, or the message
        ``"No image provided."`` when ``image`` is ``None``.
    """
    if image is None:
        return "No image provided."

    # Embed the image as a base64 data URL in a single user message.
    base64_str = encode_image_from_pil(image)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": f"data:image;base64,{base64_str}"}
            ],
        }
    ]
    # Render the prompt, passing the invoice schema as the template.
    text = processor.tokenizer.apply_chat_template(
        messages,
        template=json.dumps(invoice_schema, indent=4),
        tokenize=False,
        add_generation_prompt=True,
    )

    image_inputs = process_vision_info(messages)[0] or []

    # The model is loaded without a device map, so it must be moved to the
    # GPU here; otherwise the inputs (sent to `device` below) and the weights
    # end up on different devices and generation fails.
    model.to(device)

    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)

    # Greedy decoding: no sampling, single beam.
    generation_config = {
        "do_sample": False,
        "num_beams": 1,
        "max_new_tokens": 2048,
    }

    generated_ids = model.generate(**inputs, **generation_config)

    # Drop the prompt tokens so only newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    return output_text
125
+
126
# Gradio interface: one image input wired to process_image, with the
# result shown in a text box.
iface = gr.Interface(
    title="Invoice Parser with NuExtract",
    description="Upload an invoice image to extract structured data using AI.",
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload Invoice Image"),
    outputs=gr.Textbox(label="Extracted Invoice Data (JSON)"),
)

# Start the app when the module is executed.
iface.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ qwen-vl-utils
4
+ gradio
5
+ huggingface_hub[spaces]
6
+ accelerate
7
+ flashinfer-python