Zakariya007 commited on
Commit
3c6c02b
·
verified ·
1 Parent(s): aea0d00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -119
app.py CHANGED
@@ -1,77 +1,77 @@
1
- #!/usr/bin/env python3
2
- """
3
- DocFusion — Gradio Web UI
4
- Rihal CodeStacker 2026
5
- """
6
 
7
  import os
8
- os.system("apt-get install -y tesseract-ocr")
9
- os.system("pip install pytesseract -q")
 
 
 
 
10
 
11
  import re
12
  import json
13
  import torch
14
  import numpy as np
 
 
15
  from PIL import Image, ImageDraw
16
  from torchvision import transforms, models
17
  from transformers import LayoutLMForTokenClassification, BertTokenizerFast
18
  from huggingface_hub import hf_hub_download
19
- import gradio as gr
20
 
21
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  LABELS = ["O", "B-VENDOR", "I-VENDOR", "B-DATE", "I-DATE", "B-TOTAL", "I-TOTAL"]
23
  ID2LABEL = {i: x for i, x in enumerate(LABELS)}
24
  LABEL2ID = {x: i for i, x in enumerate(LABELS)}
25
 
26
- print("Loading models...")
27
- tokenizer = BertTokenizerFast.from_pretrained("microsoft/layoutlm-base-uncased")
 
 
 
 
 
28
  extraction_model = LayoutLMForTokenClassification.from_pretrained(
29
  "Zakariya007/docfusion-v1",
30
  num_labels=len(LABELS), id2label=ID2LABEL, label2id=LABEL2ID,
31
- )
32
- extraction_model = extraction_model.to(DEVICE)
33
  extraction_model.eval()
34
 
 
35
  forgery_model = models.efficientnet_b0(weights=None)
36
  forgery_model.classifier[1] = torch.nn.Linear(1280, 2)
37
- weights_path = hf_hub_download(
38
- repo_id="Zakariya007/docfusion-v2",
39
- filename="efficientnet_best.pth"
40
- )
41
  forgery_model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
42
  forgery_model = forgery_model.to(DEVICE)
43
  forgery_model.eval()
44
- print("✅ Models loaded!")
 
45
 
46
  def extract_fields(image):
47
  try:
48
- import pytesseract
49
- ocr_text = pytesseract.image_to_string(image)
50
-
51
- # Regex-based extraction
52
- date_match = re.search(
53
- r'\d{1,2}[\/\-\.]\d{1,2}[\/\-\.]\d{2,4}',
54
- ocr_text
55
- )
56
  date = date_match.group(0) if date_match else None
57
 
58
- total_match = re.search(
59
- r'(?:TOTAL|AMOUNT|JUMLAH)[^\d]*(\d+[\.,]\d{2})',
60
- ocr_text, re.IGNORECASE
61
- )
62
  total = total_match.group(1) if total_match else None
63
 
64
- lines = [l.strip() for l in ocr_text.split('\n') if len(l.strip()) > 3]
65
  vendor = lines[0] if lines else None
66
 
67
  return vendor, date, total
68
-
69
  except Exception as e:
70
  print(f"Extraction error: {e}")
71
  return None, None, None
72
 
73
-
74
- def detect_forgery(image, vendor, date, total):
75
  transform = transforms.Compose([
76
  transforms.Resize((224, 224)),
77
  transforms.ToTensor(),
@@ -79,106 +79,70 @@ def detect_forgery(image, vendor, date, total):
79
  ])
80
  tensor = transform(image).unsqueeze(0).to(DEVICE)
81
  with torch.no_grad():
82
- output = forgery_model(tensor)
83
- probs = torch.softmax(output, dim=1)
84
  forged_prob = probs[0][1].item()
85
- visual_flag = forged_prob > 0.5
86
-
87
- rule_flags = []
88
- if not vendor: rule_flags.append("Missing vendor")
89
- if not date: rule_flags.append("Missing date")
90
- if not total: rule_flags.append("Missing total")
91
-
92
- if total:
93
- try:
94
- total_val = float(re.sub(r"[^\d.]", "", total))
95
- if total_val > 10000: rule_flags.append("Abnormally high total")
96
- if total_val <= 0: rule_flags.append("Invalid total amount")
97
- except Exception:
98
- rule_flags.append("Unparseable total")
99
-
100
- if date:
101
- date_pattern = re.compile(
102
- r"\d{1,2}[\/\-\.]\d{1,2}[\/\-\.]\d{2,4}|\d{4}[\/\-\.]\d{2}[\/\-\.]\d{2}"
103
- )
104
- if not date_pattern.search(date):
105
- rule_flags.append("Invalid date format")
106
-
107
- rule_flag = len(rule_flags) >= 2
108
- is_forged = 1 if (visual_flag or rule_flag) else 0
109
- return is_forged, forged_prob, rule_flags
110
-
111
-
112
- def annotate_image(image, is_forged):
113
- annotated = image.copy()
114
- draw = ImageDraw.Draw(annotated)
115
- w, h = image.size
116
- if is_forged:
117
- draw.rectangle([0, 0, w-1, h-1], outline="#FF0000", width=6)
118
- draw.text((10, 10), "FORGED", fill="#FF0000")
119
- else:
120
- draw.rectangle([0, 0, w-1, h-1], outline="#00AA00", width=6)
121
- draw.text((10, 10), "GENUINE", fill="#00AA00")
122
- return annotated
123
 
124
 
 
 
 
125
  def process_receipt(image):
126
- if image is None:
127
- return None, "Please upload an image.", "", "", "", ""
128
-
129
- pil_image = Image.fromarray(image).convert("RGB")
130
- vendor, date, total = extract_fields(pil_image)
131
- is_forged, forged_prob, rule_flags = detect_forgery(pil_image, vendor, date, total)
132
- annotated = annotate_image(pil_image, is_forged)
133
-
134
- if is_forged:
135
- status = f"FORGED (confidence: {forged_prob:.1%})"
136
- if rule_flags:
137
- status += "\nFlags: " + ", ".join(rule_flags)
138
- else:
139
- status = f"GENUINE (forged probability: {forged_prob:.1%})"
140
-
141
- return (
142
- np.array(annotated),
143
- status,
144
- vendor or "Not detected",
145
- date or "Not detected",
146
- total or "Not detected",
147
- json.dumps({
148
- "vendor": vendor,
149
- "date": date,
150
- "total": total,
151
- "is_forged": is_forged
152
- }, indent=2),
153
- )
154
 
 
155
 
156
- with gr.Blocks(title="DocFusion") as demo:
157
- gr.Markdown("# DocFusion — Intelligent Document Processing")
158
- gr.Markdown("Upload a scanned receipt to extract fields and detect forgery.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
  with gr.Row():
161
  with gr.Column():
162
- input_image = gr.Image(label="Upload Receipt", type="numpy")
163
- submit_btn = gr.Button("Analyze Receipt", variant="primary")
164
  with gr.Column():
165
- output_image = gr.Image(label="Annotated Receipt")
166
- forgery_status = gr.Textbox(label="Forgery Status", lines=4)
167
 
168
  with gr.Row():
169
- vendor_out = gr.Textbox(label="Vendor")
170
- date_out = gr.Textbox(label="Date")
171
- total_out = gr.Textbox(label="Total")
172
 
173
- json_out = gr.Code(label="JSON Output", language="json")
174
 
175
- submit_btn.click(
176
- fn = process_receipt,
177
- inputs = [input_image],
178
- outputs = [output_image, forgery_status, vendor_out, date_out, total_out, json_out],
179
  )
180
 
181
- gr.Markdown("**Models:** LayoutLM v1 (extraction) + EfficientNet-B0 (forgery)")
182
-
183
  if __name__ == "__main__":
184
- demo.launch(share=True)
 
1
+
 
 
 
 
2
 
3
  import os
4
+ import subprocess
5
+
6
+
7
+ libs = ["easyocr", "transformers", "torchvision", "gradio", "huggingface_hub"]
8
+ for lib in libs:
9
+ subprocess.run(["pip", "install", lib, "-q"])
10
 
11
  import re
12
  import json
13
  import torch
14
  import numpy as np
15
+ import easyocr
16
+ import gradio as gr
17
  from PIL import Image, ImageDraw
18
  from torchvision import transforms, models
19
  from transformers import LayoutLMForTokenClassification, BertTokenizerFast
20
  from huggingface_hub import hf_hub_download
21
+
22
 
23
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  LABELS = ["O", "B-VENDOR", "I-VENDOR", "B-DATE", "I-DATE", "B-TOTAL", "I-TOTAL"]
25
  ID2LABEL = {i: x for i, x in enumerate(LABELS)}
26
  LABEL2ID = {x: i for i, x in enumerate(LABELS)}
27
 
28
+ print(f"Status: Loading models on {DEVICE}...")
29
+
30
+
31
+ reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
32
+
33
+
34
+ tokenizer = BertTokenizerFast.from_pretrained("microsoft/layoutlm-base-uncased")
35
  extraction_model = LayoutLMForTokenClassification.from_pretrained(
36
  "Zakariya007/docfusion-v1",
37
  num_labels=len(LABELS), id2label=ID2LABEL, label2id=LABEL2ID,
38
+ ).to(DEVICE)
 
39
  extraction_model.eval()
40
 
41
+
42
  forgery_model = models.efficientnet_b0(weights=None)
43
  forgery_model.classifier[1] = torch.nn.Linear(1280, 2)
44
+ weights_path = hf_hub_download(repo_id="Zakariya007/docfusion-v2", filename="efficientnet_best.pth")
 
 
 
45
  forgery_model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
46
  forgery_model = forgery_model.to(DEVICE)
47
  forgery_model.eval()
48
+
49
+ print("All systems ready!")
50
 
51
  def extract_fields(image):
52
  try:
53
+ img_np = np.array(image)
54
+ results = reader.readtext(img_np)
55
+
56
+ full_text = " ".join([res[1] for res in results])
57
+ lines = [res[1].strip() for res in results if len(res[1].strip()) > 2]
58
+
59
+
60
+ date_match = re.search(r'\d{1,2}[\/\-\.]\d{1,2}[\/\-\.]\d{2,4}', full_text)
61
  date = date_match.group(0) if date_match else None
62
 
63
+ total_match = re.search(r'(?:TOTAL|AMOUNT|NET|DUE|CASH|SUBTOTAL)[^\d]*([\d,]+\.\d{2})', full_text, re.IGNORECASE)
 
 
 
64
  total = total_match.group(1) if total_match else None
65
 
 
66
  vendor = lines[0] if lines else None
67
 
68
  return vendor, date, total
 
69
  except Exception as e:
70
  print(f"Extraction error: {e}")
71
  return None, None, None
72
 
73
+ def detect_forgery_pure_model(image):
74
+ """Detects forgery using only the neural network output."""
75
  transform = transforms.Compose([
76
  transforms.Resize((224, 224)),
77
  transforms.ToTensor(),
 
79
  ])
80
  tensor = transform(image).unsqueeze(0).to(DEVICE)
81
  with torch.no_grad():
82
+ output = forgery_model(tensor)
83
+ probs = torch.softmax(output, dim=1)
84
  forged_prob = probs[0][1].item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
 
87
+ is_forged = 1 if forged_prob > 0.5 else 0
88
+ return is_forged, forged_prob
89
+
90
  def process_receipt(image):
91
+ if image is None: return None, "No image uploaded", "", "", "", ""
92
+
93
+ pil_img = Image.fromarray(image).convert("RGB")
94
+
95
+
96
+ vendor, date, total = extract_fields(pil_img)
97
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
+ is_forged, prob = detect_forgery_pure_model(pil_img)
100
 
101
+
102
+ annotated = pil_img.copy()
103
+ draw = ImageDraw.Draw(annotated)
104
+ color = "#FF4B4B" if is_forged else "#24A148"
105
+ draw.rectangle([0, 0, pil_img.size[0]-1, pil_img.size[1]-1], outline=color, width=12)
106
+
107
+ status = f"Result: {'SUSPECTED FORGERY' if is_forged else 'LIKELY GENUINE'}"
108
+ confidence_str = f"Model Confidence: {prob if is_forged else (1-prob):.1%}"
109
+ full_status = f"{status}\n{confidence_str}"
110
+
111
+ res_json = json.dumps({
112
+ "vendor": vendor,
113
+ "date": date,
114
+ "total": total,
115
+ "forgery_score": round(prob, 4),
116
+ "is_forged": bool(is_forged)
117
+ }, indent=2)
118
+
119
+ return np.array(annotated), full_status, vendor or "N/A", date or "N/A", total or "N/A", res_json
120
+
121
+
122
+ with gr.Blocks(theme=gr.themes.Default()) as demo:
123
+ gr.Markdown("# 📑 DocFusion: Receipt Intelligence (V2)")
124
+ gr.Markdown("Visual Forgery Detection + Deep Learning OCR")
125
 
126
  with gr.Row():
127
  with gr.Column():
128
+ in_img = gr.Image(label="Upload Receipt Scan", type="numpy")
129
+ btn = gr.Button("Analyze Receipt", variant="primary")
130
  with gr.Column():
131
+ out_img = gr.Image(label="Visual Analysis")
132
+ out_stat = gr.Textbox(label="Forgery Detection Status", lines=2)
133
 
134
  with gr.Row():
135
+ v_out = gr.Textbox(label="Vendor")
136
+ d_out = gr.Textbox(label="Date")
137
+ t_out = gr.Textbox(label="Total Amount")
138
 
139
+ js_out = gr.Code(label="Metadata Output (JSON)", language="json")
140
 
141
+ btn.click(
142
+ process_receipt,
143
+ inputs=[in_img],
144
+ outputs=[out_img, out_stat, v_out, d_out, t_out, js_out]
145
  )
146
 
 
 
147
  if __name__ == "__main__":
148
+ demo.launch(debug=True)