GSoumyajit2005 commited on
Commit
097a95c
·
1 Parent(s): e81f779

refactor: remove obsolete OCR test file and enhance address extraction logic

Browse files
Files changed (5) hide show
  1. README.md +0 -2
  2. src/extraction.py +103 -1
  3. src/ml_extraction.py +36 -1
  4. src/schema.py +8 -3
  5. tests/test_ocr.py +0 -101
README.md CHANGED
@@ -156,7 +156,6 @@ _UI shows simple format hints and confidence._
156
  ### Prerequisites
157
 
158
  - Python 3.10+
159
- - Tesseract OCR
160
  - (Optional) CUDA-capable GPU for training/inference speed
161
 
162
  ### Installation
@@ -342,7 +341,6 @@ invoice-processor-ml/
342
  ├── tests/
343
  │ ├── test_extraction.py # Tests for regex extraction module
344
  │ ├── test_full_pipeline.py # Full end-to-end integration tests
345
- │ ├── test_ocr.py # Tests for the OCR module
346
  │ ├── test_pipeline.py # Pipeline process tests
347
  │ └── test_preprocessing.py # Tests for the preprocessing module
348
 
 
156
  ### Prerequisites
157
 
158
  - Python 3.10+
 
159
  - (Optional) CUDA-capable GPU for training/inference speed
160
 
161
  ### Installation
 
341
  ├── tests/
342
  │ ├── test_extraction.py # Tests for regex extraction module
343
  │ ├── test_full_pipeline.py # Full end-to-end integration tests
 
344
  │ ├── test_pipeline.py # Pipeline process tests
345
  │ └── test_preprocessing.py # Tests for the preprocessing module
346
 
src/extraction.py CHANGED
@@ -3,6 +3,7 @@
3
  import re
4
  from typing import List, Dict, Optional, Any
5
  from datetime import datetime
 
6
 
7
  def extract_dates(text: str) -> List[str]:
8
  """
@@ -132,7 +133,7 @@ def extract_invoice_number(text: str) -> Optional[str]:
132
  for line in lines[:25]: # Scan top 25 lines
133
  line_upper = line.upper()
134
 
135
- # ⚠️ CRITICAL FIX: Skip lines that look like Tax IDs (GST/REG)
136
  # But allow if the line explicitly says "INVOICE" (e.g. "Tax Invoice / GST Reg No")
137
  if any(bad in line_upper for bad in TOXIC_LINE_INDICATORS) and "INVOICE" not in line_upper:
138
  continue
@@ -165,6 +166,107 @@ def extract_bill_to(text: str) -> Optional[Dict[str, str]]:
165
  return {"name": name, "email": None}
166
  return None
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  def extract_line_items(text: str) -> List[Dict[str, Any]]:
169
  return []
170
 
 
3
  import re
4
  from typing import List, Dict, Optional, Any
5
  from datetime import datetime
6
+ from difflib import SequenceMatcher
7
 
8
  def extract_dates(text: str) -> List[str]:
9
  """
 
133
  for line in lines[:25]: # Scan top 25 lines
134
  line_upper = line.upper()
135
 
136
+ # CRITICAL FIX: Skip lines that look like Tax IDs (GST/REG)
137
  # But allow if the line explicitly says "INVOICE" (e.g. "Tax Invoice / GST Reg No")
138
  if any(bad in line_upper for bad in TOXIC_LINE_INDICATORS) and "INVOICE" not in line_upper:
139
  continue
 
166
  return {"name": name, "email": None}
167
  return None
168
 
169
+ def extract_address(text: str, vendor_name: Optional[str] = None) -> Optional[str]:
170
+ """
171
+ Generalized Address Extraction using Spatial Heuristics.
172
+ Strategy:
173
+ 1. If Vendor is known, look at the lines immediately FOLLOWING it (Spatial).
174
+ 2. If Vendor is unknown, look for lines in the top header with 'Address-like' traits
175
+ (mix of text + numbers, 3+ words, contains Zip-code-like patterns).
176
+ """
177
+ if not text: return None
178
+
179
+ lines = [line.strip() for line in text.split('\n') if line.strip()]
180
+
181
+ # --- FILTERS (Generalized) ---
182
+ # Skip lines that are clearly NOT addresses
183
+ def is_invalid_line(line):
184
+ line_upper = line.upper()
185
+ # 1. It's a Phone/Fax/Email/URL
186
+ if any(x in line_upper for x in ['TEL', 'FAX', 'PHONE', 'EMAIL', '@', 'WWW.', '.COM', 'HTTP']):
187
+ return True
188
+ # 2. It's a Date
189
+ if len(line) < 15 and any(c.isdigit() for c in line) and ('/' in line or '-' in line):
190
+ return True
191
+ # 3. It's the Vendor name itself (if provided)
192
+ if vendor_name and vendor_name.lower() in line.lower():
193
+ return True
194
+ return False
195
+
196
+ # --- STRATEGY 1: Contextual Search (Below Vendor) ---
197
+ # This is the most accurate method for receipts worldwide.
198
+ candidate_lines = []
199
+
200
+ if vendor_name:
201
+ vendor_found = False
202
+ # Find where the vendor appears
203
+ for i, line in enumerate(lines[:15]): # Check top 15 lines only
204
+ if vendor_name.lower() in line.lower() or (len(vendor_name) > 5 and SequenceMatcher(None, vendor_name, line).ratio() > 0.8):
205
+ vendor_found = True
206
+ # Grab the next 1-3 lines as the potential address block
207
+ # We stop if we hit a phone number or blank line
208
+ for j in range(1, 4):
209
+ if i + j < len(lines):
210
+ next_line = lines[i + j]
211
+ if not is_invalid_line(next_line):
212
+ candidate_lines.append(next_line)
213
+ else:
214
+ # If we hit a phone number, the address block usually ended
215
+ break
216
+ break
217
+
218
+ # If Strategy 1 found something, join it and return
219
+ if candidate_lines:
220
+ return ", ".join(candidate_lines)
221
+
222
+ # --- STRATEGY 2: Header Scan (Density Heuristic) ---
223
+ # If we couldn't anchor to the vendor, we scan the top 10 lines for "Address-looking" text.
224
+ # An address usually has:
225
+ # - At least one digit (Building number, Zip code)
226
+ # - At least 3 words
227
+ # - Is NOT a phone number
228
+ #
229
+ # CONTIGUITY RULE: Once we start collecting candidates, we STOP at the first
230
+ # invalid line (phone/fax/etc). This prevents capturing non-adjacent lines
231
+ # like GST numbers that appear after phone numbers.
232
+
233
+ fallback_candidates = []
234
+ started_collecting = False
235
+
236
+ for line in lines[:10]:
237
+ if is_invalid_line(line):
238
+ # If we've already started collecting, an invalid line means
239
+ # the address block has ended - don't continue past it
240
+ if started_collecting:
241
+ break
242
+ continue
243
+
244
+ # Check for Address Density:
245
+ # 1. Has digits (e.g. "123 Main St" or "Singapore 55123")
246
+ has_digits = any(c.isdigit() for c in line)
247
+ # 2. Length is substantial (avoid short noise)
248
+ is_long_enough = len(line) > 10
249
+ # 3. Has spaces (at least 2 spaces => 3 words)
250
+ is_multi_word = line.count(' ') >= 2
251
+
252
+ # FIRST line must have digits (to anchor on building/street number)
253
+ # CONTINUATION lines only need length + multi-word (city/state names often lack digits)
254
+ is_valid_first_line = has_digits and is_long_enough and is_multi_word
255
+ is_valid_continuation = started_collecting and is_long_enough and is_multi_word
256
+
257
+ if is_valid_first_line or is_valid_continuation:
258
+ # We found a strong candidate line
259
+ fallback_candidates.append(line)
260
+ started_collecting = True
261
+ # If we have 3 candidates, that's probably the full address block
262
+ if len(fallback_candidates) >= 3:
263
+ break
264
+
265
+ if fallback_candidates:
266
+ return ", ".join(fallback_candidates)
267
+
268
+ return None
269
+
270
  def extract_line_items(text: str) -> List[Dict[str, Any]]:
271
  return []
272
 
src/ml_extraction.py CHANGED
@@ -8,7 +8,7 @@ from PIL import Image
8
  from typing import List, Dict, Any, Tuple
9
  import re
10
  import numpy as np
11
- from extraction import extract_invoice_number, extract_total
12
  from doctr.io import DocumentFile
13
  from doctr.models import ocr_predictor
14
 
@@ -254,6 +254,41 @@ def extract_ml_based(image_path: str) -> Dict[str, Any]:
254
  largest_idx = max(top_words_indices, key=lambda i: unnormalized_boxes[i][3])
255
  final_output["vendor"] = words[largest_idx]
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  # Fallbacks
258
  ml_total = extracted_entities.get("TOTAL", {}).get("text")
259
  if ml_total:
 
8
  from typing import List, Dict, Any, Tuple
9
  import re
10
  import numpy as np
11
+ from extraction import extract_invoice_number, extract_total, extract_address
12
  from doctr.io import DocumentFile
13
  from doctr.models import ocr_predictor
14
 
 
254
  largest_idx = max(top_words_indices, key=lambda i: unnormalized_boxes[i][3])
255
  final_output["vendor"] = words[largest_idx]
256
 
257
+ # --- ADDRESS FALLBACK ---
258
+ if not final_output["address"]:
259
+ # We pass the extracted (or fallback) Vendor Name to help anchor the search
260
+ # Use the raw text and the known vendor to find the address spatially
261
+ fallback_address = extract_address(raw_text, vendor_name=final_output["vendor"])
262
+
263
+ if fallback_address:
264
+ final_output["address"] = fallback_address
265
+
266
+ # Backfill Bounding Boxes for Address Fallback
267
+ # If Regex found the address but ML didn't, find its boxes in the OCR data
268
+ if final_output["address"] and "ADDRESS" not in final_output["raw_predictions"]:
269
+ address_text = final_output["address"]
270
+ address_boxes = []
271
+
272
+ # The address may span multiple words, so we search for each word
273
+ # Split by comma first (since extract_address joins lines with ", ")
274
+ address_parts = [part.strip() for part in address_text.split(",")]
275
+
276
+ for part in address_parts:
277
+ part_words = part.split()
278
+ for target_word in part_words:
279
+ for i, word in enumerate(words):
280
+ # Case-insensitive match
281
+ if target_word.lower() == word.lower() or target_word.lower() in word.lower():
282
+ address_boxes.append(unnormalized_boxes[i])
283
+ break # Only match once per target word
284
+
285
+ # If we found any boxes, inject into raw_predictions
286
+ if address_boxes:
287
+ final_output["raw_predictions"]["ADDRESS"] = {
288
+ "text": address_text,
289
+ "bbox": address_boxes
290
+ }
291
+
292
  # Fallbacks
293
  ml_total = extracted_entities.get("TOTAL", {}).get("text")
294
  if ml_total:
src/schema.py CHANGED
@@ -64,9 +64,14 @@ class InvoiceData(BaseModel):
64
  if isinstance(v, str):
65
  try:
66
  # Try common formats
67
- for fmt in ("%d/%m/%Y", "%Y-%m-%d", "%d-%m-%Y", "%d.%m.%Y"):
 
 
 
68
  try:
69
  parsed_date = datetime.strptime(v, fmt).date()
 
 
70
  break
71
  except ValueError:
72
  continue
@@ -78,8 +83,8 @@ class InvoiceData(BaseModel):
78
  if parsed_date > today:
79
  return None
80
 
81
- # ⚠️ FIX: Use 'DateType' constructor
82
- min_date = DateType(today.year - 10, 1, 1)
83
  if parsed_date < min_date:
84
  return None
85
 
 
64
  if isinstance(v, str):
65
  try:
66
  # Try common formats
67
+ for fmt in (
68
+ "%d/%m/%Y", "%Y-%m-%d", "%d-%m-%Y", "%d.%m.%Y",
69
+ "%m/%d/%Y", "%m-%d-%Y"
70
+ ):
71
  try:
72
  parsed_date = datetime.strptime(v, fmt).date()
73
+ # Sanity check: If we parsed 05/01/2020, was it May 1st or Jan 5th?
74
+ # Usually, if we are here, strict parsing succeeded.
75
  break
76
  except ValueError:
77
  continue
 
83
  if parsed_date > today:
84
  return None
85
 
86
+ # FIX: Use 'DateType' constructor
87
+ min_date = DateType(today.year - 30, 1, 1)
88
  if parsed_date < min_date:
89
  return None
90
 
tests/test_ocr.py DELETED
@@ -1,101 +0,0 @@
1
- import sys
2
- sys.path.append('src')
3
-
4
- from preprocessing import load_image, convert_to_grayscale, remove_noise
5
- from ocr import extract_text
6
- import matplotlib.pyplot as plt
7
- import numpy as np
8
-
9
- print("=" * 60)
10
- print("🎯 OPTIMIZING GRAYSCALE OCR")
11
- print("=" * 60)
12
-
13
- # Load and convert to grayscale
14
- image = load_image('data/raw/receipt3.jpg')
15
- gray = convert_to_grayscale(image)
16
-
17
- # Test 1: Different PSM modes
18
- print("\n📊 Testing different Tesseract PSM modes...\n")
19
-
20
- psm_configs = [
21
- ('', 'Default'),
22
- ('--psm 3', 'Automatic page segmentation'),
23
- ('--psm 4', 'Single column of text'),
24
- ('--psm 6', 'Uniform block of text'),
25
- ('--psm 11', 'Sparse text, find as much as possible'),
26
- ('--psm 12', 'Sparse text with OSD (Orientation and Script Detection)'),
27
- ]
28
-
29
- results = {}
30
- for config, desc in psm_configs:
31
- text = extract_text(gray, config=config)
32
- results[desc] = text
33
- print(f"{desc:50s} → {len(text):4d} chars")
34
-
35
- # Find best result
36
- best_desc = max(results, key=lambda k: len(results[k]))
37
- best_text = results[best_desc]
38
-
39
- print(f"\n✅ WINNER: {best_desc} ({len(best_text)} chars)")
40
-
41
- # Test 2: With slight denoising
42
- print("\n📊 Testing with light denoising...\n")
43
-
44
- denoised = remove_noise(gray, kernel_size=3)
45
- text_denoised = extract_text(denoised, config='--psm 6')
46
- print(f"Grayscale + Denoise (psm 6): {len(text_denoised)} chars")
47
-
48
-
49
- # Display best result
50
- print("\n" + "=" * 60)
51
- print("📄 BEST EXTRACTED TEXT:")
52
- print("=" * 60)
53
- print(best_text)
54
- print("=" * 60)
55
-
56
- # Visualize
57
- fig, axes = plt.subplots(1, 3, figsize=(15, 5))
58
-
59
- axes[0].imshow(image)
60
- axes[0].set_title("Original")
61
- axes[0].axis('off')
62
-
63
- axes[1].imshow(gray, cmap='gray')
64
- axes[1].set_title(f"Grayscale\n({len(best_text)} chars - {best_desc})")
65
- axes[1].axis('off')
66
-
67
- axes[2].imshow(denoised, cmap='gray')
68
- axes[2].set_title(f"Denoised\n({len(text_denoised)} chars)")
69
- axes[2].axis('off')
70
-
71
- plt.tight_layout()
72
- plt.show()
73
-
74
- print(f"\n💡 Recommended pipeline: Grayscale + {best_desc}")
75
-
76
- # Test the combination we missed!
77
- print("\n📊 Testing BEST combination...\n")
78
-
79
- denoised = remove_noise(gray, kernel_size=3)
80
-
81
- # Test PSM 11 on denoised
82
- text_denoised_psm11 = extract_text(denoised, config='--psm 11')
83
- text_denoised_psm6 = extract_text(denoised, config='--psm 6')
84
-
85
- print(f"Denoised + PSM 6: {len(text_denoised_psm6)} chars")
86
- print(f"Denoised + PSM 11: {len(text_denoised_psm11)} chars")
87
-
88
- if len(text_denoised_psm11) > len(text_denoised_psm6):
89
- print(f"\n✅ PSM 11 wins! ({len(text_denoised_psm11)} chars)")
90
- best_config = '--psm 11'
91
- best_text_final = text_denoised_psm11
92
- else:
93
- print(f"\n✅ PSM 6 wins! ({len(text_denoised_psm6)} chars)")
94
- best_config = '--psm 6'
95
- best_text_final = text_denoised_psm6
96
-
97
- print(f"\n🏆 FINAL WINNER: Denoised + {best_config}")
98
- print("\nFull text:")
99
- print("=" * 60)
100
- print(best_text_final)
101
- print("=" * 60)