ngupta2026 committed on
Commit
d60e25f
·
verified ·
1 Parent(s): 8be9c6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -28
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import pytesseract
3
- from PIL import Image
4
  import torch
5
  import re
6
  import requests
@@ -12,9 +12,7 @@ from transformers import LayoutLMTokenizerFast, LayoutLMForTokenClassification
12
  # CONFIG
13
  # =====================================================
14
  RESEND_API_KEY = os.getenv("RESEND_API_KEY")
15
-
16
- # Use verified sender from Resend
17
- FROM_EMAIL = "AI Claims <claims@yudham.com>"
18
 
19
  MODEL_NAME = "ngupta2026/sroie-layoutlm"
20
 
@@ -39,7 +37,7 @@ model.to(device)
39
  model.eval()
40
 
41
  # =====================================================
42
- # NORMALIZE BOX
43
  # =====================================================
44
  def normalize(box, width, height):
45
  return [
@@ -50,26 +48,48 @@ def normalize(box, width, height):
50
  ]
51
 
52
  # =====================================================
53
- # AVG CONFIDENCE
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  # =====================================================
55
- def avg_conf(values):
56
- if len(values) == 0:
57
  return 0
58
- return sum(values) / len(values)
59
 
60
  # =====================================================
61
- # OCR + EXTRACTION (IMPROVED ACCURACY)
62
  # =====================================================
63
  def extract_receipt(image):
64
 
65
  try:
66
- # Keep quality high for OCR
67
- image = image.convert("RGB")
68
 
 
69
  data = pytesseract.image_to_data(
70
  image,
71
  output_type=pytesseract.Output.DICT,
72
- config="--oem 3 --psm 6"
73
  )
74
 
75
  words = []
@@ -77,16 +97,16 @@ def extract_receipt(image):
77
 
78
  for i in range(len(data["text"])):
79
 
80
- text = data["text"][i].strip()
81
 
82
- if text != "" and text != "|":
83
 
84
  x = data["left"][i]
85
  y = data["top"][i]
86
  w = data["width"][i]
87
  h = data["height"][i]
88
 
89
- words.append(text)
90
  boxes.append([x, y, x + w, y + h])
91
 
92
  if len(words) == 0:
@@ -95,7 +115,9 @@ def extract_receipt(image):
95
  width, height = image.size
96
  boxes = [normalize(b, width, height) for b in boxes]
97
 
98
- # IMPORTANT: use 512 for better predictions
 
 
99
  encoding = tokenizer(
100
  words,
101
  boxes=boxes,
@@ -108,6 +130,9 @@ def extract_receipt(image):
108
 
109
  encoding = {k: v.to(device) for k, v in encoding.items()}
110
 
 
 
 
111
  with torch.no_grad():
112
  outputs = model(**encoding)
113
 
@@ -129,32 +154,39 @@ def extract_receipt(image):
129
  }
130
 
131
  # =================================================
132
- # TOKEN LEVEL EXTRACTION
133
  # =================================================
134
  for word, pred, conf in zip(words, preds, confs):
135
 
136
  label = id2label[pred.item()]
137
  c = conf.item()
138
 
139
- # COMPANY from model
 
 
140
  if label == "COMPANY":
141
  result["company"].append(word)
142
  conf_store["company"].append(c)
143
 
144
- # DATE regex
 
 
145
  if re.search(r"\d{1,2}[/-]\d{1,2}[/-]\d{2,4}", word):
146
  result["date"].append(word)
147
  conf_store["date"].append(c)
148
 
149
- # TOTAL numeric values
150
- cleaned = word.replace(",", "").replace("₹", "")
 
 
151
 
152
  if re.fullmatch(r"\d+(\.\d{1,2})?", cleaned):
 
153
  try:
154
  value = float(cleaned)
155
 
156
- # Better range for totals
157
- if value >= 10:
158
  result["total"].append(value)
159
  conf_store["total"].append(c)
160
 
@@ -167,14 +199,19 @@ def extract_receipt(image):
167
 
168
  # COMPANY
169
  company = " ".join(result["company"][:6]).strip()
 
170
  if company == "":
171
- company = "Not Found"
 
172
 
173
  # DATE
174
  date = result["date"][0] if result["date"] else "Not Found"
175
 
176
- # TOTAL = highest amount (better than last token)
177
- total = str(max(result["total"])) if result["total"] else "Not Found"
 
 
 
178
 
179
  # CONFIDENCE
180
  company_conf = avg_conf(conf_store["company"])
@@ -298,7 +335,7 @@ demo = gr.Interface(
298
  ],
299
 
300
  title="📄 AI Insurance Claim Generator",
301
- description="Upload receipt → Better extraction → Confidence check → Auto Email"
302
  )
303
 
304
  demo.launch()
 
1
  import gradio as gr
2
  import pytesseract
3
+ from PIL import Image, ImageFilter, ImageOps
4
  import torch
5
  import re
6
  import requests
 
12
  # CONFIG
13
  # =====================================================
14
  RESEND_API_KEY = os.getenv("RESEND_API_KEY")
15
+ FROM_EMAIL = "AI Claims <claims@yudham.com>" # verified sender
 
 
16
 
17
  MODEL_NAME = "ngupta2026/sroie-layoutlm"
18
 
 
37
  model.eval()
38
 
39
  # =====================================================
40
+ # NORMALIZE BOUNDING BOXES
41
  # =====================================================
42
  def normalize(box, width, height):
43
  return [
 
48
  ]
49
 
50
  # =====================================================
51
+ # IMAGE PREPROCESSING (VERY IMPORTANT)
52
+ # =====================================================
53
+ def preprocess_image(image):
54
+
55
+ image = image.convert("RGB")
56
+
57
+ # upscale for OCR
58
+ w, h = image.size
59
+ image = image.resize((w * 2, h * 2))
60
+
61
+ # grayscale
62
+ image = image.convert("L")
63
+
64
+ # sharpen
65
+ image = image.filter(ImageFilter.SHARPEN)
66
+
67
+ # auto contrast
68
+ image = ImageOps.autocontrast(image)
69
+
70
+ return image
71
+
72
+ # =====================================================
73
+ # CONFIDENCE AVG
74
  # =====================================================
75
+ def avg_conf(lst):
76
+ if len(lst) == 0:
77
  return 0
78
+ return sum(lst) / len(lst)
79
 
80
  # =====================================================
81
+ # OCR + EXTRACTION
82
  # =====================================================
83
  def extract_receipt(image):
84
 
85
  try:
86
+ image = preprocess_image(image)
 
87
 
88
+ # Better OCR mode for receipts
89
  data = pytesseract.image_to_data(
90
  image,
91
  output_type=pytesseract.Output.DICT,
92
+ config="--oem 3 --psm 4"
93
  )
94
 
95
  words = []
 
97
 
98
  for i in range(len(data["text"])):
99
 
100
+ txt = data["text"][i].strip()
101
 
102
+ if txt != "" and txt != "|":
103
 
104
  x = data["left"][i]
105
  y = data["top"][i]
106
  w = data["width"][i]
107
  h = data["height"][i]
108
 
109
+ words.append(txt)
110
  boxes.append([x, y, x + w, y + h])
111
 
112
  if len(words) == 0:
 
115
  width, height = image.size
116
  boxes = [normalize(b, width, height) for b in boxes]
117
 
118
+ # =================================================
119
+ # TOKENIZER
120
+ # =================================================
121
  encoding = tokenizer(
122
  words,
123
  boxes=boxes,
 
130
 
131
  encoding = {k: v.to(device) for k, v in encoding.items()}
132
 
133
+ # =================================================
134
+ # MODEL PREDICTION
135
+ # =================================================
136
  with torch.no_grad():
137
  outputs = model(**encoding)
138
 
 
154
  }
155
 
156
  # =================================================
157
+ # EXTRACT ENTITIES
158
  # =================================================
159
  for word, pred, conf in zip(words, preds, confs):
160
 
161
  label = id2label[pred.item()]
162
  c = conf.item()
163
 
164
+ # -------------------------
165
+ # COMPANY
166
+ # -------------------------
167
  if label == "COMPANY":
168
  result["company"].append(word)
169
  conf_store["company"].append(c)
170
 
171
+ # -------------------------
172
+ # DATE
173
+ # -------------------------
174
  if re.search(r"\d{1,2}[/-]\d{1,2}[/-]\d{2,4}", word):
175
  result["date"].append(word)
176
  conf_store["date"].append(c)
177
 
178
+ # -------------------------
179
+ # TOTAL
180
+ # -------------------------
181
+ cleaned = word.replace(",", "").replace("₹", "").replace("$", "")
182
 
183
  if re.fullmatch(r"\d+(\.\d{1,2})?", cleaned):
184
+
185
  try:
186
  value = float(cleaned)
187
 
188
+ # realistic receipt range
189
+ if 1 <= value <= 10000:
190
  result["total"].append(value)
191
  conf_store["total"].append(c)
192
 
 
199
 
200
  # COMPANY
201
  company = " ".join(result["company"][:6]).strip()
202
+
203
  if company == "":
204
+ # fallback top words
205
+ company = " ".join(words[:3])
206
 
207
  # DATE
208
  date = result["date"][0] if result["date"] else "Not Found"
209
 
210
+ # TOTAL = best realistic amount
211
+ if result["total"]:
212
+ total = f"{max(result['total']):.2f}"
213
+ else:
214
+ total = "Not Found"
215
 
216
  # CONFIDENCE
217
  company_conf = avg_conf(conf_store["company"])
 
335
  ],
336
 
337
  title="📄 AI Insurance Claim Generator",
338
+ description="Upload receipt → Extract fields accurately → Confidence Check → Auto Email"
339
  )
340
 
341
  demo.launch()