Commit ec0b507 (parent: 343b0c3)

Refactor: Replace Tesseract with DocTR and integrate LayoutLMv3-DocTR model

Major overhaul of the OCR/inference pipeline: swapped Tesseract for DocTR, retrained LayoutLMv3 (~83% F1), and fixed address extraction using fuzzy matching.
- .gitignore +1 -0
- Dockerfile +6 -3
- README.md +1 -1
- requirements.txt +1 -1
- scripts/prepare_doctr_data.py +377 -0
- scripts/train_combined.py +29 -11
- src/extraction.py +46 -18
- src/ml_extraction.py +180 -32
- src/ocr.py +0 -42
- src/pipeline.py +3 -7
.gitignore CHANGED

@@ -23,6 +23,7 @@ credentials.json
 *.log
 logs/
 .cache/
+*.pkl
 
 # OS
 .DS_Store
Dockerfile CHANGED

@@ -1,10 +1,13 @@
 # Use an official Python runtime
 FROM python:3.10-slim
 
-# 1. Install system dependencies (
-#
+# 1. Install system dependencies (DocTR + OpenCV + POPPLER)
+# DocTR requires OpenGL and GStreamer libraries for image processing
 RUN apt-get update && apt-get install -y \
-
+    libgl1-mesa-dev \
+    libglib2.0-0 \
+    libgstreamer1.0-0 \
+    libgstreamer-plugins-base1.0-0 \
     poppler-utils \
     ffmpeg libsm6 libxext6 \
     && rm -rf /var/lib/apt/lists/*
README.md CHANGED

@@ -374,7 +374,7 @@ invoice-processor-ml/
 
 ## ⚠️ Known Limitations
 
-1. **Layout Sensitivity**: The ML model was fine‑tuned
+1. **Layout Sensitivity**: The ML model was fine‑tuned on SROIE (retail receipts) and mychen76/invoices-and-receipts_ocr_v1 (English). Professional multi-column invoices may underperform until you fine‑tune on more diverse datasets.
 2. **Invoice Number**: SROIE dataset lacks invoice number labels. The system solves this by using the Hybrid Fallback Engine, which successfully extracts invoice numbers using Regex whenever the ML model output is empty.
 3. **Line Items/Tables**: Not trained for table extraction yet. Rule-based supports simple totals; table extraction comes later.
 4. **OCR Variability**: Tesseract outputs can vary; preprocessing and thresholds can impact ML results.
requirements.txt CHANGED

@@ -2,7 +2,7 @@
 streamlit>=1.28.0
 
 # ----- OCR -----
-
+python-doctr[torch]>=0.8.0
 opencv-python>=4.8.0
 Pillow>=10.0.0
 
scripts/prepare_doctr_data.py ADDED (+377)
# scripts/prepare_doctr_data.py

"""
Prepare training data using DocTR OCR output.

This script:
1. Iterates through SROIE training/test images
2. Runs DocTR OCR to get words and boxes
3. Aligns DocTR output with ground truth labels using fuzzy matching
4. Saves the aligned dataset to a pickle file for training

This ensures the model learns from DocTR's actual output (with its specific errors)
rather than from perfect ground truth which it will never see in production.
"""

import torch
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import json
import pickle
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from difflib import SequenceMatcher
from typing import List, Dict, Any, Tuple, Optional

from doctr.io import DocumentFile
from doctr.models import ocr_predictor

# --- CONFIGURATION ---
SROIE_DATA_PATH = "data/sroie"
OUTPUT_CACHE_PATH = "data/doctr_trained_cache.pkl"

# Ground truth field names and their corresponding BIO labels
GT_FIELD_MAPPING = {
    "company": "COMPANY",
    "date": "DATE",
    "address": "ADDRESS",
    "total": "TOTAL",
}


def load_doctr_predictor():
    """Initialize DocTR predictor with lightweight backbone and move to GPU."""
    print("Loading DocTR OCR predictor...")

    # 1. Initialize the model
    predictor = ocr_predictor(
        det_arch='db_resnet50',
        reco_arch='crnn_vgg16_bn',
        pretrained=True
    )

    # 2. Force it to GPU if available
    if torch.cuda.is_available():
        print("🚀 Moving DocTR to GPU (CUDA)...")
        predictor.cuda()
    else:
        print("⚠️ GPU not found. Running on CPU (this will be slow).")

    print("DocTR OCR predictor ready.")
    return predictor


def parse_doctr_output(doctr_result, img_width: int, img_height: int) -> Tuple[List[str], List[List[int]]]:
    """
    Parse DocTR output into words and normalized boxes (0-1000 scale).

    Returns:
        words: List of word strings
        normalized_boxes: List of [x0, y0, x1, y1] in 0-1000 scale
    """
    words = []
    normalized_boxes = []

    for page in doctr_result.pages:
        for block in page.blocks:
            for line in block.lines:
                for word in line.words:
                    if not word.value.strip():
                        continue

                    words.append(word.value)

                    # DocTR bbox format: ((x_min, y_min), (x_max, y_max)) in 0-1 scale
                    (x_min, y_min), (x_max, y_max) = word.geometry

                    # Normalize to 0-1000 scale with clamping
                    normalized_boxes.append([
                        max(0, min(1000, int(x_min * 1000))),
                        max(0, min(1000, int(y_min * 1000))),
                        max(0, min(1000, int(x_max * 1000))),
                        max(0, min(1000, int(y_max * 1000))),
                    ])

    return words, normalized_boxes


def fuzzy_match_score(s1: str, s2: str) -> float:
    """Calculate fuzzy match score between two strings."""
    return SequenceMatcher(None, s1.lower(), s2.lower()).ratio()


def find_entity_in_words(
    entity_text: str,
    words: List[str],
    start_idx: int = 0,
    threshold: float = 0.7
) -> Optional[Tuple[int, int]]:
    """
    Find a ground truth entity in the DocTR words using fuzzy matching.
    Includes expansion search to handle OCR word splitting.
    """
    entity_words = entity_text.split()
    n_target = len(entity_words)

    # 1. Single word match
    if n_target == 1:
        best_score = 0
        best_idx = -1
        for i in range(start_idx, len(words)):
            score = fuzzy_match_score(entity_text, words[i])
            if score > best_score and score >= threshold:
                best_score = score
                best_idx = i
        if best_idx >= 0:
            return (best_idx, best_idx)

    # 2. Multi-word entity: Flexible Window Search
    # We search windows of size N, N+1, N+2... up to N+5 (to catch OCR splits)
    # AND N-1, N-2... (to catch OCR merges)

    best_match_score = 0.0
    best_match_indices = None

    # Define search range: from (Length - 3) to (Length + 5)
    min_len = max(1, n_target - 3)
    max_len = min(len(words) - start_idx, n_target + 5)

    combined_entity_text = " ".join(entity_words)

    # Iterate through window sizes
    for window_size in range(min_len, max_len + 1):
        for i in range(start_idx, len(words) - window_size + 1):

            # Construct window text
            window_tokens = words[i : i + window_size]
            window_text = " ".join(window_tokens)

            score = fuzzy_match_score(combined_entity_text, window_text)

            # Optimization: If perfect match, return immediately
            if score > 0.95:
                return (i, i + window_size - 1)

            if score > best_match_score and score >= threshold:
                best_match_score = score
                best_match_indices = (i, i + window_size - 1)

    return best_match_indices


def load_ground_truth(json_path: Path) -> Dict[str, str]:
    """
    Load ground truth entities from the tagged JSON.

    The SROIE tagged JSON has: {"words": [...], "bbox": [...], "labels": [...]}
    We need to reconstruct the entity values from words + labels.
    """
    with open(json_path, encoding="utf-8") as f:
        data = json.load(f)

    words = data.get("words", [])
    labels = data.get("labels", [])

    # Reconstruct entities from BIO tags
    entities = {}
    current_entity = None
    current_text = []

    for word, label in zip(words, labels):
        if label.startswith("B-"):
            # Save previous entity if exists
            if current_entity and current_text:
                entities[current_entity.lower()] = " ".join(current_text)

            # Start new entity
            current_entity = label[2:]  # Remove "B-" prefix
            current_text = [word]

        elif label.startswith("I-") and current_entity:
            entity_type = label[2:]
            if entity_type == current_entity:
                current_text.append(word)
            else:
                # Entity type changed, save current
                if current_text:
                    entities[current_entity.lower()] = " ".join(current_text)
                current_entity = None
                current_text = []
        else:
            # "O" label - save current entity if exists
            if current_entity and current_text:
                entities[current_entity.lower()] = " ".join(current_text)
            current_entity = None
            current_text = []

    # Don't forget the last entity
    if current_entity and current_text:
        entities[current_entity.lower()] = " ".join(current_text)

    return entities


def align_labels(
    doctr_words: List[str],
    ground_truth: Dict[str, str]
) -> List[str]:
    labels = ["O"] * len(doctr_words)
    used_indices = set()

    for gt_field, bio_label in GT_FIELD_MAPPING.items():
        if gt_field not in ground_truth:
            continue

        entity_text = ground_truth[gt_field]
        if not entity_text or not entity_text.strip():
            continue

        # DYNAMIC THRESHOLD: Be lenient with Addresses, strict with Dates/Totals
        current_threshold = 0.6
        if bio_label == "ADDRESS":
            current_threshold = 0.45  # Lower threshold for messy addresses
        elif bio_label in ["DATE", "TOTAL"]:
            current_threshold = 0.7  # Keep strict for precision fields

        match = find_entity_in_words(entity_text, doctr_words, start_idx=0, threshold=current_threshold)

        if match:
            start_idx, end_idx = match

            # Overlap check
            if any(i in used_indices for i in range(start_idx, end_idx + 1)):
                continue

            labels[start_idx] = f"B-{bio_label}"
            for i in range(start_idx + 1, end_idx + 1):
                labels[i] = f"I-{bio_label}"

            used_indices.update(range(start_idx, end_idx + 1))

    return labels


def process_split(
    split_path: Path,
    predictor,
    split_name: str
) -> List[Dict[str, Any]]:
    """Process all images in a split directory."""

    # Find image and annotation directories
    if (split_path / "images").exists():
        img_dir = split_path / "images"
    elif (split_path / "img").exists():
        img_dir = split_path / "img"
    else:
        print(f"  ⚠️ No image directory found in {split_path}")
        return []

    if (split_path / "tagged").exists():
        ann_dir = split_path / "tagged"
    elif (split_path / "box").exists():
        ann_dir = split_path / "box"
    else:
        print(f"  ⚠️ No annotation directory found in {split_path}")
        return []

    examples = []
    image_files = sorted([f for f in img_dir.iterdir() if f.suffix.lower() in [".jpg", ".png"]])

    print(f"  Processing {len(image_files)} images in {split_name}...")

    for img_file in tqdm(image_files, desc=f"  {split_name}"):
        try:
            # Check for corresponding annotation
            json_path = ann_dir / f"{img_file.stem}.json"
            if not json_path.exists():
                continue

            # Load image dimensions
            with Image.open(img_file) as img:
                width, height = img.size

            # Run DocTR OCR
            doc = DocumentFile.from_images(str(img_file))
            doctr_result = predictor(doc)

            # Parse DocTR output
            words, boxes = parse_doctr_output(doctr_result, width, height)

            if not words:
                continue

            # Load ground truth and align labels
            ground_truth = load_ground_truth(json_path)
            aligned_labels = align_labels(words, ground_truth)

            # Create example
            examples.append({
                "image_path": str(img_file),
                "words": words,
                "bboxes": boxes,
                "ner_tags": aligned_labels,
                "ground_truth": ground_truth  # Keep for debugging
            })

        except Exception as e:
            print(f"\n  ❌ Error processing {img_file.name}: {e}")
            continue

    return examples


def main():
    print("=" * 60)
    print("📦 DocTR Training Data Preparation")
    print("=" * 60)

    sroie_path = Path(SROIE_DATA_PATH)

    if not sroie_path.exists():
        print(f"❌ SROIE path not found: {sroie_path}")
        return

    # Load DocTR predictor
    predictor = load_doctr_predictor()

    dataset = {"train": [], "test": []}

    # Process each split
    for split in ["train", "test"]:
        split_path = sroie_path / split
        if not split_path.exists():
            print(f"  ⚠️ Split not found: {split}")
            continue

        print(f"\n📂 Processing {split} split...")
        examples = process_split(split_path, predictor, split)
        dataset[split] = examples

        # Stats
        total_entities = sum(
            sum(1 for label in ex["ner_tags"] if label.startswith("B-"))
            for ex in examples
        )
        print(f"  ✅ {len(examples)} images processed")
        print(f"  📊 {total_entities} entities aligned")

    # Save cache
    print(f"\n💾 Saving cache to {OUTPUT_CACHE_PATH}...")
    output_path = Path(OUTPUT_CACHE_PATH)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, "wb") as f:
        pickle.dump(dataset, f)

    print("✅ Cache saved!")
    print(f" - Train examples: {len(dataset['train'])}")
    print(f" - Test examples: {len(dataset['test'])}")
    print("=" * 60)


if __name__ == "__main__":
    main()
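The flexible-window alignment in find_entity_in_words can be distilled into a small standalone sketch using only difflib. The helper names below (fuzzy_score, best_window) are illustrative, not part of the repo; it shows why the window expansion recovers entities that DocTR split across tokens:

```python
from difflib import SequenceMatcher

def fuzzy_score(a: str, b: str) -> float:
    """Case-insensitive similarity ratio in [0, 1]."""
    return SequenceMatcher(None, a.lower(), b.lower()).ratio()

def best_window(entity: str, words: list, threshold: float = 0.6):
    """Return (start, end) word indices of the window best matching entity."""
    n = len(entity.split())
    best, best_span = 0.0, None
    # Try windows slightly smaller/larger than the target to absorb OCR splits/merges
    for size in range(max(1, n - 1), n + 3):
        for i in range(len(words) - size + 1):
            score = fuzzy_score(entity, " ".join(words[i:i + size]))
            if score > best and score >= threshold:
                best, best_span = score, (i, i + size - 1)
    return best_span

# OCR split "ABC-123" into two tokens; the window search still finds the span
words = ["TAX", "INVOICE", "ABC-", "123", "TOTAL", "9.80"]
print(best_window("ABC-123", words))  # → (2, 3)
```

A size-1 window over "ABC-" alone scores ~0.73, but the size-2 window "ABC- 123" scores ~0.93, so the split entity is still recovered as one span.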
scripts/train_combined.py CHANGED

@@ -13,6 +13,7 @@ from pathlib import Path
 import numpy as np
 import random
 import os
+import pickle
 
 # --- IMPORTS ---
 from src.sroie_loader import load_sroie

@@ -21,8 +22,9 @@ from src.data_loader import load_unified_dataset
 # --- CONFIGURATION ---
 # Points to your local SROIE copy
 SROIE_DATA_PATH = "data/sroie"
+DOCTR_CACHE_PATH = "data/doctr_trained_cache.pkl"  # DocTR pre-processed cache
 MODEL_CHECKPOINT = "microsoft/layoutlmv3-base"
-OUTPUT_DIR = "models/layoutlmv3-
+OUTPUT_DIR = "models/layoutlmv3-doctr-trained"
 
 # Standard Label Set
 LABEL_LIST = ['O', 'B-COMPANY', 'I-COMPANY', 'B-DATE', 'I-DATE',

@@ -86,18 +88,34 @@ class UnifiedDataset(Dataset):
 
         return {k: v.squeeze(0) for k, v in encoding.items()}
 
+
+def load_doctr_cache(cache_path: str) -> dict:
+    """Load pre-processed DocTR training data from cache."""
+    print(f"📦 Loading DocTR cache from {cache_path}...")
+    with open(cache_path, "rb") as f:
+        data = pickle.load(f)
+    print(f"  ✅ Loaded {len(data.get('train', []))} train, {len(data.get('test', []))} test examples")
+    return data
+
+
 def train():
     print(f"{'='*40}\n🚀 STARTING HYBRID TRAINING\n{'='*40}")
 
-    #
-    if
-        print(
+    # 1. Load SROIE data (prefer DocTR cache if available)
+    if os.path.exists(DOCTR_CACHE_PATH):
+        print("🔄 Using DocTR-aligned training data (recommended)")
+        sroie_data = load_doctr_cache(DOCTR_CACHE_PATH)
+    else:
+        print("⚠️ DocTR cache not found. Using original SROIE loader.")
+        print("   Run 'python scripts/prepare_doctr_data.py' to generate the cache.")
+
+        if not os.path.exists(SROIE_DATA_PATH):
+            print(f"❌ Error: SROIE path not found at {SROIE_DATA_PATH}")
+            print("Please make sure you copied the 'sroie' folder into 'data/'.")
+            return
+
+        sroie_data = load_sroie(SROIE_DATA_PATH)
+
     print(f" - SROIE Train: {len(sroie_data['train'])}")
     print(f" - SROIE Test: {len(sroie_data['test'])}")

@@ -141,7 +159,7 @@ def train():
     # 6. Optimize & Train
     optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
     best_f1 = 0.0
-    NUM_EPOCHS =
+    NUM_EPOCHS = 10
 
     print("\n🔥 Beginning Fine-Tuning...")
     for epoch in range(NUM_EPOCHS):
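The cache-first branch above expects a pickle with top-level "train"/"test" keys, as written by prepare_doctr_data.py. A minimal round-trip of that format (temporary path and toy example record, purely illustrative):

```python
import os
import pickle
import tempfile

# Toy example in the cache's shape: {"train": [...], "test": [...]}
dataset = {
    "train": [{"words": ["TOTAL", "9.80"], "ner_tags": ["O", "B-TOTAL"]}],
    "test": [],
}

path = os.path.join(tempfile.mkdtemp(), "doctr_cache.pkl")
with open(path, "wb") as f:
    pickle.dump(dataset, f)

with open(path, "rb") as f:
    loaded = pickle.load(f)

print(len(loaded["train"]), len(loaded["test"]))  # → 1 0
```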
src/extraction.py CHANGED

@@ -102,29 +102,57 @@ def extract_vendor(text: str) -> Optional[str]:
     return None
 
 def extract_invoice_number(text: str) -> Optional[str]:
-    """
-    Improved regex that handles alphanumeric AND numeric IDs, plus variations like "Tax Inv".
-    """
     if not text: return None
+
+    # 1. BLOCK LIST: Words that might be captured as the ID itself by mistake
+    FORBIDDEN_WORDS = {
+        'INVOICE', 'TAX', 'RECEIPT', 'BILL', 'NUMBER', 'NO', 'DATE',
+        'ORIGINAL', 'COPY', 'GST', 'REG', 'MEMBER', 'SLIP', 'TEL', 'FAX'
+    }
+
+    # 2. TOXIC CONTEXTS: If a line contains these, it's likely a Tax ID or Phone #, not an Invoice #
+    # We skip the line entirely if these are found (unless "INVOICE" is also strictly present)
+    TOXIC_LINE_INDICATORS = ['GST', 'REG', 'SSM', 'TIN', 'PHONE', 'TEL', 'FAX', 'UBL', 'UEN']
+
+    # Strategy 1: Explicit Label Search (High Confidence)
+    # matches "Invoice No:", "Slip No:", "Bill #:", etc.
+    # ADDED: 'SLIP' to the valid prefixes
+    keyword_pattern = r'(?i)(?:TAX\s*)?(?:INVOICE|INV|BILL|RECEIPT|SLIP)\s*(?:NO|NUMBER|#|NUM)\s*[:\.]?\s*([A-Z0-9\-/]+)'
+    matches = re.findall(keyword_pattern, text)
 
-
-        return match.group(1)
+    for match in matches:
+        clean_match = match.strip()
+        # Verify length and ensure the match itself isn't a forbidden word
+        if len(clean_match) >= 3 and clean_match.upper() not in FORBIDDEN_WORDS:
+            return clean_match
 
-    # Strategy 2:
-    #
+    # Strategy 2: Contextual Line Search (Medium Confidence)
+    # We scan line-by-line for loose patterns like "No: 12345" or "Slip: 555"
     lines = text.split('\n')
-    for line in lines[:
+    for line in lines[:25]:  # Scan top 25 lines
+        line_upper = line.upper()
+
+        # ⚠️ CRITICAL FIX: Skip lines that look like Tax IDs (GST/REG)
+        # But allow if the line explicitly says "INVOICE" (e.g. "Tax Invoice / GST Reg No")
+        if any(bad in line_upper for bad in TOXIC_LINE_INDICATORS) and "INVOICE" not in line_upper:
+            continue
+
+        # Look for Invoice-like keywords (Added SLIP)
+        # matches " NO", " #", "SLIP"
+        if any(k in line_upper for k in ['INVOICE', ' NO', ' #', 'INV', 'SLIP', 'BILL']):
+
+            # Find candidate tokens: 3+ alphanumeric chars
+            tokens = re.findall(r'\b[A-Z0-9\-/]{3,}\b', line_upper)
+
+            for token in tokens:
+                if token in FORBIDDEN_WORDS:
+                    continue
 
+                # Heuristic: Invoice numbers almost always have digits.
+                # This filters out purely alpha strings like "CREDIT" or "CASH"
+                if any(c.isdigit() for c in token):
+                    return token
+
     return None
 
 def extract_bill_to(text: str) -> Optional[Dict[str, str]]:
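The Strategy 1 pattern above can be sanity-checked in isolation. The pattern is copied from the diff; the sample OCR lines are made up for illustration:

```python
import re

# Pattern from Strategy 1: invoice-like label + separator + candidate ID
keyword_pattern = r'(?i)(?:TAX\s*)?(?:INVOICE|INV|BILL|RECEIPT|SLIP)\s*(?:NO|NUMBER|#|NUM)\s*[:\.]?\s*([A-Z0-9\-/]+)'

# A labelled invoice number is captured...
print(re.findall(keyword_pattern, "TAX INVOICE NO: INV-2024/001"))  # → ['INV-2024/001']

# ...while a tax-registration line has no invoice keyword, so nothing matches
print(re.findall(keyword_pattern, "GST REG NO: 001234567890"))  # → []
```

Strategy 2 then handles the looser cases, with the FORBIDDEN_WORDS and TOXIC_LINE_INDICATORS filters guarding against tax IDs and phone numbers.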
src/ml_extraction.py
CHANGED
|
@@ -5,17 +5,18 @@ import torch
|
|
| 5 |
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
|
| 6 |
from huggingface_hub import snapshot_download
|
| 7 |
from PIL import Image
|
| 8 |
-
import
|
| 9 |
-
from typing import List, Dict, Any
|
| 10 |
import re
|
| 11 |
import numpy as np
|
| 12 |
from extraction import extract_invoice_number, extract_total
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# --- CONFIGURATION ---
|
| 15 |
-
LOCAL_MODEL_PATH = "./models/layoutlmv3-
|
| 16 |
-
HUB_MODEL_ID = "GSoumyajit2005/layoutlmv3-
|
| 17 |
|
| 18 |
-
# --- Load Model ---
|
| 19 |
def load_model_and_processor(model_path, hub_id):
|
| 20 |
print("Loading processor from microsoft/layoutlmv3-base...")
|
| 21 |
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
|
|
@@ -32,7 +33,26 @@ def load_model_and_processor(model_path, hub_id):
|
|
| 32 |
|
| 33 |
return model, processor
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
MODEL, PROCESSOR = load_model_and_processor(LOCAL_MODEL_PATH, HUB_MODEL_ID)
|
|
|
|
| 36 |
|
| 37 |
if MODEL and PROCESSOR:
|
| 38 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
@@ -43,6 +63,71 @@ else:
|
|
| 43 |
DEVICE = None
|
| 44 |
print("❌ Could not load ML model.")
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
def _process_predictions(words, unnormalized_boxes, encoding, predictions, id2label):
|
| 47 |
word_ids = encoding.word_ids(batch_index=0)
|
| 48 |
word_level_preds = {}
|
|
@@ -70,6 +155,7 @@ def _process_predictions(words, unnormalized_boxes, encoding, predictions, id2la
|
|
| 70 |
|
| 71 |
return entities
|
| 72 |
|
|
|
|
| 73 |
def extract_ml_based(image_path: str) -> Dict[str, Any]:
|
| 74 |
if not MODEL or not PROCESSOR:
|
| 75 |
raise RuntimeError("ML model is not loaded.")
|
|
@@ -77,35 +163,59 @@ def extract_ml_based(image_path: str) -> Dict[str, Any]:
|
|
| 77 |
# 1. Load Image
|
| 78 |
image = Image.open(image_path).convert("RGB")
|
| 79 |
width, height = image.size
|
| 80 |
-
ocr_data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
-
#
|
| 95 |
-
normalized_boxes = []
|
| 96 |
-
for box in unnormalized_boxes:
|
| 97 |
-
x, y, w, h = box
|
| 98 |
-
x0, y0, x1, y1 = x, y, x + w, y + h
|
| 99 |
-
|
| 100 |
-
# ⚠️ The Fix: Ensure values never exceed 1000 or drop below 0
|
| 101 |
-
normalized_boxes.append([
|
| 102 |
-
max(0, min(1000, int(1000 * (x0 / width)))),
|
| 103 |
-
max(0, min(1000, int(1000 * (y0 / height)))),
|
| 104 |
-
max(0, min(1000, int(1000 * (x1 / width)))),
|
| 105 |
-
max(0, min(1000, int(1000 * (y1 / height)))),
|
| 106 |
-
])
|
| 107 |
-
|
| 108 |
-
# 3. Inference
|
| 109 |
encoding = PROCESSOR(
|
| 110 |
image, text=words, boxes=normalized_boxes,
|
| 111 |
truncation=True, max_length=512, return_tensors="pt"
|
|
@@ -117,7 +227,7 @@ def extract_ml_based(image_path: str) -> Dict[str, Any]:
|
|
| 117 |
predictions = outputs.logits.argmax(-1).squeeze().tolist()
|
| 118 |
extracted_entities = _process_predictions(words, unnormalized_boxes, encoding, predictions, MODEL.config.id2label)
|
| 119 |
|
| 120 |
-
#
|
| 121 |
final_output = {
|
| 122 |
"vendor": extracted_entities.get("COMPANY", {}).get("text"),
|
| 123 |
"date": extracted_entities.get("DATE", {}).get("text"),
|
|
@@ -130,6 +240,20 @@ def extract_ml_based(image_path: str) -> Dict[str, Any]:
|
|
| 130 |
"raw_predictions": extracted_entities # Contains text and bbox data for each entity
|
| 131 |
}
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
# Fallbacks
|
| 134 |
ml_total = extracted_entities.get("TOTAL", {}).get("text")
|
| 135 |
if ml_total:
|
|
@@ -144,5 +268,29 @@ def extract_ml_based(image_path: str) -> Dict[str, Any]:

     if not final_output["receipt_number"]:
         final_output["receipt_number"] = extract_invoice_number(raw_text)

     return final_output
 from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
 from huggingface_hub import snapshot_download
 from PIL import Image
+from typing import List, Dict, Any, Tuple
 import re
 import numpy as np
 from extraction import extract_invoice_number, extract_total
+from doctr.io import DocumentFile
+from doctr.models import ocr_predictor

 # --- CONFIGURATION ---
+LOCAL_MODEL_PATH = "./models/layoutlmv3-doctr-trained"
+HUB_MODEL_ID = "GSoumyajit2005/layoutlmv3-doctr-invoice-processor"

+# --- Load LayoutLMv3 Model ---
 def load_model_and_processor(model_path, hub_id):
     print("Loading processor from microsoft/layoutlmv3-base...")
     processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
...
     return model, processor

+# --- Load DocTR OCR Predictor ---
+def load_doctr_predictor():
+    """Initialize DocTR predictor and move to GPU for speed."""
+    print("Loading DocTR OCR predictor...")
+    predictor = ocr_predictor(
+        det_arch='db_resnet50',
+        reco_arch='crnn_vgg16_bn',
+        pretrained=True
+    )
+    if torch.cuda.is_available():
+        print("🚀 Moving DocTR to GPU (CUDA)...")
+        predictor.cuda()
+    else:
+        print("⚠️ GPU not found. Running on CPU (slow).")
+
+    print("DocTR OCR predictor is ready.")
+    return predictor
+
 MODEL, PROCESSOR = load_model_and_processor(LOCAL_MODEL_PATH, HUB_MODEL_ID)
+DOCTR_PREDICTOR = load_doctr_predictor()

 if MODEL and PROCESSOR:
     DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
...
     DEVICE = None
     print("❌ Could not load ML model.")

+
+def parse_doctr_output(doctr_result, img_width: int, img_height: int) -> Tuple[List[str], List[List[int]], List[List[int]]]:
+    """
+    Parse DocTR's hierarchical output (Page -> Block -> Line -> Word)
+    into flat lists of words and bounding boxes for LayoutLMv3.
+
+    DocTR returns coordinates in 0-1.0 scale (relative to image).
+    We convert to:
+    - unnormalized_boxes: pixel coordinates [x, y, width, height] for visualization
+    - normalized_boxes: 0-1000 scale [x0, y0, x1, y1] for LayoutLMv3
+
+    Args:
+        doctr_result: Output from DocTR predictor
+        img_width: Original image width in pixels
+        img_height: Original image height in pixels
+
+    Returns:
+        words: List of word strings
+        unnormalized_boxes: List of [x, y, width, height] in pixel coordinates
+        normalized_boxes: List of [x0, y0, x1, y1] in 0-1000 scale
+    """
+    words = []
+    unnormalized_boxes = []
+    normalized_boxes = []
+
+    # DocTR hierarchy: Document -> Page -> Block -> Line -> Word
+    for page in doctr_result.pages:
+        for block in page.blocks:
+            for line in block.lines:
+                for word in line.words:
+                    # Skip empty words
+                    if not word.value.strip():
+                        continue
+
+                    words.append(word.value)
+
+                    # DocTR bbox format: ((x_min, y_min), (x_max, y_max)) in 0-1 scale
+                    (x_min, y_min), (x_max, y_max) = word.geometry
+
+                    # Convert to pixel coordinates for visualization
+                    px_x0 = int(x_min * img_width)
+                    px_y0 = int(y_min * img_height)
+                    px_x1 = int(x_max * img_width)
+                    px_y1 = int(y_max * img_height)
+
+                    # Unnormalized box: [x, y, width, height] for visualization overlay
+                    unnormalized_boxes.append([
+                        px_x0,
+                        px_y0,
+                        px_x1 - px_x0,  # width
+                        px_y1 - px_y0   # height
+                    ])
+
+                    # Normalized box: [x0, y0, x1, y1] in 0-1000 scale for LayoutLMv3
+                    # Clamp values to ensure they stay within [0, 1000]
+                    normalized_boxes.append([
+                        max(0, min(1000, int(x_min * 1000))),
+                        max(0, min(1000, int(y_min * 1000))),
+                        max(0, min(1000, int(x_max * 1000))),
+                        max(0, min(1000, int(y_max * 1000))),
+                    ])
+
+    return words, unnormalized_boxes, normalized_boxes
+
+
 def _process_predictions(words, unnormalized_boxes, encoding, predictions, id2label):
     word_ids = encoding.word_ids(batch_index=0)
     word_level_preds = {}
...
     return entities

+
 def extract_ml_based(image_path: str) -> Dict[str, Any]:
     if not MODEL or not PROCESSOR:
         raise RuntimeError("ML model is not loaded.")
...
     # 1. Load Image
     image = Image.open(image_path).convert("RGB")
     width, height = image.size

+    # 2. Run DocTR OCR
+    doc = DocumentFile.from_images(image_path)
+    doctr_result = DOCTR_PREDICTOR(doc)
+
+    # 3. Parse DocTR output to get words and boxes
+    words, unnormalized_boxes, normalized_boxes = parse_doctr_output(
+        doctr_result, width, height
+    )
+
+    # Reconstruct lines so regex can work line-by-line
+    lines = []
+    current_line = []
+
+    if len(unnormalized_boxes) > 0:
+        # Initialize with first word's Y and Height
+        current_y = unnormalized_boxes[0][1]
+        current_h = unnormalized_boxes[0][3]
+
+        for i, word in enumerate(words):
+            y = unnormalized_boxes[i][1]
+            h = unnormalized_boxes[i][3]
+
+            # If vertical gap > 50% of line height, it's a new line
+            if abs(y - current_y) > max(current_h, h) / 2:
+                lines.append(" ".join(current_line))
+                current_line = []
+                current_y = y
+                current_h = h
+
+            current_line.append(word)
+
+        # Append the last line
+        if current_line:
+            lines.append(" ".join(current_line))
+
+    raw_text = "\n".join(lines)
+
+    # Handle empty OCR result
+    if not words:
+        return {
+            "vendor": None,
+            "date": None,
+            "address": None,
+            "receipt_number": None,
+            "bill_to": None,
+            "total_amount": None,
+            "items": [],
+            "raw_text": "",
+            "raw_predictions": {}
+        }

+    # 4. Inference with LayoutLMv3
     encoding = PROCESSOR(
         image, text=words, boxes=normalized_boxes,
         truncation=True, max_length=512, return_tensors="pt"
...
     predictions = outputs.logits.argmax(-1).squeeze().tolist()
     extracted_entities = _process_predictions(words, unnormalized_boxes, encoding, predictions, MODEL.config.id2label)

+    # 5. Construct Output
     final_output = {
         "vendor": extracted_entities.get("COMPANY", {}).get("text"),
         "date": extracted_entities.get("DATE", {}).get("text"),
...
         "raw_predictions": extracted_entities  # Contains text and bbox data for each entity
     }

+    # 6. Vendor Fallback (Spatial Heuristic)
+    # If ML failed to find a vendor, assume the largest text at the top is the vendor
+    if not final_output["vendor"] and unnormalized_boxes:
+        # Filter for words in the top 20% of the image
+        top_words_indices = [
+            i for i, box in enumerate(unnormalized_boxes)
+            if box[1] < height * 0.2
+        ]
+
+        if top_words_indices:
+            # Find the word with the largest height (font size)
+            largest_idx = max(top_words_indices, key=lambda i: unnormalized_boxes[i][3])
+            final_output["vendor"] = words[largest_idx]
+
     # Fallbacks
     ml_total = extracted_entities.get("TOTAL", {}).get("text")
     if ml_total:
...
     if not final_output["receipt_number"]:
         final_output["receipt_number"] = extract_invoice_number(raw_text)
+
+    # Backfill Bounding Boxes for Regex Results
+    # If Regex found the number but ML didn't, we must find its box
+    # in the OCR data so the UI can draw it.
+    if final_output["receipt_number"] and "INVOICE_NO" not in final_output["raw_predictions"]:
+        target_val = final_output["receipt_number"].strip()
+        found_box = None
+
+        # 1. Try finding the exact word in the OCR list
+        # 'words' and 'unnormalized_boxes' are available from step 3
+        for i, word in enumerate(words):
+            # Check for exact match or if the word contains the target (e.g. "Inv#123")
+            if target_val == word or (len(target_val) > 3 and target_val in word):
+                found_box = unnormalized_boxes[i]
+                break
+
+        # 2. If found, inject it into raw_predictions
+        if found_box:
+            # The UI expects a list of boxes
+            final_output["raw_predictions"]["INVOICE_NO"] = {
+                "text": target_val,
+                "bbox": [found_box]
+            }

     return final_output
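The line-reconstruction heuristic inside `extract_ml_based` (start a new line whenever the vertical gap between consecutive words exceeds half the line height) can be exercised in isolation. The following is a minimal sketch on synthetic data; the words and `[x, y, width, height]` boxes are invented for illustration:

```python
def group_into_lines(words, boxes):
    """Group OCR words into text lines using the vertical-gap heuristic."""
    lines, current_line = [], []
    if boxes:
        # Track the y-position and height of the current line
        current_y, current_h = boxes[0][1], boxes[0][3]
        for word, box in zip(words, boxes):
            y, h = box[1], box[3]
            # A vertical jump larger than half the line height starts a new line
            if abs(y - current_y) > max(current_h, h) / 2:
                lines.append(" ".join(current_line))
                current_line = []
                current_y, current_h = y, h
            current_line.append(word)
        if current_line:
            lines.append(" ".join(current_line))
    return "\n".join(lines)

# Synthetic OCR output: two words on one row, two words on a lower row
words = ["ACME", "Corp", "Invoice", "#123"]
boxes = [[10, 10, 60, 20], [80, 12, 50, 20], [10, 60, 70, 18], [90, 61, 40, 18]]
print(group_into_lines(words, boxes))
# ACME Corp
# Invoice #123
```

The reconstructed multi-line `raw_text` is what the regex fallbacks (`extract_invoice_number`, `extract_total`) operate on line by line.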
src/ocr.py
DELETED

@@ -1,42 +0,0 @@
-# src/ocr.py
-
-import pytesseract
-import numpy as np
-import os
-import shutil
-import sys
-
-# --- Dynamic Tesseract Configuration ---
-# This block ensures the code runs on both Windows (Local) and Linux (Production)
-if os.name == 'nt':  # Windows
-    # Common default installation paths for Windows
-    possible_paths = [
-        r'C:\Program Files\Tesseract-OCR\tesseract.exe',
-        r'C:\Program Files (x86)\Tesseract-OCR\tesseract.exe',
-        r'C:\Users\{}\AppData\Local\Tesseract-OCR\tesseract.exe'.format(os.getlogin())
-    ]
-
-    # Search for the executable
-    found = False
-    for path in possible_paths:
-        if os.path.exists(path):
-            pytesseract.pytesseract.tesseract_cmd = path
-            found = True
-            print(f"✅ Found Tesseract at: {path}")
-            break
-
-    if not found:
-        print("⚠️ Warning: Tesseract exe not found in standard paths. Assuming it's in system PATH.")
-else:
-    # Linux/Mac (Docker/Production)
-    if not shutil.which('tesseract'):
-        print("⚠️ Warning: 'tesseract' binary not found in PATH. Please install tesseract-ocr.")
-
-def extract_text(image: np.ndarray, lang: str='eng', config: str='--psm 11') -> str:
-    if image is None:
-        raise ValueError("Input image is None")
-    # Pytesseract will now use the path found above (or default to PATH)
-    return pytesseract.image_to_string(image, lang=lang, config=config).strip()
-
-def extract_text_with_boxes(image):
-    pass
src/pipeline.py
CHANGED

@@ -13,7 +13,6 @@ import cv2

 # --- IMPORTS ---
 from preprocessing import load_image, convert_to_grayscale, remove_noise
-from ocr import extract_text
 from extraction import structure_output
 from ml_extraction import extract_ml_based
 from schema import InvoiceData

@@ -90,13 +89,10 @@ def process_invoice(image_path: str,

     elif method == 'rules':
         try:
-
-
-            preprocessed_image = remove_noise(gray_image, kernel_size=3)
-            text = extract_text(preprocessed_image, config='--psm 6')
-            raw_result = structure_output(text)
+            print("⚠️ Rule-based mode is deprecated. Redirecting to ML-based extraction.")
+            raw_result = extract_ml_based(image_path)
         except Exception as e:
-            raise ValueError(f"Error during
+            raise ValueError(f"Error during ML-based extraction: {e}")

     # Clean up temp file if we created one
     if image_path.endswith('.jpg') and 'sample_pdf' in image_path:  # Safety check
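The coordinate handling that `parse_doctr_output` introduces in this commit can also be checked in isolation: DocTR reports each word's geometry as `((x_min, y_min), (x_max, y_max))` in 0-1 relative scale, while LayoutLMv3 expects integer boxes on a 0-1000 scale. A minimal sketch with synthetic geometry values (invented for illustration):

```python
def to_layoutlmv3_box(geometry, img_w, img_h):
    """Convert one DocTR-style relative box to LayoutLMv3 and pixel boxes."""
    (x_min, y_min), (x_max, y_max) = geometry
    # 0-1000 box for LayoutLMv3, clamped so rounding never leaves [0, 1000]
    norm = [
        max(0, min(1000, int(x_min * 1000))),
        max(0, min(1000, int(y_min * 1000))),
        max(0, min(1000, int(x_max * 1000))),
        max(0, min(1000, int(y_max * 1000))),
    ]
    # Pixel-space [x, y, width, height] for drawing overlays in the UI
    px_x0, px_y0 = int(x_min * img_w), int(y_min * img_h)
    px_x1, px_y1 = int(x_max * img_w), int(y_max * img_h)
    px = [px_x0, px_y0, px_x1 - px_x0, px_y1 - px_y0]
    return norm, px

norm, px = to_layoutlmv3_box(((0.1, 0.2), (0.5, 0.25)), 800, 600)
print(norm)  # [100, 200, 500, 250]
print(px)    # [80, 120, 320, 30]
```

The clamping also protects against detector boxes that slightly overshoot the image edge: a relative coordinate of 1.02 maps to 1000, not 1020.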