ngupta2026 commited on
Commit
2bd90f0
·
verified ·
1 Parent(s): f8fd50e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -42
app.py CHANGED
@@ -3,18 +3,24 @@ import pytesseract
3
  from PIL import Image
4
  import torch
5
  import re
 
 
 
 
 
 
6
  from transformers import LayoutLMTokenizerFast, LayoutLMForTokenClassification
7
 
8
- # =========================
9
  # LABELS
10
- # =========================
11
- label2id = {"O":0, "COMPANY":1, "DATE":2, "TOTAL":3}
12
- id2label = {v:k for k,v in label2id.items()}
13
 
14
- # =========================
15
- # LOAD YOUR TRAINED MODEL
16
- # =========================
17
- MODEL_NAME = "ngupta2026/sroie-layoutlm" # 🔥 your new model
18
 
19
  model = LayoutLMForTokenClassification.from_pretrained(MODEL_NAME)
20
  tokenizer = LayoutLMTokenizerFast.from_pretrained(MODEL_NAME)
@@ -23,9 +29,18 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
23
  model.to(device)
24
  model.eval()
25
 
26
- # =========================
 
 
 
 
 
 
 
 
 
27
  # NORMALIZE BOXES
28
- # =========================
29
  def normalize(box, width, height):
30
  return [
31
  int(1000 * box[0] / width),
@@ -34,18 +49,21 @@ def normalize(box, width, height):
34
  int(1000 * box[3] / height),
35
  ]
36
 
37
- # =========================
38
- # MAIN FUNCTION
39
- # =========================
40
- def process(image):
41
 
42
- # OCR
43
- data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
 
 
44
 
45
  words = []
46
  boxes = []
47
 
48
  for i in range(len(data["text"])):
 
49
  text = data["text"][i].strip()
50
 
51
  if text != "":
@@ -55,16 +73,14 @@ def process(image):
55
  h = data["height"][i]
56
 
57
  words.append(text)
58
- boxes.append([x, y, x+w, y+h])
59
 
60
  if len(words) == 0:
61
  return {"error": "No text detected"}
62
 
63
- # normalize boxes
64
  width, height = image.size
65
  boxes = [normalize(box, width, height) for box in boxes]
66
 
67
- # tokenize
68
  encoding = tokenizer(
69
  words,
70
  boxes=boxes,
@@ -75,58 +91,134 @@ def process(image):
75
  max_length=512
76
  )
77
 
78
- encoding = {k:v.to(device) for k,v in encoding.items()}
79
 
80
- # model prediction
81
  with torch.no_grad():
82
  outputs = model(**encoding)
83
 
84
  predictions = torch.argmax(outputs.logits, dim=2)[0][:len(words)]
85
 
86
- # =========================
87
- # HYBRID EXTRACTION
88
- # =========================
89
- result = {"company": [], "date": [], "total": []}
 
90
 
91
  for word, pred in zip(words, predictions):
92
 
93
  label = id2label[pred.item()]
94
 
95
- # 🧠 MODEL (company)
96
  if label == "COMPANY":
97
  result["company"].append(word)
98
 
99
- # 📅 DATE (strong regex)
100
  if re.search(r"\d{2}[/-]\d{2}[/-]\d{2,4}", word):
101
  result["date"].append(word)
102
 
103
- # 💰 TOTAL (better filtering)
104
  if re.search(r"\d+(\.\d{2})?", word):
105
  try:
106
  value = float(word.replace(",", ""))
107
- if value > 50: # ignore small numbers
108
  result["total"].append(word)
109
  except:
110
  pass
111
 
112
- # =========================
113
- # CLEAN OUTPUT
114
- # =========================
115
- result["company"] = " ".join(result["company"]) if result["company"] else "Not Found"
116
- result["date"] = result["date"][0] if result["date"] else "Not Found"
117
- result["total"] = result["total"][-1] if result["total"] else "Not Found"
 
 
 
 
 
 
 
 
118
 
119
  return result
120
 
121
- # =========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  # UI
123
- # =========================
124
  demo = gr.Interface(
125
- fn=process,
126
- inputs=gr.Image(type="pil"),
127
- outputs="json",
128
- title="📄 Document AI Extractor",
129
- description="Upload invoice image → Extract Company, Date, Total"
 
 
 
 
 
 
130
  )
131
 
132
  demo.launch()
 
3
  from PIL import Image
4
  import torch
5
  import re
6
+ import os
7
+ import smtplib
8
+
9
+ from email.mime.text import MIMEText
10
+ from email.mime.multipart import MIMEMultipart
11
+
12
  from transformers import LayoutLMTokenizerFast, LayoutLMForTokenClassification
13
 
14
+ # =====================================================
15
  # LABELS
16
+ # =====================================================
17
+ label2id = {"O": 0, "COMPANY": 1, "DATE": 2, "TOTAL": 3}
18
+ id2label = {v: k for k, v in label2id.items()}
19
 
20
+ # =====================================================
21
+ # LOAD MODEL
22
+ # =====================================================
23
+ MODEL_NAME = "ngupta2026/sroie-layoutlm"
24
 
25
  model = LayoutLMForTokenClassification.from_pretrained(MODEL_NAME)
26
  tokenizer = LayoutLMTokenizerFast.from_pretrained(MODEL_NAME)
 
29
  model.to(device)
30
  model.eval()
31
 
32
+ # =====================================================
33
+ # EMAIL CONFIG
34
+ # Add these in Hugging Face Space Secrets:
35
+ # EMAIL_USER = yourgmail@gmail.com
36
+ # EMAIL_PASS = your_app_password
37
+ # =====================================================
38
+ EMAIL_USER = os.getenv("EMAIL_USER")
39
+ EMAIL_PASS = os.getenv("EMAIL_PASS")
40
+
41
+ # =====================================================
42
  # NORMALIZE BOXES
43
+ # =====================================================
44
  def normalize(box, width, height):
45
  return [
46
  int(1000 * box[0] / width),
 
49
  int(1000 * box[3] / height),
50
  ]
51
 
52
+ # =====================================================
53
+ # EXTRACT DATA
54
+ # =====================================================
55
+ def extract_receipt(image):
56
 
57
+ data = pytesseract.image_to_data(
58
+ image,
59
+ output_type=pytesseract.Output.DICT
60
+ )
61
 
62
  words = []
63
  boxes = []
64
 
65
  for i in range(len(data["text"])):
66
+
67
  text = data["text"][i].strip()
68
 
69
  if text != "":
 
73
  h = data["height"][i]
74
 
75
  words.append(text)
76
+ boxes.append([x, y, x + w, y + h])
77
 
78
  if len(words) == 0:
79
  return {"error": "No text detected"}
80
 
 
81
  width, height = image.size
82
  boxes = [normalize(box, width, height) for box in boxes]
83
 
 
84
  encoding = tokenizer(
85
  words,
86
  boxes=boxes,
 
91
  max_length=512
92
  )
93
 
94
+ encoding = {k: v.to(device) for k, v in encoding.items()}
95
 
 
96
  with torch.no_grad():
97
  outputs = model(**encoding)
98
 
99
  predictions = torch.argmax(outputs.logits, dim=2)[0][:len(words)]
100
 
101
+ result = {
102
+ "company": [],
103
+ "date": [],
104
+ "total": []
105
+ }
106
 
107
  for word, pred in zip(words, predictions):
108
 
109
  label = id2label[pred.item()]
110
 
111
+ # company from model
112
  if label == "COMPANY":
113
  result["company"].append(word)
114
 
115
+ # date from regex
116
  if re.search(r"\d{2}[/-]\d{2}[/-]\d{2,4}", word):
117
  result["date"].append(word)
118
 
119
+ # total from regex
120
  if re.search(r"\d+(\.\d{2})?", word):
121
  try:
122
  value = float(word.replace(",", ""))
123
+ if value > 50:
124
  result["total"].append(word)
125
  except:
126
  pass
127
 
128
+ result["company"] = (
129
+ " ".join(result["company"])
130
+ if result["company"] else "Not Found"
131
+ )
132
+
133
+ result["date"] = (
134
+ result["date"][0]
135
+ if result["date"] else "Not Found"
136
+ )
137
+
138
+ result["total"] = (
139
+ result["total"][-1]
140
+ if result["total"] else "Not Found"
141
+ )
142
 
143
  return result
144
 
145
+ # =====================================================
146
+ # SEND EMAIL
147
+ # =====================================================
148
+ def send_claim_email(to_email, extracted):
149
+
150
+ if not EMAIL_USER or not EMAIL_PASS:
151
+ return "Email secrets not configured."
152
+
153
+ subject = "Insurance Claim Request"
154
+
155
+ body = f"""
156
+ Dear Claims Team,
157
+
158
+ I would like to request reimbursement for an eligible expense.
159
+
160
+ Provider Name: {extracted['company']}
161
+ Bill Date: {extracted['date']}
162
+ Claim Amount: ₹{extracted['total']}
163
+
164
+ Please process the claim.
165
+
166
+ Regards
167
+ Customer
168
+ """
169
+
170
+ msg = MIMEMultipart()
171
+ msg["From"] = EMAIL_USER
172
+ msg["To"] = to_email
173
+ msg["Subject"] = subject
174
+
175
+ msg.attach(MIMEText(body, "plain"))
176
+
177
+ try:
178
+ server = smtplib.SMTP("smtp.gmail.com", 587)
179
+ server.starttls()
180
+ server.login(EMAIL_USER, EMAIL_PASS)
181
+ server.sendmail(
182
+ EMAIL_USER,
183
+ to_email,
184
+ msg.as_string()
185
+ )
186
+ server.quit()
187
+
188
+ return f"✅ Email sent successfully to {to_email}"
189
+
190
+ except Exception as e:
191
+ return f"❌ Email failed: {str(e)}"
192
+
193
+ # =====================================================
194
+ # MAIN UI FUNCTION
195
+ # =====================================================
196
+ def process_and_send(image, email_id):
197
+
198
+ extracted = extract_receipt(image)
199
+
200
+ if "error" in extracted:
201
+ return extracted, extracted["error"]
202
+
203
+ email_status = send_claim_email(email_id, extracted)
204
+
205
+ return extracted, email_status
206
+
207
+ # =====================================================
208
  # UI
209
+ # =====================================================
210
  demo = gr.Interface(
211
+ fn=process_and_send,
212
+ inputs=[
213
+ gr.Image(type="pil", label="Upload Receipt"),
214
+ gr.Textbox(label="Insurance Email ID")
215
+ ],
216
+ outputs=[
217
+ gr.JSON(label="Extracted Data"),
218
+ gr.Textbox(label="Email Status")
219
+ ],
220
+ title="📄 AI Insurance Claim Generator",
221
+ description="Upload receipt → Extract details → Auto send claim email"
222
  )
223
 
224
  demo.launch()