Spaces:
Sleeping
Sleeping
Commit ·
d79b7f7
1
Parent(s): 22fe020
feat: Add Phase 3 generalization scripts and clean up legacy files
Browse files- .gitignore +73 -70
- LICENSE +21 -21
- README.md +322 -322
- app.py +312 -312
- eval_new_dataset.py +42 -0
- explore_new_dataset.py +113 -0
- load_sroie_dataset.py +65 -0
- notebooks/test_setup.py +0 -11
- notebooks/test_visual.ipynb +0 -0
- requirements.txt +0 -0
- src/data_loader.py +197 -0
- src/extraction.py +123 -273
- src/ml_extraction.py +143 -175
- src/ocr.py +15 -15
- src/pipeline.py +150 -150
- src/preprocessing.py +78 -78
- tests/test_extraction.py +40 -40
- tests/test_full_pipeline.py +41 -41
- tests/test_ocr.py +100 -100
- tests/test_pipeline.py +95 -95
- tests/test_preprocessing.py +177 -177
- tests/utils.py +6 -6
- train_combined.py +187 -0
- train_layoutlm.py +185 -0
.gitignore
CHANGED
|
@@ -1,70 +1,73 @@
|
|
| 1 |
-
# Python
|
| 2 |
-
__pycache__/
|
| 3 |
-
*.pyc
|
| 4 |
-
*.pyo
|
| 5 |
-
*.pyd
|
| 6 |
-
|
| 7 |
-
# Environment
|
| 8 |
-
env/
|
| 9 |
-
venv/
|
| 10 |
-
.env
|
| 11 |
-
config.yaml
|
| 12 |
-
credentials.json
|
| 13 |
-
|
| 14 |
-
# IDE / Editor
|
| 15 |
-
.vscode/
|
| 16 |
-
.idea/
|
| 17 |
-
*.swp
|
| 18 |
-
*.swo
|
| 19 |
-
|
| 20 |
-
# Notebooks / caches / logs
|
| 21 |
-
.ipynb_checkpoints/
|
| 22 |
-
.pytest_cache/
|
| 23 |
-
*.log
|
| 24 |
-
logs/
|
| 25 |
-
.cache/
|
| 26 |
-
|
| 27 |
-
# OS
|
| 28 |
-
.DS_Store
|
| 29 |
-
Thumbs.db
|
| 30 |
-
ehthumbs.db
|
| 31 |
-
*.code-workspace
|
| 32 |
-
Desktop.ini
|
| 33 |
-
|
| 34 |
-
# Streamlit temp folder
|
| 35 |
-
temp/
|
| 36 |
-
.streamlit/
|
| 37 |
-
|
| 38 |
-
# Jupyter Notebook
|
| 39 |
-
.ipynb_checkpoints
|
| 40 |
-
|
| 41 |
-
# JSON outputs
|
| 42 |
-
outputs/
|
| 43 |
-
|
| 44 |
-
# Logs
|
| 45 |
-
logs/
|
| 46 |
-
*.log
|
| 47 |
-
|
| 48 |
-
# --- Data Folders ---
|
| 49 |
-
# Ignore all files inside the raw and processed data folders
|
| 50 |
-
data/raw/*
|
| 51 |
-
data/processed/*
|
| 52 |
-
|
| 53 |
-
# But DO NOT ignore the .gitkeep files inside them
|
| 54 |
-
!data/raw/.gitkeep
|
| 55 |
-
!data/processed/.gitkeep
|
| 56 |
-
|
| 57 |
-
!requirements.txt
|
| 58 |
-
!README.md
|
| 59 |
-
|
| 60 |
-
datasets/
|
| 61 |
-
checkpoints/
|
| 62 |
-
lightning_logs/
|
| 63 |
-
wandb/
|
| 64 |
-
mlruns/
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
# Ignore all files in the models directory
|
| 68 |
-
models/*
|
| 69 |
-
!models/.gitkeep
|
| 70 |
-
!models/README.md
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
|
| 7 |
+
# Environment
|
| 8 |
+
env/
|
| 9 |
+
venv/
|
| 10 |
+
.env
|
| 11 |
+
config.yaml
|
| 12 |
+
credentials.json
|
| 13 |
+
|
| 14 |
+
# IDE / Editor
|
| 15 |
+
.vscode/
|
| 16 |
+
.idea/
|
| 17 |
+
*.swp
|
| 18 |
+
*.swo
|
| 19 |
+
|
| 20 |
+
# Notebooks / caches / logs
|
| 21 |
+
.ipynb_checkpoints/
|
| 22 |
+
.pytest_cache/
|
| 23 |
+
*.log
|
| 24 |
+
logs/
|
| 25 |
+
.cache/
|
| 26 |
+
|
| 27 |
+
# OS
|
| 28 |
+
.DS_Store
|
| 29 |
+
Thumbs.db
|
| 30 |
+
ehthumbs.db
|
| 31 |
+
*.code-workspace
|
| 32 |
+
Desktop.ini
|
| 33 |
+
|
| 34 |
+
# Streamlit temp folder
|
| 35 |
+
temp/
|
| 36 |
+
.streamlit/
|
| 37 |
+
|
| 38 |
+
# Jupyter Notebook
|
| 39 |
+
.ipynb_checkpoints
|
| 40 |
+
|
| 41 |
+
# JSON outputs
|
| 42 |
+
outputs/
|
| 43 |
+
|
| 44 |
+
# Logs
|
| 45 |
+
logs/
|
| 46 |
+
*.log
|
| 47 |
+
|
| 48 |
+
# --- Data Folders ---
|
| 49 |
+
# Ignore all files inside the raw and processed data folders
|
| 50 |
+
data/raw/*
|
| 51 |
+
data/processed/*
|
| 52 |
+
|
| 53 |
+
# But DO NOT ignore the .gitkeep files inside them
|
| 54 |
+
!data/raw/.gitkeep
|
| 55 |
+
!data/processed/.gitkeep
|
| 56 |
+
|
| 57 |
+
!requirements.txt
|
| 58 |
+
!README.md
|
| 59 |
+
|
| 60 |
+
datasets/
|
| 61 |
+
checkpoints/
|
| 62 |
+
lightning_logs/
|
| 63 |
+
wandb/
|
| 64 |
+
mlruns/
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# Ignore all files in the models directory
|
| 68 |
+
models/*
|
| 69 |
+
!models/.gitkeep
|
| 70 |
+
!models/README.md
|
| 71 |
+
|
| 72 |
+
# Ignore sroie files in the data directory
|
| 73 |
+
data/sroie/
|
LICENSE
CHANGED
|
@@ -1,21 +1,21 @@
|
|
| 1 |
-
MIT License
|
| 2 |
-
|
| 3 |
-
Copyright (c) 2025 Soumyajit Ghosh
|
| 4 |
-
|
| 5 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
-
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
-
in the Software without restriction, including without limitation the rights
|
| 8 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
-
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
-
furnished to do so, subject to the following conditions:
|
| 11 |
-
|
| 12 |
-
The above copyright notice and this permission notice shall be included in all
|
| 13 |
-
copies or substantial portions of the Software.
|
| 14 |
-
|
| 15 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
-
SOFTWARE.
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Soumyajit Ghosh
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,323 +1,323 @@
|
|
| 1 |
-
# 📄 Smart Invoice Processor
|
| 2 |
-
|
| 3 |
-
End-to-end invoice/receipt processing with OCR + Rule-based extraction and a fine‑tuned LayoutLMv3 model. Upload an image or run via CLI to get clean, structured JSON (vendor, date, totals, address, etc.).
|
| 4 |
-
|
| 5 |
-

|
| 6 |
-

|
| 7 |
-

|
| 8 |
-

|
| 9 |
-

|
| 10 |
-
|
| 11 |
-
---
|
| 12 |
-
|
| 13 |
-
## 🎯 Features
|
| 14 |
-
|
| 15 |
-
- ✅ OCR using Tesseract (configurable, fast, multi-platform)
|
| 16 |
-
- ✅ Rule-based extraction (regex baselines)
|
| 17 |
-
- ✅ ML-based extraction (LayoutLMv3 fine‑tuned on SROIE) for robust field detection
|
| 18 |
-
- ✅ Clean JSON output (date, total, vendor, address, receipt number*)
|
| 19 |
-
- ✅ Confidence and simple validation (e.g., total found among amounts)
|
| 20 |
-
- ✅ Streamlit web UI with method toggle (ML vs Regex)
|
| 21 |
-
- ✅ CLI for single/batch processing with saving to JSON
|
| 22 |
-
- ✅ Tests for preprocessing/OCR/pipeline
|
| 23 |
-
|
| 24 |
-
> Note: SROIE does not include invoice/receipt number labels; the ML model won’t output it unless you add labeled data. The rule-based extractor can still provide it when formats allow.
|
| 25 |
-
|
| 26 |
-
---
|
| 27 |
-
|
| 28 |
-
## 📊 Demo
|
| 29 |
-
|
| 30 |
-
### Web Interface
|
| 31 |
-

|
| 32 |
-
*Clean upload → extract flow with method selector (ML vs Regex).*
|
| 33 |
-
|
| 34 |
-
### Successful Extraction (ML-based)
|
| 35 |
-

|
| 36 |
-
*Fields extracted with LayoutLMv3.*
|
| 37 |
-
|
| 38 |
-
### Format Detection (simulated)
|
| 39 |
-

|
| 40 |
-
*UI shows simple format hints and confidence.*
|
| 41 |
-
|
| 42 |
-
### Example JSON (Rule-based)
|
| 43 |
-
```json
|
| 44 |
-
{
|
| 45 |
-
"receipt_number": "PEGIV-1030765",
|
| 46 |
-
"date": "15/01/2019",
|
| 47 |
-
"bill_to": {
|
| 48 |
-
"name": "THE PEAK QUARRY WORKS",
|
| 49 |
-
"email": null
|
| 50 |
-
},
|
| 51 |
-
"items": [],
|
| 52 |
-
"total_amount": 193.0,
|
| 53 |
-
"extraction_confidence": 100,
|
| 54 |
-
"validation_passed": true,
|
| 55 |
-
"vendor": "OJC MARKETING SDN BHD",
|
| 56 |
-
"address": "NO JALAN BAYU 4, BANDAR SERI ALAM, 81750 MASAI, JOHOR"
|
| 57 |
-
}
|
| 58 |
-
```
|
| 59 |
-
### Example JSON (ML-based)
|
| 60 |
-
```json
|
| 61 |
-
{
|
| 62 |
-
"receipt_number": null,
|
| 63 |
-
"date": "15/01/2019",
|
| 64 |
-
"bill_to": null,
|
| 65 |
-
"items": [],
|
| 66 |
-
"total_amount": 193.0,
|
| 67 |
-
"vendor": "OJC MARKETING SDN BHD",
|
| 68 |
-
"address": "NO JALAN BAYU 4, BANDAR SERI ALAM, 81750 MASAI, JOHOR",
|
| 69 |
-
"raw_text": "…",
|
| 70 |
-
"raw_ocr_words": ["…"],
|
| 71 |
-
"raw_predictions": {
|
| 72 |
-
"DATE": {"text": "15/01/2019", "bbox": [[…]]},
|
| 73 |
-
"TOTAL": {"text": "193.00", "bbox": [[…]]},
|
| 74 |
-
"COMPANY": {"text": "OJC MARKETING SDN BHD", "bbox": [[…]]},
|
| 75 |
-
"ADDRESS": {"text": "…", "bbox": [[…]]}
|
| 76 |
-
}
|
| 77 |
-
}
|
| 78 |
-
```
|
| 79 |
-
|
| 80 |
-
## 🚀 Quick Start
|
| 81 |
-
|
| 82 |
-
### Prerequisites
|
| 83 |
-
- Python 3.10+
|
| 84 |
-
- Tesseract OCR
|
| 85 |
-
- (Optional) CUDA-capable GPU for training/inference speed
|
| 86 |
-
|
| 87 |
-
### Installation
|
| 88 |
-
|
| 89 |
-
1. Clone the repository
|
| 90 |
-
```bash
|
| 91 |
-
git clone https://github.com/GSoumyajit2005/invoice-processor-ml
|
| 92 |
-
cd invoice-processor-ml
|
| 93 |
-
```
|
| 94 |
-
|
| 95 |
-
2. Install dependencies
|
| 96 |
-
```bash
|
| 97 |
-
pip install -r requirements.txt
|
| 98 |
-
```
|
| 99 |
-
|
| 100 |
-
3. Install Tesseract OCR
|
| 101 |
-
- **Windows**: Download from [UB Mannheim](https://github.com/UB-Mannheim/tesseract/wiki)
|
| 102 |
-
- **Mac**: `brew install tesseract`
|
| 103 |
-
- **Linux**: `sudo apt install tesseract-ocr`
|
| 104 |
-
|
| 105 |
-
4. (Optional, Windows) Set Tesseract path in src/ocr.py if needed:
|
| 106 |
-
```bash
|
| 107 |
-
pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
|
| 108 |
-
```
|
| 109 |
-
|
| 110 |
-
5. Run the web app
|
| 111 |
-
```bash
|
| 112 |
-
streamlit run app.py
|
| 113 |
-
```
|
| 114 |
-
|
| 115 |
-
## 💻 Usage
|
| 116 |
-
|
| 117 |
-
### Web Interface (Recommended)
|
| 118 |
-
|
| 119 |
-
The easiest way to use the processor is via the web interface.
|
| 120 |
-
|
| 121 |
-
```bash
|
| 122 |
-
streamlit run app.py
|
| 123 |
-
```
|
| 124 |
-
- Upload an invoice image (PNG/JPG).
|
| 125 |
-
- Choose extraction method in sidebar:
|
| 126 |
-
- ML-Based (LayoutLMv3)
|
| 127 |
-
- Rule-Based (Regex)
|
| 128 |
-
- View JSON, download results.
|
| 129 |
-
|
| 130 |
-
### Command-Line Interface (CLI)
|
| 131 |
-
|
| 132 |
-
You can also process invoices directly from the command line.
|
| 133 |
-
|
| 134 |
-
#### 1. Processing a Single Invoice
|
| 135 |
-
|
| 136 |
-
This command processes the provided sample invoice and prints the results to the console.
|
| 137 |
-
|
| 138 |
-
```bash
|
| 139 |
-
python src/pipeline.py data/samples/sample_invoice.jpg --save --method ml
|
| 140 |
-
# or
|
| 141 |
-
python src/pipeline.py data/samples/sample_invoice.jpg --save --method rules
|
| 142 |
-
```
|
| 143 |
-
|
| 144 |
-
#### 2. Batch Processing a Folder
|
| 145 |
-
|
| 146 |
-
The CLI can process an entire folder of images at once.
|
| 147 |
-
|
| 148 |
-
First, place your own invoice images (e.g., `my_invoice1.jpg`, `my_invoice2.png`) into the `data/raw/` folder.
|
| 149 |
-
|
| 150 |
-
Then, run the following command. It will process all images in `data/raw/`. Saved files are written to `outputs/{stem}_{method}.json`.
|
| 151 |
-
|
| 152 |
-
```bash
|
| 153 |
-
python src/pipeline.py data/raw --save --method ml
|
| 154 |
-
```
|
| 155 |
-
|
| 156 |
-
### Python API
|
| 157 |
-
|
| 158 |
-
You can integrate the pipeline directly into your own Python scripts.
|
| 159 |
-
|
| 160 |
-
```python
|
| 161 |
-
from src.pipeline import process_invoice
|
| 162 |
-
import json
|
| 163 |
-
|
| 164 |
-
result = process_invoice('data/samples/sample_invoice.jpg', method='ml')
|
| 165 |
-
print(json.dumps(result, indent=2))
|
| 166 |
-
```
|
| 167 |
-
|
| 168 |
-
## 🏗️ Architecture
|
| 169 |
-
|
| 170 |
-
```
|
| 171 |
-
┌────────────────┐
|
| 172 |
-
│ Upload Image │
|
| 173 |
-
└───────┬────────┘
|
| 174 |
-
│
|
| 175 |
-
▼
|
| 176 |
-
┌────────────────────┐
|
| 177 |
-
│ Preprocessing │ (OpenCV grayscale/denoise)
|
| 178 |
-
└────────┬───────────┘
|
| 179 |
-
│
|
| 180 |
-
▼
|
| 181 |
-
┌───────────────┐
|
| 182 |
-
│ OCR │ (Tesseract)
|
| 183 |
-
└───────┬───────┘
|
| 184 |
-
│
|
| 185 |
-
┌──────────────┴──────────────┐
|
| 186 |
-
│ │
|
| 187 |
-
▼ ▼
|
| 188 |
-
┌──────────────────┐ ┌────────────────────────┐
|
| 189 |
-
│ Rule-based IE │ │ ML-based IE (NER) │
|
| 190 |
-
│ (regex, heur.) │ │ LayoutLMv3 token-class │
|
| 191 |
-
└────────┬─────────┘ └───────────┬────────────┘
|
| 192 |
-
│ │
|
| 193 |
-
└──────────────┬──────────────────┘
|
| 194 |
-
▼
|
| 195 |
-
┌──────────────────┐
|
| 196 |
-
│ Post-process │
|
| 197 |
-
│ validate, scores │
|
| 198 |
-
└────────┬─────────┘
|
| 199 |
-
▼
|
| 200 |
-
┌──────────────────┐
|
| 201 |
-
│ JSON Output │
|
| 202 |
-
└──────────────────┘
|
| 203 |
-
```
|
| 204 |
-
|
| 205 |
-
## 📁 Project Structure
|
| 206 |
-
|
| 207 |
-
```
|
| 208 |
-
invoice-processor-ml/
|
| 209 |
-
│
|
| 210 |
-
├── data/
|
| 211 |
-
│ ├── raw/ # Input invoice images for processing
|
| 212 |
-
│ └── processed/ # (Reserved for future use)
|
| 213 |
-
│
|
| 214 |
-
│
|
| 215 |
-
├── data/samples/
|
| 216 |
-
│ └── sample_invoice.jpg # Public sample for quick testing
|
| 217 |
-
│
|
| 218 |
-
├── docs/
|
| 219 |
-
│ └── screenshots/ # UI Screenshots for the README demo
|
| 220 |
-
│
|
| 221 |
-
│
|
| 222 |
-
├── models/
|
| 223 |
-
│ └── layoutlmv3-sroie-best/ # Fine-tuned model (created after training)
|
| 224 |
-
│
|
| 225 |
-
├── outputs/ # Default folder for saved JSON results
|
| 226 |
-
│
|
| 227 |
-
├── src/
|
| 228 |
-
│ ├── preprocessing.py # Image preprocessing functions (grayscale, denoise)
|
| 229 |
-
│ ├── ocr.py # Tesseract OCR integration
|
| 230 |
-
│ ├── extraction.py # Regex-based information extraction logic
|
| 231 |
-
│ ├── ml_extraction.py # ML-based extraction (LayoutLMv3)
|
| 232 |
-
│ └── pipeline.py # Main orchestrator for the pipeline and CLI
|
| 233 |
-
│
|
| 234 |
-
│
|
| 235 |
-
├── tests/ # <-- ADD THIS FOLDER
|
| 236 |
-
│ ├── test_preprocessing.py # Tests for the preprocessing module
|
| 237 |
-
│ ├── test_ocr.py # Tests for the OCR module
|
| 238 |
-
│ └── test_pipeline.py # End-to-end pipeline tests
|
| 239 |
-
│
|
| 240 |
-
├── app.py # Streamlit web interface
|
| 241 |
-
├── requirements.txt # Python dependencies
|
| 242 |
-
└── README.md # You are Here!
|
| 243 |
-
```
|
| 244 |
-
|
| 245 |
-
## 🧠 Model & Training
|
| 246 |
-
|
| 247 |
-
- **Model**: `microsoft/layoutlmv3-base` (125M params)
|
| 248 |
-
- **Task**: Token Classification (NER) with 9 labels: `O, B/I-COMPANY, B/I-ADDRESS, B/I-DATE, B/I-TOTAL`
|
| 249 |
-
- **Dataset**: SROIE (ICDAR 2019, English retail receipts)
|
| 250 |
-
- **Training**: RTX 3050 6GB, PyTorch 2.x, Transformers 4.x
|
| 251 |
-
- **Result**: Best F1 ≈ 0.922 on validation (epoch 5 saved)
|
| 252 |
-
|
| 253 |
-
- Training scripts(local):
|
| 254 |
-
- `train_layoutlm.py` (data prep, training loop with validation + model save)
|
| 255 |
-
- Model saved to: `models/layoutlmv3-sroie-best/`
|
| 256 |
-
|
| 257 |
-
## 📈 Performance
|
| 258 |
-
|
| 259 |
-
- **OCR accuracy (clear images)**: High with Tesseract
|
| 260 |
-
- **Rule-based extraction**: Strong on simple retail receipts
|
| 261 |
-
- **ML-based extraction (SROIE-style)**:
|
| 262 |
-
- COMPANY / ADDRESS / DATE / TOTAL: High F1 on simple receipts
|
| 263 |
-
- Complex business invoices: Partial extraction unless further fine-tuned
|
| 264 |
-
|
| 265 |
-
## ⚠️ Known Limitations
|
| 266 |
-
|
| 267 |
-
1. **Layout Sensitivity**: The ML model was fine‑tuned only on SROIE (retail receipts). Professional multi-column invoices may underperform until you fine‑tune on more diverse datasets.
|
| 268 |
-
2. **Invoice Number (ML)**: SROIE lacks invoice number labels; the ML model won’t output it unless you add labeled data. The rule-based method can still recover it on many formats.
|
| 269 |
-
3. **Line Items/Tables**: Not trained for table extraction yet. Rule-based supports simple totals; table extraction comes later.
|
| 270 |
-
4. **OCR Variability**: Tesseract outputs can vary; preprocessing and thresholds can impact ML results.
|
| 271 |
-
|
| 272 |
-
## 🔮 Future Enhancements
|
| 273 |
-
|
| 274 |
-
- [ ] Add and fine‑tune on mychen76/invoices-and-receipts_ocr_v1 (English) for broader invoice formats
|
| 275 |
-
- [ ] (Optional) Add FATURA (table-focused) for line-item extraction
|
| 276 |
-
- [ ] Sliding-window chunking for >512 token documents (to avoid truncation)
|
| 277 |
-
- [ ] Table detection (Camelot/Tabula/DeepDeSRT) for line items
|
| 278 |
-
- [ ] PDF support (pdf2image) for multipage invoices
|
| 279 |
-
- [ ] FastAPI backend + Docker
|
| 280 |
-
- [ ] Multilingual OCR (PaddleOCR) and multilingual fine‑tuning
|
| 281 |
-
- [ ] Confidence calibration and better validation rules
|
| 282 |
-
|
| 283 |
-
## 🛠️ Tech Stack
|
| 284 |
-
|
| 285 |
-
| Component | Technology |
|
| 286 |
-
|-----------|------------|
|
| 287 |
-
| OCR | Tesseract 5.0+ |
|
| 288 |
-
| Image Processing | OpenCV, Pillow |
|
| 289 |
-
| ML/NLP | PyTorch 2.x, Transformers |
|
| 290 |
-
| Model | LayoutLMv3 (token class.) |
|
| 291 |
-
| Web Interface | Streamlit |
|
| 292 |
-
| Data Format | JSON |
|
| 293 |
-
|
| 294 |
-
## 📚 What I Learned
|
| 295 |
-
|
| 296 |
-
- OCR challenges (confusable characters, confidence-based filtering)
|
| 297 |
-
- Layout-aware NER with LayoutLMv3 (text + bbox + pixels)
|
| 298 |
-
- Data normalization (bbox to 0–1000 scale)
|
| 299 |
-
- End-to-end pipelines (UI + CLI + JSON output)
|
| 300 |
-
- When regex is enough vs when ML is needed
|
| 301 |
-
- Evaluation (seqeval F1 for NER)
|
| 302 |
-
|
| 303 |
-
## 🤝 Contributing
|
| 304 |
-
|
| 305 |
-
Contributions welcome! Areas needing improvement:
|
| 306 |
-
- New patterns for regex extractor
|
| 307 |
-
- Better preprocessing for OCR
|
| 308 |
-
- New datasets and training configs
|
| 309 |
-
- Tests and CI
|
| 310 |
-
|
| 311 |
-
## 📝 License
|
| 312 |
-
|
| 313 |
-
MIT License - See LICENSE file for details
|
| 314 |
-
|
| 315 |
-
## 👨‍💻 Author
|
| 316 |
-
|
| 317 |
-
**Soumyajit Ghosh** - 3rd Year BTech Student
|
| 318 |
-
- Exploring AI/ML and practical applications
|
| 319 |
-
- [LinkedIn](https://www.linkedin.com/in/soumyajit-ghosh-49a5b02b2?utm_source=share&utm_campaign) | [GitHub](https://github.com/GSoumyajit2005) | [Portfolio](#)(Coming Soon)
|
| 320 |
-
|
| 321 |
-
---
|
| 322 |
-
|
| 323 |
**Note**: "This is a learning project demonstrating an end-to-end ML pipeline. Not recommended for production use without further validation, retraining on diverse datasets, and security hardening."
|
|
|
|
| 1 |
+
# 📄 Smart Invoice Processor
|
| 2 |
+
|
| 3 |
+
End-to-end invoice/receipt processing with OCR + Rule-based extraction and a fine‑tuned LayoutLMv3 model. Upload an image or run via CLI to get clean, structured JSON (vendor, date, totals, address, etc.).
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+

|
| 7 |
+

|
| 8 |
+

|
| 9 |
+

|
| 10 |
+
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
## 🎯 Features
|
| 14 |
+
|
| 15 |
+
- ✅ OCR using Tesseract (configurable, fast, multi-platform)
|
| 16 |
+
- ✅ Rule-based extraction (regex baselines)
|
| 17 |
+
- ✅ ML-based extraction (LayoutLMv3 fine‑tuned on SROIE) for robust field detection
|
| 18 |
+
- ✅ Clean JSON output (date, total, vendor, address, receipt number*)
|
| 19 |
+
- ✅ Confidence and simple validation (e.g., total found among amounts)
|
| 20 |
+
- ✅ Streamlit web UI with method toggle (ML vs Regex)
|
| 21 |
+
- ✅ CLI for single/batch processing with saving to JSON
|
| 22 |
+
- ✅ Tests for preprocessing/OCR/pipeline
|
| 23 |
+
|
| 24 |
+
> Note: SROIE does not include invoice/receipt number labels; the ML model won’t output it unless you add labeled data. The rule-based extractor can still provide it when formats allow.
|
| 25 |
+
|
| 26 |
+
---
|
| 27 |
+
|
| 28 |
+
## 📊 Demo
|
| 29 |
+
|
| 30 |
+
### Web Interface
|
| 31 |
+

|
| 32 |
+
*Clean upload → extract flow with method selector (ML vs Regex).*
|
| 33 |
+
|
| 34 |
+
### Successful Extraction (ML-based)
|
| 35 |
+

|
| 36 |
+
*Fields extracted with LayoutLMv3.*
|
| 37 |
+
|
| 38 |
+
### Format Detection (simulated)
|
| 39 |
+

|
| 40 |
+
*UI shows simple format hints and confidence.*
|
| 41 |
+
|
| 42 |
+
### Example JSON (Rule-based)
|
| 43 |
+
```json
|
| 44 |
+
{
|
| 45 |
+
"receipt_number": "PEGIV-1030765",
|
| 46 |
+
"date": "15/01/2019",
|
| 47 |
+
"bill_to": {
|
| 48 |
+
"name": "THE PEAK QUARRY WORKS",
|
| 49 |
+
"email": null
|
| 50 |
+
},
|
| 51 |
+
"items": [],
|
| 52 |
+
"total_amount": 193.0,
|
| 53 |
+
"extraction_confidence": 100,
|
| 54 |
+
"validation_passed": true,
|
| 55 |
+
"vendor": "OJC MARKETING SDN BHD",
|
| 56 |
+
"address": "NO JALAN BAYU 4, BANDAR SERI ALAM, 81750 MASAI, JOHOR"
|
| 57 |
+
}
|
| 58 |
+
```
|
| 59 |
+
### Example JSON (ML-based)
|
| 60 |
+
```json
|
| 61 |
+
{
|
| 62 |
+
"receipt_number": null,
|
| 63 |
+
"date": "15/01/2019",
|
| 64 |
+
"bill_to": null,
|
| 65 |
+
"items": [],
|
| 66 |
+
"total_amount": 193.0,
|
| 67 |
+
"vendor": "OJC MARKETING SDN BHD",
|
| 68 |
+
"address": "NO JALAN BAYU 4, BANDAR SERI ALAM, 81750 MASAI, JOHOR",
|
| 69 |
+
"raw_text": "…",
|
| 70 |
+
"raw_ocr_words": ["…"],
|
| 71 |
+
"raw_predictions": {
|
| 72 |
+
"DATE": {"text": "15/01/2019", "bbox": [[…]]},
|
| 73 |
+
"TOTAL": {"text": "193.00", "bbox": [[…]]},
|
| 74 |
+
"COMPANY": {"text": "OJC MARKETING SDN BHD", "bbox": [[…]]},
|
| 75 |
+
"ADDRESS": {"text": "…", "bbox": [[…]]}
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
## 🚀 Quick Start
|
| 81 |
+
|
| 82 |
+
### Prerequisites
|
| 83 |
+
- Python 3.10+
|
| 84 |
+
- Tesseract OCR
|
| 85 |
+
- (Optional) CUDA-capable GPU for training/inference speed
|
| 86 |
+
|
| 87 |
+
### Installation
|
| 88 |
+
|
| 89 |
+
1. Clone the repository
|
| 90 |
+
```bash
|
| 91 |
+
git clone https://github.com/GSoumyajit2005/invoice-processor-ml
|
| 92 |
+
cd invoice-processor-ml
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
2. Install dependencies
|
| 96 |
+
```bash
|
| 97 |
+
pip install -r requirements.txt
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
3. Install Tesseract OCR
|
| 101 |
+
- **Windows**: Download from [UB Mannheim](https://github.com/UB-Mannheim/tesseract/wiki)
|
| 102 |
+
- **Mac**: `brew install tesseract`
|
| 103 |
+
- **Linux**: `sudo apt install tesseract-ocr`
|
| 104 |
+
|
| 105 |
+
4. (Optional, Windows) Set Tesseract path in src/ocr.py if needed:
|
| 106 |
+
```bash
|
| 107 |
+
pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
5. Run the web app
|
| 111 |
+
```bash
|
| 112 |
+
streamlit run app.py
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
## 💻 Usage
|
| 116 |
+
|
| 117 |
+
### Web Interface (Recommended)
|
| 118 |
+
|
| 119 |
+
The easiest way to use the processor is via the web interface.
|
| 120 |
+
|
| 121 |
+
```bash
|
| 122 |
+
streamlit run app.py
|
| 123 |
+
```
|
| 124 |
+
- Upload an invoice image (PNG/JPG).
|
| 125 |
+
- Choose extraction method in sidebar:
|
| 126 |
+
- ML-Based (LayoutLMv3)
|
| 127 |
+
- Rule-Based (Regex)
|
| 128 |
+
- View JSON, download results.
|
| 129 |
+
|
| 130 |
+
### Command-Line Interface (CLI)
|
| 131 |
+
|
| 132 |
+
You can also process invoices directly from the command line.
|
| 133 |
+
|
| 134 |
+
#### 1. Processing a Single Invoice
|
| 135 |
+
|
| 136 |
+
This command processes the provided sample invoice and prints the results to the console.
|
| 137 |
+
|
| 138 |
+
```bash
|
| 139 |
+
python src/pipeline.py data/samples/sample_invoice.jpg --save --method ml
|
| 140 |
+
# or
|
| 141 |
+
python src/pipeline.py data/samples/sample_invoice.jpg --save --method rules
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
#### 2. Batch Processing a Folder
|
| 145 |
+
|
| 146 |
+
The CLI can process an entire folder of images at once.
|
| 147 |
+
|
| 148 |
+
First, place your own invoice images (e.g., `my_invoice1.jpg`, `my_invoice2.png`) into the `data/raw/` folder.
|
| 149 |
+
|
| 150 |
+
Then, run the following command. It will process all images in `data/raw/`. Saved files are written to `outputs/{stem}_{method}.json`.
|
| 151 |
+
|
| 152 |
+
```bash
|
| 153 |
+
python src/pipeline.py data/raw --save --method ml
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
### Python API
|
| 157 |
+
|
| 158 |
+
You can integrate the pipeline directly into your own Python scripts.
|
| 159 |
+
|
| 160 |
+
```python
|
| 161 |
+
from src.pipeline import process_invoice
|
| 162 |
+
import json
|
| 163 |
+
|
| 164 |
+
result = process_invoice('data/samples/sample_invoice.jpg', method='ml')
|
| 165 |
+
print(json.dumps(result, indent=2))
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
## 🏗️ Architecture
|
| 169 |
+
|
| 170 |
+
```
|
| 171 |
+
┌────────────────┐
|
| 172 |
+
│ Upload Image │
|
| 173 |
+
└───────┬────────┘
|
| 174 |
+
│
|
| 175 |
+
▼
|
| 176 |
+
┌────────────────────┐
|
| 177 |
+
│ Preprocessing │ (OpenCV grayscale/denoise)
|
| 178 |
+
└────────┬───────────┘
|
| 179 |
+
│
|
| 180 |
+
▼
|
| 181 |
+
┌───────────────┐
|
| 182 |
+
│ OCR │ (Tesseract)
|
| 183 |
+
└───────┬───────┘
|
| 184 |
+
│
|
| 185 |
+
┌──────────────┴──────────────┐
|
| 186 |
+
│ │
|
| 187 |
+
▼ ▼
|
| 188 |
+
┌──────────────────┐ ┌────────────────────────┐
|
| 189 |
+
│ Rule-based IE │ │ ML-based IE (NER) │
|
| 190 |
+
│ (regex, heur.) │ │ LayoutLMv3 token-class │
|
| 191 |
+
└────────┬─────────┘ └───────────┬────────────┘
|
| 192 |
+
│ │
|
| 193 |
+
└──────────────┬──────────────────┘
|
| 194 |
+
▼
|
| 195 |
+
┌──────────────────┐
|
| 196 |
+
│ Post-process │
|
| 197 |
+
│ validate, scores │
|
| 198 |
+
└────────┬─────────┘
|
| 199 |
+
▼
|
| 200 |
+
┌──────────────────┐
|
| 201 |
+
│ JSON Output │
|
| 202 |
+
└──────────────────┘
|
| 203 |
+
```
|
| 204 |
+
|
| 205 |
+
## 📁 Project Structure
|
| 206 |
+
|
| 207 |
+
```
|
| 208 |
+
invoice-processor-ml/
|
| 209 |
+
│
|
| 210 |
+
├── data/
|
| 211 |
+
│ ├── raw/ # Input invoice images for processing
|
| 212 |
+
│ └── processed/ # (Reserved for future use)
|
| 213 |
+
│
|
| 214 |
+
│
|
| 215 |
+
├── data/samples/
|
| 216 |
+
│ └── sample_invoice.jpg # Public sample for quick testing
|
| 217 |
+
│
|
| 218 |
+
├── docs/
|
| 219 |
+
│ └── screenshots/ # UI Screenshots for the README demo
|
| 220 |
+
│
|
| 221 |
+
│
|
| 222 |
+
├── models/
|
| 223 |
+
│ └── layoutlmv3-sroie-best/ # Fine-tuned model (created after training)
|
| 224 |
+
│
|
| 225 |
+
├── outputs/ # Default folder for saved JSON results
|
| 226 |
+
│
|
| 227 |
+
├── src/
|
| 228 |
+
│ ├── preprocessing.py # Image preprocessing functions (grayscale, denoise)
|
| 229 |
+
│ ├── ocr.py # Tesseract OCR integration
|
| 230 |
+
│ ├── extraction.py # Regex-based information extraction logic
|
| 231 |
+
│ ├── ml_extraction.py # ML-based extraction (LayoutLMv3)
|
| 232 |
+
│ └── pipeline.py # Main orchestrator for the pipeline and CLI
|
| 233 |
+
│
|
| 234 |
+
│
|
| 235 |
+
├── tests/ # <-- ADD THIS FOLDER
|
| 236 |
+
│ ├── test_preprocessing.py # Tests for the preprocessing module
|
| 237 |
+
│ ├── test_ocr.py # Tests for the OCR module
|
| 238 |
+
│ └── test_pipeline.py # End-to-end pipeline tests
|
| 239 |
+
│
|
| 240 |
+
├── app.py # Streamlit web interface
|
| 241 |
+
├── requirements.txt # Python dependencies
|
| 242 |
+
└── README.md # You are Here!
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
## 🧠 Model & Training
|
| 246 |
+
|
| 247 |
+
- **Model**: `microsoft/layoutlmv3-base` (125M params)
|
| 248 |
+
- **Task**: Token Classification (NER) with 9 labels: `O, B/I-COMPANY, B/I-ADDRESS, B/I-DATE, B/I-TOTAL`
|
| 249 |
+
- **Dataset**: SROIE (ICDAR 2019, English retail receipts)
|
| 250 |
+
- **Training**: RTX 3050 6GB, PyTorch 2.x, Transformers 4.x
|
| 251 |
+
- **Result**: Best F1 ≈ 0.922 on validation (epoch 5 saved)
|
| 252 |
+
|
| 253 |
+
- Training scripts (local):
|
| 254 |
+
- `train_layoutlm.py` (data prep, training loop with validation + model save)
|
| 255 |
+
- Model saved to: `models/layoutlmv3-sroie-best/`
|
| 256 |
+
|
| 257 |
+
## 📈 Performance
|
| 258 |
+
|
| 259 |
+
- **OCR accuracy (clear images)**: High with Tesseract
|
| 260 |
+
- **Rule-based extraction**: Strong on simple retail receipts
|
| 261 |
+
- **ML-based extraction (SROIE-style)**:
|
| 262 |
+
- COMPANY / ADDRESS / DATE / TOTAL: High F1 on simple receipts
|
| 263 |
+
- Complex business invoices: Partial extraction unless further fine-tuned
|
| 264 |
+
|
| 265 |
+
## ⚠️ Known Limitations
|
| 266 |
+
|
| 267 |
+
1. **Layout Sensitivity**: The ML model was fine‑tuned only on SROIE (retail receipts). Professional multi-column invoices may underperform until you fine‑tune on more diverse datasets.
|
| 268 |
+
2. **Invoice Number (ML)**: SROIE lacks invoice number labels; the ML model won’t output it unless you add labeled data. The rule-based method can still recover it on many formats.
|
| 269 |
+
3. **Line Items/Tables**: Not trained for table extraction yet. Rule-based supports simple totals; table extraction comes later.
|
| 270 |
+
4. **OCR Variability**: Tesseract outputs can vary; preprocessing and thresholds can impact ML results.
|
| 271 |
+
|
| 272 |
+
## 🔮 Future Enhancements
|
| 273 |
+
|
| 274 |
+
- [ ] Add and fine‑tune on mychen76/invoices-and-receipts_ocr_v1 (English) for broader invoice formats
|
| 275 |
+
- [ ] (Optional) Add FATURA (table-focused) for line-item extraction
|
| 276 |
+
- [ ] Sliding-window chunking for >512 token documents (to avoid truncation)
|
| 277 |
+
- [ ] Table detection (Camelot/Tabula/DeepDeSRT) for line items
|
| 278 |
+
- [ ] PDF support (pdf2image) for multipage invoices
|
| 279 |
+
- [ ] FastAPI backend + Docker
|
| 280 |
+
- [ ] Multilingual OCR (PaddleOCR) and multilingual fine‑tuning
|
| 281 |
+
- [ ] Confidence calibration and better validation rules
|
| 282 |
+
|
| 283 |
+
## 🛠️ Tech Stack
|
| 284 |
+
|
| 285 |
+
| Component | Technology |
|
| 286 |
+
|-----------|------------|
|
| 287 |
+
| OCR | Tesseract 5.0+ |
|
| 288 |
+
| Image Processing | OpenCV, Pillow |
|
| 289 |
+
| ML/NLP | PyTorch 2.x, Transformers |
|
| 290 |
+
| Model | LayoutLMv3 (token class.) |
|
| 291 |
+
| Web Interface | Streamlit |
|
| 292 |
+
| Data Format | JSON |
|
| 293 |
+
|
| 294 |
+
## 📚 What I Learned
|
| 295 |
+
|
| 296 |
+
- OCR challenges (confusable characters, confidence-based filtering)
|
| 297 |
+
- Layout-aware NER with LayoutLMv3 (text + bbox + pixels)
|
| 298 |
+
- Data normalization (bbox to 0–1000 scale)
|
| 299 |
+
- End-to-end pipelines (UI + CLI + JSON output)
|
| 300 |
+
- When regex is enough vs when ML is needed
|
| 301 |
+
- Evaluation (seqeval F1 for NER)
|
| 302 |
+
|
| 303 |
+
## 🤝 Contributing
|
| 304 |
+
|
| 305 |
+
Contributions welcome! Areas needing improvement:
|
| 306 |
+
- New patterns for regex extractor
|
| 307 |
+
- Better preprocessing for OCR
|
| 308 |
+
- New datasets and training configs
|
| 309 |
+
- Tests and CI
|
| 310 |
+
|
| 311 |
+
## 📝 License
|
| 312 |
+
|
| 313 |
+
MIT License - See LICENSE file for details
|
| 314 |
+
|
| 315 |
+
## 👨‍💻 Author
|
| 316 |
+
|
| 317 |
+
**Soumyajit Ghosh** - 3rd Year BTech Student
|
| 318 |
+
- Exploring AI/ML and practical applications
|
| 319 |
+
- [LinkedIn](https://www.linkedin.com/in/soumyajit-ghosh-49a5b02b2?utm_source=share&utm_campaign) | [GitHub](https://github.com/GSoumyajit2005) | [Portfolio](#)(Coming Soon)
|
| 320 |
+
|
| 321 |
+
---
|
| 322 |
+
|
| 323 |
**Note**: "This is a learning project demonstrating an end-to-end ML pipeline. Not recommended for production use without further validation, retraining on diverse datasets, and security hardening."
|
app.py
CHANGED
|
@@ -1,313 +1,313 @@
|
|
| 1 |
-
import streamlit as st
|
| 2 |
-
import os
|
| 3 |
-
import json
|
| 4 |
-
from datetime import datetime
|
| 5 |
-
from PIL import Image
|
| 6 |
-
import numpy as np
|
| 7 |
-
import pandas as pd
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
|
| 10 |
-
# Import our actual, working pipeline function
|
| 11 |
-
import sys
|
| 12 |
-
sys.path.append('src')
|
| 13 |
-
from pipeline import process_invoice
|
| 14 |
-
|
| 15 |
-
# --- Mock Functions to support the UI without errors ---
|
| 16 |
-
# These functions simulate the ones from your example README.
|
| 17 |
-
# They allow the UI to render without needing to build a complex format detector today.
|
| 18 |
-
|
| 19 |
-
def detect_invoice_format(ocr_text: str):
|
| 20 |
-
"""
|
| 21 |
-
A mock function to simulate format detection.
|
| 22 |
-
In a real system, this would analyze the text layout.
|
| 23 |
-
"""
|
| 24 |
-
# Simple heuristic: if it contains "SDN BHD", it's our known format.
|
| 25 |
-
if "SDN BHD" in ocr_text:
|
| 26 |
-
return {
|
| 27 |
-
'name': 'Template A (Retail)',
|
| 28 |
-
'confidence': 95.0,
|
| 29 |
-
'supported': True,
|
| 30 |
-
'indicators': ["Found 'SDN BHD' suffix", "Date format DD/MM/YYYY detected"]
|
| 31 |
-
}
|
| 32 |
-
else:
|
| 33 |
-
return {
|
| 34 |
-
'name': 'Unknown Format',
|
| 35 |
-
'confidence': 20.0,
|
| 36 |
-
'supported': False,
|
| 37 |
-
'indicators': ["No known company suffixes found"]
|
| 38 |
-
}
|
| 39 |
-
|
| 40 |
-
def get_format_recommendations(format_info):
|
| 41 |
-
"""Mock recommendations based on the detected format."""
|
| 42 |
-
if format_info['supported']:
|
| 43 |
-
return ["• Extraction should be highly accurate."]
|
| 44 |
-
else:
|
| 45 |
-
return ["• Results may be incomplete.", "• Consider adding patterns for this format."]
|
| 46 |
-
|
| 47 |
-
# --- Streamlit App ---
|
| 48 |
-
|
| 49 |
-
# Page configuration
|
| 50 |
-
st.set_page_config(
|
| 51 |
-
page_title="Invoice Processor",
|
| 52 |
-
page_icon="📄",
|
| 53 |
-
layout="wide",
|
| 54 |
-
initial_sidebar_state="expanded"
|
| 55 |
-
)
|
| 56 |
-
|
| 57 |
-
# Custom CSS for styling
|
| 58 |
-
st.markdown("""
|
| 59 |
-
<style>
|
| 60 |
-
.main-header {
|
| 61 |
-
font-size: 3rem;
|
| 62 |
-
color: #1f77b4;
|
| 63 |
-
text-align: center;
|
| 64 |
-
margin-bottom: 2rem;
|
| 65 |
-
}
|
| 66 |
-
.success-box {
|
| 67 |
-
padding: 1rem;
|
| 68 |
-
border-radius: 0.5rem;
|
| 69 |
-
background-color: #d4edda;
|
| 70 |
-
border: 1px solid #c3e6cb;
|
| 71 |
-
margin: 1rem 0;
|
| 72 |
-
}
|
| 73 |
-
.warning-box {
|
| 74 |
-
padding: 1rem;
|
| 75 |
-
border-radius: 0.5rem;
|
| 76 |
-
background-color: #fff3cd;
|
| 77 |
-
border: 1px solid #ffeaa7;
|
| 78 |
-
margin: 1rem 0;
|
| 79 |
-
}
|
| 80 |
-
.error-box {
|
| 81 |
-
padding: 1rem;
|
| 82 |
-
border-radius: 0.5rem;
|
| 83 |
-
background-color: #f8d7da;
|
| 84 |
-
border: 1px solid #f5c6cb;
|
| 85 |
-
margin: 1rem 0;
|
| 86 |
-
}
|
| 87 |
-
</style>
|
| 88 |
-
""", unsafe_allow_html=True)
|
| 89 |
-
|
| 90 |
-
# Title
|
| 91 |
-
st.markdown('<h1 class="main-header">📄 Smart Invoice Processor</h1>', unsafe_allow_html=True)
|
| 92 |
-
st.markdown("### Extract structured data from invoices using your custom-built OCR pipeline")
|
| 93 |
-
|
| 94 |
-
# Sidebar
|
| 95 |
-
with st.sidebar:
|
| 96 |
-
st.header("ℹ️ About")
|
| 97 |
-
st.info("""
|
| 98 |
-
This app uses the pipeline you built to automatically extract:
|
| 99 |
-
- Receipt/Invoice number
|
| 100 |
-
- Date
|
| 101 |
-
- Customer information
|
| 102 |
-
- Line items
|
| 103 |
-
- Total amount
|
| 104 |
-
|
| 105 |
-
**Technology Stack:**
|
| 106 |
-
- Tesseract OCR
|
| 107 |
-
- OpenCV
|
| 108 |
-
- Python Regex
|
| 109 |
-
- Streamlit
|
| 110 |
-
""")
|
| 111 |
-
|
| 112 |
-
st.header("📊 Stats")
|
| 113 |
-
if 'processed_count' not in st.session_state:
|
| 114 |
-
st.session_state.processed_count = 0
|
| 115 |
-
st.metric("Invoices Processed Today", st.session_state.processed_count)
|
| 116 |
-
|
| 117 |
-
st.header("⚙️ Configuration")
|
| 118 |
-
extraction_method = st.selectbox(
|
| 119 |
-
"Choose Extraction Method:",
|
| 120 |
-
('ML-Based (LayoutLMv3)', 'Rule-Based (Regex)'),
|
| 121 |
-
help="ML-Based is more robust but may miss fields not in its training data. Rule-Based is faster but more fragile."
|
| 122 |
-
)
|
| 123 |
-
|
| 124 |
-
# Main content
|
| 125 |
-
tab1, tab2, tab3 = st.tabs(["📤 Upload & Process", "📚 Sample Invoices", "ℹ️ How It Works"])
|
| 126 |
-
|
| 127 |
-
with tab1:
|
| 128 |
-
st.header("Upload an Invoice")
|
| 129 |
-
|
| 130 |
-
uploaded_file = st.file_uploader(
|
| 131 |
-
"Choose an invoice image (JPG, PNG)",
|
| 132 |
-
type=['jpg', 'jpeg', 'png'],
|
| 133 |
-
help="Upload a clear image of an invoice or receipt"
|
| 134 |
-
)
|
| 135 |
-
|
| 136 |
-
if uploaded_file is not None:
|
| 137 |
-
col1, col2 = st.columns([1, 1])
|
| 138 |
-
|
| 139 |
-
with col1:
|
| 140 |
-
st.subheader("📸 Original Image")
|
| 141 |
-
image = Image.open(uploaded_file)
|
| 142 |
-
st.image(image, use_container_width=True)
|
| 143 |
-
st.caption(f"Filename: {uploaded_file.name}")
|
| 144 |
-
|
| 145 |
-
with col2:
|
| 146 |
-
st.subheader("🔄 Processing Status")
|
| 147 |
-
|
| 148 |
-
if st.button("🚀 Extract Data", type="primary"):
|
| 149 |
-
with st.spinner("Executing your custom pipeline..."):
|
| 150 |
-
try:
|
| 151 |
-
# Save the uploaded file to a temporary path to be used by our pipeline
|
| 152 |
-
temp_dir = "temp"
|
| 153 |
-
os.makedirs(temp_dir, exist_ok=True)
|
| 154 |
-
temp_path = os.path.join(temp_dir, uploaded_file.name)
|
| 155 |
-
with open(temp_path, "wb") as f:
|
| 156 |
-
f.write(uploaded_file.getbuffer())
|
| 157 |
-
|
| 158 |
-
# Step 1: Call YOUR full pipeline function
|
| 159 |
-
st.write("✅ Calling `process_invoice`...")
|
| 160 |
-
# Map the user-friendly name from the dropdown to the actual method parameter
|
| 161 |
-
method = 'ml' if extraction_method == 'ML-Based (LayoutLMv3)' else 'rules'
|
| 162 |
-
st.write(f"⚙️ Using **{method.upper()}** extraction method...")
|
| 163 |
-
|
| 164 |
-
# Call the pipeline with the selected method
|
| 165 |
-
extracted_data = process_invoice(temp_path, method=method)
|
| 166 |
-
|
| 167 |
-
# Step 2: Simulate format detection using the extracted data
|
| 168 |
-
st.write("✅ Simulating format detection...")
|
| 169 |
-
format_info = detect_invoice_format(extracted_data.get("raw_text", ""))
|
| 170 |
-
|
| 171 |
-
# Store results in session state to display them
|
| 172 |
-
st.session_state.extracted_data = extracted_data
|
| 173 |
-
st.session_state.format_info = format_info
|
| 174 |
-
st.session_state.processed_count += 1
|
| 175 |
-
|
| 176 |
-
st.success("✅ Pipeline executed successfully!")
|
| 177 |
-
|
| 178 |
-
except Exception as e:
|
| 179 |
-
st.error(f"❌ An error occurred in the pipeline: {str(e)}")
|
| 180 |
-
|
| 181 |
-
# Display results if they exist in the session state
|
| 182 |
-
if 'extracted_data' in st.session_state:
|
| 183 |
-
st.markdown("---")
|
| 184 |
-
st.header("📊 Extraction Results")
|
| 185 |
-
|
| 186 |
-
# --- Format Detection Section ---
|
| 187 |
-
format_info = st.session_state.format_info
|
| 188 |
-
st.subheader("📋 Detected Format (Simulated)")
|
| 189 |
-
col1_fmt, col2_fmt = st.columns([2, 3])
|
| 190 |
-
with col1_fmt:
|
| 191 |
-
st.metric("Format Type", format_info['name'])
|
| 192 |
-
st.metric("Detection Confidence", f"{format_info['confidence']:.0f}%")
|
| 193 |
-
if format_info['supported']: st.success("✅ Fully Supported")
|
| 194 |
-
else: st.warning("⚠️ Limited Support")
|
| 195 |
-
with col2_fmt:
|
| 196 |
-
st.write("**Detected Indicators:**")
|
| 197 |
-
for indicator in format_info['indicators']: st.write(f"• {indicator}")
|
| 198 |
-
st.write("**Recommendations:**")
|
| 199 |
-
for rec in get_format_recommendations(format_info): st.write(rec)
|
| 200 |
-
st.markdown("---")
|
| 201 |
-
|
| 202 |
-
# --- Main Results Section ---
|
| 203 |
-
data = st.session_state.extracted_data
|
| 204 |
-
|
| 205 |
-
# Confidence display
|
| 206 |
-
confidence = data.get('extraction_confidence', 0)
|
| 207 |
-
if confidence >= 80:
|
| 208 |
-
st.markdown(f'<div class="success-box">✅ <strong>High Confidence: {confidence}%</strong> - Most key fields were found.</div>', unsafe_allow_html=True)
|
| 209 |
-
elif confidence >= 50:
|
| 210 |
-
st.markdown(f'<div class="warning-box">⚠️ <strong>Medium Confidence: {confidence}%</strong> - Some fields may be missing.</div>', unsafe_allow_html=True)
|
| 211 |
-
else:
|
| 212 |
-
st.markdown(f'<div class="error-box">❌ <strong>Low Confidence: {confidence}%</strong> - Format likely unsupported.</div>', unsafe_allow_html=True)
|
| 213 |
-
|
| 214 |
-
# Validation display
|
| 215 |
-
if data.get('validation_passed', False):
|
| 216 |
-
st.success("✔️ Validation Passed: Total amount appears consistent with other extracted amounts.")
|
| 217 |
-
else:
|
| 218 |
-
st.warning("⚠️ Validation Failed: Total amount could not be verified against other numbers.")
|
| 219 |
-
|
| 220 |
-
# Key metrics display
|
| 221 |
-
# Key metrics display
|
| 222 |
-
st.metric("🏢 Vendor", data.get('vendor') or "N/A") # <-- ADD THIS
|
| 223 |
-
|
| 224 |
-
res_col1, res_col2, res_col3 = st.columns(3)
|
| 225 |
-
res_col1.metric("📄 Receipt Number", data.get('receipt_number') or "N/A")
|
| 226 |
-
res_col2.metric("📅 Date", data.get('date') or "N/A")
|
| 227 |
-
res_col3.metric("💵 Total Amount", f"${data.get('total_amount'):.2f}" if data.get('total_amount') is not None else "N/A")
|
| 228 |
-
|
| 229 |
-
# Use an expander for longer text fields like address
|
| 230 |
-
with st.expander("Show More Details"):
|
| 231 |
-
st.markdown(f"**👤 Bill To:** {data.get('bill_to', {}).get('name') if data.get('bill_to') else 'N/A'}")
|
| 232 |
-
st.markdown(f"**📍 Vendor Address:** {data.get('address') or 'N/A'}")
|
| 233 |
-
|
| 234 |
-
# Line items table
|
| 235 |
-
if data.get('items'):
|
| 236 |
-
st.subheader("🛒 Line Items")
|
| 237 |
-
# Ensure data is in the right format for DataFrame
|
| 238 |
-
items_df_data = [{
|
| 239 |
-
"Description": item.get("description", "N/A"),
|
| 240 |
-
"Qty": item.get("quantity", "N/A"),
|
| 241 |
-
"Unit Price": f"${item.get('unit_price', 0.0):.2f}",
|
| 242 |
-
"Total": f"${item.get('total', 0.0):.2f}"
|
| 243 |
-
} for item in data['items']]
|
| 244 |
-
df = pd.DataFrame(items_df_data)
|
| 245 |
-
st.dataframe(df, use_container_width=True)
|
| 246 |
-
else:
|
| 247 |
-
st.info("ℹ️ No line items were extracted.")
|
| 248 |
-
|
| 249 |
-
# JSON output and download
|
| 250 |
-
with st.expander("📄 View Full JSON Output"):
|
| 251 |
-
st.json(data)
|
| 252 |
-
|
| 253 |
-
json_str = json.dumps(data, indent=2)
|
| 254 |
-
st.download_button(
|
| 255 |
-
label="💾 Download JSON",
|
| 256 |
-
data=json_str,
|
| 257 |
-
file_name=f"invoice_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
|
| 258 |
-
mime="application/json"
|
| 259 |
-
)
|
| 260 |
-
|
| 261 |
-
with st.expander("📝 View Raw OCR Text"):
|
| 262 |
-
raw_text = data.get('raw_text', '')
|
| 263 |
-
if raw_text:
|
| 264 |
-
st.text(raw_text)
|
| 265 |
-
else:
|
| 266 |
-
st.info("No OCR text available.")
|
| 267 |
-
|
| 268 |
-
with tab2:
|
| 269 |
-
st.header("📚 Sample Invoices")
|
| 270 |
-
st.write("Try the sample invoice below to see how the system performs:")
|
| 271 |
-
|
| 272 |
-
sample_dir = "data/samples" # ✅ Points to the correct folder
|
| 273 |
-
if os.path.exists(sample_dir):
|
| 274 |
-
sample_files = [f for f in os.listdir(sample_dir) if f.endswith(('.jpg', '.png', '.jpeg'))]
|
| 275 |
-
|
| 276 |
-
if sample_files:
|
| 277 |
-
# Display the first sample found
|
| 278 |
-
img_path = os.path.join(sample_dir, sample_files[0])
|
| 279 |
-
st.image(Image.open(img_path), caption=sample_files[0], use_container_width=True)
|
| 280 |
-
st.info("You can download this image and upload it in the 'Upload & Process' tab to test the pipeline.")
|
| 281 |
-
else:
|
| 282 |
-
st.warning("No sample invoices found in `data/samples/`.")
|
| 283 |
-
else:
|
| 284 |
-
st.error("The `data/samples` directory was not found.")
|
| 285 |
-
|
| 286 |
-
with tab3:
|
| 287 |
-
st.header("ℹ️ How It Works (Your Custom Pipeline)")
|
| 288 |
-
st.markdown("""
|
| 289 |
-
This app follows the exact pipeline you built:
|
| 290 |
-
```
|
| 291 |
-
1. 📸 Image Upload
|
| 292 |
-
↓
|
| 293 |
-
2. 🔄 Preprocessing (OpenCV)
|
| 294 |
-
Grayscale conversion and noise removal.
|
| 295 |
-
↓
|
| 296 |
-
3. 🔍 OCR (Tesseract)
|
| 297 |
-
Optimized with PSM 6 for receipt layouts.
|
| 298 |
-
↓
|
| 299 |
-
4. 🎯 Rule-Based Extraction (Regex)
|
| 300 |
-
Your custom patterns find specific fields.
|
| 301 |
-
↓
|
| 302 |
-
5. ✅ Confidence & Validation
|
| 303 |
-
Heuristics to check the quality of the extraction.
|
| 304 |
-
↓
|
| 305 |
-
6. 📊 Output JSON
|
| 306 |
-
Presents all extracted data in a structured format.
|
| 307 |
-
```
|
| 308 |
-
""")
|
| 309 |
-
st.info("This rule-based system is a great foundation. The next step is to replace the extraction logic with an ML model like LayoutLM to handle more diverse formats!")
|
| 310 |
-
|
| 311 |
-
# Footer
|
| 312 |
-
st.markdown("---")
|
| 313 |
st.markdown("<div style='text-align: center; color: #666;'>Built with your custom Python pipeline | UI by Streamlit</div>", unsafe_allow_html=True)
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
# Import our actual, working pipeline function
|
| 11 |
+
import sys
|
| 12 |
+
sys.path.append('src')
|
| 13 |
+
from pipeline import process_invoice
|
| 14 |
+
|
| 15 |
+
# --- Mock Functions to support the UI without errors ---
|
| 16 |
+
# These functions simulate the ones from your example README.
|
| 17 |
+
# They allow the UI to render without needing to build a complex format detector today.
|
| 18 |
+
|
| 19 |
+
def detect_invoice_format(ocr_text: str):
    """Mock format detector that drives the UI's "Detected Format" panel.

    A real system would analyse the text layout; this stub only checks the
    OCR text for a known company suffix and returns a canned verdict.
    """
    # "SDN BHD" marks the one retail template we handle well.
    is_known_template = "SDN BHD" in ocr_text

    if is_known_template:
        return {
            'name': 'Template A (Retail)',
            'confidence': 95.0,
            'supported': True,
            'indicators': ["Found 'SDN BHD' suffix", "Date format DD/MM/YYYY detected"],
        }

    # Anything else falls through to a low-confidence "unknown" result.
    return {
        'name': 'Unknown Format',
        'confidence': 20.0,
        'supported': False,
        'indicators': ["No known company suffixes found"],
    }
|
| 39 |
+
|
| 40 |
+
def get_format_recommendations(format_info):
    """Return mock advice strings for the format returned by detect_invoice_format."""
    # Unsupported formats get caveats; supported ones need no warnings.
    if not format_info['supported']:
        return ["• Results may be incomplete.", "• Consider adding patterns for this format."]
    return ["• Extraction should be highly accurate."]
|
| 46 |
+
|
| 47 |
+
# --- Streamlit App ---
|
| 48 |
+
|
| 49 |
+
# Page configuration
|
| 50 |
+
st.set_page_config(
|
| 51 |
+
page_title="Invoice Processor",
|
| 52 |
+
page_icon="📄",
|
| 53 |
+
layout="wide",
|
| 54 |
+
initial_sidebar_state="expanded"
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Custom CSS for styling
|
| 58 |
+
st.markdown("""
|
| 59 |
+
<style>
|
| 60 |
+
.main-header {
|
| 61 |
+
font-size: 3rem;
|
| 62 |
+
color: #1f77b4;
|
| 63 |
+
text-align: center;
|
| 64 |
+
margin-bottom: 2rem;
|
| 65 |
+
}
|
| 66 |
+
.success-box {
|
| 67 |
+
padding: 1rem;
|
| 68 |
+
border-radius: 0.5rem;
|
| 69 |
+
background-color: #d4edda;
|
| 70 |
+
border: 1px solid #c3e6cb;
|
| 71 |
+
margin: 1rem 0;
|
| 72 |
+
}
|
| 73 |
+
.warning-box {
|
| 74 |
+
padding: 1rem;
|
| 75 |
+
border-radius: 0.5rem;
|
| 76 |
+
background-color: #fff3cd;
|
| 77 |
+
border: 1px solid #ffeaa7;
|
| 78 |
+
margin: 1rem 0;
|
| 79 |
+
}
|
| 80 |
+
.error-box {
|
| 81 |
+
padding: 1rem;
|
| 82 |
+
border-radius: 0.5rem;
|
| 83 |
+
background-color: #f8d7da;
|
| 84 |
+
border: 1px solid #f5c6cb;
|
| 85 |
+
margin: 1rem 0;
|
| 86 |
+
}
|
| 87 |
+
</style>
|
| 88 |
+
""", unsafe_allow_html=True)
|
| 89 |
+
|
| 90 |
+
# Title
|
| 91 |
+
st.markdown('<h1 class="main-header">📄 Smart Invoice Processor</h1>', unsafe_allow_html=True)
|
| 92 |
+
st.markdown("### Extract structured data from invoices using your custom-built OCR pipeline")
|
| 93 |
+
|
| 94 |
+
# Sidebar
|
| 95 |
+
with st.sidebar:
|
| 96 |
+
st.header("ℹ️ About")
|
| 97 |
+
st.info("""
|
| 98 |
+
This app uses the pipeline you built to automatically extract:
|
| 99 |
+
- Receipt/Invoice number
|
| 100 |
+
- Date
|
| 101 |
+
- Customer information
|
| 102 |
+
- Line items
|
| 103 |
+
- Total amount
|
| 104 |
+
|
| 105 |
+
**Technology Stack:**
|
| 106 |
+
- Tesseract OCR
|
| 107 |
+
- OpenCV
|
| 108 |
+
- Python Regex
|
| 109 |
+
- Streamlit
|
| 110 |
+
""")
|
| 111 |
+
|
| 112 |
+
st.header("📊 Stats")
|
| 113 |
+
if 'processed_count' not in st.session_state:
|
| 114 |
+
st.session_state.processed_count = 0
|
| 115 |
+
st.metric("Invoices Processed Today", st.session_state.processed_count)
|
| 116 |
+
|
| 117 |
+
st.header("⚙️ Configuration")
|
| 118 |
+
extraction_method = st.selectbox(
|
| 119 |
+
"Choose Extraction Method:",
|
| 120 |
+
('ML-Based (LayoutLMv3)', 'Rule-Based (Regex)'),
|
| 121 |
+
help="ML-Based is more robust but may miss fields not in its training data. Rule-Based is faster but more fragile."
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# Main content
|
| 125 |
+
tab1, tab2, tab3 = st.tabs(["📤 Upload & Process", "📚 Sample Invoices", "ℹ️ How It Works"])
|
| 126 |
+
|
| 127 |
+
with tab1:
|
| 128 |
+
st.header("Upload an Invoice")
|
| 129 |
+
|
| 130 |
+
uploaded_file = st.file_uploader(
|
| 131 |
+
"Choose an invoice image (JPG, PNG)",
|
| 132 |
+
type=['jpg', 'jpeg', 'png'],
|
| 133 |
+
help="Upload a clear image of an invoice or receipt"
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
if uploaded_file is not None:
|
| 137 |
+
col1, col2 = st.columns([1, 1])
|
| 138 |
+
|
| 139 |
+
with col1:
|
| 140 |
+
st.subheader("📸 Original Image")
|
| 141 |
+
image = Image.open(uploaded_file)
|
| 142 |
+
st.image(image, use_container_width=True)
|
| 143 |
+
st.caption(f"Filename: {uploaded_file.name}")
|
| 144 |
+
|
| 145 |
+
with col2:
|
| 146 |
+
st.subheader("🔄 Processing Status")
|
| 147 |
+
|
| 148 |
+
if st.button("🚀 Extract Data", type="primary"):
|
| 149 |
+
with st.spinner("Executing your custom pipeline..."):
|
| 150 |
+
try:
|
| 151 |
+
# Save the uploaded file to a temporary path to be used by our pipeline
|
| 152 |
+
temp_dir = "temp"
|
| 153 |
+
os.makedirs(temp_dir, exist_ok=True)
|
| 154 |
+
temp_path = os.path.join(temp_dir, uploaded_file.name)
|
| 155 |
+
with open(temp_path, "wb") as f:
|
| 156 |
+
f.write(uploaded_file.getbuffer())
|
| 157 |
+
|
| 158 |
+
# Step 1: Call YOUR full pipeline function
|
| 159 |
+
st.write("✅ Calling `process_invoice`...")
|
| 160 |
+
# Map the user-friendly name from the dropdown to the actual method parameter
|
| 161 |
+
method = 'ml' if extraction_method == 'ML-Based (LayoutLMv3)' else 'rules'
|
| 162 |
+
st.write(f"⚙️ Using **{method.upper()}** extraction method...")
|
| 163 |
+
|
| 164 |
+
# Call the pipeline with the selected method
|
| 165 |
+
extracted_data = process_invoice(temp_path, method=method)
|
| 166 |
+
|
| 167 |
+
# Step 2: Simulate format detection using the extracted data
|
| 168 |
+
st.write("✅ Simulating format detection...")
|
| 169 |
+
format_info = detect_invoice_format(extracted_data.get("raw_text", ""))
|
| 170 |
+
|
| 171 |
+
# Store results in session state to display them
|
| 172 |
+
st.session_state.extracted_data = extracted_data
|
| 173 |
+
st.session_state.format_info = format_info
|
| 174 |
+
st.session_state.processed_count += 1
|
| 175 |
+
|
| 176 |
+
st.success("✅ Pipeline executed successfully!")
|
| 177 |
+
|
| 178 |
+
except Exception as e:
|
| 179 |
+
st.error(f"❌ An error occurred in the pipeline: {str(e)}")
|
| 180 |
+
|
| 181 |
+
# Display results if they exist in the session state
|
| 182 |
+
if 'extracted_data' in st.session_state:
|
| 183 |
+
st.markdown("---")
|
| 184 |
+
st.header("📊 Extraction Results")
|
| 185 |
+
|
| 186 |
+
# --- Format Detection Section ---
|
| 187 |
+
format_info = st.session_state.format_info
|
| 188 |
+
st.subheader("📋 Detected Format (Simulated)")
|
| 189 |
+
col1_fmt, col2_fmt = st.columns([2, 3])
|
| 190 |
+
with col1_fmt:
|
| 191 |
+
st.metric("Format Type", format_info['name'])
|
| 192 |
+
st.metric("Detection Confidence", f"{format_info['confidence']:.0f}%")
|
| 193 |
+
if format_info['supported']: st.success("✅ Fully Supported")
|
| 194 |
+
else: st.warning("⚠️ Limited Support")
|
| 195 |
+
with col2_fmt:
|
| 196 |
+
st.write("**Detected Indicators:**")
|
| 197 |
+
for indicator in format_info['indicators']: st.write(f"• {indicator}")
|
| 198 |
+
st.write("**Recommendations:**")
|
| 199 |
+
for rec in get_format_recommendations(format_info): st.write(rec)
|
| 200 |
+
st.markdown("---")
|
| 201 |
+
|
| 202 |
+
# --- Main Results Section ---
|
| 203 |
+
data = st.session_state.extracted_data
|
| 204 |
+
|
| 205 |
+
# Confidence display
|
| 206 |
+
confidence = data.get('extraction_confidence', 0)
|
| 207 |
+
if confidence >= 80:
|
| 208 |
+
st.markdown(f'<div class="success-box">✅ <strong>High Confidence: {confidence}%</strong> - Most key fields were found.</div>', unsafe_allow_html=True)
|
| 209 |
+
elif confidence >= 50:
|
| 210 |
+
st.markdown(f'<div class="warning-box">⚠️ <strong>Medium Confidence: {confidence}%</strong> - Some fields may be missing.</div>', unsafe_allow_html=True)
|
| 211 |
+
else:
|
| 212 |
+
st.markdown(f'<div class="error-box">❌ <strong>Low Confidence: {confidence}%</strong> - Format likely unsupported.</div>', unsafe_allow_html=True)
|
| 213 |
+
|
| 214 |
+
# Validation display
|
| 215 |
+
if data.get('validation_passed', False):
|
| 216 |
+
st.success("✔️ Validation Passed: Total amount appears consistent with other extracted amounts.")
|
| 217 |
+
else:
|
| 218 |
+
st.warning("⚠️ Validation Failed: Total amount could not be verified against other numbers.")
|
| 219 |
+
|
| 220 |
+
# Key metrics display
|
| 221 |
+
# Key metrics display
|
| 222 |
+
st.metric("🏢 Vendor", data.get('vendor') or "N/A") # <-- ADD THIS
|
| 223 |
+
|
| 224 |
+
res_col1, res_col2, res_col3 = st.columns(3)
|
| 225 |
+
res_col1.metric("📄 Receipt Number", data.get('receipt_number') or "N/A")
|
| 226 |
+
res_col2.metric("📅 Date", data.get('date') or "N/A")
|
| 227 |
+
res_col3.metric("💵 Total Amount", f"${data.get('total_amount'):.2f}" if data.get('total_amount') is not None else "N/A")
|
| 228 |
+
|
| 229 |
+
# Use an expander for longer text fields like address
|
| 230 |
+
with st.expander("Show More Details"):
|
| 231 |
+
st.markdown(f"**👤 Bill To:** {data.get('bill_to', {}).get('name') if data.get('bill_to') else 'N/A'}")
|
| 232 |
+
st.markdown(f"**📍 Vendor Address:** {data.get('address') or 'N/A'}")
|
| 233 |
+
|
| 234 |
+
# Line items table
|
| 235 |
+
if data.get('items'):
|
| 236 |
+
st.subheader("🛒 Line Items")
|
| 237 |
+
# Ensure data is in the right format for DataFrame
|
| 238 |
+
items_df_data = [{
|
| 239 |
+
"Description": item.get("description", "N/A"),
|
| 240 |
+
"Qty": item.get("quantity", "N/A"),
|
| 241 |
+
"Unit Price": f"${item.get('unit_price', 0.0):.2f}",
|
| 242 |
+
"Total": f"${item.get('total', 0.0):.2f}"
|
| 243 |
+
} for item in data['items']]
|
| 244 |
+
df = pd.DataFrame(items_df_data)
|
| 245 |
+
st.dataframe(df, use_container_width=True)
|
| 246 |
+
else:
|
| 247 |
+
st.info("ℹ️ No line items were extracted.")
|
| 248 |
+
|
| 249 |
+
# JSON output and download
|
| 250 |
+
with st.expander("📄 View Full JSON Output"):
|
| 251 |
+
st.json(data)
|
| 252 |
+
|
| 253 |
+
json_str = json.dumps(data, indent=2)
|
| 254 |
+
st.download_button(
|
| 255 |
+
label="💾 Download JSON",
|
| 256 |
+
data=json_str,
|
| 257 |
+
file_name=f"invoice_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
|
| 258 |
+
mime="application/json"
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
with st.expander("📝 View Raw OCR Text"):
|
| 262 |
+
raw_text = data.get('raw_text', '')
|
| 263 |
+
if raw_text:
|
| 264 |
+
st.text(raw_text)
|
| 265 |
+
else:
|
| 266 |
+
st.info("No OCR text available.")
|
| 267 |
+
|
| 268 |
+
with tab2:
    st.header("📚 Sample Invoices")
    st.write("Try the sample invoice below to see how the system performs:")

    sample_dir = "data/samples"  # Bundled demo images live here
    if os.path.exists(sample_dir):
        # Match extensions case-insensitively so camera uploads such as
        # IMG_0001.JPG are not silently skipped (the original check was
        # case-sensitive and missed upper-case extensions).
        sample_files = [
            f for f in os.listdir(sample_dir)
            if f.lower().endswith(('.jpg', '.png', '.jpeg'))
        ]

        if sample_files:
            # Display the first sample found
            img_path = os.path.join(sample_dir, sample_files[0])
            st.image(Image.open(img_path), caption=sample_files[0], use_container_width=True)
            st.info("You can download this image and upload it in the 'Upload & Process' tab to test the pipeline.")
        else:
            st.warning("No sample invoices found in `data/samples/`.")
    else:
        st.error("The `data/samples` directory was not found.")
|
| 285 |
+
|
| 286 |
+
with tab3:
|
| 287 |
+
st.header("ℹ️ How It Works (Your Custom Pipeline)")
|
| 288 |
+
st.markdown("""
|
| 289 |
+
This app follows the exact pipeline you built:
|
| 290 |
+
```
|
| 291 |
+
1. 📸 Image Upload
|
| 292 |
+
↓
|
| 293 |
+
2. 🔄 Preprocessing (OpenCV)
|
| 294 |
+
Grayscale conversion and noise removal.
|
| 295 |
+
↓
|
| 296 |
+
3. 🔍 OCR (Tesseract)
|
| 297 |
+
Optimized with PSM 6 for receipt layouts.
|
| 298 |
+
↓
|
| 299 |
+
4. 🎯 Rule-Based Extraction (Regex)
|
| 300 |
+
Your custom patterns find specific fields.
|
| 301 |
+
↓
|
| 302 |
+
5. ✅ Confidence & Validation
|
| 303 |
+
Heuristics to check the quality of the extraction.
|
| 304 |
+
↓
|
| 305 |
+
6. 📊 Output JSON
|
| 306 |
+
Presents all extracted data in a structured format.
|
| 307 |
+
```
|
| 308 |
+
""")
|
| 309 |
+
st.info("This rule-based system is a great foundation. The next step is to replace the extraction logic with an ML model like LayoutLM to handle more diverse formats!")
|
| 310 |
+
|
| 311 |
+
# Footer
|
| 312 |
+
st.markdown("---")
|
| 313 |
st.markdown("<div style='text-align: center; color: #666;'>Built with your custom Python pipeline | UI by Streamlit</div>", unsafe_allow_html=True)
|
eval_new_dataset.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from src.data_loader import load_unified_dataset
|
| 3 |
+
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor, DataCollatorForTokenClassification
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from seqeval.metrics import classification_report
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from train_combined import UnifiedDataset, label2id, id2label, LABEL_LIST
|
| 8 |
+
|
| 9 |
+
# Load Model
|
| 10 |
+
model_path = "./models/layoutlmv3-generalized"
|
| 11 |
+
model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
|
| 12 |
+
processor = LayoutLMv3Processor.from_pretrained(model_path, apply_ocr=False)
|
| 13 |
+
device = torch.device("cuda")
|
| 14 |
+
model.to(device)
|
| 15 |
+
|
| 16 |
+
# Load ONLY the new dataset (validation split)
|
| 17 |
+
# We want to see how well it learned THIS specific dataset
|
| 18 |
+
print("Loading new dataset validation split...")
|
| 19 |
+
val_data = load_unified_dataset(split="valid", sample_size=None)
|
| 20 |
+
dataset = UnifiedDataset(val_data, processor, label2id)
|
| 21 |
+
loader = DataLoader(dataset, batch_size=4, collate_fn=DataCollatorForTokenClassification(processor.tokenizer, padding=True, return_tensors="pt"))
|
| 22 |
+
|
| 23 |
+
print("Running evaluation...")
|
| 24 |
+
model.eval()
|
| 25 |
+
preds, labs = [], []
|
| 26 |
+
|
| 27 |
+
for batch in tqdm(loader):
|
| 28 |
+
batch = {k: v.to(device) for k, v in batch.items()}
|
| 29 |
+
with torch.no_grad():
|
| 30 |
+
outputs = model(**batch)
|
| 31 |
+
|
| 32 |
+
predictions = outputs.logits.argmax(dim=-1)
|
| 33 |
+
labels = batch['labels']
|
| 34 |
+
|
| 35 |
+
for i in range(len(labels)):
|
| 36 |
+
p = [id2label[p.item()] for p, l in zip(predictions[i], labels[i]) if l.item() != -100]
|
| 37 |
+
l = [id2label[l.item()] for l in labels[i] if l.item() != -100]
|
| 38 |
+
preds.append(p)
|
| 39 |
+
labs.append(l)
|
| 40 |
+
|
| 41 |
+
print("\nClassification Report:")
|
| 42 |
+
print(classification_report(labs, preds))
|
explore_new_dataset.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# explore_new_dataset.py
#
# Exploratory script: downloads the invoices/receipts OCR dataset, inspects
# its schema, parses the OCR payload of the first sample, scans the split for
# unique entity labels, and saves one sample image for manual testing.
#
# FIXES: creates the output directory before saving (previously crashed with
# FileNotFoundError when data/samples/ was missing) and narrows the bare
# `except:` clauses so KeyboardInterrupt/SystemExit are not swallowed.

from datasets import load_dataset
import json
import ast  # <--- Added for robust parsing
import os   # needed to create the sample output directory

# --- 1. Load the dataset ---
print("📥 Loading 'mychen76/invoices-and-receipts_ocr_v1' from Hugging Face...")
try:
    dataset = load_dataset("mychen76/invoices-and-receipts_ocr_v1", split='train')
    print("✅ Dataset loaded successfully!")
except Exception as e:
    print(f"❌ Failed to load dataset. Error: {e}")
    exit()

# --- 2. Print Dataset Information ---
print("\n" + "="*60)
print("📊 DATASET INFORMATION & FEATURES")
print("="*60)
print(f"Number of examples: {len(dataset)}")
print(f"\nFeatures (Columns): {dataset.features}")


# --- 3. Explore a Single Example ---
print("\n" + "="*60)
print("📄 EXPLORING THE FIRST SAMPLE")
print("="*60)
if len(dataset) > 0:
    sample = dataset[0]

    # Parse the main wrapper JSONs
    try:
        raw_data = json.loads(sample['raw_data'])
        parsed_data = json.loads(sample['parsed_data'])
    except json.JSONDecodeError as e:
        print(f"❌ Error decoding main JSON wrappers: {e}")
        exit()

    print(f"\nImage object: {sample['image']}")

    # --- ROBUST PARSING LOGIC ---
    def safe_parse(content):
        """Try JSON, fallback to AST (for single quotes)"""
        if isinstance(content, list):
            return content  # Already a list
        if isinstance(content, str):
            try:
                return json.loads(content)
            except json.JSONDecodeError:
                try:
                    return ast.literal_eval(content)
                except (ValueError, SyntaxError):
                    # FIX: was a bare `except:` — only catch literal errors
                    return None
        return None

    ocr_words = safe_parse(raw_data.get('ocr_words'))
    ocr_boxes = safe_parse(raw_data.get('ocr_boxes'))

    if ocr_words and ocr_boxes:
        print(f"\nFound {len(ocr_words)} OCR words.")
        print("Sample Word & Box Format:")
        # Print first 3 to check coordinate format (4 numbers or 8 numbers?)
        for i in range(min(3, len(ocr_words))):
            print(f"  Word: '{ocr_words[i]}' | Box: {ocr_boxes[i]}")
    else:
        print("❌ OCR fields missing or could not be parsed.")


else:
    print("Dataset is empty.")


# --- 4. Discover All Unique NER Tags ---
print("\n" + "="*60)
print("📋 ALL UNIQUE ENTITY LABELS IN THIS DATASET")
print("="*60)
if len(dataset) > 0:
    all_entity_labels = set()

    print("Scanning dataset for labels...")
    for i, example in enumerate(dataset):
        try:
            # Parse parsed_data
            parsed_example = json.loads(example['parsed_data'])

            # The 'json' field inside might be a string or a dict
            fields_data = parsed_example.get('json', {})

            if isinstance(fields_data, str):
                try:
                    fields = json.loads(fields_data)
                except json.JSONDecodeError:
                    # FIX: was a bare `except:`; ast failures are caught by
                    # the outer Exception handler below.
                    fields = ast.literal_eval(fields_data)
            else:
                fields = fields_data

            if fields:
                all_entity_labels.update(fields.keys())

        except Exception:
            continue  # Skip corrupted examples silently

    if all_entity_labels:
        print(f"\nFound {len(all_entity_labels)} unique entity labels:")
        print(sorted(list(all_entity_labels)))
    else:
        print("Could not find any entity labels.")
else:
    print("Cannot analyze tags of an empty dataset.")


# Save one sample image so the pipeline can be tested manually.
sample = dataset[0]
os.makedirs("data/samples", exist_ok=True)  # FIX: ensure folder exists
sample['image'].save("data/samples/test_invoice_no.jpg")
print("Saved sample image to data/samples/test_invoice_no.jpg")
|
load_sroie_dataset.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from PIL import Image
|
| 4 |
+
|
| 5 |
+
def load_sroie(path):
    """Load a locally-stored SROIE dataset into {'train': [...], 'test': [...]}.

    Each example dict carries: image_path, words, bboxes (normalized to the
    0-1000 LayoutLM coordinate grid), and ner_tags.

    FIX: annotations whose parallel arrays (words/bbox/labels) have mismatched
    lengths are now rejected — previously they silently produced corrupted
    word/box/tag pairings in the training data.

    Args:
        path: Root folder containing 'train'/'test' splits, each with an image
              dir ('images' or 'img') and an annotation dir ('tagged' or 'box')
              holding one JSON file per image.

    Returns:
        dict with 'train' and 'test' lists of example dicts.
    """
    print(f"🔄 Loading SROIE from local path: {path}")
    path = Path(path)
    dataset = {'train': [], 'test': []}

    for split in ["train", "test"]:
        split_path = path / split

        # Folder layout differs between SROIE mirrors.
        if (split_path / "images").exists(): img_dir = split_path / "images"
        elif (split_path / "img").exists(): img_dir = split_path / "img"
        else: continue

        if (split_path / "tagged").exists(): ann_dir = split_path / "tagged"
        elif (split_path / "box").exists(): ann_dir = split_path / "box"
        else: continue

        examples = []
        for img_file in sorted(img_dir.iterdir()):
            if img_file.suffix.lower() not in [".jpg", ".png"]: continue

            name = img_file.stem
            json_path = ann_dir / f"{name}.json"
            if not json_path.exists(): continue

            with open(json_path, encoding="utf8") as f:
                data = json.load(f)

            if "words" in data and "bbox" in data and "labels" in data:
                # FIX: guard against misaligned annotation arrays.
                if not (len(data["words"]) == len(data["bbox"]) == len(data["labels"])):
                    print(f"Skipping {name}: words/bbox/labels length mismatch")
                    continue

                # --- NORMALIZATION HAPPENS HERE ---
                try:
                    with Image.open(img_file) as img:
                        width, height = img.size

                    norm_boxes = []
                    for box in data["bbox"]:
                        # SROIE is raw [x0, y0, x1, y1]
                        x0, y0, x1, y1 = box

                        # Normalize to the 0-1000 grid and clamp
                        norm_box = [
                            int(max(0, min(1000 * (x0 / width), 1000))),
                            int(max(0, min(1000 * (y0 / height), 1000))),
                            int(max(0, min(1000 * (x1 / width), 1000))),
                            int(max(0, min(1000 * (y1 / height), 1000)))
                        ]
                        norm_boxes.append(norm_box)

                    examples.append({
                        "image_path": str(img_file),
                        "words": data["words"],
                        "bboxes": norm_boxes,  # Storing normalized boxes
                        "ner_tags": data["labels"]
                    })
                except Exception as e:
                    print(f"Skipping {name}: {e}")
                    continue

        dataset[split] = examples
        print(f"  Mapped {len(examples)} paths for {split}")

    return dataset
|
notebooks/test_setup.py
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
# This is just a verification script - you can copy this
|
| 2 |
-
import pytesseract
|
| 3 |
-
from PIL import Image
|
| 4 |
-
import cv2
|
| 5 |
-
import numpy as np
|
| 6 |
-
|
| 7 |
-
# If Windows, you might need to set this path:
|
| 8 |
-
# pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
|
| 9 |
-
|
| 10 |
-
print("✅ All imports successful!")
|
| 11 |
-
print(f"Tesseract version: {pytesseract.get_tesseract_version()}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
notebooks/test_visual.ipynb
DELETED
|
File without changes
|
requirements.txt
CHANGED
|
Binary files a/requirements.txt and b/requirements.txt differ
|
|
|
src/data_loader.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/data_loader.py
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import ast
|
| 5 |
+
import numpy as np
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
from difflib import SequenceMatcher
|
| 8 |
+
|
| 9 |
+
# --- CONFIGURATION ---
# Maps source-dataset field names onto the unified entity label set used for
# training (B-/I- prefixes are attached later, during alignment).
LABEL_MAPPING = {
    # Vendor/Company
    "seller": "COMPANY",
    "store_name": "COMPANY",

    # Address
    "store_addr": "ADDRESS",

    # Date
    "date": "DATE",
    "invoice_date": "DATE",

    # Total
    "total": "TOTAL",
    "total_gross_worth": "TOTAL",

    # Receipt Number / Invoice No
    "invoice_no": "INVOICE_NO",

    # Bill To / Client
    "client": "BILL_TO"
}
|
| 32 |
+
|
| 33 |
+
def safe_parse(content):
    """Robustly parse input that may already be a container, a JSON string,
    or a Python string literal (single-quoted).

    FIX: dict inputs previously fell through to `return []`, even though
    `load_unified_dataset` passes the already-decoded `parsed_data['json']`
    dict — all ground-truth labels were lost for those rows.

    Args:
        content: list/dict (returned as-is), str (parsed), or anything else.

    Returns:
        The parsed (or passed-through) list/dict, or [] when nothing can be
        recovered.
    """
    if isinstance(content, (list, dict)):
        return content
    if isinstance(content, str):
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            pass
        try:
            return ast.literal_eval(content)
        except (ValueError, SyntaxError):
            pass
    return []
|
| 47 |
+
|
| 48 |
+
def normalize_box(box, width, height):
    """Convert an OCR polygon into a 4-point bbox on the 0-1000 LayoutLM grid.

    Accepts either ``[[points], extra]`` (PaddleOCR-style pair) or a bare
    4-point list; returns ``[x0, y0, x1, y1]`` clamped to 0..1000, or None
    when the input shape is unrecognized or normalization fails.
    """
    # Unwrap the two nested format variations.
    if isinstance(box, list) and len(box) == 2 and isinstance(box[0], list):
        points = box[0]
    elif isinstance(box, list) and len(box) == 4 and isinstance(box[0], list):
        points = box
    else:
        return None

    try:
        x_coords = [pt[0] for pt in points]
        y_coords = [pt[1] for pt in points]

        def scale(value, span):
            # Map to the 0-1000 grid with clamping at both ends.
            return int(max(0, min(1000 * (value / span), 1000)))

        return [
            scale(min(x_coords), width),
            scale(min(y_coords), height),
            scale(max(x_coords), width),
            scale(max(y_coords), height),
        ]
    except Exception:
        return None
|
| 70 |
+
|
| 71 |
+
def tokenize_and_spread_boxes(words, boxes):
    """Split multi-word phrases into single tokens, repeating the phrase box.

    Example:
        (['Invoice #123'], [BOX_A]) -> (['Invoice', '#123'], [BOX_A, BOX_A])
    """
    out_words = []
    out_boxes = []

    for phrase, box in zip(words, boxes):
        pieces = str(phrase).split()  # whitespace tokenization
        out_words.extend(pieces)
        out_boxes.extend([box] * len(pieces))

    return out_words, out_boxes
|
| 88 |
+
|
| 89 |
+
def align_labels(ocr_words, label_map):
    """Match OCR words to ground-truth values via sliding-window search.

    FIX: tokens that strip down to nothing (pure punctuation such as ':' or
    '-') previously matched *every* target, because the empty string is a
    substring of everything — punctuation all over the page got labelled.
    Empty cleaned tokens are now treated as non-matches.

    Args:
        ocr_words: flat list of OCR tokens.
        label_map: {field_value_text: LABEL_CLASS} pairs to locate.

    Returns:
        BIO tag list, one tag per OCR word ('O' when unmatched).
    """
    tags = ["O"] * len(ocr_words)

    for target_text, label_class in label_map.items():
        if not target_text: continue

        target_tokens = str(target_text).split()
        if not target_tokens: continue

        n_target = len(target_tokens)

        # Sliding window search
        for i in range(len(ocr_words) - n_target + 1):
            window = ocr_words[i : i + n_target]

            # Check match (fuzzy: either side may be a substring of the other)
            match = True
            for j in range(n_target):
                # Clean punctuation for comparison
                w_clean = window[j].strip(".,-:")
                t_clean = target_tokens[j].strip(".,-:")
                # Empty strings are substrings of everything — never a match.
                if not w_clean or not t_clean:
                    match = False
                    break
                if w_clean not in t_clean and t_clean not in w_clean:
                    match = False
                    break

            if match:
                tags[i] = f"B-{label_class}"
                for k in range(1, n_target):
                    tags[i + k] = f"I-{label_class}"

    return tags
|
| 121 |
+
|
| 122 |
+
def load_unified_dataset(split="train", sample_size=None):
    """Download the HF invoices/receipts OCR dataset and convert it into
    word/bbox/BIO-tag examples for LayoutLM-style token classification.

    Args:
        split: Hugging Face dataset split name (default "train").
        sample_size: optional cap on the number of raw rows to process.

    Returns:
        List of dicts with keys: image (PIL RGB), words (flat tokens),
        bboxes (0-1000 normalized, one per token), ner_tags (BIO strings).
        Rows with no detected entity, or any parse failure, are dropped.
    """
    print(f"🔄 Loading dataset 'mychen76/invoices-and-receipts_ocr_v1' ({split})...")
    dataset = load_dataset("mychen76/invoices-and-receipts_ocr_v1", split=split)

    if sample_size:
        dataset = dataset.select(range(sample_size))

    processed_data = []

    print("⚙️ Processing, Tokenizing, and Aligning...")
    for example in dataset:
        try:
            image = example['image']
            if image.mode != "RGB":
                image = image.convert("RGB")
            width, height = image.size

            # 1. Parse the raw OCR payload (stringified JSON / Python literals).
            #    NOTE(review): raw_data is decoded twice here — confirm a
            #    single json.loads could be hoisted without behavior change.
            raw_words = safe_parse(json.loads(example['raw_data']).get('ocr_words'))
            raw_boxes = safe_parse(json.loads(example['raw_data']).get('ocr_boxes'))

            # Words and boxes must be parallel arrays; otherwise skip the row.
            if not raw_words or not raw_boxes or len(raw_words) != len(raw_boxes):
                continue

            # 2. Normalize boxes first, dropping words whose box is malformed
            norm_boxes = []
            valid_words = []
            for i, box in enumerate(raw_boxes):
                nb = normalize_box(box, width, height)
                if nb:
                    norm_boxes.append(nb)
                    valid_words.append(raw_words[i])

            # 3. Split multi-word phrases into single tokens, duplicating boxes
            final_words, final_boxes = tokenize_and_spread_boxes(valid_words, norm_boxes)

            # 4. Map ground-truth field names onto unified label classes
            parsed_json = json.loads(example['parsed_data'])
            fields = safe_parse(parsed_json.get('json', {}))
            label_value_map = {}
            if isinstance(fields, dict):
                for k, v in fields.items():
                    if k in LABEL_MAPPING and v:
                        label_value_map[v] = LABEL_MAPPING[k]

            # 5. Align ground-truth values against the OCR token stream
            final_tags = align_labels(final_words, label_value_map)

            # Only keep if we found at least one entity (cleaner training data)
            unique_tags = set(final_tags)
            if len(unique_tags) > 1:
                processed_data.append({
                    "image": image,
                    "words": final_words,
                    "bboxes": final_boxes,
                    "ner_tags": final_tags
                })

        except Exception:
            # Best-effort pipeline: corrupted rows are skipped silently.
            continue

    print(f"✅ Successfully processed {len(processed_data)} examples.")
    return processed_data
|
| 185 |
+
|
| 186 |
+
if __name__ == "__main__":
    # Smoke test: process a small sample and report which labels were found.
    data = load_unified_dataset(sample_size=20)
    if len(data) > 0:
        print(f"\nSample 0 Words: {data[0]['words'][:10]}...")
        print(f"Sample 0 Tags: {data[0]['ner_tags'][:10]}...")

        # Flatten all tags to see overall label coverage.
        all_tags = [t for item in data for t in item['ner_tags']]
        unique_tags = set(all_tags)
        print(f"\nUnique Tags Found in Sample: {unique_tags}")
    else:
        print("No valid examples found in sample.")
|
src/extraction.py
CHANGED
|
@@ -1,273 +1,123 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
pattern2 = r'\d{
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
dates.extend(re.findall(
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
if
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
#
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
bill_to_text = None
|
| 126 |
-
for i, line in enumerate(lines):
|
| 127 |
-
lower_line = line.lower()
|
| 128 |
-
if any(h in lower_line for h in headings):
|
| 129 |
-
# Capture text after colon or hyphen if present
|
| 130 |
-
split_line = re.split(r'[:\-]', line, maxsplit=1)
|
| 131 |
-
if len(split_line) > 1:
|
| 132 |
-
bill_to_text = split_line[1].strip()
|
| 133 |
-
else:
|
| 134 |
-
# If name is on next line
|
| 135 |
-
if i + 1 < len(lines):
|
| 136 |
-
bill_to_text = lines[i + 1].strip()
|
| 137 |
-
break
|
| 138 |
-
|
| 139 |
-
if not bill_to_text:
|
| 140 |
-
return None
|
| 141 |
-
|
| 142 |
-
# Extract email if present
|
| 143 |
-
email_match = re.search(r'[\w\.-]+@[\w\.-]+\.\w+', bill_to_text)
|
| 144 |
-
email = email_match.group(0) if email_match else None
|
| 145 |
-
|
| 146 |
-
# Remove email from name
|
| 147 |
-
if email:
|
| 148 |
-
bill_to_text = bill_to_text.replace(email, '').strip()
|
| 149 |
-
|
| 150 |
-
if len(bill_to_text) > 2: # Basic validation
|
| 151 |
-
bill_to = {"name": bill_to_text, "email": email}
|
| 152 |
-
|
| 153 |
-
return bill_to
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
def extract_line_items(text: str) -> List[Dict[str, Any]]:
|
| 157 |
-
"""
|
| 158 |
-
Extract line items from receipt text more robustly.
|
| 159 |
-
Handles:
|
| 160 |
-
- Multi-line descriptions
|
| 161 |
-
- Prices with or without currency symbols
|
| 162 |
-
- Quantities in different formats
|
| 163 |
-
- Missing decimals
|
| 164 |
-
|
| 165 |
-
Args:
|
| 166 |
-
text: Raw OCR text
|
| 167 |
-
|
| 168 |
-
Returns:
|
| 169 |
-
List of dictionaries with description, quantity, unit_price, total
|
| 170 |
-
"""
|
| 171 |
-
items = []
|
| 172 |
-
lines = text.split('\n')
|
| 173 |
-
|
| 174 |
-
# Keywords to detect start/end of item section
|
| 175 |
-
start_keywords = ['description', 'item', 'qty', 'price', 'amount']
|
| 176 |
-
end_keywords = ['total', 'subtotal', 'tax', 'gst']
|
| 177 |
-
|
| 178 |
-
# Detect section
|
| 179 |
-
start_index = -1
|
| 180 |
-
end_index = len(lines)
|
| 181 |
-
for i, line in enumerate(lines):
|
| 182 |
-
lower = line.lower()
|
| 183 |
-
if start_index == -1 and any(k in lower for k in start_keywords):
|
| 184 |
-
start_index = i + 1
|
| 185 |
-
if start_index != -1 and any(k in lower for k in end_keywords):
|
| 186 |
-
end_index = i
|
| 187 |
-
break
|
| 188 |
-
|
| 189 |
-
if start_index == -1:
|
| 190 |
-
return []
|
| 191 |
-
|
| 192 |
-
item_lines = lines[start_index:end_index]
|
| 193 |
-
|
| 194 |
-
current_description = ""
|
| 195 |
-
for line in item_lines:
|
| 196 |
-
# Remove currency symbols, commas, etc.
|
| 197 |
-
clean_line = re.sub(r'[^\d\.\s]', '', line)
|
| 198 |
-
|
| 199 |
-
# Find all numbers (floats or integers)
|
| 200 |
-
amounts_on_line = re.findall(r'\d+(?:\.\d+)?', clean_line)
|
| 201 |
-
|
| 202 |
-
# Attempt to detect quantity at the start: "2 ", "3 x", etc.
|
| 203 |
-
qty_match = re.match(r'^\s*(\d+)\s*(?:x)?', line)
|
| 204 |
-
quantity = int(qty_match.group(1)) if qty_match else 1
|
| 205 |
-
|
| 206 |
-
# Extract description by removing numbers and common symbols
|
| 207 |
-
desc_part = re.sub(r'[\d\.\s]+', '', line).strip()
|
| 208 |
-
if len(desc_part) > 0:
|
| 209 |
-
if current_description:
|
| 210 |
-
current_description += " " + desc_part
|
| 211 |
-
else:
|
| 212 |
-
current_description = desc_part
|
| 213 |
-
|
| 214 |
-
# If there are numbers and a description, create item
|
| 215 |
-
if amounts_on_line and current_description:
|
| 216 |
-
try:
|
| 217 |
-
# Heuristic: last number is total, second last is unit price
|
| 218 |
-
item_total = float(amounts_on_line[-1])
|
| 219 |
-
unit_price = float(amounts_on_line[-2]) if len(amounts_on_line) > 1 else item_total
|
| 220 |
-
|
| 221 |
-
items.append({
|
| 222 |
-
"description": current_description.strip(),
|
| 223 |
-
"quantity": quantity,
|
| 224 |
-
"unit_price": unit_price,
|
| 225 |
-
"total": item_total
|
| 226 |
-
})
|
| 227 |
-
current_description = "" # reset for next item
|
| 228 |
-
except ValueError:
|
| 229 |
-
current_description = ""
|
| 230 |
-
continue
|
| 231 |
-
|
| 232 |
-
return items
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
def structure_output(text: str) -> Dict[str, Any]:
|
| 236 |
-
"""
|
| 237 |
-
Extract all information and return in the desired advanced format.
|
| 238 |
-
"""
|
| 239 |
-
|
| 240 |
-
# Old fields
|
| 241 |
-
date = extract_dates(text)[0] if extract_dates(text) else None
|
| 242 |
-
total = extract_total(text)
|
| 243 |
-
|
| 244 |
-
# New fields
|
| 245 |
-
bill_to = extract_bill_to(text)
|
| 246 |
-
items = extract_line_items(text)
|
| 247 |
-
invoice_num = extract_invoice_number(text) # Renamed for clarity
|
| 248 |
-
|
| 249 |
-
data = {
|
| 250 |
-
"receipt_number": invoice_num,
|
| 251 |
-
"date": date,
|
| 252 |
-
"bill_to": bill_to,
|
| 253 |
-
"items": items,
|
| 254 |
-
"total_amount": total,
|
| 255 |
-
"raw_text": text
|
| 256 |
-
}
|
| 257 |
-
|
| 258 |
-
# --- Confidence and Validation ---
|
| 259 |
-
fields_to_check = ['receipt_number', 'date', 'bill_to', 'total_amount']
|
| 260 |
-
extracted_fields = sum(1 for field in fields_to_check if data.get(field) is not None)
|
| 261 |
-
if items: # Count items as an extracted field
|
| 262 |
-
extracted_fields += 1
|
| 263 |
-
|
| 264 |
-
data['extraction_confidence'] = int((extracted_fields / (len(fields_to_check) + 1)) * 100)
|
| 265 |
-
|
| 266 |
-
# A more advanced validation
|
| 267 |
-
items_total = sum(item.get('total', 0) for item in items)
|
| 268 |
-
data['validation_passed'] = False
|
| 269 |
-
if total is not None and abs(total - items_total) < 0.01: # Check if total matches sum of items
|
| 270 |
-
data['validation_passed'] = True
|
| 271 |
-
|
| 272 |
-
return data
|
| 273 |
-
|
|
|
|
| 1 |
+
# src/extraction.py
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from typing import List, Dict, Optional, Any
|
| 5 |
+
|
| 6 |
+
def extract_dates(text: str) -> List[str]:
    """Return all date-like substrings, order preserved and de-duplicated."""
    if not text:
        return []

    patterns = (
        r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b',  # DD/MM/YYYY or DD-MM-YYYY
        r'\b\d{4}[/-]\d{1,2}[/-]\d{1,2}\b',    # YYYY-MM-DD
    )

    found: List[str] = []
    for pattern in patterns:
        found += re.findall(pattern, text)

    # dict.fromkeys keeps first-seen order while dropping duplicates.
    return list(dict.fromkeys(found))
|
| 17 |
+
|
| 18 |
+
def extract_amounts(text: str) -> List[float]:
    """Return every monetary-looking value (e.g. '1,234.56') as a float."""
    if not text:
        return []

    # Two decimal places required; optional thousands separators.
    money_re = r'\b\d{1,3}(?:,\d{3})*\.\d{2}\b'

    values: List[float] = []
    for token in re.findall(money_re, text):
        try:
            values.append(float(token.replace(',', '')))
        except ValueError:
            # Shouldn't happen given the pattern, but stay defensive.
            continue
    return values
|
| 32 |
+
|
| 33 |
+
def extract_total(text: str) -> Optional[float]:
    """
    Robust total extraction looking for keywords + largest number context.

    FIX: the keyword filler was greedy (``[\\w\\s]*``), which let the regex
    engine backtrack INTO the amount itself — "TOTAL 123.45" captured "3.45".
    A lazy filler (``[\\w\\s]*?``) keeps the full amount in the capture group.

    Returns:
        The keyword-anchored amount (last occurrence, usually the grand total
        at the bottom), else the largest monetary value in the text, else None.
    """
    if not text:
        return None

    # 1. Try specific "Total" keywords first.
    # Looks for "Total: 123.45" or "Total Amount $123.45".
    pattern = r'(?:TOTAL|AMOUNT DUE|GRAND TOTAL|BALANCE|PAYABLE)[\w\s]*?[:$]?\s*([\d,]+\.\d{2})'
    matches = re.findall(pattern, text, re.IGNORECASE)

    if matches:
        # Return the last match (often the grand total at bottom)
        try:
            return float(matches[-1].replace(',', ''))
        except ValueError:
            pass

    # 2. Fallback: largest monetary value anywhere in the text
    # (risky, but better than None)
    amounts = extract_amounts(text)
    if amounts:
        return max(amounts)

    return None
|
| 58 |
+
|
| 59 |
+
def extract_vendor(text: str) -> Optional[str]:
    """Guess the vendor name from the top of the receipt text."""
    if not text:
        return None

    lines = text.strip().split('\n')
    suffixes = ['SDN BHD', 'INC', 'LTD', 'LLC', 'PLC', 'CORP', 'PTY', 'PVT', 'LIMITED']

    # Pass 1: within the top 10 lines, a company-suffix line wins outright.
    for candidate in lines[:10]:
        upper = candidate.upper()
        if any(sfx in upper for sfx in suffixes):
            return candidate.strip()

    # Pass 2: first reasonably long top line that does not look like a date.
    for candidate in lines[:5]:
        if len(candidate.strip()) > 3 and not re.search(r'\d{2}/\d{2}', candidate):
            return candidate.strip()

    return None
|
| 74 |
+
|
| 75 |
+
def extract_invoice_number(text: str) -> Optional[str]:
    """Find an invoice/receipt identifier — alphanumeric or purely numeric."""
    if not text:
        return None

    # Strategy 1: keyword-anchored ID, e.g. "Invoice No: 12345", "Bill No. AB-123"
    anchored = re.search(
        r'(?:INVOICE|BILL|RECEIPT)\s*(?:NO|NUMBER|#|NUM)?[\s\.:-]*([A-Z0-9\-/]{3,})',
        text,
        re.IGNORECASE,
    )
    if anchored:
        return anchored.group(1)

    # Strategy 2: scan the top lines for a standalone labeled token.
    # Pure-digit IDs (e.g. 40378170) are allowed when 4+ characters long.
    for line in text.split('\n')[:20]:
        lowered = line.lower()
        if 'invoice' in lowered or 'no' in lowered or '#' in lowered:
            token = re.search(r'\b([A-Z0-9-]{4,})\b', line)
            if token:
                return token.group(1)

    return None
|
| 100 |
+
|
| 101 |
+
def extract_bill_to(text: str) -> Optional[Dict[str, str]]:
    """Extract the 'Bill To' recipient name when a labelled block exists."""
    if not text:
        return None

    hit = re.search(r'(?:BILL|BILLED)\s*TO[:\s]+([^\n]+)', text, re.IGNORECASE)
    if hit is None:
        return None

    # Email extraction is not attempted here; key kept for schema stability.
    return {"name": hit.group(1).strip(), "email": None}
|
| 110 |
+
|
| 111 |
+
def extract_line_items(text: str) -> List[Dict[str, Any]]:
    """Placeholder: line-item parsing is intentionally not implemented yet."""
    return list()
|
| 114 |
+
|
| 115 |
+
def structure_output(text: str) -> Dict[str, Any]:
    """Legacy wrapper for the rule-based-only pipeline.

    FIX: extract_dates() was invoked twice to fill one field; the date list
    is now computed once and reused.

    Args:
        text: raw OCR text.

    Returns:
        dict with receipt_number, date (first match or None), total_amount,
        vendor, and the original raw text.
    """
    dates = extract_dates(text)
    return {
        "receipt_number": extract_invoice_number(text),
        "date": dates[0] if dates else None,
        "total_amount": extract_total(text),
        "vendor": extract_vendor(text),
        "raw_text": text
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/ml_extraction.py
CHANGED
|
@@ -1,176 +1,144 @@
|
|
| 1 |
-
# src/ml_extraction.py
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
|
| 5 |
-
from PIL import Image
|
| 6 |
-
import pytesseract
|
| 7 |
-
from typing import List, Dict, Any
|
| 8 |
-
import re
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
HUB_MODEL_ID = "GSoumyajit2005/layoutlmv3-sroie-invoice-extraction"
|
| 15 |
-
|
| 16 |
-
# ---
|
| 17 |
-
def load_model_and_processor(model_path, hub_id):
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
print(f"
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
if
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
""
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
#
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
#
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
]
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
extracted_entities = _process_predictions(words, unnormalized_boxes, encoding, predictions, MODEL.config.id2label)
|
| 145 |
-
|
| 146 |
-
# 7. Format the output to be consistent with your rule-based output
|
| 147 |
-
# Format the output to be consistent with the desired UI structure
|
| 148 |
-
# Format the output to be a superset of all possible fields
|
| 149 |
-
final_output = {
|
| 150 |
-
# --- Standard UI Fields ---
|
| 151 |
-
"receipt_number": None, # SROIE doesn't train for this. Your regex model will provide it.
|
| 152 |
-
"date": extracted_entities.get("DATE", {}).get("text"),
|
| 153 |
-
"bill_to": None, # SROIE doesn't train for this. Your regex model will provide it.
|
| 154 |
-
"items": [], # SROIE doesn't train for line items.
|
| 155 |
-
"total_amount": None,
|
| 156 |
-
|
| 157 |
-
# --- Additional Fields from ML Model ---
|
| 158 |
-
"vendor": extracted_entities.get("COMPANY", {}).get("text"), # The ML model finds 'COMPANY'
|
| 159 |
-
"address": extracted_entities.get("ADDRESS", {}).get("text"),
|
| 160 |
-
|
| 161 |
-
# --- Debugging Info ---
|
| 162 |
-
"raw_text": " ".join(words),
|
| 163 |
-
"raw_ocr_words": words,
|
| 164 |
-
"raw_predictions": extracted_entities
|
| 165 |
-
}
|
| 166 |
-
|
| 167 |
-
# Safely extract and convert total
|
| 168 |
-
total_text = extracted_entities.get("TOTAL", {}).get("text")
|
| 169 |
-
if total_text:
|
| 170 |
-
try:
|
| 171 |
-
cleaned_total = re.sub(r'[^\d.]', '', total_text)
|
| 172 |
-
final_output["total_amount"] = float(cleaned_total)
|
| 173 |
-
except (ValueError, TypeError):
|
| 174 |
-
final_output["total_amount"] = None
|
| 175 |
-
|
| 176 |
return final_output
|
|
|
|
| 1 |
+
# src/ml_extraction.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import pytesseract
|
| 7 |
+
from typing import List, Dict, Any
|
| 8 |
+
import re
|
| 9 |
+
import numpy as np
|
| 10 |
+
from extraction import extract_invoice_number, extract_total
|
| 11 |
+
|
| 12 |
+
# --- CONFIGURATION ---
# Preferred on-disk checkpoint location; when missing, the Hub repo below is
# downloaded into it (see load_model_and_processor).
LOCAL_MODEL_PATH = "./models/layoutlmv3-generalized"
HUB_MODEL_ID = "GSoumyajit2005/layoutlmv3-sroie-invoice-extraction"
|
| 15 |
+
|
| 16 |
+
# --- Load Model ---
|
| 17 |
+
def load_model_and_processor(model_path, hub_id):
    """Load the LayoutLMv3 model and processor, preferring a local copy.

    Tries ``model_path`` first; on a local cache miss, downloads the model
    snapshot from the Hugging Face Hub (``hub_id``) into ``model_path`` and
    loads from there.

    Args:
        model_path: Local directory expected to contain the fine-tuned model.
        hub_id: Hub repository id used as a fallback download source.

    Returns:
        A ``(model, processor)`` tuple, or ``(None, None)`` when neither the
        local path nor the Hub download succeeds. (The module-level
        ``if MODEL and PROCESSOR`` guard relies on this None return; the
        previous version let download errors propagate, so that guard's
        else-branch was unreachable.)
    """
    try:
        print(f"Attempting to load model from local path: {model_path}...")
        processor = LayoutLMv3Processor.from_pretrained(model_path)
        model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
        print("✅ Model loaded successfully from local path.")
    except OSError:
        # Local load failed — fall back to downloading the snapshot.
        try:
            print(f"Model not found locally. Downloading from Hub: {hub_id}...")
            from huggingface_hub import snapshot_download
            snapshot_download(repo_id=hub_id, local_dir=model_path, local_dir_use_symlinks=False)
            processor = LayoutLMv3Processor.from_pretrained(model_path)
            model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
            print("✅ Model downloaded and loaded successfully.")
        except Exception as e:
            # Fix: swallow the failure and signal it via the return value so
            # callers can degrade gracefully instead of crashing at import.
            print(f"❌ Failed to load model locally or from Hub: {e}")
            return None, None
    return model, processor
|
| 31 |
+
|
| 32 |
+
# Load the fine-tuned model once at import time so every request reuses it.
MODEL, PROCESSOR = load_model_and_processor(LOCAL_MODEL_PATH, HUB_MODEL_ID)

if not (MODEL and PROCESSOR):
    # Model unavailable: extract_ml_based() will raise when called.
    DEVICE = None
    print("❌ Could not load ML model.")
else:
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MODEL.to(DEVICE)
    MODEL.eval()
    print(f"ML Model is ready on device: {DEVICE}")
|
| 42 |
+
|
| 43 |
+
def _process_predictions(words, unnormalized_boxes, encoding, predictions, id2label):
|
| 44 |
+
word_ids = encoding.word_ids(batch_index=0)
|
| 45 |
+
word_level_preds = {}
|
| 46 |
+
for idx, word_id in enumerate(word_ids):
|
| 47 |
+
if word_id is not None:
|
| 48 |
+
label_id = predictions[idx]
|
| 49 |
+
if label_id != -100:
|
| 50 |
+
if word_id not in word_level_preds:
|
| 51 |
+
word_level_preds[word_id] = id2label[label_id]
|
| 52 |
+
|
| 53 |
+
entities = {}
|
| 54 |
+
for word_idx, label in word_level_preds.items():
|
| 55 |
+
if label == 'O': continue
|
| 56 |
+
entity_type = label[2:]
|
| 57 |
+
word = words[word_idx]
|
| 58 |
+
|
| 59 |
+
if label.startswith('B-'):
|
| 60 |
+
entities[entity_type] = {"text": word, "bbox": [unnormalized_boxes[word_idx]]}
|
| 61 |
+
elif label.startswith('I-') and entity_type in entities:
|
| 62 |
+
entities[entity_type]['text'] += " " + word
|
| 63 |
+
entities[entity_type]['bbox'].append(unnormalized_boxes[word_idx])
|
| 64 |
+
|
| 65 |
+
for entity in entities.values():
|
| 66 |
+
entity['text'] = entity['text'].strip()
|
| 67 |
+
|
| 68 |
+
return entities
|
| 69 |
+
|
| 70 |
+
def extract_ml_based(image_path: str) -> Dict[str, Any]:
    """Extract invoice fields from an image with the LayoutLMv3 model.

    Pipeline: Tesseract OCR -> box normalization to LayoutLM's 0-1000 grid
    -> token classification -> BIO merge -> regex fallbacks for the total
    and receipt number.

    Args:
        image_path: Path to the invoice/receipt image file.

    Returns:
        Dict with keys ``vendor``, ``date``, ``address``, ``receipt_number``,
        ``bill_to``, ``total_amount``, ``items`` (always empty — the model is
        not trained on line items) and ``raw_text``.

    Raises:
        RuntimeError: If the module-level model failed to load.
    """
    if not MODEL or not PROCESSOR:
        raise RuntimeError("ML model is not loaded.")

    # 1. OCR: collect confident words and their pixel-space boxes.
    image = Image.open(image_path).convert("RGB")
    width, height = image.size
    ocr_data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)

    words = []
    unnormalized_boxes = []
    for i in range(len(ocr_data['level'])):
        # Fix: some Tesseract builds report confidences as float strings
        # ("96.33"), which int() rejects — parse as float before comparing.
        if float(ocr_data['conf'][i]) > 30 and ocr_data['text'][i].strip() != '':
            words.append(ocr_data['text'][i])
            unnormalized_boxes.append([
                ocr_data['left'][i], ocr_data['top'][i],
                ocr_data['width'][i], ocr_data['height'][i]
            ])

    raw_text = " ".join(words)

    # 2. Normalize boxes to the 0-1000 grid LayoutLM expects; clamp so
    #    rounding can never push a coordinate out of [0, 1000].
    normalized_boxes = []
    for x, y, w, h in unnormalized_boxes:
        x0, y0, x1, y1 = x, y, x + w, y + h
        normalized_boxes.append([
            max(0, min(1000, int(1000 * (x0 / width)))),
            max(0, min(1000, int(1000 * (y0 / height)))),
            max(0, min(1000, int(1000 * (x1 / width)))),
            max(0, min(1000, int(1000 * (y1 / height)))),
        ])

    # 3. Inference.
    encoding = PROCESSOR(
        image, text=words, boxes=normalized_boxes,
        truncation=True, max_length=512, return_tensors="pt"
    ).to(DEVICE)

    with torch.no_grad():
        outputs = MODEL(**encoding)

    # Fix: squeeze only the batch axis — a bare .squeeze() would collapse a
    # length-1 sequence to a scalar and break per-token indexing below.
    predictions = outputs.logits.argmax(-1).squeeze(0).tolist()
    extracted_entities = _process_predictions(words, unnormalized_boxes, encoding, predictions, MODEL.config.id2label)

    # 4. Assemble the output record (superset of UI fields).
    final_output = {
        "vendor": extracted_entities.get("COMPANY", {}).get("text"),
        "date": extracted_entities.get("DATE", {}).get("text"),
        "address": extracted_entities.get("ADDRESS", {}).get("text"),
        "receipt_number": extracted_entities.get("INVOICE_NO", {}).get("text"),
        "bill_to": extracted_entities.get("BILL_TO", {}).get("text"),
        "total_amount": None,
        "items": [],
        "raw_text": raw_text
    }

    # Prefer the ML-predicted total; fall back to regex extraction below.
    ml_total = extracted_entities.get("TOTAL", {}).get("text")
    if ml_total:
        try:
            cleaned = re.sub(r'[^\d.,]', '', ml_total).replace(',', '.')
            final_output["total_amount"] = float(cleaned)
        except (ValueError, TypeError):
            pass

    if final_output["total_amount"] is None:
        final_output["total_amount"] = extract_total(raw_text)

    if not final_output["receipt_number"]:
        final_output["receipt_number"] = extract_invoice_number(raw_text)

    return final_output
|
src/ocr.py
CHANGED
|
@@ -1,15 +1,15 @@
|
|
| 1 |
-
import pytesseract
|
| 2 |
-
import numpy as np
|
| 3 |
-
from typing import Optional
|
| 4 |
-
|
| 5 |
-
pytesseract.pytesseract.tesseract_cmd = r'C:\\Program Files\\Tesseract-OCR\\tesseract.exe'
|
| 6 |
-
|
| 7 |
-
def extract_text(image: np.ndarray, lang: str='eng', config: str='--psm 11') -> str:
|
| 8 |
-
if image is None:
|
| 9 |
-
raise ValueError("Input image is None")
|
| 10 |
-
text = pytesseract.image_to_string(image, lang=lang, config=config)
|
| 11 |
-
return text.strip()
|
| 12 |
-
|
| 13 |
-
def extract_text_with_boxes(image):
|
| 14 |
-
pass
|
| 15 |
-
|
|
|
|
| 1 |
+
import pytesseract
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
#pytesseract.pytesseract.tesseract_cmd = r'C:\\Program Files\\Tesseract-OCR\\tesseract.exe'
|
| 6 |
+
|
| 7 |
+
def extract_text(image: np.ndarray, lang: str='eng', config: str='--psm 11') -> str:
    """Run Tesseract OCR on a preprocessed image and return stripped text.

    Args:
        image: Image array (grayscale or RGB) to recognize.
        lang: Tesseract language code.
        config: Extra Tesseract CLI flags (page-segmentation mode, etc.).

    Raises:
        ValueError: If ``image`` is None.
    """
    if image is None:
        raise ValueError("Input image is None")
    recognized = pytesseract.image_to_string(image, lang=lang, config=config)
    return recognized.strip()
|
| 12 |
+
|
| 13 |
+
def extract_text_with_boxes(image):
    """Placeholder for word-level OCR with bounding boxes.

    Not implemented yet; currently returns None.
    """
    return None
|
| 15 |
+
|
src/pipeline.py
CHANGED
|
@@ -1,151 +1,151 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Main invoice processing pipeline
|
| 3 |
-
Orchestrates preprocessing, OCR, and extraction
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
from typing import Dict, Any, Optional
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
import json
|
| 9 |
-
|
| 10 |
-
# Make sure all your modules are imported
|
| 11 |
-
from preprocessing import load_image, convert_to_grayscale, remove_noise
|
| 12 |
-
from ocr import extract_text
|
| 13 |
-
from extraction import structure_output
|
| 14 |
-
from ml_extraction import extract_ml_based
|
| 15 |
-
|
| 16 |
-
def process_invoice(image_path: str,
|
| 17 |
-
method: str = 'ml', # <-- New parameter: 'ml' or 'rules'
|
| 18 |
-
save_results: bool = False,
|
| 19 |
-
output_dir: str = 'outputs') -> Dict[str, Any]:
|
| 20 |
-
"""
|
| 21 |
-
Process an invoice image using either rule-based or ML-based extraction.
|
| 22 |
-
|
| 23 |
-
Args:
|
| 24 |
-
image_path: Path to the invoice image.
|
| 25 |
-
method: The extraction method to use ('ml' or 'rules'). Default is 'ml'.
|
| 26 |
-
save_results: Whether to save JSON results to a file.
|
| 27 |
-
output_dir: Directory to save results.
|
| 28 |
-
|
| 29 |
-
Returns:
|
| 30 |
-
A dictionary with the extracted invoice data.
|
| 31 |
-
"""
|
| 32 |
-
if not Path(image_path).exists():
|
| 33 |
-
raise FileNotFoundError(f"Image not found at path: {image_path}")
|
| 34 |
-
|
| 35 |
-
print(f"Processing with '{method}' method...")
|
| 36 |
-
|
| 37 |
-
if method == 'ml':
|
| 38 |
-
# --- ML-Based Extraction ---
|
| 39 |
-
try:
|
| 40 |
-
# The ml_extraction function handles everything internally
|
| 41 |
-
structured_data = extract_ml_based(image_path)
|
| 42 |
-
except Exception as e:
|
| 43 |
-
raise ValueError(f"Error during ML-based extraction: {e}")
|
| 44 |
-
|
| 45 |
-
elif method == 'rules':
|
| 46 |
-
# --- Rule-Based Extraction (Your original logic) ---
|
| 47 |
-
try:
|
| 48 |
-
image = load_image(image_path)
|
| 49 |
-
gray_image = convert_to_grayscale(image)
|
| 50 |
-
preprocessed_image = remove_noise(gray_image, kernel_size=3)
|
| 51 |
-
text = extract_text(preprocessed_image, config='--psm 6')
|
| 52 |
-
structured_data = structure_output(text) # Calls your old extraction.py
|
| 53 |
-
except Exception as e:
|
| 54 |
-
raise ValueError(f"Error during rule-based extraction: {e}")
|
| 55 |
-
|
| 56 |
-
else:
|
| 57 |
-
raise ValueError(f"Unknown extraction method: '{method}'. Choose 'ml' or 'rules'.")
|
| 58 |
-
|
| 59 |
-
# --- Saving Logic (remains the same) ---
|
| 60 |
-
if save_results:
|
| 61 |
-
output_path = Path(output_dir)
|
| 62 |
-
output_path.mkdir(parents=True, exist_ok=True)
|
| 63 |
-
json_path = output_path / (Path(image_path).stem + f"_{method}.json") # Add method to filename
|
| 64 |
-
try:
|
| 65 |
-
with open(json_path, 'w', encoding='utf-8') as f:
|
| 66 |
-
json.dump(structured_data, f, indent=2, ensure_ascii=False)
|
| 67 |
-
except Exception as e:
|
| 68 |
-
raise IOError(f"Error saving results to {json_path}: {e}")
|
| 69 |
-
|
| 70 |
-
return structured_data
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def process_batch(image_folder: str, output_dir: str = 'outputs') -> list:
|
| 74 |
-
"""Process multiple invoices in a folder""" # Corrected indentation
|
| 75 |
-
results = []
|
| 76 |
-
|
| 77 |
-
supported_extensions = ['*.jpg', '*.png', '*.jpeg']
|
| 78 |
-
|
| 79 |
-
for ext in supported_extensions:
|
| 80 |
-
for img_file in Path(image_folder).glob(ext):
|
| 81 |
-
print(f"🔄 Processing: {img_file}")
|
| 82 |
-
try:
|
| 83 |
-
result = process_invoice(str(img_file), save_results=True, output_dir=output_dir)
|
| 84 |
-
results.append(result)
|
| 85 |
-
except Exception as e:
|
| 86 |
-
print(f"❌ Error processing {img_file}: {e}")
|
| 87 |
-
|
| 88 |
-
print(f"\n🎉 Batch processing complete! {len(results)} invoices processed.")
|
| 89 |
-
return results
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
def main():
|
| 93 |
-
"""Command-line interface for invoice processing"""
|
| 94 |
-
import argparse
|
| 95 |
-
|
| 96 |
-
parser = argparse.ArgumentParser(
|
| 97 |
-
description='Process invoice images or folders and extract structured data.',
|
| 98 |
-
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 99 |
-
epilog="""
|
| 100 |
-
Examples:
|
| 101 |
-
# Process a single invoice
|
| 102 |
-
python src/pipeline.py data/raw/receipt1.jpg
|
| 103 |
-
|
| 104 |
-
# Process and save a single invoice
|
| 105 |
-
python src/pipeline.py data/raw/receipt1.jpg --save
|
| 106 |
-
|
| 107 |
-
# Process an entire folder of invoices
|
| 108 |
-
python src/pipeline.py data/raw --save --output results/
|
| 109 |
-
"""
|
| 110 |
-
)
|
| 111 |
-
|
| 112 |
-
# Corrected: Single 'path' argument
|
| 113 |
-
parser.add_argument('path', help='Path to an invoice image or a folder of images')
|
| 114 |
-
parser.add_argument('--save', action='store_true', help='Save results to JSON files')
|
| 115 |
-
parser.add_argument('--output', default='outputs', help='Output directory for JSON files')
|
| 116 |
-
parser.add_argument('--method', default='ml', choices=['ml', 'rules'], help="Extraction method: 'ml' or 'rules'")
|
| 117 |
-
|
| 118 |
-
args = parser.parse_args()
|
| 119 |
-
|
| 120 |
-
try:
|
| 121 |
-
# Check if path is a directory or a file
|
| 122 |
-
if Path(args.path).is_dir():
|
| 123 |
-
process_batch(args.path, output_dir=args.output)
|
| 124 |
-
elif Path(args.path).is_file():
|
| 125 |
-
# Corrected: Use args.path
|
| 126 |
-
print(f"🔄 Processing: {args.path}")
|
| 127 |
-
result = process_invoice(args.path, method=args.method, save_results=args.save, output_dir=args.output)
|
| 128 |
-
|
| 129 |
-
print("\n📊 Extracted Data:")
|
| 130 |
-
print("=" * 60)
|
| 131 |
-
print(f"Vendor: {result.get('vendor', 'N/A')}")
|
| 132 |
-
print(f"Invoice Number: {result.get('invoice_number', 'N/A')}")
|
| 133 |
-
print(f"Date: {result.get('date', 'N/A')}")
|
| 134 |
-
print(f"Total: ${result.get('
|
| 135 |
-
print("=" * 60)
|
| 136 |
-
|
| 137 |
-
if args.save:
|
| 138 |
-
print(f"\n💾 JSON saved to: {args.output}/{Path(args.path).stem}.json")
|
| 139 |
-
else:
|
| 140 |
-
raise FileNotFoundError(f"Path does not exist: {args.path}")
|
| 141 |
-
|
| 142 |
-
except Exception as e:
|
| 143 |
-
print(f"❌ An error occurred: {e}")
|
| 144 |
-
return 1
|
| 145 |
-
|
| 146 |
-
return 0
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
if __name__ == '__main__':
|
| 150 |
-
import sys
|
| 151 |
sys.exit(main())
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Main invoice processing pipeline
|
| 3 |
+
Orchestrates preprocessing, OCR, and extraction
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import Dict, Any, Optional
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import json
|
| 9 |
+
|
| 10 |
+
# Make sure all your modules are imported
|
| 11 |
+
from preprocessing import load_image, convert_to_grayscale, remove_noise
|
| 12 |
+
from ocr import extract_text
|
| 13 |
+
from extraction import structure_output
|
| 14 |
+
from ml_extraction import extract_ml_based
|
| 15 |
+
|
| 16 |
+
def process_invoice(image_path: str,
                    method: str = 'ml',  # 'ml' or 'rules'
                    save_results: bool = False,
                    output_dir: str = 'outputs') -> Dict[str, Any]:
    """
    Process an invoice image using either rule-based or ML-based extraction.

    Args:
        image_path: Path to the invoice image.
        method: The extraction method to use ('ml' or 'rules'). Default is 'ml'.
        save_results: Whether to save JSON results to a file.
        output_dir: Directory to save results.

    Returns:
        A dictionary with the extracted invoice data.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        ValueError: If extraction fails or ``method`` is unknown.
        IOError: If the result JSON cannot be written.
    """
    if not Path(image_path).exists():
        raise FileNotFoundError(f"Image not found at path: {image_path}")

    print(f"Processing with '{method}' method...")

    if method == 'ml':
        # --- ML-Based Extraction ---
        try:
            # ml_extraction handles OCR + inference internally.
            structured_data = extract_ml_based(image_path)
        except Exception as e:
            # Fix: chain with `from e` so the root-cause traceback survives.
            raise ValueError(f"Error during ML-based extraction: {e}") from e

    elif method == 'rules':
        # --- Rule-Based Extraction ---
        try:
            image = load_image(image_path)
            gray_image = convert_to_grayscale(image)
            preprocessed_image = remove_noise(gray_image, kernel_size=3)
            text = extract_text(preprocessed_image, config='--psm 6')
            structured_data = structure_output(text)
        except Exception as e:
            raise ValueError(f"Error during rule-based extraction: {e}") from e

    else:
        raise ValueError(f"Unknown extraction method: '{method}'. Choose 'ml' or 'rules'.")

    if save_results:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        # Method suffix lets 'ml' and 'rules' results coexist side by side.
        json_path = output_path / (Path(image_path).stem + f"_{method}.json")
        try:
            with open(json_path, 'w', encoding='utf-8') as f:
                json.dump(structured_data, f, indent=2, ensure_ascii=False)
        except Exception as e:
            raise IOError(f"Error saving results to {json_path}: {e}") from e

    return structured_data
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def process_batch(image_folder: str, output_dir: str = 'outputs', method: str = 'ml') -> list:
    """Process every supported image in a folder and save each result.

    Args:
        image_folder: Directory containing invoice images (.jpg/.jpeg/.png).
        output_dir: Directory where JSON results are written.
        method: Extraction method forwarded to process_invoice ('ml' or
            'rules'). New backward-compatible parameter — the previous
            version always used the default method.

    Returns:
        List of result dicts for the invoices that processed successfully;
        failures are logged and skipped.
    """
    results = []

    supported_extensions = ['*.jpg', '*.png', '*.jpeg']

    for ext in supported_extensions:
        for img_file in Path(image_folder).glob(ext):
            print(f"🔄 Processing: {img_file}")
            try:
                result = process_invoice(str(img_file), method=method,
                                         save_results=True, output_dir=output_dir)
                results.append(result)
            except Exception as e:
                # Best-effort batch: keep going on per-file failures.
                print(f"❌ Error processing {img_file}: {e}")

    print(f"\n🎉 Batch processing complete! {len(results)} invoices processed.")
    return results
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def main():
    """Command-line interface for invoice processing.

    Returns 0 on success and 1 on any error (suitable for sys.exit).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Process invoice images or folders and extract structured data.',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Process a single invoice
  python src/pipeline.py data/raw/receipt1.jpg

  # Process and save a single invoice
  python src/pipeline.py data/raw/receipt1.jpg --save

  # Process an entire folder of invoices
  python src/pipeline.py data/raw --save --output results/
"""
    )

    parser.add_argument('path', help='Path to an invoice image or a folder of images')
    parser.add_argument('--save', action='store_true', help='Save results to JSON files')
    parser.add_argument('--output', default='outputs', help='Output directory for JSON files')
    parser.add_argument('--method', default='ml', choices=['ml', 'rules'], help="Extraction method: 'ml' or 'rules'")

    args = parser.parse_args()

    try:
        # Dispatch on whether the path is a folder or a single image.
        if Path(args.path).is_dir():
            process_batch(args.path, output_dir=args.output)
        elif Path(args.path).is_file():
            print(f"🔄 Processing: {args.path}")
            result = process_invoice(args.path, method=args.method, save_results=args.save, output_dir=args.output)

            print("\n📊 Extracted Data:")
            print("=" * 60)
            print(f"Vendor: {result.get('vendor', 'N/A')}")
            # Fix: the ML path returns 'receipt_number' while the rules path
            # returns 'invoice_number' — accept either key.
            print(f"Invoice Number: {result.get('invoice_number') or result.get('receipt_number', 'N/A')}")
            print(f"Date: {result.get('date', 'N/A')}")
            print(f"Total: ${result.get('total_amount', 0.0)}")
            print("=" * 60)

            if args.save:
                # Fix: report the filename actually written by process_invoice,
                # which appends the extraction-method suffix.
                print(f"\n💾 JSON saved to: {args.output}/{Path(args.path).stem}_{args.method}.json")
        else:
            raise FileNotFoundError(f"Path does not exist: {args.path}")

    except Exception as e:
        print(f"❌ An error occurred: {e}")
        return 1

    return 0


if __name__ == '__main__':
    import sys
    sys.exit(main())
|
src/preprocessing.py
CHANGED
|
@@ -1,78 +1,78 @@
|
|
| 1 |
-
import cv2
|
| 2 |
-
import numpy as np
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
def load_image(image_path: str) -> np.ndarray:
|
| 8 |
-
if not Path(image_path).exists():
|
| 9 |
-
raise FileNotFoundError(f"Image not found : {image_path}")
|
| 10 |
-
image = cv2.imread(image_path)
|
| 11 |
-
if image is None:
|
| 12 |
-
raise ValueError(f"Could not load image: {image_path}")
|
| 13 |
-
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 14 |
-
return image
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def convert_to_grayscale(image: np.ndarray) -> np.ndarray:
|
| 18 |
-
if image is None:
|
| 19 |
-
raise ValueError(f"Image is None, cannot convert to grayscale")
|
| 20 |
-
if len(image.shape) ==2:
|
| 21 |
-
return image
|
| 22 |
-
return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def remove_noise(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
|
| 26 |
-
if image is None:
|
| 27 |
-
raise ValueError(f"Image is None, cannot remove noise")
|
| 28 |
-
if kernel_size <= 0:
|
| 29 |
-
raise ValueError("Kernel size must be positive")
|
| 30 |
-
if kernel_size % 2 == 0:
|
| 31 |
-
raise ValueError("Kernel size must be odd")
|
| 32 |
-
denoised_image = cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
|
| 33 |
-
return denoised_image
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def binarize(image: np.ndarray, method: str = 'adaptive', block_size: int=11, C: int=2) -> np.ndarray:
|
| 37 |
-
if image is None:
|
| 38 |
-
raise ValueError(f"Image is None, cannot binarize")
|
| 39 |
-
if image.ndim != 2:
|
| 40 |
-
raise ValueError("Input image must be grayscale for binarization")
|
| 41 |
-
if method == 'simple':
|
| 42 |
-
_, binary_image = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)
|
| 43 |
-
elif method == 'adaptive':
|
| 44 |
-
binary_image = cv2.adaptiveThreshold(image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,cv2.THRESH_BINARY, block_size, C)
|
| 45 |
-
else:
|
| 46 |
-
raise ValueError(f"Unknown binarization method: {method}")
|
| 47 |
-
return binary_image
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def deskew(image):
|
| 51 |
-
pass
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def preprocess_pipeline(image: np.ndarray,
|
| 55 |
-
steps: list = ['grayscale', 'denoise', 'binarize'],
|
| 56 |
-
denoise_kernel: int = 3,
|
| 57 |
-
binarize_method: str = 'adaptive',
|
| 58 |
-
binarize_block_size: int = 11,
|
| 59 |
-
binarize_C: int = 2) -> np.ndarray:
|
| 60 |
-
if image is None:
|
| 61 |
-
raise ValueError("Input image is None")
|
| 62 |
-
|
| 63 |
-
processed = image
|
| 64 |
-
|
| 65 |
-
for step in steps:
|
| 66 |
-
if step == 'grayscale':
|
| 67 |
-
processed = convert_to_grayscale(processed)
|
| 68 |
-
elif step == 'denoise':
|
| 69 |
-
processed = remove_noise(processed, kernel_size=denoise_kernel)
|
| 70 |
-
elif step == 'binarize':
|
| 71 |
-
processed = binarize(processed,
|
| 72 |
-
method=binarize_method,
|
| 73 |
-
block_size=binarize_block_size,
|
| 74 |
-
C=binarize_C)
|
| 75 |
-
else:
|
| 76 |
-
raise ValueError(f"Unknown preprocessing step: {step}")
|
| 77 |
-
|
| 78 |
-
return processed
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def load_image(image_path: str) -> np.ndarray:
    """Read an image from disk and return it as an RGB array.

    Args:
        image_path: Filesystem path to the image.

    Raises:
        FileNotFoundError: If the path does not exist.
        ValueError: If OpenCV cannot decode the file.
    """
    if not Path(image_path).exists():
        raise FileNotFoundError(f"Image not found : {image_path}")
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise ValueError(f"Could not load image: {image_path}")
    # OpenCV loads BGR; the rest of the pipeline expects RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def convert_to_grayscale(image: np.ndarray) -> np.ndarray:
    """Return a single-channel version of ``image``.

    Already-grayscale (2-D) inputs are returned unchanged.

    Raises:
        ValueError: If ``image`` is None.
    """
    if image is None:
        raise ValueError("Image is None, cannot convert to grayscale")
    if image.ndim == 2:
        return image
    return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def remove_noise(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
    """Apply a Gaussian blur to suppress noise before OCR.

    Args:
        image: Input image array.
        kernel_size: Side length of the Gaussian kernel; must be a positive
            odd integer.

    Raises:
        ValueError: If ``image`` is None or ``kernel_size`` is invalid.
    """
    if image is None:
        raise ValueError("Image is None, cannot remove noise")
    if kernel_size <= 0:
        raise ValueError("Kernel size must be positive")
    if kernel_size % 2 == 0:
        raise ValueError("Kernel size must be odd")
    return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def binarize(image: np.ndarray, method: str = 'adaptive', block_size: int=11, C: int=2) -> np.ndarray:
    """Threshold a grayscale image to black and white.

    Args:
        image: 2-D (grayscale) input array.
        method: 'simple' for a fixed 127 threshold, 'adaptive' for Gaussian
            adaptive thresholding.
        block_size: Neighborhood size used by the adaptive method.
        C: Constant subtracted from the adaptive-mean threshold.

    Raises:
        ValueError: If the image is None, not grayscale, or the method is
            unknown.
    """
    if image is None:
        raise ValueError("Image is None, cannot binarize")
    if image.ndim != 2:
        raise ValueError("Input image must be grayscale for binarization")
    if method == 'simple':
        _, thresholded = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)
        return thresholded
    if method == 'adaptive':
        return cv2.adaptiveThreshold(image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                     cv2.THRESH_BINARY, block_size, C)
    raise ValueError(f"Unknown binarization method: {method}")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def deskew(image):
    """Placeholder for skew correction.

    Not implemented yet; currently returns None.
    """
    return None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def preprocess_pipeline(image: np.ndarray,
                        steps: list = None,
                        denoise_kernel: int = 3,
                        binarize_method: str = 'adaptive',
                        binarize_block_size: int = 11,
                        binarize_C: int = 2) -> np.ndarray:
    """Run a configurable sequence of preprocessing steps on an image.

    Args:
        image: Input image array.
        steps: Ordered step names from {'grayscale', 'denoise', 'binarize'};
            defaults to all three in that order.
        denoise_kernel: Kernel size forwarded to remove_noise.
        binarize_method: Threshold method forwarded to binarize.
        binarize_block_size: Adaptive block size forwarded to binarize.
        binarize_C: Adaptive constant forwarded to binarize.

    Returns:
        The processed image after all steps have been applied.

    Raises:
        ValueError: If ``image`` is None or a step name is unknown.
    """
    if image is None:
        raise ValueError("Input image is None")

    # Fix: the previous mutable-list default was shared across calls;
    # build the default per call instead.
    if steps is None:
        steps = ['grayscale', 'denoise', 'binarize']

    processed = image

    for step in steps:
        if step == 'grayscale':
            processed = convert_to_grayscale(processed)
        elif step == 'denoise':
            processed = remove_noise(processed, kernel_size=denoise_kernel)
        elif step == 'binarize':
            processed = binarize(processed,
                                 method=binarize_method,
                                 block_size=binarize_block_size,
                                 C=binarize_C)
        else:
            raise ValueError(f"Unknown preprocessing step: {step}")

    return processed
|
tests/test_extraction.py
CHANGED
|
@@ -1,41 +1,41 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
sys.path.append('src')
|
| 3 |
-
|
| 4 |
-
from extraction import extract_dates, extract_amounts, extract_total, extract_vendor, extract_invoice_number
|
| 5 |
-
|
| 6 |
-
receipt_text = """
|
| 7 |
-
tan chay yee
|
| 8 |
-
|
| 9 |
-
*** COPY ***
|
| 10 |
-
|
| 11 |
-
OJC MARKETING SDN BHD.
|
| 12 |
-
|
| 13 |
-
ROC NO: 538358-H
|
| 14 |
-
|
| 15 |
-
TAX INVOICE
|
| 16 |
-
|
| 17 |
-
Invoice No: PEGIV-1030765
|
| 18 |
-
Date: 15/01/2019 11:05:16 AM
|
| 19 |
-
|
| 20 |
-
TOTAL: 193.00
|
| 21 |
-
"""
|
| 22 |
-
|
| 23 |
-
print("🧪 Testing Extraction Functions")
|
| 24 |
-
print("=" * 60)
|
| 25 |
-
|
| 26 |
-
dates = extract_dates(receipt_text)
|
| 27 |
-
print(f"\n📅 Date: {dates}")
|
| 28 |
-
|
| 29 |
-
amounts = extract_amounts(receipt_text)
|
| 30 |
-
print(f"\n💰 Amounts: {amounts}")
|
| 31 |
-
|
| 32 |
-
total = extract_total(receipt_text)
|
| 33 |
-
print(f"\n💵 Total: {total}")
|
| 34 |
-
|
| 35 |
-
vendor = extract_vendor(receipt_text)
|
| 36 |
-
print(f"\n🏢 Vendor: {vendor}")
|
| 37 |
-
|
| 38 |
-
invoice_num = extract_invoice_number(receipt_text)
|
| 39 |
-
print(f"\n📄 Invoice Number: {invoice_num}")
|
| 40 |
-
|
| 41 |
print("\n✅ All extraction tests complete!")
|
|
|
|
| 1 |
+
import sys
sys.path.append('src')

from extraction import extract_dates, extract_amounts, extract_total, extract_vendor, extract_invoice_number

# Sample SROIE-style receipt text used as the shared fixture.
receipt_text = """
tan chay yee

*** COPY ***

OJC MARKETING SDN BHD.

ROC NO: 538358-H

TAX INVOICE

Invoice No: PEGIV-1030765
Date: 15/01/2019 11:05:16 AM

TOTAL: 193.00
"""

print("🧪 Testing Extraction Functions")
print("=" * 60)

# Run each extractor and echo its result in the same order as before.
for heading, extractor in (
    ("📅 Date", extract_dates),
    ("💰 Amounts", extract_amounts),
    ("💵 Total", extract_total),
    ("🏢 Vendor", extract_vendor),
    ("📄 Invoice Number", extract_invoice_number),
):
    print(f"\n{heading}: {extractor(receipt_text)}")

print("\n✅ All extraction tests complete!")
|
tests/test_full_pipeline.py
CHANGED
|
@@ -1,42 +1,42 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
sys.path.append('src')
|
| 3 |
-
|
| 4 |
-
from preprocessing import load_image, convert_to_grayscale, remove_noise
|
| 5 |
-
from ocr import extract_text
|
| 6 |
-
from extraction import structure_output
|
| 7 |
-
import json
|
| 8 |
-
|
| 9 |
-
print("=" * 60)
|
| 10 |
-
print("🎯 FULL INVOICE PROCESSING PIPELINE TEST")
|
| 11 |
-
print("=" * 60)
|
| 12 |
-
|
| 13 |
-
# Step 1: Load and preprocess image
|
| 14 |
-
print("\n1️⃣ Loading and preprocessing image...")
|
| 15 |
-
image = load_image('data/raw/receipt3.jpg')
|
| 16 |
-
gray = convert_to_grayscale(image)
|
| 17 |
-
denoised = remove_noise(gray, kernel_size=3)
|
| 18 |
-
print("✅ Image preprocessed")
|
| 19 |
-
|
| 20 |
-
# Step 2: Extract text with OCR
|
| 21 |
-
print("\n2️⃣ Extracting text with OCR...")
|
| 22 |
-
text = extract_text(denoised, config='--psm 6')
|
| 23 |
-
print(f"✅ Extracted {len(text)} characters")
|
| 24 |
-
|
| 25 |
-
# Step 3: Extract structured information
|
| 26 |
-
print("\n3️⃣ Extracting structured information...")
|
| 27 |
-
result = structure_output(text)
|
| 28 |
-
print("✅ Information extracted")
|
| 29 |
-
|
| 30 |
-
# Step 4: Display results
|
| 31 |
-
print("\n" + "=" * 60)
|
| 32 |
-
print("📊 EXTRACTED INVOICE DATA (JSON)")
|
| 33 |
-
print("=" * 60)
|
| 34 |
-
print(json.dumps(result, indent=2, ensure_ascii=False))
|
| 35 |
-
print("=" * 60)
|
| 36 |
-
|
| 37 |
-
print("\n🎉 PIPELINE COMPLETE!")
|
| 38 |
-
print("\n📋 Summary:")
|
| 39 |
-
print(f" Vendor: {result['vendor']}")
|
| 40 |
-
print(f" Invoice #: {result['invoice_number']}")
|
| 41 |
-
print(f" Date: {result['date']}")
|
| 42 |
print(f" Total: ${result['total']}")
|
|
|
|
| 1 |
+
import sys
sys.path.append('src')

from preprocessing import load_image, convert_to_grayscale, remove_noise
from ocr import extract_text
from extraction import structure_output
import json

BANNER = "=" * 60

print(BANNER)
print("🎯 FULL INVOICE PROCESSING PIPELINE TEST")
print(BANNER)

# Step 1: Load and preprocess the sample receipt.
print("\n1️⃣ Loading and preprocessing image...")
denoised = remove_noise(convert_to_grayscale(load_image('data/raw/receipt3.jpg')), kernel_size=3)
print("✅ Image preprocessed")

# Step 2: Run OCR on the cleaned image.
print("\n2️⃣ Extracting text with OCR...")
text = extract_text(denoised, config='--psm 6')
print(f"✅ Extracted {len(text)} characters")

# Step 3: Pull structured fields out of the raw text.
print("\n3️⃣ Extracting structured information...")
result = structure_output(text)
print("✅ Information extracted")

# Step 4: Display the results.
print("\n" + BANNER)
print("📊 EXTRACTED INVOICE DATA (JSON)")
print(BANNER)
print(json.dumps(result, indent=2, ensure_ascii=False))
print(BANNER)

print("\n🎉 PIPELINE COMPLETE!")
print("\n📋 Summary:")
print(f" Vendor: {result['vendor']}")
print(f" Invoice #: {result['invoice_number']}")
print(f" Date: {result['date']}")
print(f" Total: ${result['total']}")
|
tests/test_ocr.py
CHANGED
|
@@ -1,101 +1,101 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
sys.path.append('src')
|
| 3 |
-
|
| 4 |
-
from preprocessing import load_image, convert_to_grayscale, remove_noise
|
| 5 |
-
from ocr import extract_text
|
| 6 |
-
import matplotlib.pyplot as plt
|
| 7 |
-
import numpy as np
|
| 8 |
-
|
| 9 |
-
print("=" * 60)
|
| 10 |
-
print("🎯 OPTIMIZING GRAYSCALE OCR")
|
| 11 |
-
print("=" * 60)
|
| 12 |
-
|
| 13 |
-
# Load and convert to grayscale
|
| 14 |
-
image = load_image('data/raw/receipt3.jpg')
|
| 15 |
-
gray = convert_to_grayscale(image)
|
| 16 |
-
|
| 17 |
-
# Test 1: Different PSM modes
|
| 18 |
-
print("\n📊 Testing different Tesseract PSM modes...\n")
|
| 19 |
-
|
| 20 |
-
psm_configs = [
|
| 21 |
-
('', 'Default'),
|
| 22 |
-
('--psm 3', 'Automatic page segmentation'),
|
| 23 |
-
('--psm 4', 'Single column of text'),
|
| 24 |
-
('--psm 6', 'Uniform block of text'),
|
| 25 |
-
('--psm 11', 'Sparse text, find as much as possible'),
|
| 26 |
-
('--psm 12', 'Sparse text with OSD (Orientation and Script Detection)'),
|
| 27 |
-
]
|
| 28 |
-
|
| 29 |
-
results = {}
|
| 30 |
-
for config, desc in psm_configs:
|
| 31 |
-
text = extract_text(gray, config=config)
|
| 32 |
-
results[desc] = text
|
| 33 |
-
print(f"{desc:50s} → {len(text):4d} chars")
|
| 34 |
-
|
| 35 |
-
# Find best result
|
| 36 |
-
best_desc = max(results, key=lambda k: len(results[k]))
|
| 37 |
-
best_text = results[best_desc]
|
| 38 |
-
|
| 39 |
-
print(f"\n✅ WINNER: {best_desc} ({len(best_text)} chars)")
|
| 40 |
-
|
| 41 |
-
# Test 2: With slight denoising
|
| 42 |
-
print("\n📊 Testing with light denoising...\n")
|
| 43 |
-
|
| 44 |
-
denoised = remove_noise(gray, kernel_size=3)
|
| 45 |
-
text_denoised = extract_text(denoised, config='--psm 6')
|
| 46 |
-
print(f"Grayscale + Denoise (psm 6): {len(text_denoised)} chars")
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
# Display best result
|
| 50 |
-
print("\n" + "=" * 60)
|
| 51 |
-
print("📄 BEST EXTRACTED TEXT:")
|
| 52 |
-
print("=" * 60)
|
| 53 |
-
print(best_text)
|
| 54 |
-
print("=" * 60)
|
| 55 |
-
|
| 56 |
-
# Visualize
|
| 57 |
-
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
| 58 |
-
|
| 59 |
-
axes[0].imshow(image)
|
| 60 |
-
axes[0].set_title("Original")
|
| 61 |
-
axes[0].axis('off')
|
| 62 |
-
|
| 63 |
-
axes[1].imshow(gray, cmap='gray')
|
| 64 |
-
axes[1].set_title(f"Grayscale\n({len(best_text)} chars - {best_desc})")
|
| 65 |
-
axes[1].axis('off')
|
| 66 |
-
|
| 67 |
-
axes[2].imshow(denoised, cmap='gray')
|
| 68 |
-
axes[2].set_title(f"Denoised\n({len(text_denoised)} chars)")
|
| 69 |
-
axes[2].axis('off')
|
| 70 |
-
|
| 71 |
-
plt.tight_layout()
|
| 72 |
-
plt.show()
|
| 73 |
-
|
| 74 |
-
print(f"\n💡 Recommended pipeline: Grayscale + {best_desc}")
|
| 75 |
-
|
| 76 |
-
# Test the combination we missed!
|
| 77 |
-
print("\n📊 Testing BEST combination...\n")
|
| 78 |
-
|
| 79 |
-
denoised = remove_noise(gray, kernel_size=3)
|
| 80 |
-
|
| 81 |
-
# Test PSM 11 on denoised
|
| 82 |
-
text_denoised_psm11 = extract_text(denoised, config='--psm 11')
|
| 83 |
-
text_denoised_psm6 = extract_text(denoised, config='--psm 6')
|
| 84 |
-
|
| 85 |
-
print(f"Denoised + PSM 6: {len(text_denoised_psm6)} chars")
|
| 86 |
-
print(f"Denoised + PSM 11: {len(text_denoised_psm11)} chars")
|
| 87 |
-
|
| 88 |
-
if len(text_denoised_psm11) > len(text_denoised_psm6):
|
| 89 |
-
print(f"\n✅ PSM 11 wins! ({len(text_denoised_psm11)} chars)")
|
| 90 |
-
best_config = '--psm 11'
|
| 91 |
-
best_text_final = text_denoised_psm11
|
| 92 |
-
else:
|
| 93 |
-
print(f"\n✅ PSM 6 wins! ({len(text_denoised_psm6)} chars)")
|
| 94 |
-
best_config = '--psm 6'
|
| 95 |
-
best_text_final = text_denoised_psm6
|
| 96 |
-
|
| 97 |
-
print(f"\n🏆 FINAL WINNER: Denoised + {best_config}")
|
| 98 |
-
print("\nFull text:")
|
| 99 |
-
print("=" * 60)
|
| 100 |
-
print(best_text_final)
|
| 101 |
print("=" * 60)
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('src')
|
| 3 |
+
|
| 4 |
+
from preprocessing import load_image, convert_to_grayscale, remove_noise
|
| 5 |
+
from ocr import extract_text
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
print("=" * 60)
|
| 10 |
+
print("🎯 OPTIMIZING GRAYSCALE OCR")
|
| 11 |
+
print("=" * 60)
|
| 12 |
+
|
| 13 |
+
# Load and convert to grayscale
|
| 14 |
+
image = load_image('data/raw/receipt3.jpg')
|
| 15 |
+
gray = convert_to_grayscale(image)
|
| 16 |
+
|
| 17 |
+
# Test 1: Different PSM modes
|
| 18 |
+
print("\n📊 Testing different Tesseract PSM modes...\n")
|
| 19 |
+
|
| 20 |
+
psm_configs = [
|
| 21 |
+
('', 'Default'),
|
| 22 |
+
('--psm 3', 'Automatic page segmentation'),
|
| 23 |
+
('--psm 4', 'Single column of text'),
|
| 24 |
+
('--psm 6', 'Uniform block of text'),
|
| 25 |
+
('--psm 11', 'Sparse text, find as much as possible'),
|
| 26 |
+
('--psm 12', 'Sparse text with OSD (Orientation and Script Detection)'),
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
results = {}
|
| 30 |
+
for config, desc in psm_configs:
|
| 31 |
+
text = extract_text(gray, config=config)
|
| 32 |
+
results[desc] = text
|
| 33 |
+
print(f"{desc:50s} → {len(text):4d} chars")
|
| 34 |
+
|
| 35 |
+
# Find best result
|
| 36 |
+
best_desc = max(results, key=lambda k: len(results[k]))
|
| 37 |
+
best_text = results[best_desc]
|
| 38 |
+
|
| 39 |
+
print(f"\n✅ WINNER: {best_desc} ({len(best_text)} chars)")
|
| 40 |
+
|
| 41 |
+
# Test 2: With slight denoising
|
| 42 |
+
print("\n📊 Testing with light denoising...\n")
|
| 43 |
+
|
| 44 |
+
denoised = remove_noise(gray, kernel_size=3)
|
| 45 |
+
text_denoised = extract_text(denoised, config='--psm 6')
|
| 46 |
+
print(f"Grayscale + Denoise (psm 6): {len(text_denoised)} chars")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Display best result
|
| 50 |
+
print("\n" + "=" * 60)
|
| 51 |
+
print("📄 BEST EXTRACTED TEXT:")
|
| 52 |
+
print("=" * 60)
|
| 53 |
+
print(best_text)
|
| 54 |
+
print("=" * 60)
|
| 55 |
+
|
| 56 |
+
# Visualize
|
| 57 |
+
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
| 58 |
+
|
| 59 |
+
axes[0].imshow(image)
|
| 60 |
+
axes[0].set_title("Original")
|
| 61 |
+
axes[0].axis('off')
|
| 62 |
+
|
| 63 |
+
axes[1].imshow(gray, cmap='gray')
|
| 64 |
+
axes[1].set_title(f"Grayscale\n({len(best_text)} chars - {best_desc})")
|
| 65 |
+
axes[1].axis('off')
|
| 66 |
+
|
| 67 |
+
axes[2].imshow(denoised, cmap='gray')
|
| 68 |
+
axes[2].set_title(f"Denoised\n({len(text_denoised)} chars)")
|
| 69 |
+
axes[2].axis('off')
|
| 70 |
+
|
| 71 |
+
plt.tight_layout()
|
| 72 |
+
plt.show()
|
| 73 |
+
|
| 74 |
+
print(f"\n💡 Recommended pipeline: Grayscale + {best_desc}")
|
| 75 |
+
|
| 76 |
+
# Test the combination we missed!
|
| 77 |
+
print("\n📊 Testing BEST combination...\n")
|
| 78 |
+
|
| 79 |
+
denoised = remove_noise(gray, kernel_size=3)
|
| 80 |
+
|
| 81 |
+
# Test PSM 11 on denoised
|
| 82 |
+
text_denoised_psm11 = extract_text(denoised, config='--psm 11')
|
| 83 |
+
text_denoised_psm6 = extract_text(denoised, config='--psm 6')
|
| 84 |
+
|
| 85 |
+
print(f"Denoised + PSM 6: {len(text_denoised_psm6)} chars")
|
| 86 |
+
print(f"Denoised + PSM 11: {len(text_denoised_psm11)} chars")
|
| 87 |
+
|
| 88 |
+
if len(text_denoised_psm11) > len(text_denoised_psm6):
|
| 89 |
+
print(f"\n✅ PSM 11 wins! ({len(text_denoised_psm11)} chars)")
|
| 90 |
+
best_config = '--psm 11'
|
| 91 |
+
best_text_final = text_denoised_psm11
|
| 92 |
+
else:
|
| 93 |
+
print(f"\n✅ PSM 6 wins! ({len(text_denoised_psm6)} chars)")
|
| 94 |
+
best_config = '--psm 6'
|
| 95 |
+
best_text_final = text_denoised_psm6
|
| 96 |
+
|
| 97 |
+
print(f"\n🏆 FINAL WINNER: Denoised + {best_config}")
|
| 98 |
+
print("\nFull text:")
|
| 99 |
+
print("=" * 60)
|
| 100 |
+
print(best_text_final)
|
| 101 |
print("=" * 60)
|
tests/test_pipeline.py
CHANGED
|
@@ -1,96 +1,96 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
import json
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
|
| 5 |
-
# Add the 'src' directory to the Python path
|
| 6 |
-
sys.path.append('src')
|
| 7 |
-
|
| 8 |
-
from pipeline import process_invoice
|
| 9 |
-
|
| 10 |
-
def test_full_pipeline():
|
| 11 |
-
"""
|
| 12 |
-
Tests the full invoice processing pipeline on a sample receipt
|
| 13 |
-
and prints the advanced JSON structure.
|
| 14 |
-
"""
|
| 15 |
-
print("=" * 60)
|
| 16 |
-
print("🎯 ADVANCED INVOICE PROCESSING PIPELINE TEST")
|
| 17 |
-
print("=" * 60)
|
| 18 |
-
|
| 19 |
-
# --- Configuration ---
|
| 20 |
-
image_path = 'data/raw/receipt1.jpg'
|
| 21 |
-
save_output = True
|
| 22 |
-
output_dir = 'outputs'
|
| 23 |
-
|
| 24 |
-
# Check if the image exists
|
| 25 |
-
if not Path(image_path).exists():
|
| 26 |
-
print(f"❌ ERROR: Test image not found at '{image_path}'")
|
| 27 |
-
return
|
| 28 |
-
|
| 29 |
-
# --- Processing ---
|
| 30 |
-
print(f"\n🔄 Processing invoice: {image_path}...")
|
| 31 |
-
try:
|
| 32 |
-
# Call the main processing function
|
| 33 |
-
result = process_invoice(image_path, save_results=save_output, output_dir=output_dir)
|
| 34 |
-
print("✅ Invoice processed successfully!")
|
| 35 |
-
except Exception as e:
|
| 36 |
-
print(f"❌ An error occurred during processing: {e}")
|
| 37 |
-
# Print traceback for detailed debugging
|
| 38 |
-
import traceback
|
| 39 |
-
traceback.print_exc()
|
| 40 |
-
return
|
| 41 |
-
|
| 42 |
-
# --- Display Results ---
|
| 43 |
-
print("\n" + "=" * 60)
|
| 44 |
-
print("📊 EXTRACTED INVOICE DATA (Advanced JSON)")
|
| 45 |
-
print("=" * 60)
|
| 46 |
-
|
| 47 |
-
# Pretty-print the JSON to the console
|
| 48 |
-
print(json.dumps(result, indent=2, ensure_ascii=False))
|
| 49 |
-
|
| 50 |
-
print("\n" + "=" * 60)
|
| 51 |
-
print("📋 SUMMARY OF KEY EXTRACTED FIELDS")
|
| 52 |
-
print("=" * 60)
|
| 53 |
-
|
| 54 |
-
# --- Print a clean summary ---
|
| 55 |
-
print(f"📄 Receipt Number: {result.get('receipt_number', 'N/A')}")
|
| 56 |
-
print(f"📅 Date: {result.get('date', 'N/A')}")
|
| 57 |
-
|
| 58 |
-
# Print Bill To info safely
|
| 59 |
-
bill_to = result.get('bill_to')
|
| 60 |
-
if bill_to and isinstance(bill_to, dict):
|
| 61 |
-
print(f"👤 Bill To: {bill_to.get('name', 'N/A')}")
|
| 62 |
-
else:
|
| 63 |
-
print("👤 Bill To: N/A")
|
| 64 |
-
|
| 65 |
-
# Print line items
|
| 66 |
-
print("\n🛒 Line Items:")
|
| 67 |
-
items = result.get('items', [])
|
| 68 |
-
if items:
|
| 69 |
-
for i, item in enumerate(items, 1):
|
| 70 |
-
desc = item.get('description', 'No Description')
|
| 71 |
-
qty = item.get('quantity', 1)
|
| 72 |
-
total = item.get('total', 0.0)
|
| 73 |
-
print(f" - Item {i}: {desc[:40]:<40} | Qty: {qty} | Total: {total:.2f}")
|
| 74 |
-
else:
|
| 75 |
-
print(" - No line items extracted.")
|
| 76 |
-
|
| 77 |
-
# Print total and validation status
|
| 78 |
-
print(f"\n💵 Total Amount: ${result.get('total_amount', 0.0):.2f}")
|
| 79 |
-
|
| 80 |
-
confidence = result.get('extraction_confidence', 0)
|
| 81 |
-
print(f"📈 Confidence: {confidence}%")
|
| 82 |
-
|
| 83 |
-
validation = "✅ Passed" if result.get('validation_passed', False) else "❌ Failed"
|
| 84 |
-
print(f"✔️ Validation: {validation}")
|
| 85 |
-
|
| 86 |
-
print("\n" + "=" * 60)
|
| 87 |
-
|
| 88 |
-
if save_output:
|
| 89 |
-
json_path = Path(output_dir) / (Path(image_path).stem + '.json')
|
| 90 |
-
print(f"\n💾 Full JSON output saved to: {json_path}")
|
| 91 |
-
|
| 92 |
-
print("\n🎉 PIPELINE TEST COMPLETE!")
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
if __name__ == '__main__':
|
| 96 |
test_full_pipeline()
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
# Add the 'src' directory to the Python path
|
| 6 |
+
sys.path.append('src')
|
| 7 |
+
|
| 8 |
+
from pipeline import process_invoice
|
| 9 |
+
|
| 10 |
+
def test_full_pipeline():
|
| 11 |
+
"""
|
| 12 |
+
Tests the full invoice processing pipeline on a sample receipt
|
| 13 |
+
and prints the advanced JSON structure.
|
| 14 |
+
"""
|
| 15 |
+
print("=" * 60)
|
| 16 |
+
print("🎯 ADVANCED INVOICE PROCESSING PIPELINE TEST")
|
| 17 |
+
print("=" * 60)
|
| 18 |
+
|
| 19 |
+
# --- Configuration ---
|
| 20 |
+
image_path = 'data/raw/receipt1.jpg'
|
| 21 |
+
save_output = True
|
| 22 |
+
output_dir = 'outputs'
|
| 23 |
+
|
| 24 |
+
# Check if the image exists
|
| 25 |
+
if not Path(image_path).exists():
|
| 26 |
+
print(f"❌ ERROR: Test image not found at '{image_path}'")
|
| 27 |
+
return
|
| 28 |
+
|
| 29 |
+
# --- Processing ---
|
| 30 |
+
print(f"\n🔄 Processing invoice: {image_path}...")
|
| 31 |
+
try:
|
| 32 |
+
# Call the main processing function
|
| 33 |
+
result = process_invoice(image_path, save_results=save_output, output_dir=output_dir)
|
| 34 |
+
print("✅ Invoice processed successfully!")
|
| 35 |
+
except Exception as e:
|
| 36 |
+
print(f"❌ An error occurred during processing: {e}")
|
| 37 |
+
# Print traceback for detailed debugging
|
| 38 |
+
import traceback
|
| 39 |
+
traceback.print_exc()
|
| 40 |
+
return
|
| 41 |
+
|
| 42 |
+
# --- Display Results ---
|
| 43 |
+
print("\n" + "=" * 60)
|
| 44 |
+
print("📊 EXTRACTED INVOICE DATA (Advanced JSON)")
|
| 45 |
+
print("=" * 60)
|
| 46 |
+
|
| 47 |
+
# Pretty-print the JSON to the console
|
| 48 |
+
print(json.dumps(result, indent=2, ensure_ascii=False))
|
| 49 |
+
|
| 50 |
+
print("\n" + "=" * 60)
|
| 51 |
+
print("📋 SUMMARY OF KEY EXTRACTED FIELDS")
|
| 52 |
+
print("=" * 60)
|
| 53 |
+
|
| 54 |
+
# --- Print a clean summary ---
|
| 55 |
+
print(f"📄 Receipt Number: {result.get('receipt_number', 'N/A')}")
|
| 56 |
+
print(f"📅 Date: {result.get('date', 'N/A')}")
|
| 57 |
+
|
| 58 |
+
# Print Bill To info safely
|
| 59 |
+
bill_to = result.get('bill_to')
|
| 60 |
+
if bill_to and isinstance(bill_to, dict):
|
| 61 |
+
print(f"👤 Bill To: {bill_to.get('name', 'N/A')}")
|
| 62 |
+
else:
|
| 63 |
+
print("👤 Bill To: N/A")
|
| 64 |
+
|
| 65 |
+
# Print line items
|
| 66 |
+
print("\n🛒 Line Items:")
|
| 67 |
+
items = result.get('items', [])
|
| 68 |
+
if items:
|
| 69 |
+
for i, item in enumerate(items, 1):
|
| 70 |
+
desc = item.get('description', 'No Description')
|
| 71 |
+
qty = item.get('quantity', 1)
|
| 72 |
+
total = item.get('total', 0.0)
|
| 73 |
+
print(f" - Item {i}: {desc[:40]:<40} | Qty: {qty} | Total: {total:.2f}")
|
| 74 |
+
else:
|
| 75 |
+
print(" - No line items extracted.")
|
| 76 |
+
|
| 77 |
+
# Print total and validation status
|
| 78 |
+
print(f"\n💵 Total Amount: ${result.get('total_amount', 0.0):.2f}")
|
| 79 |
+
|
| 80 |
+
confidence = result.get('extraction_confidence', 0)
|
| 81 |
+
print(f"📈 Confidence: {confidence}%")
|
| 82 |
+
|
| 83 |
+
validation = "✅ Passed" if result.get('validation_passed', False) else "❌ Failed"
|
| 84 |
+
print(f"✔️ Validation: {validation}")
|
| 85 |
+
|
| 86 |
+
print("\n" + "=" * 60)
|
| 87 |
+
|
| 88 |
+
if save_output:
|
| 89 |
+
json_path = Path(output_dir) / (Path(image_path).stem + '.json')
|
| 90 |
+
print(f"\n💾 Full JSON output saved to: {json_path}")
|
| 91 |
+
|
| 92 |
+
print("\n🎉 PIPELINE TEST COMPLETE!")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
if __name__ == '__main__':
|
| 96 |
test_full_pipeline()
|
tests/test_preprocessing.py
CHANGED
|
@@ -1,177 +1,177 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
sys.path.append('src') # So Python can find our modules
|
| 3 |
-
|
| 4 |
-
from preprocessing import load_image, convert_to_grayscale, remove_noise, binarize, preprocess_pipeline
|
| 5 |
-
import numpy as np
|
| 6 |
-
import matplotlib.pyplot as plt
|
| 7 |
-
|
| 8 |
-
# Test 1: Load a valid image
|
| 9 |
-
print("Test 1: Loading receipt1.jpg...")
|
| 10 |
-
image = load_image('data/raw/receipt1.jpg')
|
| 11 |
-
print(f"✅ Success! Image shape: {image.shape}")
|
| 12 |
-
print(f" Data type: {image.dtype}")
|
| 13 |
-
print(f" Value range: {image.min()} to {image.max()}")
|
| 14 |
-
|
| 15 |
-
# Test 2: Visualize it
|
| 16 |
-
print("\nTest 2: Displaying image...")
|
| 17 |
-
plt.imshow(image)
|
| 18 |
-
plt.title("Loaded Receipt")
|
| 19 |
-
plt.axis('off')
|
| 20 |
-
plt.show()
|
| 21 |
-
print("✅ If you see the receipt image, it worked!")
|
| 22 |
-
|
| 23 |
-
# Test 3: Try loading non-existent file
|
| 24 |
-
print("\nTest 3: Testing error handling...")
|
| 25 |
-
try:
|
| 26 |
-
load_image('data/raw/fake_image.jpg')
|
| 27 |
-
print("❌ Should have raised FileNotFoundError!")
|
| 28 |
-
except FileNotFoundError as e:
|
| 29 |
-
print(f"✅ Correctly raised error: {e}")
|
| 30 |
-
|
| 31 |
-
# Test 4: Grayscale conversion
|
| 32 |
-
print("\nTest 4: Converting to grayscale...")
|
| 33 |
-
gray = convert_to_grayscale(image)
|
| 34 |
-
print(f"✅ Success! Grayscale shape: {gray.shape}")
|
| 35 |
-
print(f" Original had 3 channels, now has: {len(gray.shape)} dimensions")
|
| 36 |
-
|
| 37 |
-
# Visualize side-by-side
|
| 38 |
-
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
|
| 39 |
-
ax1.imshow(image)
|
| 40 |
-
ax1.set_title("Original (RGB)")
|
| 41 |
-
ax1.axis('off')
|
| 42 |
-
|
| 43 |
-
ax2.imshow(gray, cmap='gray') # cmap='gray' tells matplotlib to display in grayscale
|
| 44 |
-
ax2.set_title("Grayscale")
|
| 45 |
-
ax2.axis('off')
|
| 46 |
-
|
| 47 |
-
plt.tight_layout()
|
| 48 |
-
plt.show()
|
| 49 |
-
|
| 50 |
-
# Test 5: Already grayscale (should return as-is)
|
| 51 |
-
print("\nTest 5: Converting already-grayscale image...")
|
| 52 |
-
gray_again = convert_to_grayscale(gray)
|
| 53 |
-
print(f"✅ Returned without error: {gray_again.shape}")
|
| 54 |
-
assert gray_again is gray, "Should return same object if already grayscale"
|
| 55 |
-
print("✅ Correctly returned the same image!")
|
| 56 |
-
|
| 57 |
-
print("\n🎉 Grayscale tests passed!")
|
| 58 |
-
|
| 59 |
-
# Test 6: Binarization - Simple method
|
| 60 |
-
print("\nTest 6: Simple binarization...")
|
| 61 |
-
binary_simple = binarize(gray, method='simple')
|
| 62 |
-
print(f"✅ Success! Binary shape: {binary_simple.shape}")
|
| 63 |
-
print(f" Unique values: {np.unique(binary_simple)}") # Should be [0, 255]
|
| 64 |
-
|
| 65 |
-
# Test 7: Binarization - Adaptive method
|
| 66 |
-
print("\nTest 7: Adaptive binarization...")
|
| 67 |
-
binary_adaptive = binarize(gray, method='adaptive', block_size=11, C=2)
|
| 68 |
-
print(f"✅ Success! Binary shape: {binary_adaptive.shape}")
|
| 69 |
-
print(f" Unique values: {np.unique(binary_adaptive)}")
|
| 70 |
-
|
| 71 |
-
# Visualize comparison
|
| 72 |
-
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
| 73 |
-
|
| 74 |
-
axes[0, 0].imshow(image)
|
| 75 |
-
axes[0, 0].set_title("1. Original (RGB)")
|
| 76 |
-
axes[0, 0].axis('off')
|
| 77 |
-
|
| 78 |
-
axes[0, 1].imshow(gray, cmap='gray')
|
| 79 |
-
axes[0, 1].set_title("2. Grayscale")
|
| 80 |
-
axes[0, 1].axis('off')
|
| 81 |
-
|
| 82 |
-
axes[1, 0].imshow(binary_simple, cmap='gray')
|
| 83 |
-
axes[1, 0].set_title("3. Simple Threshold")
|
| 84 |
-
axes[1, 0].axis('off')
|
| 85 |
-
|
| 86 |
-
axes[1, 1].imshow(binary_adaptive, cmap='gray')
|
| 87 |
-
axes[1, 1].set_title("4. Adaptive Threshold")
|
| 88 |
-
axes[1, 1].axis('off')
|
| 89 |
-
|
| 90 |
-
plt.tight_layout()
|
| 91 |
-
plt.show()
|
| 92 |
-
|
| 93 |
-
# Test 8: Error handling
|
| 94 |
-
print("\nTest 8: Testing error handling...")
|
| 95 |
-
try:
|
| 96 |
-
binarize(image, method='adaptive') # RGB image (3D) should fail
|
| 97 |
-
print("❌ Should have raised ValueError!")
|
| 98 |
-
except ValueError as e:
|
| 99 |
-
print(f"✅ Correctly raised error: {e}")
|
| 100 |
-
|
| 101 |
-
print("\n🎉 Binarization tests passed!")
|
| 102 |
-
|
| 103 |
-
# Test 9: Noise removal
|
| 104 |
-
print("\nTest 9: Noise removal...")
|
| 105 |
-
denoised = remove_noise(gray, kernel_size=3)
|
| 106 |
-
print(f"✅ Success! Denoised shape: {denoised.shape}")
|
| 107 |
-
|
| 108 |
-
# Test different kernel sizes
|
| 109 |
-
denoised_light = remove_noise(gray, kernel_size=3)
|
| 110 |
-
denoised_heavy = remove_noise(gray, kernel_size=7)
|
| 111 |
-
|
| 112 |
-
# Visualize comparison
|
| 113 |
-
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
| 114 |
-
|
| 115 |
-
axes[0].imshow(gray, cmap='gray')
|
| 116 |
-
axes[0].set_title("Original Grayscale")
|
| 117 |
-
axes[0].axis('off')
|
| 118 |
-
|
| 119 |
-
axes[1].imshow(denoised_light, cmap='gray')
|
| 120 |
-
axes[1].set_title("Denoised (kernel=3)")
|
| 121 |
-
axes[1].axis('off')
|
| 122 |
-
|
| 123 |
-
axes[2].imshow(denoised_heavy, cmap='gray')
|
| 124 |
-
axes[2].set_title("Denoised (kernel=7)")
|
| 125 |
-
axes[2].axis('off')
|
| 126 |
-
|
| 127 |
-
plt.tight_layout()
|
| 128 |
-
plt.show()
|
| 129 |
-
print(" Notice: kernel=7 is blurrier but removes more noise")
|
| 130 |
-
|
| 131 |
-
# Test 10: Error handling
|
| 132 |
-
print("\nTest 10: Noise removal error handling...")
|
| 133 |
-
try:
|
| 134 |
-
remove_noise(gray, kernel_size=4) # Even number
|
| 135 |
-
print("❌ Should have raised ValueError!")
|
| 136 |
-
except ValueError as e:
|
| 137 |
-
print(f"✅ Correctly raised error: {e}")
|
| 138 |
-
|
| 139 |
-
print("\n🎉 Noise removal tests passed!")
|
| 140 |
-
|
| 141 |
-
# Test 11: Full pipeline
|
| 142 |
-
print("\nTest 11: Full preprocessing pipeline...")
|
| 143 |
-
|
| 144 |
-
# Test with all steps
|
| 145 |
-
full_processed = preprocess_pipeline(image,
|
| 146 |
-
steps=['grayscale', 'denoise', 'binarize'],
|
| 147 |
-
denoise_kernel=3,
|
| 148 |
-
binarize_method='adaptive')
|
| 149 |
-
print(f"✅ Full pipeline success! Shape: {full_processed.shape}")
|
| 150 |
-
|
| 151 |
-
# Test with selective steps (your clean images)
|
| 152 |
-
clean_processed = preprocess_pipeline(image,
|
| 153 |
-
steps=['grayscale', 'binarize'],
|
| 154 |
-
binarize_method='adaptive')
|
| 155 |
-
print(f"✅ Clean pipeline success! Shape: {clean_processed.shape}")
|
| 156 |
-
|
| 157 |
-
# Visualize comparison
|
| 158 |
-
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
| 159 |
-
|
| 160 |
-
axes[0].imshow(image)
|
| 161 |
-
axes[0].set_title("Original")
|
| 162 |
-
axes[0].axis('off')
|
| 163 |
-
|
| 164 |
-
axes[1].imshow(full_processed, cmap='gray')
|
| 165 |
-
axes[1].set_title("Full Pipeline\n(grayscale → denoise → binarize)")
|
| 166 |
-
axes[1].axis('off')
|
| 167 |
-
|
| 168 |
-
axes[2].imshow(clean_processed, cmap='gray')
|
| 169 |
-
axes[2].set_title("Clean Pipeline\n(grayscale → binarize)")
|
| 170 |
-
axes[2].axis('off')
|
| 171 |
-
|
| 172 |
-
plt.tight_layout()
|
| 173 |
-
plt.show()
|
| 174 |
-
|
| 175 |
-
print("\n🎉 Pipeline tests passed!")
|
| 176 |
-
|
| 177 |
-
print("\n🎉 All tests passed!")
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('src') # So Python can find our modules
|
| 3 |
+
|
| 4 |
+
from preprocessing import load_image, convert_to_grayscale, remove_noise, binarize, preprocess_pipeline
|
| 5 |
+
import numpy as np
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
|
| 8 |
+
# Test 1: Load a valid image
|
| 9 |
+
print("Test 1: Loading receipt1.jpg...")
|
| 10 |
+
image = load_image('data/raw/receipt1.jpg')
|
| 11 |
+
print(f"✅ Success! Image shape: {image.shape}")
|
| 12 |
+
print(f" Data type: {image.dtype}")
|
| 13 |
+
print(f" Value range: {image.min()} to {image.max()}")
|
| 14 |
+
|
| 15 |
+
# Test 2: Visualize it
|
| 16 |
+
print("\nTest 2: Displaying image...")
|
| 17 |
+
plt.imshow(image)
|
| 18 |
+
plt.title("Loaded Receipt")
|
| 19 |
+
plt.axis('off')
|
| 20 |
+
plt.show()
|
| 21 |
+
print("✅ If you see the receipt image, it worked!")
|
| 22 |
+
|
| 23 |
+
# Test 3: Try loading non-existent file
|
| 24 |
+
print("\nTest 3: Testing error handling...")
|
| 25 |
+
try:
|
| 26 |
+
load_image('data/raw/fake_image.jpg')
|
| 27 |
+
print("❌ Should have raised FileNotFoundError!")
|
| 28 |
+
except FileNotFoundError as e:
|
| 29 |
+
print(f"✅ Correctly raised error: {e}")
|
| 30 |
+
|
| 31 |
+
# Test 4: Grayscale conversion
|
| 32 |
+
print("\nTest 4: Converting to grayscale...")
|
| 33 |
+
gray = convert_to_grayscale(image)
|
| 34 |
+
print(f"✅ Success! Grayscale shape: {gray.shape}")
|
| 35 |
+
print(f" Original had 3 channels, now has: {len(gray.shape)} dimensions")
|
| 36 |
+
|
| 37 |
+
# Visualize side-by-side
|
| 38 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
|
| 39 |
+
ax1.imshow(image)
|
| 40 |
+
ax1.set_title("Original (RGB)")
|
| 41 |
+
ax1.axis('off')
|
| 42 |
+
|
| 43 |
+
ax2.imshow(gray, cmap='gray') # cmap='gray' tells matplotlib to display in grayscale
|
| 44 |
+
ax2.set_title("Grayscale")
|
| 45 |
+
ax2.axis('off')
|
| 46 |
+
|
| 47 |
+
plt.tight_layout()
|
| 48 |
+
plt.show()
|
| 49 |
+
|
| 50 |
+
# Test 5: Already grayscale (should return as-is)
|
| 51 |
+
print("\nTest 5: Converting already-grayscale image...")
|
| 52 |
+
gray_again = convert_to_grayscale(gray)
|
| 53 |
+
print(f"✅ Returned without error: {gray_again.shape}")
|
| 54 |
+
assert gray_again is gray, "Should return same object if already grayscale"
|
| 55 |
+
print("✅ Correctly returned the same image!")
|
| 56 |
+
|
| 57 |
+
print("\n🎉 Grayscale tests passed!")
|
| 58 |
+
|
| 59 |
+
# Test 6: Binarization - Simple method
|
| 60 |
+
print("\nTest 6: Simple binarization...")
|
| 61 |
+
binary_simple = binarize(gray, method='simple')
|
| 62 |
+
print(f"✅ Success! Binary shape: {binary_simple.shape}")
|
| 63 |
+
print(f" Unique values: {np.unique(binary_simple)}") # Should be [0, 255]
|
| 64 |
+
|
| 65 |
+
# Test 7: Binarization - Adaptive method
|
| 66 |
+
print("\nTest 7: Adaptive binarization...")
|
| 67 |
+
binary_adaptive = binarize(gray, method='adaptive', block_size=11, C=2)
|
| 68 |
+
print(f"✅ Success! Binary shape: {binary_adaptive.shape}")
|
| 69 |
+
print(f" Unique values: {np.unique(binary_adaptive)}")
|
| 70 |
+
|
| 71 |
+
# Visualize comparison
|
| 72 |
+
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
| 73 |
+
|
| 74 |
+
axes[0, 0].imshow(image)
|
| 75 |
+
axes[0, 0].set_title("1. Original (RGB)")
|
| 76 |
+
axes[0, 0].axis('off')
|
| 77 |
+
|
| 78 |
+
axes[0, 1].imshow(gray, cmap='gray')
|
| 79 |
+
axes[0, 1].set_title("2. Grayscale")
|
| 80 |
+
axes[0, 1].axis('off')
|
| 81 |
+
|
| 82 |
+
axes[1, 0].imshow(binary_simple, cmap='gray')
|
| 83 |
+
axes[1, 0].set_title("3. Simple Threshold")
|
| 84 |
+
axes[1, 0].axis('off')
|
| 85 |
+
|
| 86 |
+
axes[1, 1].imshow(binary_adaptive, cmap='gray')
|
| 87 |
+
axes[1, 1].set_title("4. Adaptive Threshold")
|
| 88 |
+
axes[1, 1].axis('off')
|
| 89 |
+
|
| 90 |
+
plt.tight_layout()
|
| 91 |
+
plt.show()
|
| 92 |
+
|
| 93 |
+
# Test 8: Error handling
|
| 94 |
+
print("\nTest 8: Testing error handling...")
|
| 95 |
+
try:
|
| 96 |
+
binarize(image, method='adaptive') # RGB image (3D) should fail
|
| 97 |
+
print("❌ Should have raised ValueError!")
|
| 98 |
+
except ValueError as e:
|
| 99 |
+
print(f"✅ Correctly raised error: {e}")
|
| 100 |
+
|
| 101 |
+
print("\n🎉 Binarization tests passed!")
|
| 102 |
+
|
| 103 |
+
# Test 9: Noise removal
|
| 104 |
+
print("\nTest 9: Noise removal...")
|
| 105 |
+
denoised = remove_noise(gray, kernel_size=3)
|
| 106 |
+
print(f"✅ Success! Denoised shape: {denoised.shape}")
|
| 107 |
+
|
| 108 |
+
# Test different kernel sizes
|
| 109 |
+
denoised_light = remove_noise(gray, kernel_size=3)
|
| 110 |
+
denoised_heavy = remove_noise(gray, kernel_size=7)
|
| 111 |
+
|
| 112 |
+
# Visualize comparison
|
| 113 |
+
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
| 114 |
+
|
| 115 |
+
axes[0].imshow(gray, cmap='gray')
|
| 116 |
+
axes[0].set_title("Original Grayscale")
|
| 117 |
+
axes[0].axis('off')
|
| 118 |
+
|
| 119 |
+
axes[1].imshow(denoised_light, cmap='gray')
|
| 120 |
+
axes[1].set_title("Denoised (kernel=3)")
|
| 121 |
+
axes[1].axis('off')
|
| 122 |
+
|
| 123 |
+
axes[2].imshow(denoised_heavy, cmap='gray')
|
| 124 |
+
axes[2].set_title("Denoised (kernel=7)")
|
| 125 |
+
axes[2].axis('off')
|
| 126 |
+
|
| 127 |
+
plt.tight_layout()
|
| 128 |
+
plt.show()
|
| 129 |
+
print(" Notice: kernel=7 is blurrier but removes more noise")
|
| 130 |
+
|
| 131 |
+
# Test 10: Error handling
|
| 132 |
+
print("\nTest 10: Noise removal error handling...")
|
| 133 |
+
try:
|
| 134 |
+
remove_noise(gray, kernel_size=4) # Even number
|
| 135 |
+
print("❌ Should have raised ValueError!")
|
| 136 |
+
except ValueError as e:
|
| 137 |
+
print(f"✅ Correctly raised error: {e}")
|
| 138 |
+
|
| 139 |
+
print("\n🎉 Noise removal tests passed!")
|
| 140 |
+
|
| 141 |
+
# Test 11: Full pipeline
|
| 142 |
+
print("\nTest 11: Full preprocessing pipeline...")
|
| 143 |
+
|
| 144 |
+
# Test with all steps
|
| 145 |
+
full_processed = preprocess_pipeline(image,
|
| 146 |
+
steps=['grayscale', 'denoise', 'binarize'],
|
| 147 |
+
denoise_kernel=3,
|
| 148 |
+
binarize_method='adaptive')
|
| 149 |
+
print(f"✅ Full pipeline success! Shape: {full_processed.shape}")
|
| 150 |
+
|
| 151 |
+
# Test with selective steps (your clean images)
|
| 152 |
+
clean_processed = preprocess_pipeline(image,
|
| 153 |
+
steps=['grayscale', 'binarize'],
|
| 154 |
+
binarize_method='adaptive')
|
| 155 |
+
print(f"✅ Clean pipeline success! Shape: {clean_processed.shape}")
|
| 156 |
+
|
| 157 |
+
# Visualize comparison
|
| 158 |
+
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
| 159 |
+
|
| 160 |
+
axes[0].imshow(image)
|
| 161 |
+
axes[0].set_title("Original")
|
| 162 |
+
axes[0].axis('off')
|
| 163 |
+
|
| 164 |
+
axes[1].imshow(full_processed, cmap='gray')
|
| 165 |
+
axes[1].set_title("Full Pipeline\n(grayscale → denoise → binarize)")
|
| 166 |
+
axes[1].axis('off')
|
| 167 |
+
|
| 168 |
+
axes[2].imshow(clean_processed, cmap='gray')
|
| 169 |
+
axes[2].set_title("Clean Pipeline\n(grayscale → binarize)")
|
| 170 |
+
axes[2].axis('off')
|
| 171 |
+
|
| 172 |
+
plt.tight_layout()
|
| 173 |
+
plt.show()
|
| 174 |
+
|
| 175 |
+
print("\n🎉 Pipeline tests passed!")
|
| 176 |
+
|
| 177 |
+
print("\n🎉 All tests passed!")
|
tests/utils.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
-
def save_image(image, path):
|
| 2 |
-
|
| 3 |
-
def visualize_boxes(image, boxes, text):
|
| 4 |
-
|
| 5 |
-
def validate_output(data):
|
| 6 |
-
|
| 7 |
def format_currency(amount):
|
|
|
|
| 1 |
+
def save_image(image, path):
|
| 2 |
+
|
| 3 |
+
def visualize_boxes(image, boxes, text):
|
| 4 |
+
|
| 5 |
+
def validate_output(data):
|
| 6 |
+
|
| 7 |
def format_currency(amount):
|
train_combined.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset, DataLoader
|
| 3 |
+
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor, DataCollatorForTokenClassification
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from seqeval.metrics import f1_score
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import numpy as np
|
| 9 |
+
import random
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
# --- IMPORTS ---
|
| 13 |
+
from load_sroie_dataset import load_sroie
|
| 14 |
+
from src.data_loader import load_unified_dataset
|
| 15 |
+
|
| 16 |
+
# --- CONFIGURATION ---
|
| 17 |
+
# Points to your local SROIE copy
|
| 18 |
+
SROIE_DATA_PATH = "data/sroie"
|
| 19 |
+
MODEL_CHECKPOINT = "microsoft/layoutlmv3-base"
|
| 20 |
+
OUTPUT_DIR = "models/layoutlmv3-generalized"
|
| 21 |
+
|
| 22 |
+
# Standard Label Set
|
| 23 |
+
LABEL_LIST = ['O', 'B-COMPANY', 'I-COMPANY', 'B-DATE', 'I-DATE',
|
| 24 |
+
'B-ADDRESS', 'I-ADDRESS', 'B-TOTAL', 'I-TOTAL',
|
| 25 |
+
'B-INVOICE_NO', 'I-INVOICE_NO','B-BILL_TO', 'I-BILL_TO']
|
| 26 |
+
label2id = {label: idx for idx, label in enumerate(LABEL_LIST)}
|
| 27 |
+
id2label = {idx: label for idx, label in enumerate(LABEL_LIST)}
|
| 28 |
+
|
| 29 |
+
class UnifiedDataset(Dataset):
    """Torch dataset over merged SROIE + general-invoice examples.

    Each example is a dict carrying ``words``, ``bboxes`` (already
    normalized to the 0-1000 LayoutLM coordinate space), ``ner_tags``,
    and either an in-memory ``image`` or an ``image_path``.
    """

    def __init__(self, data, processor, label2id):
        self.data = data
        self.processor = processor
        self.label2id = label2id

    def __len__(self):
        return len(self.data)

    def _open_image(self, example):
        """Return the example's image as RGB; fall back to a blank canvas.

        A white 224x224 placeholder keeps training alive when an image is
        missing or unreadable rather than crashing mid-epoch.
        """
        try:
            if 'image' in example and isinstance(example['image'], Image.Image):
                return example['image']
            if 'image_path' in example:
                return Image.open(example['image_path']).convert("RGB")
        except Exception:
            pass
        return Image.new('RGB', (224, 224), color='white')

    def __getitem__(self, idx):
        example = self.data[idx]

        image = self._open_image(example)

        # Boxes arrive pre-normalized; clamp each coordinate into
        # [0, 1000] anyway as a defensive safety net.
        boxes = [
            [max(0, min(int(box[i]), 1000)) for i in range(4)]
            for box in example['bboxes']
        ]

        # Unknown tags silently map to class 0 ('O').
        word_labels = [self.label2id.get(tag, 0) for tag in example['ner_tags']]

        encoding = self.processor(
            image,
            text=example['words'],
            boxes=boxes,
            word_labels=word_labels,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt"
        )

        # Drop the singleton batch dimension the processor adds.
        return {key: tensor.squeeze(0) for key, tensor in encoding.items()}
|
| 83 |
+
|
| 84 |
+
def train():
    """Fine-tune LayoutLMv3 on the combined SROIE + general-invoice data.

    Loads both datasets, merges their train/test splits, fine-tunes for
    NUM_EPOCHS epochs, evaluates entity-level F1 with seqeval after each
    epoch, and saves the model/processor to OUTPUT_DIR whenever F1 improves.
    """
    print(f"{'='*40}\n🚀 STARTING HYBRID TRAINING\n{'='*40}")

    # Check SROIE path before touching the network/model hub.
    if not os.path.exists(SROIE_DATA_PATH):
        print(f"❌ Error: SROIE path not found at {SROIE_DATA_PATH}")
        print("Please make sure you copied the 'sroie' folder into 'data/'.")
        return

    # 1. Load SROIE (local copy, pre-split into train/test).
    print("📦 Loading SROIE dataset...")
    sroie_data = load_sroie(SROIE_DATA_PATH)
    print(f" - SROIE Train: {len(sroie_data['train'])}")
    print(f" - SROIE Test: {len(sroie_data['test'])}")

    # 2. Load New Dataset
    print("📦 Loading General Invoice dataset...")
    # Reduced sample size slightly to stay safe on RAM.
    new_data = load_unified_dataset(split='train', sample_size=600)

    # 90/10 split of the general data; shuffle first so the split is random.
    # NOTE(review): no fixed random seed — splits differ between runs.
    random.shuffle(new_data)
    split_idx = int(len(new_data) * 0.9)
    new_train = new_data[:split_idx]
    new_test = new_data[split_idx:]

    print(f" - General Train: {len(new_train)}")
    print(f" - General Test: {len(new_test)}")

    # 3. Merge the two sources into single train/test lists.
    full_train_data = sroie_data['train'] + new_train
    full_test_data = sroie_data['test'] + new_test
    print(f"\n🔗 COMBINED DATASET SIZE: {len(full_train_data)} Training Images")

    # 4. Setup Model. apply_ocr=False: we supply our own words/boxes.
    processor = LayoutLMv3Processor.from_pretrained(MODEL_CHECKPOINT, apply_ocr=False)
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        MODEL_CHECKPOINT, num_labels=len(LABEL_LIST),
        id2label=id2label, label2id=label2id
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f" - Device: {device}")

    # 5. Dataloaders. batch_size=2 keeps memory usage low on small GPUs/CPU.
    train_ds = UnifiedDataset(full_train_data, processor, label2id)
    test_ds = UnifiedDataset(full_test_data, processor, label2id)

    collator = DataCollatorForTokenClassification(processor.tokenizer, padding=True, return_tensors="pt")
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, collate_fn=collator)
    test_loader = DataLoader(test_ds, batch_size=2, collate_fn=collator)

    # 6. Optimize & Train
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
    best_f1 = 0.0
    NUM_EPOCHS = 5

    print("\n🔥 Beginning Fine-Tuning...")
    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0

        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
        for batch in progress:
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            progress.set_postfix({"loss": f"{loss.item():.4f}"})

        # --- Evaluation ---
        model.eval()
        all_preds, all_labels = [], []
        print(" Running Validation...")
        with torch.no_grad():
            for batch in test_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                predictions = outputs.logits.argmax(dim=-1)
                labels = batch['labels']

                # -100 marks subword/padding positions that carry no label;
                # filter them so seqeval sees only real word-level tags.
                for i in range(len(labels)):
                    true_labels = [id2label[l.item()] for l in labels[i] if l.item() != -100]
                    pred_labels = [id2label[p.item()] for p, l in zip(predictions[i], labels[i]) if l.item() != -100]
                    all_labels.append(true_labels)
                    all_preds.append(pred_labels)

        # Entity-level (not token-level) F1 via seqeval.
        f1 = f1_score(all_labels, all_preds)
        print(f" 📊 Epoch {epoch+1} F1 Score: {f1:.4f}")

        # Keep only the best checkpoint seen so far.
        if f1 > best_f1:
            best_f1 = f1
            print(f" 💾 Saving Improved Model to {OUTPUT_DIR}")
            Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
            model.save_pretrained(OUTPUT_DIR)
            processor.save_pretrained(OUTPUT_DIR)

if __name__ == "__main__":
    train()
|
train_layoutlm.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor, DataCollatorForTokenClassification
from load_sroie_dataset import load_sroie  # Assumes your helper script is in the root
from PIL import Image
from tqdm import tqdm
from seqeval.metrics import f1_score, precision_score, recall_score
from pathlib import Path

# --- 1. Global Configuration & Label Mapping ---
print("Setting up configuration...")
# BIO tags for the four SROIE entity types (smaller set than the
# combined-training script, which adds INVOICE_NO and BILL_TO).
label_list = ['O', 'B-COMPANY', 'I-COMPANY', 'B-DATE', 'I-DATE',
              'B-ADDRESS', 'I-ADDRESS', 'B-TOTAL', 'I-TOTAL']
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for idx, label in enumerate(label_list)}

MODEL_CHECKPOINT = "microsoft/layoutlmv3-base"
# NOTE(review): machine-specific absolute Windows path — will not exist on
# other machines or in the deployed Space; consider an env var or a
# repo-relative path like "data/sroie".
SROIE_DATA_PATH = "C:\\Users\\Soumyajit Ghosh\\Downloads\\sroie\\sroie" # Make sure this path is correct
|
| 19 |
+
|
| 20 |
+
# --- 2. PyTorch Dataset Class ---
|
| 21 |
+
class SROIEDataset(Dataset):
    """PyTorch Dataset over raw SROIE examples for LayoutLMv3.

    Each example dict carries ``image_path``, ``words``, ``bboxes`` in
    pixel (x, y, w, h) form, and string ``ner_tags``.
    """

    def __init__(self, data, processor, label2id):
        self.data = data
        self.processor = processor
        self.label2id = label2id

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _normalize_box(box, width, height):
        """Map a pixel (x, y, w, h) box to clipped 0-1000 (x0, y0, x1, y1)."""
        x, y, w, h = box
        corners = (x, y, x + w, y + h)
        scales = (width, height, width, height)
        # Scale to LayoutLM's 0-1000 space, then clip into range.
        return [max(0, min(int((c / s) * 1000), 1000))
                for c, s in zip(corners, scales)]

    def __getitem__(self, idx):
        example = self.data[idx]

        # Image dimensions are needed to normalize the pixel boxes.
        image = Image.open(example['image_path']).convert("RGB")
        width, height = image.size

        boxes = [self._normalize_box(b, width, height) for b in example['bboxes']]

        # Convert string NER tags to class ids (KeyError on unknown tags).
        word_labels = [self.label2id[tag] for tag in example['ner_tags']]

        # Encode words + boxes + labels together, truncating long receipts.
        encoding = self.processor(
            image,
            text=example['words'],
            boxes=boxes,
            word_labels=word_labels,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

        # Squeeze away the batch dimension so the collator can re-batch.
        return {key: val.squeeze(0) for key, val in encoding.items()}
|
| 74 |
+
|
| 75 |
+
# --- 3. Main Training Script ---
|
| 76 |
+
def train():
    """Main function to run the training process.

    Fine-tunes LayoutLMv3 on SROIE alone for NUM_EPOCHS epochs, reporting
    seqeval precision/recall/F1 after each epoch and saving the best-F1
    model/processor to ./models/layoutlmv3-sroie-best.
    """
    # --- Load Data ---
    print("Loading SROIE dataset...")
    raw_dataset = load_sroie(SROIE_DATA_PATH)

    # --- Load Processor ---
    # apply_ocr=False because SROIE already provides words and boxes.
    print("Creating processor...")
    processor = LayoutLMv3Processor.from_pretrained(MODEL_CHECKPOINT, apply_ocr=False)

    # --- Create PyTorch Datasets and DataLoaders ---
    print("Creating PyTorch datasets and dataloaders...")
    train_dataset = SROIEDataset(raw_dataset['train'], processor, label2id)
    test_dataset = SROIEDataset(raw_dataset['test'], processor, label2id)

    # Dynamic padding per batch (SROIEDataset does not pad to max_length).
    data_collator = DataCollatorForTokenClassification(
        tokenizer=processor.tokenizer,
        padding=True,
        return_tensors="pt"
    )

    train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=data_collator)
    test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=data_collator)

    # --- Load Model ---
    print("Loading LayoutLMv3 model for fine-tuning...")
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        MODEL_CHECKPOINT,
        num_labels=len(label_list),
        id2label=id2label,
        label2id=label2id
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Training on: {device}")

    # --- Setup Optimizer ---
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

    # --- Training Loop ---
    best_f1 = 0
    NUM_EPOCHS = 10

    for epoch in range(NUM_EPOCHS):
        print(f"\n{'='*60}\nEpoch {epoch + 1}/{NUM_EPOCHS}\n{'='*60}")

        # --- Training Step ---
        model.train()
        total_train_loss = 0
        train_progress_bar = tqdm(train_dataloader, desc=f"Training Epoch {epoch+1}")
        for batch in train_progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(**batch)
            loss = outputs.loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_train_loss += loss.item()
            train_progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

        avg_train_loss = total_train_loss / len(train_dataloader)

        # --- Validation Step ---
        model.eval()
        all_predictions = []
        all_labels = []
        with torch.no_grad():
            for batch in tqdm(test_dataloader, desc="Validation"):
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)

                predictions = outputs.logits.argmax(dim=-1)
                labels = batch['labels']

                # Filter out -100 positions (subwords/padding without labels)
                # so seqeval compares only real word-level tags.
                for i in range(labels.shape[0]):
                    true_labels_i = [id2label[l.item()] for l in labels[i] if l.item() != -100]
                    pred_labels_i = [id2label[p.item()] for p, l in zip(predictions[i], labels[i]) if l.item() != -100]
                    all_labels.append(true_labels_i)
                    all_predictions.append(pred_labels_i)

        # --- Calculate Metrics --- (entity-level, via seqeval)
        f1 = f1_score(all_labels, all_predictions)
        precision = precision_score(all_labels, all_predictions)
        recall = recall_score(all_labels, all_predictions)

        print(f"\n📊 Epoch {epoch + 1} Results:")
        print(f" Train Loss: {avg_train_loss:.4f}")
        print(f" F1 Score: {f1:.4f}")
        print(f" Precision: {precision:.4f}")
        print(f" Recall: {recall:.4f}")

        # --- Save Best Model ---
        if f1 > best_f1:
            best_f1 = f1
            print(f" 🌟 New best F1! Saving model...")
            save_path = Path("./models/layoutlmv3-sroie-best")
            save_path.mkdir(parents=True, exist_ok=True)
            model.save_pretrained(save_path)
            processor.save_pretrained(save_path)

    print(f"\n🎉 TRAINING COMPLETE! Best F1 Score: {best_f1:.4f}")
    print(f"Model saved to: ./models/layoutlmv3-sroie-best")


if __name__ == '__main__':
    train()
|