GSoumyajit2005 committed on
Commit
ec0b507
·
1 Parent(s): 343b0c3

Refactor: Replace Tesseract with DocTR and integrate LayoutLMv3-DocTR model

Browse files

Major overhaul of OCR/Inference pipeline. Swapped Tesseract for DocTR, retrained LayoutLMv3 (~83% F1), and fixed address extraction using Fuzzy Matching.

.gitignore CHANGED
@@ -23,6 +23,7 @@ credentials.json
23
  *.log
24
  logs/
25
  .cache/
 
26
 
27
  # OS
28
  .DS_Store
 
23
  *.log
24
  logs/
25
  .cache/
26
+ *.pkl
27
 
28
  # OS
29
  .DS_Store
Dockerfile CHANGED
@@ -1,10 +1,13 @@
1
  # Use an official Python runtime
2
  FROM python:3.10-slim
3
 
4
- # 1. Install system dependencies (Tesseract + OpenCV + POPPLER)
5
- # Added poppler-utils because src/pdf_utils.py uses pdf2image
6
  RUN apt-get update && apt-get install -y \
7
- tesseract-ocr \
 
 
 
8
  poppler-utils \
9
  ffmpeg libsm6 libxext6 \
10
  && rm -rf /var/lib/apt/lists/*
 
1
  # Use an official Python runtime
2
  FROM python:3.10-slim
3
 
4
+ # 1. Install system dependencies (DocTR + OpenCV + POPPLER)
5
+ # DocTR requires OpenGL and GStreamer libraries for image processing
6
  RUN apt-get update && apt-get install -y \
7
+ libgl1-mesa-dev \
8
+ libglib2.0-0 \
9
+ libgstreamer1.0-0 \
10
+ libgstreamer-plugins-base1.0-0 \
11
  poppler-utils \
12
  ffmpeg libsm6 libxext6 \
13
  && rm -rf /var/lib/apt/lists/*
README.md CHANGED
@@ -374,7 +374,7 @@ invoice-processor-ml/
374
 
375
  ## ⚠️ Known Limitations
376
 
377
- 1. **Layout Sensitivity**: The ML model was fine‑tuned only on SROIE (retail receipts). Professional multi-column invoices may underperform until you fine‑tune on more diverse datasets.
378
  2. **Invoice Number**: SROIE dataset lacks invoice number labels. The system solves this by using the Hybrid Fallback Engine, which successfully extracts invoice numbers using Regex whenever the ML model output is empty.
379
  3. **Line Items/Tables**: Not trained for table extraction yet. Rule-based supports simple totals; table extraction comes later.
380
  4. **OCR Variability**: Tesseract outputs can vary; preprocessing and thresholds can impact ML results.
 
374
 
375
  ## ⚠️ Known Limitations
376
 
377
+ 1. **Layout Sensitivity**: The ML model was fine‑tuned on SROIE (retail receipts) and mychen76/invoices-and-receipts_ocr_v1 (English). Professional multi-column invoices may underperform until you fine‑tune on more diverse datasets.
378
  2. **Invoice Number**: SROIE dataset lacks invoice number labels. The system solves this by using the Hybrid Fallback Engine, which successfully extracts invoice numbers using Regex whenever the ML model output is empty.
379
  3. **Line Items/Tables**: Not trained for table extraction yet. Rule-based supports simple totals; table extraction comes later.
380
  4. **OCR Variability**: Tesseract outputs can vary; preprocessing and thresholds can impact ML results.
requirements.txt CHANGED
@@ -2,7 +2,7 @@
2
  streamlit>=1.28.0
3
 
4
  # ----- OCR -----
5
- pytesseract>=0.3.10
6
  opencv-python>=4.8.0
7
  Pillow>=10.0.0
8
 
 
2
  streamlit>=1.28.0
3
 
4
  # ----- OCR -----
5
+ python-doctr[torch]>=0.8.0
6
  opencv-python>=4.8.0
7
  Pillow>=10.0.0
8
 
scripts/prepare_doctr_data.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # scripts/prepare_doctr_data.py
2
+
3
+ """
4
+ Prepare training data using DocTR OCR output.
5
+
6
+ This script:
7
+ 1. Iterates through SROIE training/test images
8
+ 2. Runs DocTR OCR to get words and boxes
9
+ 3. Aligns DocTR output with ground truth labels using fuzzy matching
10
+ 4. Saves the aligned dataset to a pickle file for training
11
+
12
+ This ensures the model learns from DocTR's actual output (with its specific errors)
13
+ rather than from perfect ground truth which it will never see in production.
14
+ """
15
+
16
+ import torch
17
+ import sys
18
+ import os
19
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
20
+
21
+ import json
22
+ import pickle
23
+ from pathlib import Path
24
+ from PIL import Image
25
+ from tqdm import tqdm
26
+ from difflib import SequenceMatcher
27
+ from typing import List, Dict, Any, Tuple, Optional
28
+
29
+ from doctr.io import DocumentFile
30
+ from doctr.models import ocr_predictor
31
+
32
# --- CONFIGURATION ---
# Root of the local SROIE dataset copy (expects train/ and test/ subfolders).
SROIE_DATA_PATH = "data/sroie"
# Destination pickle for the DocTR-aligned training examples.
OUTPUT_CACHE_PATH = "data/doctr_trained_cache.pkl"

# Ground truth field names and their corresponding BIO labels.
# Keys match the lower-cased entity names produced by load_ground_truth();
# values are the label stems used to build "B-.../I-..." tags.
GT_FIELD_MAPPING = {
    "company": "COMPANY",
    "date": "DATE",
    "address": "ADDRESS",
    "total": "TOTAL",
}
43
+
44
+
45
def load_doctr_predictor():
    """Build the pretrained DocTR OCR predictor, preferring GPU execution.

    Returns:
        A DocTR ``ocr_predictor`` (DB-ResNet50 detector + CRNN-VGG16-BN
        recognizer), moved to CUDA when a GPU is available.
    """
    print("Loading DocTR OCR predictor...")

    # Detector/recognizer pair used throughout this project.
    ocr_engine = ocr_predictor(
        det_arch='db_resnet50',
        reco_arch='crnn_vgg16_bn',
        pretrained=True,
    )

    # Prefer CUDA when present; otherwise warn about CPU speed.
    if torch.cuda.is_available():
        print("🚀 Moving DocTR to GPU (CUDA)...")
        ocr_engine.cuda()
    else:
        print("⚠️ GPU not found. Running on CPU (this will be slow).")

    print("DocTR OCR predictor ready.")
    return ocr_engine
65
+
66
+
67
def parse_doctr_output(doctr_result, img_width: int, img_height: int) -> Tuple[List[str], List[List[int]]]:
    """Flatten DocTR's Page -> Block -> Line -> Word hierarchy into flat lists.

    Args:
        doctr_result: DocTR predictor output exposing ``.pages``.
        img_width: Source image width in pixels (kept for interface parity;
            DocTR geometry is already relative, so it is unused here).
        img_height: Source image height in pixels (unused, see above).

    Returns:
        words: list of non-empty word strings.
        normalized_boxes: matching [x0, y0, x1, y1] boxes on LayoutLMv3's
            0-1000 scale, clamped into range.
    """
    def _scale(coord: float) -> int:
        # DocTR coordinates are relative (0-1); the clamp guards float drift.
        return max(0, min(1000, int(coord * 1000)))

    words: List[str] = []
    normalized_boxes: List[List[int]] = []

    for page in doctr_result.pages:
        for block in page.blocks:
            for line in block.lines:
                for word in line.words:
                    # Drop whitespace-only detections.
                    if not word.value.strip():
                        continue

                    words.append(word.value)

                    # DocTR bbox format: ((x_min, y_min), (x_max, y_max)) in 0-1 scale.
                    (x0, y0), (x1, y1) = word.geometry
                    normalized_boxes.append(
                        [_scale(x0), _scale(y0), _scale(x1), _scale(y1)]
                    )

    return words, normalized_boxes
99
+
100
+
101
def fuzzy_match_score(s1: str, s2: str) -> float:
    """Return a case-insensitive similarity ratio in [0, 1] for two strings."""
    lowered = (s1.lower(), s2.lower())
    return SequenceMatcher(None, *lowered).ratio()
104
+
105
+
106
def find_entity_in_words(
    entity_text: str,
    words: List[str],
    start_idx: int = 0,
    threshold: float = 0.7
) -> Optional[Tuple[int, int]]:
    """
    Find a ground truth entity in the DocTR words using fuzzy matching.
    Includes expansion search to handle OCR word splitting.

    Args:
        entity_text: Ground-truth value to locate (may span several words).
        words: OCR word list to search.
        start_idx: First index in ``words`` to consider.
        threshold: Minimum fuzzy-match ratio (0-1) for a hit.

    Returns:
        Inclusive (start, end) word-index pair for the best match, or None.

    NOTE(review): when the single-word pass finds nothing, control falls
    through to the window search below, so a one-word entity can still match
    a multi-word window (handles OCR splitting one token into several).
    """
    entity_words = entity_text.split()
    n_target = len(entity_words)

    # 1. Single word match: best-scoring individual word above the threshold.
    if n_target == 1:
        best_score = 0
        best_idx = -1
        for i in range(start_idx, len(words)):
            score = fuzzy_match_score(entity_text, words[i])
            if score > best_score and score >= threshold:
                best_score = score
                best_idx = i
        if best_idx >= 0:
            return (best_idx, best_idx)

    # 2. Multi-word entity: Flexible Window Search
    # We search windows of size N, N+1, N+2... up to N+5 (to catch OCR splits)
    # AND N-1, N-2... (to catch OCR merges)

    best_match_score = 0.0
    best_match_indices = None

    # Define search range: from (Length - 3) to (Length + 5)
    min_len = max(1, n_target - 3)
    max_len = min(len(words) - start_idx, n_target + 5)

    combined_entity_text = " ".join(entity_words)

    # Iterate through window sizes, sliding each window over the word list.
    for window_size in range(min_len, max_len + 1):
        for i in range(start_idx, len(words) - window_size + 1):

            # Construct window text
            window_tokens = words[i : i + window_size]
            window_text = " ".join(window_tokens)

            score = fuzzy_match_score(combined_entity_text, window_text)

            # Optimization: If perfect match, return immediately
            if score > 0.95:
                return (i, i + window_size - 1)

            if score > best_match_score and score >= threshold:
                best_match_score = score
                best_match_indices = (i, i + window_size - 1)

    return best_match_indices
163
+
164
+
165
def load_ground_truth(json_path: Path) -> Dict[str, str]:
    """
    Reconstruct entity values from a tagged SROIE JSON file.

    The file holds parallel lists {"words": [...], "labels": [...]} with BIO
    tags; contiguous B-/I- runs of one type are joined back into a single
    string per entity, keyed by the lower-cased entity type. When a type
    occurs more than once, the last occurrence wins.
    """
    with open(json_path, encoding="utf-8") as fh:
        payload = json.load(fh)

    tokens = payload.get("words", [])
    tags = payload.get("labels", [])

    entities: Dict[str, str] = {}
    active_type = None
    active_tokens = []

    def _flush():
        # Commit the in-progress entity run, if any.
        if active_type and active_tokens:
            entities[active_type.lower()] = " ".join(active_tokens)

    for token, tag in zip(tokens, tags):
        if tag.startswith("B-"):
            # A new run begins: save the previous one first.
            _flush()
            active_type = tag[2:]
            active_tokens = [token]
        elif tag.startswith("I-") and active_type:
            if tag[2:] == active_type:
                active_tokens.append(token)
            else:
                # Type switched without a B- tag: close the previous run.
                if active_tokens:
                    entities[active_type.lower()] = " ".join(active_tokens)
                active_type = None
                active_tokens = []
        else:
            # "O" tag (or stray I-): terminate any open run.
            _flush()
            active_type = None
            active_tokens = []

    # Don't forget a run that extends to the end of the document.
    _flush()
    return entities
215
+
216
+
217
def align_labels(
    doctr_words: List[str],
    ground_truth: Dict[str, str]
) -> List[str]:
    """Project ground-truth entities onto DocTR's word sequence as BIO tags.

    For every field in GT_FIELD_MAPPING present in ``ground_truth``, fuzzy-find
    its word span via find_entity_in_words() and tag it B-/I-. Words not
    covered by any entity stay "O". Spans that would overlap an already
    assigned entity are skipped.
    """
    labels = ["O"] * len(doctr_words)
    used_indices = set()

    for gt_field, bio_label in GT_FIELD_MAPPING.items():
        if gt_field not in ground_truth:
            continue

        entity_text = ground_truth[gt_field]
        if not entity_text or not entity_text.strip():
            continue

        # DYNAMIC THRESHOLD: Be lenient with Addresses, strict with Dates/Totals
        current_threshold = 0.6
        if bio_label == "ADDRESS":
            current_threshold = 0.45  # Lower threshold for messy addresses
        elif bio_label in ["DATE", "TOTAL"]:
            current_threshold = 0.7  # Keep strict for precision fields

        match = find_entity_in_words(entity_text, doctr_words, start_idx=0, threshold=current_threshold)

        if match:
            start_idx, end_idx = match

            # Overlap check: never overwrite a previously tagged span.
            if any(i in used_indices for i in range(start_idx, end_idx + 1)):
                continue

            labels[start_idx] = f"B-{bio_label}"
            for i in range(start_idx + 1, end_idx + 1):
                labels[i] = f"I-{bio_label}"

            used_indices.update(range(start_idx, end_idx + 1))

    return labels
255
+
256
+
257
def process_split(
    split_path: Path,
    predictor,
    split_name: str
) -> List[Dict[str, Any]]:
    """Process all images in a split directory.

    For each image with a matching JSON annotation: run DocTR OCR, align the
    ground-truth entities onto the OCR words, and collect one training example
    per image. Per-image failures are logged and skipped.

    Args:
        split_path: Directory containing the split's image/annotation folders.
        predictor: Callable OCR predictor taking a DocTR document
            (assumed: the object returned by load_doctr_predictor()).
        split_name: Human-readable split name used in progress output.

    Returns:
        List of dicts with keys: image_path, words, bboxes, ner_tags,
        ground_truth.
    """

    # Find image and annotation directories (two known layouts supported).
    if (split_path / "images").exists():
        img_dir = split_path / "images"
    elif (split_path / "img").exists():
        img_dir = split_path / "img"
    else:
        print(f" ⚠️ No image directory found in {split_path}")
        return []

    if (split_path / "tagged").exists():
        ann_dir = split_path / "tagged"
    elif (split_path / "box").exists():
        ann_dir = split_path / "box"
    else:
        print(f" ⚠️ No annotation directory found in {split_path}")
        return []

    examples = []
    image_files = sorted([f for f in img_dir.iterdir() if f.suffix.lower() in [".jpg", ".png"]])

    print(f" Processing {len(image_files)} images in {split_name}...")

    for img_file in tqdm(image_files, desc=f" {split_name}"):
        try:
            # Check for corresponding annotation; unannotated images are skipped.
            json_path = ann_dir / f"{img_file.stem}.json"
            if not json_path.exists():
                continue

            # Load image dimensions (file handle released immediately).
            with Image.open(img_file) as img:
                width, height = img.size

            # Run DocTR OCR
            doc = DocumentFile.from_images(str(img_file))
            doctr_result = predictor(doc)

            # Parse DocTR output into flat word/box lists.
            words, boxes = parse_doctr_output(doctr_result, width, height)

            if not words:
                continue

            # Load ground truth and align labels onto the OCR words.
            ground_truth = load_ground_truth(json_path)
            aligned_labels = align_labels(words, ground_truth)

            # Create example
            examples.append({
                "image_path": str(img_file),
                "words": words,
                "bboxes": boxes,
                "ner_tags": aligned_labels,
                "ground_truth": ground_truth  # Keep for debugging
            })

        except Exception as e:
            # Best-effort: a bad image must not abort the whole split.
            print(f"\n ❌ Error processing {img_file.name}: {e}")
            continue

    return examples
325
+
326
+
327
def main():
    """Run OCR over the SROIE splits and pickle the aligned dataset.

    Pipeline: load DocTR -> process 'train' and 'test' splits -> write the
    combined dict {"train": [...], "test": [...]} to OUTPUT_CACHE_PATH.
    """
    print("=" * 60)
    print("📦 DocTR Training Data Preparation")
    print("=" * 60)

    sroie_path = Path(SROIE_DATA_PATH)

    if not sroie_path.exists():
        print(f"❌ SROIE path not found: {sroie_path}")
        return

    # Load DocTR predictor (GPU if available).
    predictor = load_doctr_predictor()

    dataset = {"train": [], "test": []}

    # Process each split independently; a missing split is non-fatal.
    for split in ["train", "test"]:
        split_path = sroie_path / split
        if not split_path.exists():
            print(f" ⚠️ Split not found: {split}")
            continue

        print(f"\n📂 Processing {split} split...")
        examples = process_split(split_path, predictor, split)
        dataset[split] = examples

        # Stats: count aligned entities (one per B- tag).
        total_entities = sum(
            sum(1 for label in ex["ner_tags"] if label.startswith("B-"))
            for ex in examples
        )
        print(f" ✅ {len(examples)} images processed")
        print(f" 📊 {total_entities} entities aligned")

    # Save cache (parent dir created if missing).
    print(f"\n💾 Saving cache to {OUTPUT_CACHE_PATH}...")
    output_path = Path(OUTPUT_CACHE_PATH)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, "wb") as f:
        pickle.dump(dataset, f)

    print(f"✅ Cache saved!")
    print(f" - Train examples: {len(dataset['train'])}")
    print(f" - Test examples: {len(dataset['test'])}")
    print("=" * 60)


if __name__ == "__main__":
    main()
scripts/train_combined.py CHANGED
@@ -13,6 +13,7 @@ from pathlib import Path
13
  import numpy as np
14
  import random
15
  import os
 
16
 
17
  # --- IMPORTS ---
18
  from src.sroie_loader import load_sroie
@@ -21,8 +22,9 @@ from src.data_loader import load_unified_dataset
21
  # --- CONFIGURATION ---
22
  # Points to your local SROIE copy
23
  SROIE_DATA_PATH = "data/sroie"
 
24
  MODEL_CHECKPOINT = "microsoft/layoutlmv3-base"
25
- OUTPUT_DIR = "models/layoutlmv3-generalized"
26
 
27
  # Standard Label Set
28
  LABEL_LIST = ['O', 'B-COMPANY', 'I-COMPANY', 'B-DATE', 'I-DATE',
@@ -86,18 +88,34 @@ class UnifiedDataset(Dataset):
86
 
87
  return {k: v.squeeze(0) for k, v in encoding.items()}
88
 
 
 
 
 
 
 
 
 
 
 
89
  def train():
90
  print(f"{'='*40}\n🚀 STARTING HYBRID TRAINING\n{'='*40}")
91
 
92
- # Check SROIE path
93
- if not os.path.exists(SROIE_DATA_PATH):
94
- print(f" Error: SROIE path not found at {SROIE_DATA_PATH}")
95
- print("Please make sure you copied the 'sroie' folder into 'data/'.")
96
- return
97
-
98
- # 1. Load SROIE
99
- print("📦 Loading SROIE dataset...")
100
- sroie_data = load_sroie(SROIE_DATA_PATH)
 
 
 
 
 
 
101
  print(f" - SROIE Train: {len(sroie_data['train'])}")
102
  print(f" - SROIE Test: {len(sroie_data['test'])}")
103
 
@@ -141,7 +159,7 @@ def train():
141
  # 6. Optimize & Train
142
  optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
143
  best_f1 = 0.0
144
- NUM_EPOCHS = 5
145
 
146
  print("\n🔥 Beginning Fine-Tuning...")
147
  for epoch in range(NUM_EPOCHS):
 
13
  import numpy as np
14
  import random
15
  import os
16
+ import pickle
17
 
18
  # --- IMPORTS ---
19
  from src.sroie_loader import load_sroie
 
22
  # --- CONFIGURATION ---
23
  # Points to your local SROIE copy
24
  SROIE_DATA_PATH = "data/sroie"
25
+ DOCTR_CACHE_PATH = "data/doctr_trained_cache.pkl" # DocTR pre-processed cache
26
  MODEL_CHECKPOINT = "microsoft/layoutlmv3-base"
27
+ OUTPUT_DIR = "models/layoutlmv3-doctr-trained"
28
 
29
  # Standard Label Set
30
  LABEL_LIST = ['O', 'B-COMPANY', 'I-COMPANY', 'B-DATE', 'I-DATE',
 
88
 
89
  return {k: v.squeeze(0) for k, v in encoding.items()}
90
 
91
+
92
def load_doctr_cache(cache_path: str) -> dict:
    """Deserialize the pickled DocTR-aligned dataset from *cache_path*.

    NOTE(review): ``pickle.load`` executes arbitrary code from untrusted
    files — only load caches produced by prepare_doctr_data.py.
    """
    print(f"📦 Loading DocTR cache from {cache_path}...")
    with open(cache_path, "rb") as cache_file:
        dataset = pickle.load(cache_file)
    n_train = len(dataset.get('train', []))
    n_test = len(dataset.get('test', []))
    print(f" ✅ Loaded {n_train} train, {n_test} test examples")
    return dataset
99
+
100
+
101
  def train():
102
  print(f"{'='*40}\n🚀 STARTING HYBRID TRAINING\n{'='*40}")
103
 
104
+ # 1. Load SROIE data (prefer DocTR cache if available)
105
+ if os.path.exists(DOCTR_CACHE_PATH):
106
+ print("🔄 Using DocTR-aligned training data (recommended)")
107
+ sroie_data = load_doctr_cache(DOCTR_CACHE_PATH)
108
+ else:
109
+ print("⚠️ DocTR cache not found. Using original SROIE loader.")
110
+ print(" Run 'python scripts/prepare_doctr_data.py' to generate the cache.")
111
+
112
+ if not os.path.exists(SROIE_DATA_PATH):
113
+ print(f"❌ Error: SROIE path not found at {SROIE_DATA_PATH}")
114
+ print("Please make sure you copied the 'sroie' folder into 'data/'.")
115
+ return
116
+
117
+ sroie_data = load_sroie(SROIE_DATA_PATH)
118
+
119
  print(f" - SROIE Train: {len(sroie_data['train'])}")
120
  print(f" - SROIE Test: {len(sroie_data['test'])}")
121
 
 
159
  # 6. Optimize & Train
160
  optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
161
  best_f1 = 0.0
162
+ NUM_EPOCHS = 10
163
 
164
  print("\n🔥 Beginning Fine-Tuning...")
165
  for epoch in range(NUM_EPOCHS):
src/extraction.py CHANGED
@@ -102,29 +102,57 @@ def extract_vendor(text: str) -> Optional[str]:
102
  return None
103
 
104
  def extract_invoice_number(text: str) -> Optional[str]:
105
- """
106
- Improved regex that handles alphanumeric AND numeric IDs, plus variations like "Tax Inv".
107
- """
108
  if not text: return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
- # Strategy 1: Look for "Invoice No: XXXXX" pattern
111
- # UPDATED: Handles "Tax Invoice", "Inv No", and standard variations
112
- keyword_pattern = r'(?:TAX\s*)?(?:INVOICE|INV|BILL|RECEIPT)\s*(?:NO|NUMBER|#|NUM)?[\s\.:-]*([A-Z0-9\-/]{3,})'
113
- match = re.search(keyword_pattern, text, re.IGNORECASE)
114
- if match:
115
- return match.group(1)
116
 
117
- # Strategy 2: Look for standalone labeled patterns (Existing Logic)
118
- # Only if Strategy 1 fails
119
  lines = text.split('\n')
120
- for line in lines[:20]:
121
- if any(k in line.lower() for k in ['invoice', 'no', '#']):
122
- # Allow pure digits now if they are long enough (e.g. 40378170)
123
- # Match 4+ digits OR alphanumeric
124
- token_match = re.search(r'\b([A-Z0-9-]{4,})\b', line)
125
- if token_match:
126
- return token_match.group(1)
 
 
 
 
 
 
 
 
 
 
 
127
 
 
 
 
 
 
128
  return None
129
 
130
  def extract_bill_to(text: str) -> Optional[Dict[str, str]]:
 
102
  return None
103
 
104
def extract_invoice_number(text: str) -> Optional[str]:
    """Pull an invoice/receipt identifier out of raw OCR text.

    Two passes: an explicit-label regex ("Invoice No: X"), then a contextual
    line-by-line scan over the document head that skips lines resembling tax
    registrations or phone numbers. Returns None when nothing plausible is
    found.
    """
    if not text:
        return None

    # Tokens that must never be returned as the ID itself.
    blocked = {
        'INVOICE', 'TAX', 'RECEIPT', 'BILL', 'NUMBER', 'NO', 'DATE',
        'ORIGINAL', 'COPY', 'GST', 'REG', 'MEMBER', 'SLIP', 'TEL', 'FAX'
    }
    # Lines carrying these markers are probably tax/registration/phone data,
    # unless the word INVOICE also appears on the same line.
    tax_like = ['GST', 'REG', 'SSM', 'TIN', 'PHONE', 'TEL', 'FAX', 'UBL', 'UEN']

    # Pass 1: high-confidence labelled IDs, e.g. "Invoice No: INV-001".
    labelled = re.findall(
        r'(?i)(?:TAX\s*)?(?:INVOICE|INV|BILL|RECEIPT|SLIP)\s*(?:NO|NUMBER|#|NUM)\s*[:\.]?\s*([A-Z0-9\-/]+)',
        text,
    )
    for candidate in labelled:
        candidate = candidate.strip()
        # Require a minimum length and reject label words captured by mistake.
        if len(candidate) >= 3 and candidate.upper() not in blocked:
            return candidate

    # Pass 2: scan the first 25 lines for loosely labelled candidates.
    for raw_line in text.split('\n')[:25]:
        upper = raw_line.upper()

        # Skip probable tax-ID / phone lines unless INVOICE appears too.
        looks_toxic = any(marker in upper for marker in tax_like)
        if looks_toxic and "INVOICE" not in upper:
            continue

        # Only consider lines carrying an invoice-like keyword.
        has_keyword = any(k in upper for k in ['INVOICE', ' NO', ' #', 'INV', 'SLIP', 'BILL'])
        if not has_keyword:
            continue

        # Candidate tokens: 3+ alphanumeric chars (with - and /).
        for token in re.findall(r'\b[A-Z0-9\-/]{3,}\b', upper):
            if token in blocked:
                continue
            # Heuristic: real invoice numbers virtually always contain a
            # digit, which filters words like CASH or CREDIT.
            if any(ch.isdigit() for ch in token):
                return token

    return None
157
 
158
  def extract_bill_to(text: str) -> Optional[Dict[str, str]]:
src/ml_extraction.py CHANGED
@@ -5,17 +5,18 @@ import torch
5
  from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
6
  from huggingface_hub import snapshot_download
7
  from PIL import Image
8
- import pytesseract
9
- from typing import List, Dict, Any
10
  import re
11
  import numpy as np
12
  from extraction import extract_invoice_number, extract_total
 
 
13
 
14
  # --- CONFIGURATION ---
15
- LOCAL_MODEL_PATH = "./models/layoutlmv3-generalized"
16
- HUB_MODEL_ID = "GSoumyajit2005/layoutlmv3-sroie-invoice-extraction"
17
 
18
- # --- Load Model ---
19
  def load_model_and_processor(model_path, hub_id):
20
  print("Loading processor from microsoft/layoutlmv3-base...")
21
  processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
@@ -32,7 +33,26 @@ def load_model_and_processor(model_path, hub_id):
32
 
33
  return model, processor
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  MODEL, PROCESSOR = load_model_and_processor(LOCAL_MODEL_PATH, HUB_MODEL_ID)
 
36
 
37
  if MODEL and PROCESSOR:
38
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -43,6 +63,71 @@ else:
43
  DEVICE = None
44
  print("❌ Could not load ML model.")
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def _process_predictions(words, unnormalized_boxes, encoding, predictions, id2label):
47
  word_ids = encoding.word_ids(batch_index=0)
48
  word_level_preds = {}
@@ -70,6 +155,7 @@ def _process_predictions(words, unnormalized_boxes, encoding, predictions, id2la
70
 
71
  return entities
72
 
 
73
  def extract_ml_based(image_path: str) -> Dict[str, Any]:
74
  if not MODEL or not PROCESSOR:
75
  raise RuntimeError("ML model is not loaded.")
@@ -77,35 +163,59 @@ def extract_ml_based(image_path: str) -> Dict[str, Any]:
77
  # 1. Load Image
78
  image = Image.open(image_path).convert("RGB")
79
  width, height = image.size
80
- ocr_data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
81
 
82
- words = []
83
- unnormalized_boxes = []
84
- for i in range(len(ocr_data['level'])):
85
- if int(ocr_data['conf'][i]) > 30 and ocr_data['text'][i].strip() != '':
86
- words.append(ocr_data['text'][i])
87
- unnormalized_boxes.append([
88
- ocr_data['left'][i], ocr_data['top'][i],
89
- ocr_data['width'][i], ocr_data['height'][i]
90
- ])
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- raw_text = " ".join(words)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- # 2. Normalize Boxes (WITH SAFETY CLAMP)
95
- normalized_boxes = []
96
- for box in unnormalized_boxes:
97
- x, y, w, h = box
98
- x0, y0, x1, y1 = x, y, x + w, y + h
99
-
100
- # ⚠️ The Fix: Ensure values never exceed 1000 or drop below 0
101
- normalized_boxes.append([
102
- max(0, min(1000, int(1000 * (x0 / width)))),
103
- max(0, min(1000, int(1000 * (y0 / height)))),
104
- max(0, min(1000, int(1000 * (x1 / width)))),
105
- max(0, min(1000, int(1000 * (y1 / height)))),
106
- ])
107
-
108
- # 3. Inference
109
  encoding = PROCESSOR(
110
  image, text=words, boxes=normalized_boxes,
111
  truncation=True, max_length=512, return_tensors="pt"
@@ -117,7 +227,7 @@ def extract_ml_based(image_path: str) -> Dict[str, Any]:
117
  predictions = outputs.logits.argmax(-1).squeeze().tolist()
118
  extracted_entities = _process_predictions(words, unnormalized_boxes, encoding, predictions, MODEL.config.id2label)
119
 
120
- # 4. Construct Output
121
  final_output = {
122
  "vendor": extracted_entities.get("COMPANY", {}).get("text"),
123
  "date": extracted_entities.get("DATE", {}).get("text"),
@@ -130,6 +240,20 @@ def extract_ml_based(image_path: str) -> Dict[str, Any]:
130
  "raw_predictions": extracted_entities # Contains text and bbox data for each entity
131
  }
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  # Fallbacks
134
  ml_total = extracted_entities.get("TOTAL", {}).get("text")
135
  if ml_total:
@@ -144,5 +268,29 @@ def extract_ml_based(image_path: str) -> Dict[str, Any]:
144
 
145
  if not final_output["receipt_number"]:
146
  final_output["receipt_number"] = extract_invoice_number(raw_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  return final_output
 
5
  from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
6
  from huggingface_hub import snapshot_download
7
  from PIL import Image
8
+ from typing import List, Dict, Any, Tuple
 
9
  import re
10
  import numpy as np
11
  from extraction import extract_invoice_number, extract_total
12
+ from doctr.io import DocumentFile
13
+ from doctr.models import ocr_predictor
14
 
15
  # --- CONFIGURATION ---
16
+ LOCAL_MODEL_PATH = "./models/layoutlmv3-doctr-trained"
17
+ HUB_MODEL_ID = "GSoumyajit2005/layoutlmv3-doctr-invoice-processor"
18
 
19
+ # --- Load LayoutLMv3 Model ---
20
  def load_model_and_processor(model_path, hub_id):
21
  print("Loading processor from microsoft/layoutlmv3-base...")
22
  processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
 
33
 
34
  return model, processor
35
 
36
+ # --- Load DocTR OCR Predictor ---
37
def load_doctr_predictor():
    """Initialize DocTR predictor and move to GPU for speed.

    Builds the pretrained DB-ResNet50 + CRNN-VGG16-BN pipeline (same pairing
    used by scripts/prepare_doctr_data.py, so inference OCR matches the OCR
    the model was trained against).

    Returns:
        The DocTR ``ocr_predictor``, moved to CUDA when available.
    """
    print("Loading DocTR OCR predictor...")
    predictor = ocr_predictor(
        det_arch='db_resnet50',
        reco_arch='crnn_vgg16_bn',
        pretrained=True
    )
    # Prefer CUDA when present; otherwise warn about CPU speed.
    if torch.cuda.is_available():
        print("🚀 Moving DocTR to GPU (CUDA)...")
        predictor.cuda()
    else:
        print("⚠️ GPU not found. Running on CPU (slow).")

    print("DocTR OCR predictor is ready.")
    return predictor
53
+
54
  MODEL, PROCESSOR = load_model_and_processor(LOCAL_MODEL_PATH, HUB_MODEL_ID)
55
+ DOCTR_PREDICTOR = load_doctr_predictor()
56
 
57
  if MODEL and PROCESSOR:
58
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
63
  DEVICE = None
64
  print("❌ Could not load ML model.")
65
 
66
+
67
def parse_doctr_output(doctr_result, img_width: int, img_height: int) -> Tuple[List[str], List[List[int]], List[List[int]]]:
    """Flatten DocTR output (Page -> Block -> Line -> Word) for LayoutLMv3.

    DocTR reports word geometry as ((x_min, y_min), (x_max, y_max)) relative
    to the page (0-1 scale). Whitespace-only words are dropped.

    Args:
        doctr_result: Output from the DocTR predictor (exposes ``.pages``).
        img_width: Original image width in pixels.
        img_height: Original image height in pixels.

    Returns:
        words: detected word strings.
        unnormalized_boxes: [x, y, width, height] pixel boxes for the
            visualization overlay.
        normalized_boxes: [x0, y0, x1, y1] boxes on the clamped 0-1000 scale
            that LayoutLMv3 expects.
    """
    words: List[str] = []
    unnormalized_boxes: List[List[int]] = []
    normalized_boxes: List[List[int]] = []

    # Walk the full hierarchy, emitting one entry per non-empty word.
    detected = (
        word
        for page in doctr_result.pages
        for block in page.blocks
        for line in block.lines
        for word in line.words
        if word.value.strip()
    )

    for word in detected:
        (rel_x0, rel_y0), (rel_x1, rel_y1) = word.geometry

        # Pixel-space corners for visualization.
        px0 = int(rel_x0 * img_width)
        py0 = int(rel_y0 * img_height)
        px1 = int(rel_x1 * img_width)
        py1 = int(rel_y1 * img_height)

        words.append(word.value)
        # [x, y, width, height] in pixels.
        unnormalized_boxes.append([px0, py0, px1 - px0, py1 - py0])
        # [x0, y0, x1, y1] on the 0-1000 scale, clamped against float drift.
        normalized_boxes.append([
            max(0, min(1000, int(rel_x0 * 1000))),
            max(0, min(1000, int(rel_y0 * 1000))),
            max(0, min(1000, int(rel_x1 * 1000))),
            max(0, min(1000, int(rel_y1 * 1000))),
        ])

    return words, unnormalized_boxes, normalized_boxes
129
+
130
+
131
  def _process_predictions(words, unnormalized_boxes, encoding, predictions, id2label):
132
  word_ids = encoding.word_ids(batch_index=0)
133
  word_level_preds = {}
 
155
 
156
  return entities
157
 
158
+
159
  def extract_ml_based(image_path: str) -> Dict[str, Any]:
160
  if not MODEL or not PROCESSOR:
161
  raise RuntimeError("ML model is not loaded.")
 
163
  # 1. Load Image
164
  image = Image.open(image_path).convert("RGB")
165
  width, height = image.size
 
166
 
167
+ # 2. Run DocTR OCR
168
+ doc = DocumentFile.from_images(image_path)
169
+ doctr_result = DOCTR_PREDICTOR(doc)
170
+
171
+ # 3. Parse DocTR output to get words and boxes
172
+ words, unnormalized_boxes, normalized_boxes = parse_doctr_output(
173
+ doctr_result, width, height
174
+ )
175
+
176
+ # Reconstructs lines so regex can work line-by-line
177
+ lines = []
178
+ current_line = []
179
+
180
+ if len(unnormalized_boxes) > 0:
181
+ # Initialize with first word's Y and Height
182
+ current_y = unnormalized_boxes[0][1]
183
+ current_h = unnormalized_boxes[0][3]
184
+
185
+ for i, word in enumerate(words):
186
+ y = unnormalized_boxes[i][1]
187
+ h = unnormalized_boxes[i][3]
188
 
189
+ # If vertical gap > 50% of line height, it's a new line
190
+ if abs(y - current_y) > max(current_h, h) / 2:
191
+ lines.append(" ".join(current_line))
192
+ current_line = []
193
+ current_y = y
194
+ current_h = h
195
+
196
+ current_line.append(word)
197
+
198
+ # Append the last line
199
+ if current_line:
200
+ lines.append(" ".join(current_line))
201
+
202
+ raw_text = "\n".join(lines)
203
+
204
+ # Handle empty OCR result
205
+ if not words:
206
+ return {
207
+ "vendor": None,
208
+ "date": None,
209
+ "address": None,
210
+ "receipt_number": None,
211
+ "bill_to": None,
212
+ "total_amount": None,
213
+ "items": [],
214
+ "raw_text": "",
215
+ "raw_predictions": {}
216
+ }
217
 
218
+ # 4. Inference with LayoutLMv3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  encoding = PROCESSOR(
220
  image, text=words, boxes=normalized_boxes,
221
  truncation=True, max_length=512, return_tensors="pt"
 
227
  predictions = outputs.logits.argmax(-1).squeeze().tolist()
228
  extracted_entities = _process_predictions(words, unnormalized_boxes, encoding, predictions, MODEL.config.id2label)
229
 
230
+ # 5. Construct Output
231
  final_output = {
232
  "vendor": extracted_entities.get("COMPANY", {}).get("text"),
233
  "date": extracted_entities.get("DATE", {}).get("text"),
 
240
  "raw_predictions": extracted_entities # Contains text and bbox data for each entity
241
  }
242
 
243
+ # 6. Vendor Fallback (Spatial Heuristic)
244
+ # If ML failed to find a vendor, assume the largest text at the top is the vendor
245
+ if not final_output["vendor"] and unnormalized_boxes:
246
+ # Filter for words in the top 20% of the image
247
+ top_words_indices = [
248
+ i for i, box in enumerate(unnormalized_boxes)
249
+ if box[1] < height * 0.2
250
+ ]
251
+
252
+ if top_words_indices:
253
+ # Find the word with the largest height (font size)
254
+ largest_idx = max(top_words_indices, key=lambda i: unnormalized_boxes[i][3])
255
+ final_output["vendor"] = words[largest_idx]
256
+
257
  # Fallbacks
258
  ml_total = extracted_entities.get("TOTAL", {}).get("text")
259
  if ml_total:
 
268
 
269
  if not final_output["receipt_number"]:
270
  final_output["receipt_number"] = extract_invoice_number(raw_text)
271
+
272
+ # Backfill Bounding Boxes for Regex Results
273
+ # If Regex found the number but ML didn't, we must find its box
274
+ # in the OCR data so the UI can draw it.
275
+
276
+ if final_output["receipt_number"] and "INVOICE_NO" not in final_output["raw_predictions"]:
277
+ target_val = final_output["receipt_number"].strip()
278
+ found_box = None
279
+
280
+ # 1. Try finding the exact word in the OCR list
281
+ # 'words' and 'unnormalized_boxes' are available from step 3
282
+ for i, word in enumerate(words):
283
+ # Check for exact match or if the word contains the target (e.g. "Inv#123")
284
+ if target_val == word or (len(target_val) > 3 and target_val in word):
285
+ found_box = unnormalized_boxes[i]
286
+ break
287
+
288
+ # 2. If found, inject it into raw_predictions
289
+ if found_box:
290
+ # The UI expects a list of boxes
291
+ final_output["raw_predictions"]["INVOICE_NO"] = {
292
+ "text": target_val,
293
+ "bbox": [found_box]
294
+ }
295
 
296
  return final_output
src/ocr.py DELETED
@@ -1,42 +0,0 @@
1
- # src/ocr.py
2
-
3
- import pytesseract
4
- import numpy as np
5
- import os
6
- import shutil
7
- import sys
8
-
9
- # --- Dynamic Tesseract Configuration ---
10
- # This block ensures the code runs on both Windows (Local) and Linux (Production)
11
- if os.name == 'nt': # Windows
12
- # Common default installation paths for Windows
13
- possible_paths = [
14
- r'C:\Program Files\Tesseract-OCR\tesseract.exe',
15
- r'C:\Program Files (x86)\Tesseract-OCR\tesseract.exe',
16
- r'C:\Users\{}\AppData\Local\Tesseract-OCR\tesseract.exe'.format(os.getlogin())
17
- ]
18
-
19
- # Search for the executable
20
- found = False
21
- for path in possible_paths:
22
- if os.path.exists(path):
23
- pytesseract.pytesseract.tesseract_cmd = path
24
- found = True
25
- print(f"✅ Found Tesseract at: {path}")
26
- break
27
-
28
- if not found:
29
- print("⚠️ Warning: Tesseract exe not found in standard paths. Assuming it's in system PATH.")
30
- else:
31
- # Linux/Mac (Docker/Production)
32
- if not shutil.which('tesseract'):
33
- print("⚠️ Warning: 'tesseract' binary not found in PATH. Please install tesseract-ocr.")
34
-
35
- def extract_text(image: np.ndarray, lang: str='eng', config: str='--psm 11') -> str:
36
- if image is None:
37
- raise ValueError("Input image is None")
38
- # Pytesseract will now use the path found above (or default to PATH)
39
- return pytesseract.image_to_string(image, lang=lang, config=config).strip()
40
-
41
- def extract_text_with_boxes(image):
42
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/pipeline.py CHANGED
@@ -13,7 +13,6 @@ import cv2
13
 
14
  # --- IMPORTS ---
15
  from preprocessing import load_image, convert_to_grayscale, remove_noise
16
- from ocr import extract_text
17
  from extraction import structure_output
18
  from ml_extraction import extract_ml_based
19
  from schema import InvoiceData
@@ -90,13 +89,10 @@ def process_invoice(image_path: str,
90
 
91
  elif method == 'rules':
92
  try:
93
- image = load_image(image_path)
94
- gray_image = convert_to_grayscale(image)
95
- preprocessed_image = remove_noise(gray_image, kernel_size=3)
96
- text = extract_text(preprocessed_image, config='--psm 6')
97
- raw_result = structure_output(text)
98
  except Exception as e:
99
- raise ValueError(f"Error during rule-based extraction: {e}")
100
 
101
  # Clean up temp file if we created one
102
  if image_path.endswith('.jpg') and 'sample_pdf' in image_path: # Safety check
 
13
 
14
  # --- IMPORTS ---
15
  from preprocessing import load_image, convert_to_grayscale, remove_noise
 
16
  from extraction import structure_output
17
  from ml_extraction import extract_ml_based
18
  from schema import InvoiceData
 
89
 
90
  elif method == 'rules':
91
  try:
92
+ print("⚠️ Rule-based mode is deprecated. Redirecting to ML-based extraction.")
93
+ raw_result = extract_ml_based(image_path)
 
 
 
94
  except Exception as e:
95
+ raise ValueError(f"Error during ML-based extraction: {e}")
96
 
97
  # Clean up temp file if we created one
98
  if image_path.endswith('.jpg') and 'sample_pdf' in image_path: # Safety check