GSoumyajit2005 committed on
Commit
d79b7f7
·
1 Parent(s): 22fe020

feat: Add Phase 3 generalization scripts and clean up legacy files

Browse files
.gitignore CHANGED
@@ -1,70 +1,73 @@
1
- # Python
2
- __pycache__/
3
- *.pyc
4
- *.pyo
5
- *.pyd
6
-
7
- # Environment
8
- env/
9
- venv/
10
- .env
11
- config.yaml
12
- credentials.json
13
-
14
- # IDE / Editor
15
- .vscode/
16
- .idea/
17
- *.swp
18
- *.swo
19
-
20
- # Notebooks / caches / logs
21
- .ipynb_checkpoints/
22
- .pytest_cache/
23
- *.log
24
- logs/
25
- .cache/
26
-
27
- # OS
28
- .DS_Store
29
- Thumbs.db
30
- ehthumbs.db
31
- *.code-workspace
32
- Desktop.ini
33
-
34
- # Streamlit temp folder
35
- temp/
36
- .streamlit/
37
-
38
- # Jupyter Notebook
39
- .ipynb_checkpoints
40
-
41
- # JSON outputs
42
- outputs/
43
-
44
- # Logs
45
- logs/
46
- *.log
47
-
48
- # --- Data Folders ---
49
- # Ignore all files inside the raw and processed data folders
50
- data/raw/*
51
- data/processed/*
52
-
53
- # But DO NOT ignore the .gitkeep files inside them
54
- !data/raw/.gitkeep
55
- !data/processed/.gitkeep
56
-
57
- !requirements.txt
58
- !README.md
59
-
60
- datasets/
61
- checkpoints/
62
- lightning_logs/
63
- wandb/
64
- mlruns/
65
-
66
-
67
- # Ignore all files in the models directory
68
- models/*
69
- !models/.gitkeep
70
- !models/README.md
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+
7
+ # Environment
8
+ env/
9
+ venv/
10
+ .env
11
+ config.yaml
12
+ credentials.json
13
+
14
+ # IDE / Editor
15
+ .vscode/
16
+ .idea/
17
+ *.swp
18
+ *.swo
19
+
20
+ # Notebooks / caches / logs
21
+ .ipynb_checkpoints/
22
+ .pytest_cache/
23
+ *.log
24
+ logs/
25
+ .cache/
26
+
27
+ # OS
28
+ .DS_Store
29
+ Thumbs.db
30
+ ehthumbs.db
31
+ *.code-workspace
32
+ Desktop.ini
33
+
34
+ # Streamlit temp folder
35
+ temp/
36
+ .streamlit/
37
+
38
+ # Jupyter Notebook
39
+ .ipynb_checkpoints
40
+
41
+ # JSON outputs
42
+ outputs/
43
+
44
+ # Logs
45
+ logs/
46
+ *.log
47
+
48
+ # --- Data Folders ---
49
+ # Ignore all files inside the raw and processed data folders
50
+ data/raw/*
51
+ data/processed/*
52
+
53
+ # But DO NOT ignore the .gitkeep files inside them
54
+ !data/raw/.gitkeep
55
+ !data/processed/.gitkeep
56
+
57
+ !requirements.txt
58
+ !README.md
59
+
60
+ datasets/
61
+ checkpoints/
62
+ lightning_logs/
63
+ wandb/
64
+ mlruns/
65
+
66
+
67
+ # Ignore all files in the models directory
68
+ models/*
69
+ !models/.gitkeep
70
+ !models/README.md
71
+
72
+ # Ignore sroie files in the data directory
73
+ data/sroie/
LICENSE CHANGED
@@ -1,21 +1,21 @@
1
- MIT License
2
-
3
- Copyright (c) 2025 Soumyajit Ghosh
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Soumyajit Ghosh
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,323 +1,323 @@
1
- # 📄 Smart Invoice Processor
2
-
3
- End-to-end invoice/receipt processing with OCR + Rule-based extraction and a fine‑tuned LayoutLMv3 model. Upload an image or run via CLI to get clean, structured JSON (vendor, date, totals, address, etc.).
4
-
5
- ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
6
- ![Streamlit](https://img.shields.io/badge/Streamlit-1.51+-red.svg)
7
- ![Tesseract](https://img.shields.io/badge/Tesseract-5.0+-green.svg)
8
- ![Transformers](https://img.shields.io/badge/Transformers-4.x-purple.svg)
9
- ![PyTorch](https://img.shields.io/badge/PyTorch-2.x-orange.svg)
10
-
11
- ---
12
-
13
- ## 🎯 Features
14
-
15
- - ✅ OCR using Tesseract (configurable, fast, multi-platform)
16
- - ✅ Rule-based extraction (regex baselines)
17
- - ✅ ML-based extraction (LayoutLMv3 fine‑tuned on SROIE) for robust field detection
18
- - ✅ Clean JSON output (date, total, vendor, address, receipt number*)
19
- - ✅ Confidence and simple validation (e.g., total found among amounts)
20
- - ✅ Streamlit web UI with method toggle (ML vs Regex)
21
- - ✅ CLI for single/batch processing with saving to JSON
22
- - ✅ Tests for preprocessing/OCR/pipeline
23
-
24
- > Note: SROIE does not include invoice/receipt number labels; the ML model won’t output it unless you add labeled data. The rule-based extractor can still provide it when formats allow.
25
-
26
- ---
27
-
28
- ## 📊 Demo
29
-
30
- ### Web Interface
31
- ![Homepage](docs/screenshots/homepage.png)
32
- *Clean upload → extract flow with method selector (ML vs Regex).*
33
-
34
- ### Successful Extraction (ML-based)
35
- ![Success Result](docs/screenshots/success_result.png)
36
- *Fields extracted with LayoutLMv3.*
37
-
38
- ### Format Detection (simulated)
39
- ![Format Detection](docs/screenshots/format_detection.png)
40
- *UI shows simple format hints and confidence.*
41
-
42
- ### Example JSON (Rule-based)
43
- ```json
44
- {
45
- "receipt_number": "PEGIV-1030765",
46
- "date": "15/01/2019",
47
- "bill_to": {
48
- "name": "THE PEAK QUARRY WORKS",
49
- "email": null
50
- },
51
- "items": [],
52
- "total_amount": 193.0,
53
- "extraction_confidence": 100,
54
- "validation_passed": true,
55
- "vendor": "OJC MARKETING SDN BHD",
56
- "address": "NO JALAN BAYU 4, BANDAR SERI ALAM, 81750 MASAI, JOHOR"
57
- }
58
- ```
59
- ### Example JSON (ML-based)
60
- ```json
61
- {
62
- "receipt_number": null,
63
- "date": "15/01/2019",
64
- "bill_to": null,
65
- "items": [],
66
- "total_amount": 193.0,
67
- "vendor": "OJC MARKETING SDN BHD",
68
- "address": "NO JALAN BAYU 4, BANDAR SERI ALAM, 81750 MASAI, JOHOR",
69
- "raw_text": "…",
70
- "raw_ocr_words": ["…"],
71
- "raw_predictions": {
72
- "DATE": {"text": "15/01/2019", "bbox": [[…]]},
73
- "TOTAL": {"text": "193.00", "bbox": [[…]]},
74
- "COMPANY": {"text": "OJC MARKETING SDN BHD", "bbox": [[…]]},
75
- "ADDRESS": {"text": "…", "bbox": [[…]]}
76
- }
77
- }
78
- ```
79
-
80
- ## 🚀 Quick Start
81
-
82
- ### Prerequisites
83
- - Python 3.10+
84
- - Tesseract OCR
85
- - (Optional) CUDA-capable GPU for training/inference speed
86
-
87
- ### Installation
88
-
89
- 1. Clone the repository
90
- ```bash
91
- git clone https://github.com/GSoumyajit2005/invoice-processor-ml
92
- cd invoice-processor-ml
93
- ```
94
-
95
- 2. Install dependencies
96
- ```bash
97
- pip install -r requirements.txt
98
- ```
99
-
100
- 3. Install Tesseract OCR
101
- - **Windows**: Download from [UB Mannheim](https://github.com/UB-Mannheim/tesseract/wiki)
102
- - **Mac**: `brew install tesseract`
103
- - **Linux**: `sudo apt install tesseract-ocr`
104
-
105
- 4. (Optional, Windows) Set Tesseract path in src/ocr.py if needed:
106
- ```bash
107
- pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
108
- ```
109
-
110
- 5. Run the web app
111
- ```bash
112
- streamlit run app.py
113
- ```
114
-
115
- ## 💻 Usage
116
-
117
- ### Web Interface (Recommended)
118
-
119
- The easiest way to use the processor is via the web interface.
120
-
121
- ```bash
122
- streamlit run app.py
123
- ```
124
- - Upload an invoice image (PNG/JPG).
125
- - Choose extraction method in sidebar:
126
- - ML-Based (LayoutLMv3)
127
- - Rule-Based (Regex)
128
- - View JSON, download results.
129
-
130
- ### Command-Line Interface (CLI)
131
-
132
- You can also process invoices directly from the command line.
133
-
134
- #### 1. Processing a Single Invoice
135
-
136
- This command processes the provided sample invoice and prints the results to the console.
137
-
138
- ```bash
139
- python src/pipeline.py data/samples/sample_invoice.jpg --save --method ml
140
- # or
141
- python src/pipeline.py data/samples/sample_invoice.jpg --save --method rules
142
- ```
143
-
144
- #### 2. Batch Processing a Folder
145
-
146
- The CLI can process an entire folder of images at once.
147
-
148
- First, place your own invoice images (e.g., `my_invoice1.jpg`, `my_invoice2.png`) into the `data/raw/` folder.
149
-
150
- Then, run the following command. It will process all images in `data/raw/`. Saved files are written to `outputs/{stem}_{method}.json`.
151
-
152
- ```bash
153
- python src/pipeline.py data/raw --save --method ml
154
- ```
155
-
156
- ### Python API
157
-
158
- You can integrate the pipeline directly into your own Python scripts.
159
-
160
- ```python
161
- from src.pipeline import process_invoice
162
- import json
163
-
164
- result = process_invoice('data/samples/sample_invoice.jpg', method='ml')
165
- print(json.dumps(result, indent=2))
166
- ```
167
-
168
- ## 🏗️ Architecture
169
-
170
- ```
171
- ┌────────────────┐
172
- │ Upload Image │
173
- └───────┬────────┘
174
-
175
-
176
- ┌────────────────────┐
177
- │ Preprocessing │ (OpenCV grayscale/denoise)
178
- └────────┬───────────┘
179
-
180
-
181
- ┌───────────────┐
182
- │ OCR │ (Tesseract)
183
- └───────┬───────┘
184
-
185
- ┌──────────────┴──────────────┐
186
- │ │
187
- ▼ ▼
188
- ┌──────────────────┐ ┌────────────────────────┐
189
- │ Rule-based IE │ │ ML-based IE (NER) │
190
- │ (regex, heur.) │ │ LayoutLMv3 token-class │
191
- └────────┬─────────┘ └───────────┬────────────┘
192
- │ │
193
- └──────────────┬──────────────────┘
194
-
195
- ┌──────────────────┐
196
- │ Post-process │
197
- │ validate, scores │
198
- └────────┬─────────┘
199
-
200
- ┌──────────────────┐
201
- │ JSON Output │
202
- └──────────────────┘
203
- ```
204
-
205
- ## 📁 Project Structure
206
-
207
- ```
208
- invoice-processor-ml/
209
-
210
- ├── data/
211
- │ ├── raw/ # Input invoice images for processing
212
- │ └── processed/ # (Reserved for future use)
213
-
214
-
215
- ├── data/samples/
216
- │ └── sample_invoice.jpg # Public sample for quick testing
217
-
218
- ├── docs/
219
- │ └── screenshots/ # UI Screenshots for the README demo
220
-
221
-
222
- ├── models/
223
- │ └── layoutlmv3-sroie-best/ # Fine-tuned model (created after training)
224
-
225
- ├── outputs/ # Default folder for saved JSON results
226
-
227
- ├── src/
228
- │ ├── preprocessing.py # Image preprocessing functions (grayscale, denoise)
229
- │ ├── ocr.py # Tesseract OCR integration
230
- │ ├── extraction.py # Regex-based information extraction logic
231
- │ ├── ml_extraction.py # ML-based extraction (LayoutLMv3)
232
- │ └── pipeline.py # Main orchestrator for the pipeline and CLI
233
-
234
-
235
- ├── tests/ # <-- ADD THIS FOLDER
236
- │ ├── test_preprocessing.py # Tests for the preprocessing module
237
- │ ├── test_ocr.py # Tests for the OCR module
238
- │ └── test_pipeline.py # End-to-end pipeline tests
239
-
240
- ├── app.py # Streamlit web interface
241
- ├── requirements.txt # Python dependencies
242
- └── README.md # You are Here!
243
- ```
244
-
245
- ## 🧠 Model & Training
246
-
247
- - **Model**: `microsoft/layoutlmv3-base` (125M params)
248
- - **Task**: Token Classification (NER) with 9 labels: `O, B/I-COMPANY, B/I-ADDRESS, B/I-DATE, B/I-TOTAL`
249
- - **Dataset**: SROIE (ICDAR 2019, English retail receipts)
250
- - **Training**: RTX 3050 6GB, PyTorch 2.x, Transformers 4.x
251
- - **Result**: Best F1 ≈ 0.922 on validation (epoch 5 saved)
252
-
253
- - Training scripts(local):
254
- - `train_layoutlm.py` (data prep, training loop with validation + model save)
255
- - Model saved to: `models/layoutlmv3-sroie-best/`
256
-
257
- ## 📈 Performance
258
-
259
- - **OCR accuracy (clear images)**: High with Tesseract
260
- - **Rule-based extraction**: Strong on simple retail receipts
261
- - **ML-based extraction (SROIE-style)**:
262
- - COMPANY / ADDRESS / DATE / TOTAL: High F1 on simple receipts
263
- - Complex business invoices: Partial extraction unless further fine-tuned
264
-
265
- ## ⚠️ Known Limitations
266
-
267
- 1. **Layout Sensitivity**: The ML model was fine‑tuned only on SROIE (retail receipts). Professional multi-column invoices may underperform until you fine‑tune on more diverse datasets.
268
- 2. **Invoice Number (ML)**: SROIE lacks invoice number labels; the ML model won’t output it unless you add labeled data. The rule-based method can still recover it on many formats.
269
- 3. **Line Items/Tables**: Not trained for table extraction yet. Rule-based supports simple totals; table extraction comes later.
270
- 4. **OCR Variability**: Tesseract outputs can vary; preprocessing and thresholds can impact ML results.
271
-
272
- ## 🔮 Future Enhancements
273
-
274
- - [ ] Add and fine‑tune on mychen76/invoices-and-receipts_ocr_v1 (English) for broader invoice formats
275
- - [ ] (Optional) Add FATURA (table-focused) for line-item extraction
276
- - [ ] Sliding-window chunking for >512 token documents (to avoid truncation)
277
- - [ ] Table detection (Camelot/Tabula/DeepDeSRT) for line items
278
- - [ ] PDF support (pdf2image) for multipage invoices
279
- - [ ] FastAPI backend + Docker
280
- - [ ] Multilingual OCR (PaddleOCR) and multilingual fine‑tuning
281
- - [ ] Confidence calibration and better validation rules
282
-
283
- ## 🛠️ Tech Stack
284
-
285
- | Component | Technology |
286
- |-----------|------------|
287
- | OCR | Tesseract 5.0+ |
288
- | Image Processing | OpenCV, Pillow |
289
- | ML/NLP | PyTorch 2.x, Transformers |
290
- | Model | LayoutLMv3 (token class.) |
291
- | Web Interface | Streamlit |
292
- | Data Format | JSON |
293
-
294
- ## 📚 What I Learned
295
-
296
- - OCR challenges (confusable characters, confidence-based filtering)
297
- - Layout-aware NER with LayoutLMv3 (text + bbox + pixels)
298
- - Data normalization (bbox to 0–1000 scale)
299
- - End-to-end pipelines (UI + CLI + JSON output)
300
- - When regex is enough vs when ML is needed
301
- - Evaluation (seqeval F1 for NER)
302
-
303
- ## 🤝 Contributing
304
-
305
- Contributions welcome! Areas needing improvement:
306
- - New patterns for regex extractor
307
- - Better preprocessing for OCR
308
- - New datasets and training configs
309
- - Tests and CI
310
-
311
- ## 📝 License
312
-
313
- MIT License - See LICENSE file for details
314
-
315
- ## 👨‍💻 Author
316
-
317
- **Soumyajit Ghosh** - 3rd Year BTech Student
318
- - Exploring AI/ML and practical applications
319
- - [LinkedIn](https://www.linkedin.com/in/soumyajit-ghosh-49a5b02b2?utm_source=share&utm_campaign) | [GitHub](https://github.com/GSoumyajit2005) | [Portfolio](#)(Coming Soon)
320
-
321
- ---
322
-
323
  **Note**: "This is a learning project demonstrating an end-to-end ML pipeline. Not recommended for production use without further validation, retraining on diverse datasets, and security hardening."
 
1
+ # 📄 Smart Invoice Processor
2
+
3
+ End-to-end invoice/receipt processing with OCR + Rule-based extraction and a fine‑tuned LayoutLMv3 model. Upload an image or run via CLI to get clean, structured JSON (vendor, date, totals, address, etc.).
4
+
5
+ ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
6
+ ![Streamlit](https://img.shields.io/badge/Streamlit-1.51+-red.svg)
7
+ ![Tesseract](https://img.shields.io/badge/Tesseract-5.0+-green.svg)
8
+ ![Transformers](https://img.shields.io/badge/Transformers-4.x-purple.svg)
9
+ ![PyTorch](https://img.shields.io/badge/PyTorch-2.x-orange.svg)
10
+
11
+ ---
12
+
13
+ ## 🎯 Features
14
+
15
+ - ✅ OCR using Tesseract (configurable, fast, multi-platform)
16
+ - ✅ Rule-based extraction (regex baselines)
17
+ - ✅ ML-based extraction (LayoutLMv3 fine‑tuned on SROIE) for robust field detection
18
+ - ✅ Clean JSON output (date, total, vendor, address, receipt number*)
19
+ - ✅ Confidence and simple validation (e.g., total found among amounts)
20
+ - ✅ Streamlit web UI with method toggle (ML vs Regex)
21
+ - ✅ CLI for single/batch processing with saving to JSON
22
+ - ✅ Tests for preprocessing/OCR/pipeline
23
+
24
+ > Note: SROIE does not include invoice/receipt number labels; the ML model won’t output it unless you add labeled data. The rule-based extractor can still provide it when formats allow.
25
+
26
+ ---
27
+
28
+ ## 📊 Demo
29
+
30
+ ### Web Interface
31
+ ![Homepage](docs/screenshots/homepage.png)
32
+ *Clean upload → extract flow with method selector (ML vs Regex).*
33
+
34
+ ### Successful Extraction (ML-based)
35
+ ![Success Result](docs/screenshots/success_result.png)
36
+ *Fields extracted with LayoutLMv3.*
37
+
38
+ ### Format Detection (simulated)
39
+ ![Format Detection](docs/screenshots/format_detection.png)
40
+ *UI shows simple format hints and confidence.*
41
+
42
+ ### Example JSON (Rule-based)
43
+ ```json
44
+ {
45
+ "receipt_number": "PEGIV-1030765",
46
+ "date": "15/01/2019",
47
+ "bill_to": {
48
+ "name": "THE PEAK QUARRY WORKS",
49
+ "email": null
50
+ },
51
+ "items": [],
52
+ "total_amount": 193.0,
53
+ "extraction_confidence": 100,
54
+ "validation_passed": true,
55
+ "vendor": "OJC MARKETING SDN BHD",
56
+ "address": "NO JALAN BAYU 4, BANDAR SERI ALAM, 81750 MASAI, JOHOR"
57
+ }
58
+ ```
59
+ ### Example JSON (ML-based)
60
+ ```json
61
+ {
62
+ "receipt_number": null,
63
+ "date": "15/01/2019",
64
+ "bill_to": null,
65
+ "items": [],
66
+ "total_amount": 193.0,
67
+ "vendor": "OJC MARKETING SDN BHD",
68
+ "address": "NO JALAN BAYU 4, BANDAR SERI ALAM, 81750 MASAI, JOHOR",
69
+ "raw_text": "…",
70
+ "raw_ocr_words": ["…"],
71
+ "raw_predictions": {
72
+ "DATE": {"text": "15/01/2019", "bbox": [[…]]},
73
+ "TOTAL": {"text": "193.00", "bbox": [[…]]},
74
+ "COMPANY": {"text": "OJC MARKETING SDN BHD", "bbox": [[…]]},
75
+ "ADDRESS": {"text": "…", "bbox": [[…]]}
76
+ }
77
+ }
78
+ ```
79
+
80
+ ## 🚀 Quick Start
81
+
82
+ ### Prerequisites
83
+ - Python 3.10+
84
+ - Tesseract OCR
85
+ - (Optional) CUDA-capable GPU for training/inference speed
86
+
87
+ ### Installation
88
+
89
+ 1. Clone the repository
90
+ ```bash
91
+ git clone https://github.com/GSoumyajit2005/invoice-processor-ml
92
+ cd invoice-processor-ml
93
+ ```
94
+
95
+ 2. Install dependencies
96
+ ```bash
97
+ pip install -r requirements.txt
98
+ ```
99
+
100
+ 3. Install Tesseract OCR
101
+ - **Windows**: Download from [UB Mannheim](https://github.com/UB-Mannheim/tesseract/wiki)
102
+ - **Mac**: `brew install tesseract`
103
+ - **Linux**: `sudo apt install tesseract-ocr`
104
+
105
+ 4. (Optional, Windows) Set Tesseract path in src/ocr.py if needed:
106
+ ```bash
107
+ pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
108
+ ```
109
+
110
+ 5. Run the web app
111
+ ```bash
112
+ streamlit run app.py
113
+ ```
114
+
115
+ ## 💻 Usage
116
+
117
+ ### Web Interface (Recommended)
118
+
119
+ The easiest way to use the processor is via the web interface.
120
+
121
+ ```bash
122
+ streamlit run app.py
123
+ ```
124
+ - Upload an invoice image (PNG/JPG).
125
+ - Choose extraction method in sidebar:
126
+ - ML-Based (LayoutLMv3)
127
+ - Rule-Based (Regex)
128
+ - View JSON, download results.
129
+
130
+ ### Command-Line Interface (CLI)
131
+
132
+ You can also process invoices directly from the command line.
133
+
134
+ #### 1. Processing a Single Invoice
135
+
136
+ This command processes the provided sample invoice and prints the results to the console.
137
+
138
+ ```bash
139
+ python src/pipeline.py data/samples/sample_invoice.jpg --save --method ml
140
+ # or
141
+ python src/pipeline.py data/samples/sample_invoice.jpg --save --method rules
142
+ ```
143
+
144
+ #### 2. Batch Processing a Folder
145
+
146
+ The CLI can process an entire folder of images at once.
147
+
148
+ First, place your own invoice images (e.g., `my_invoice1.jpg`, `my_invoice2.png`) into the `data/raw/` folder.
149
+
150
+ Then, run the following command. It will process all images in `data/raw/`. Saved files are written to `outputs/{stem}_{method}.json`.
151
+
152
+ ```bash
153
+ python src/pipeline.py data/raw --save --method ml
154
+ ```
155
+
156
+ ### Python API
157
+
158
+ You can integrate the pipeline directly into your own Python scripts.
159
+
160
+ ```python
161
+ from src.pipeline import process_invoice
162
+ import json
163
+
164
+ result = process_invoice('data/samples/sample_invoice.jpg', method='ml')
165
+ print(json.dumps(result, indent=2))
166
+ ```
167
+
168
+ ## 🏗️ Architecture
169
+
170
+ ```
171
+ ┌────────────────┐
172
+ │ Upload Image │
173
+ └───────┬────────┘
174
+
175
+
176
+ ┌────────────────────┐
177
+ │ Preprocessing │ (OpenCV grayscale/denoise)
178
+ └────────┬───────────┘
179
+
180
+
181
+ ┌───────────────┐
182
+ │ OCR │ (Tesseract)
183
+ └───────┬───────┘
184
+
185
+ ┌──────────────┴──────────────┐
186
+ │ │
187
+ ▼ ▼
188
+ ┌──────────────────┐ ┌────────────────────────┐
189
+ │ Rule-based IE │ │ ML-based IE (NER) │
190
+ │ (regex, heur.) │ │ LayoutLMv3 token-class │
191
+ └────────┬─────────┘ └───────────┬────────────┘
192
+ │ │
193
+ └──────────────┬──────────────────┘
194
+
195
+ ┌──────────────────┐
196
+ │ Post-process │
197
+ │ validate, scores │
198
+ └────────┬─────────┘
199
+
200
+ ┌──────────────────┐
201
+ │ JSON Output │
202
+ └──────────────────┘
203
+ ```
204
+
205
+ ## 📁 Project Structure
206
+
207
+ ```
208
+ invoice-processor-ml/
209
+
210
+ ├── data/
211
+ │ ├── raw/ # Input invoice images for processing
212
+ │ └── processed/ # (Reserved for future use)
213
+
214
+
215
+ ├── data/samples/
216
+ │ └── sample_invoice.jpg # Public sample for quick testing
217
+
218
+ ├── docs/
219
+ │ └── screenshots/ # UI Screenshots for the README demo
220
+
221
+
222
+ ├── models/
223
+ │ └── layoutlmv3-sroie-best/ # Fine-tuned model (created after training)
224
+
225
+ ├── outputs/ # Default folder for saved JSON results
226
+
227
+ ├── src/
228
+ │ ├── preprocessing.py # Image preprocessing functions (grayscale, denoise)
229
+ │ ├── ocr.py # Tesseract OCR integration
230
+ │ ├── extraction.py # Regex-based information extraction logic
231
+ │ ├── ml_extraction.py # ML-based extraction (LayoutLMv3)
232
+ │ └── pipeline.py # Main orchestrator for the pipeline and CLI
233
+
234
+
235
+ ├── tests/ # Unit and end-to-end tests
236
+ │ ├── test_preprocessing.py # Tests for the preprocessing module
237
+ │ ├── test_ocr.py # Tests for the OCR module
238
+ │ └── test_pipeline.py # End-to-end pipeline tests
239
+
240
+ ├── app.py # Streamlit web interface
241
+ ├── requirements.txt # Python dependencies
242
+ └── README.md # You are Here!
243
+ ```
244
+
245
+ ## 🧠 Model & Training
246
+
247
+ - **Model**: `microsoft/layoutlmv3-base` (125M params)
248
+ - **Task**: Token Classification (NER) with 9 labels: `O, B/I-COMPANY, B/I-ADDRESS, B/I-DATE, B/I-TOTAL`
249
+ - **Dataset**: SROIE (ICDAR 2019, English retail receipts)
250
+ - **Training**: RTX 3050 6GB, PyTorch 2.x, Transformers 4.x
251
+ - **Result**: Best F1 ≈ 0.922 on validation (epoch 5 saved)
252
+
253
+ - Training scripts (local):
254
+ - `train_layoutlm.py` (data prep, training loop with validation + model save)
255
+ - Model saved to: `models/layoutlmv3-sroie-best/`
256
+
257
+ ## 📈 Performance
258
+
259
+ - **OCR accuracy (clear images)**: High with Tesseract
260
+ - **Rule-based extraction**: Strong on simple retail receipts
261
+ - **ML-based extraction (SROIE-style)**:
262
+ - COMPANY / ADDRESS / DATE / TOTAL: High F1 on simple receipts
263
+ - Complex business invoices: Partial extraction unless further fine-tuned
264
+
265
+ ## ⚠️ Known Limitations
266
+
267
+ 1. **Layout Sensitivity**: The ML model was fine‑tuned only on SROIE (retail receipts). Professional multi-column invoices may underperform until you fine‑tune on more diverse datasets.
268
+ 2. **Invoice Number (ML)**: SROIE lacks invoice number labels; the ML model won’t output it unless you add labeled data. The rule-based method can still recover it on many formats.
269
+ 3. **Line Items/Tables**: Not trained for table extraction yet. Rule-based supports simple totals; table extraction comes later.
270
+ 4. **OCR Variability**: Tesseract outputs can vary; preprocessing and thresholds can impact ML results.
271
+
272
+ ## 🔮 Future Enhancements
273
+
274
+ - [ ] Add and fine‑tune on mychen76/invoices-and-receipts_ocr_v1 (English) for broader invoice formats
275
+ - [ ] (Optional) Add FATURA (table-focused) for line-item extraction
276
+ - [ ] Sliding-window chunking for >512 token documents (to avoid truncation)
277
+ - [ ] Table detection (Camelot/Tabula/DeepDeSRT) for line items
278
+ - [ ] PDF support (pdf2image) for multipage invoices
279
+ - [ ] FastAPI backend + Docker
280
+ - [ ] Multilingual OCR (PaddleOCR) and multilingual fine‑tuning
281
+ - [ ] Confidence calibration and better validation rules
282
+
283
+ ## 🛠️ Tech Stack
284
+
285
+ | Component | Technology |
286
+ |-----------|------------|
287
+ | OCR | Tesseract 5.0+ |
288
+ | Image Processing | OpenCV, Pillow |
289
+ | ML/NLP | PyTorch 2.x, Transformers |
290
+ | Model | LayoutLMv3 (token class.) |
291
+ | Web Interface | Streamlit |
292
+ | Data Format | JSON |
293
+
294
+ ## 📚 What I Learned
295
+
296
+ - OCR challenges (confusable characters, confidence-based filtering)
297
+ - Layout-aware NER with LayoutLMv3 (text + bbox + pixels)
298
+ - Data normalization (bbox to 0–1000 scale)
299
+ - End-to-end pipelines (UI + CLI + JSON output)
300
+ - When regex is enough vs when ML is needed
301
+ - Evaluation (seqeval F1 for NER)
302
+
303
+ ## 🤝 Contributing
304
+
305
+ Contributions welcome! Areas needing improvement:
306
+ - New patterns for regex extractor
307
+ - Better preprocessing for OCR
308
+ - New datasets and training configs
309
+ - Tests and CI
310
+
311
+ ## 📝 License
312
+
313
+ MIT License - See LICENSE file for details
314
+
315
+ ## 👨‍💻 Author
316
+
317
+ **Soumyajit Ghosh** - 3rd Year BTech Student
318
+ - Exploring AI/ML and practical applications
319
+ - [LinkedIn](https://www.linkedin.com/in/soumyajit-ghosh-49a5b02b2) | [GitHub](https://github.com/GSoumyajit2005) | [Portfolio](#) (Coming Soon)
320
+
321
+ ---
322
+
323
  **Note**: "This is a learning project demonstrating an end-to-end ML pipeline. Not recommended for production use without further validation, retraining on diverse datasets, and security hardening."
app.py CHANGED
@@ -1,313 +1,313 @@
1
- import streamlit as st
2
- import os
3
- import json
4
- from datetime import datetime
5
- from PIL import Image
6
- import numpy as np
7
- import pandas as pd
8
- from pathlib import Path
9
-
10
- # Import our actual, working pipeline function
11
- import sys
12
- sys.path.append('src')
13
- from pipeline import process_invoice
14
-
15
- # --- Mock Functions to support the UI without errors ---
16
- # These functions simulate the ones from your example README.
17
- # They allow the UI to render without needing to build a complex format detector today.
18
-
19
- def detect_invoice_format(ocr_text: str):
20
- """
21
- A mock function to simulate format detection.
22
- In a real system, this would analyze the text layout.
23
- """
24
- # Simple heuristic: if it contains "SDN BHD", it's our known format.
25
- if "SDN BHD" in ocr_text:
26
- return {
27
- 'name': 'Template A (Retail)',
28
- 'confidence': 95.0,
29
- 'supported': True,
30
- 'indicators': ["Found 'SDN BHD' suffix", "Date format DD/MM/YYYY detected"]
31
- }
32
- else:
33
- return {
34
- 'name': 'Unknown Format',
35
- 'confidence': 20.0,
36
- 'supported': False,
37
- 'indicators': ["No known company suffixes found"]
38
- }
39
-
40
- def get_format_recommendations(format_info):
41
- """Mock recommendations based on the detected format."""
42
- if format_info['supported']:
43
- return ["• Extraction should be highly accurate."]
44
- else:
45
- return ["• Results may be incomplete.", "• Consider adding patterns for this format."]
46
-
47
- # --- Streamlit App ---
48
-
49
- # Page configuration
50
- st.set_page_config(
51
- page_title="Invoice Processor",
52
- page_icon="📄",
53
- layout="wide",
54
- initial_sidebar_state="expanded"
55
- )
56
-
57
- # Custom CSS for styling
58
- st.markdown("""
59
- <style>
60
- .main-header {
61
- font-size: 3rem;
62
- color: #1f77b4;
63
- text-align: center;
64
- margin-bottom: 2rem;
65
- }
66
- .success-box {
67
- padding: 1rem;
68
- border-radius: 0.5rem;
69
- background-color: #d4edda;
70
- border: 1px solid #c3e6cb;
71
- margin: 1rem 0;
72
- }
73
- .warning-box {
74
- padding: 1rem;
75
- border-radius: 0.5rem;
76
- background-color: #fff3cd;
77
- border: 1px solid #ffeaa7;
78
- margin: 1rem 0;
79
- }
80
- .error-box {
81
- padding: 1rem;
82
- border-radius: 0.5rem;
83
- background-color: #f8d7da;
84
- border: 1px solid #f5c6cb;
85
- margin: 1rem 0;
86
- }
87
- </style>
88
- """, unsafe_allow_html=True)
89
-
90
- # Title
91
- st.markdown('<h1 class="main-header">📄 Smart Invoice Processor</h1>', unsafe_allow_html=True)
92
- st.markdown("### Extract structured data from invoices using your custom-built OCR pipeline")
93
-
94
- # Sidebar
95
- with st.sidebar:
96
- st.header("ℹ️ About")
97
- st.info("""
98
- This app uses the pipeline you built to automatically extract:
99
- - Receipt/Invoice number
100
- - Date
101
- - Customer information
102
- - Line items
103
- - Total amount
104
-
105
- **Technology Stack:**
106
- - Tesseract OCR
107
- - OpenCV
108
- - Python Regex
109
- - Streamlit
110
- """)
111
-
112
- st.header("📊 Stats")
113
- if 'processed_count' not in st.session_state:
114
- st.session_state.processed_count = 0
115
- st.metric("Invoices Processed Today", st.session_state.processed_count)
116
-
117
- st.header("⚙️ Configuration")
118
- extraction_method = st.selectbox(
119
- "Choose Extraction Method:",
120
- ('ML-Based (LayoutLMv3)', 'Rule-Based (Regex)'),
121
- help="ML-Based is more robust but may miss fields not in its training data. Rule-Based is faster but more fragile."
122
- )
123
-
124
- # Main content
125
- tab1, tab2, tab3 = st.tabs(["📤 Upload & Process", "📚 Sample Invoices", "ℹ️ How It Works"])
126
-
127
- with tab1:
128
- st.header("Upload an Invoice")
129
-
130
- uploaded_file = st.file_uploader(
131
- "Choose an invoice image (JPG, PNG)",
132
- type=['jpg', 'jpeg', 'png'],
133
- help="Upload a clear image of an invoice or receipt"
134
- )
135
-
136
- if uploaded_file is not None:
137
- col1, col2 = st.columns([1, 1])
138
-
139
- with col1:
140
- st.subheader("📸 Original Image")
141
- image = Image.open(uploaded_file)
142
- st.image(image, use_container_width=True)
143
- st.caption(f"Filename: {uploaded_file.name}")
144
-
145
- with col2:
146
- st.subheader("🔄 Processing Status")
147
-
148
- if st.button("🚀 Extract Data", type="primary"):
149
- with st.spinner("Executing your custom pipeline..."):
150
- try:
151
- # Save the uploaded file to a temporary path to be used by our pipeline
152
- temp_dir = "temp"
153
- os.makedirs(temp_dir, exist_ok=True)
154
- temp_path = os.path.join(temp_dir, uploaded_file.name)
155
- with open(temp_path, "wb") as f:
156
- f.write(uploaded_file.getbuffer())
157
-
158
- # Step 1: Call YOUR full pipeline function
159
- st.write("✅ Calling `process_invoice`...")
160
- # Map the user-friendly name from the dropdown to the actual method parameter
161
- method = 'ml' if extraction_method == 'ML-Based (LayoutLMv3)' else 'rules'
162
- st.write(f"⚙️ Using **{method.upper()}** extraction method...")
163
-
164
- # Call the pipeline with the selected method
165
- extracted_data = process_invoice(temp_path, method=method)
166
-
167
- # Step 2: Simulate format detection using the extracted data
168
- st.write("✅ Simulating format detection...")
169
- format_info = detect_invoice_format(extracted_data.get("raw_text", ""))
170
-
171
- # Store results in session state to display them
172
- st.session_state.extracted_data = extracted_data
173
- st.session_state.format_info = format_info
174
- st.session_state.processed_count += 1
175
-
176
- st.success("✅ Pipeline executed successfully!")
177
-
178
- except Exception as e:
179
- st.error(f"❌ An error occurred in the pipeline: {str(e)}")
180
-
181
- # Display results if they exist in the session state
182
- if 'extracted_data' in st.session_state:
183
- st.markdown("---")
184
- st.header("📊 Extraction Results")
185
-
186
- # --- Format Detection Section ---
187
- format_info = st.session_state.format_info
188
- st.subheader("📋 Detected Format (Simulated)")
189
- col1_fmt, col2_fmt = st.columns([2, 3])
190
- with col1_fmt:
191
- st.metric("Format Type", format_info['name'])
192
- st.metric("Detection Confidence", f"{format_info['confidence']:.0f}%")
193
- if format_info['supported']: st.success("✅ Fully Supported")
194
- else: st.warning("⚠️ Limited Support")
195
- with col2_fmt:
196
- st.write("**Detected Indicators:**")
197
- for indicator in format_info['indicators']: st.write(f"• {indicator}")
198
- st.write("**Recommendations:**")
199
- for rec in get_format_recommendations(format_info): st.write(rec)
200
- st.markdown("---")
201
-
202
- # --- Main Results Section ---
203
- data = st.session_state.extracted_data
204
-
205
- # Confidence display
206
- confidence = data.get('extraction_confidence', 0)
207
- if confidence >= 80:
208
- st.markdown(f'<div class="success-box">✅ <strong>High Confidence: {confidence}%</strong> - Most key fields were found.</div>', unsafe_allow_html=True)
209
- elif confidence >= 50:
210
- st.markdown(f'<div class="warning-box">⚠️ <strong>Medium Confidence: {confidence}%</strong> - Some fields may be missing.</div>', unsafe_allow_html=True)
211
- else:
212
- st.markdown(f'<div class="error-box">❌ <strong>Low Confidence: {confidence}%</strong> - Format likely unsupported.</div>', unsafe_allow_html=True)
213
-
214
- # Validation display
215
- if data.get('validation_passed', False):
216
- st.success("✔️ Validation Passed: Total amount appears consistent with other extracted amounts.")
217
- else:
218
- st.warning("⚠️ Validation Failed: Total amount could not be verified against other numbers.")
219
-
220
- # Key metrics display
221
- # Key metrics display
222
- st.metric("🏢 Vendor", data.get('vendor') or "N/A") # <-- ADD THIS
223
-
224
- res_col1, res_col2, res_col3 = st.columns(3)
225
- res_col1.metric("📄 Receipt Number", data.get('receipt_number') or "N/A")
226
- res_col2.metric("📅 Date", data.get('date') or "N/A")
227
- res_col3.metric("💵 Total Amount", f"${data.get('total_amount'):.2f}" if data.get('total_amount') is not None else "N/A")
228
-
229
- # Use an expander for longer text fields like address
230
- with st.expander("Show More Details"):
231
- st.markdown(f"**👤 Bill To:** {data.get('bill_to', {}).get('name') if data.get('bill_to') else 'N/A'}")
232
- st.markdown(f"**📍 Vendor Address:** {data.get('address') or 'N/A'}")
233
-
234
- # Line items table
235
- if data.get('items'):
236
- st.subheader("🛒 Line Items")
237
- # Ensure data is in the right format for DataFrame
238
- items_df_data = [{
239
- "Description": item.get("description", "N/A"),
240
- "Qty": item.get("quantity", "N/A"),
241
- "Unit Price": f"${item.get('unit_price', 0.0):.2f}",
242
- "Total": f"${item.get('total', 0.0):.2f}"
243
- } for item in data['items']]
244
- df = pd.DataFrame(items_df_data)
245
- st.dataframe(df, use_container_width=True)
246
- else:
247
- st.info("ℹ️ No line items were extracted.")
248
-
249
- # JSON output and download
250
- with st.expander("📄 View Full JSON Output"):
251
- st.json(data)
252
-
253
- json_str = json.dumps(data, indent=2)
254
- st.download_button(
255
- label="💾 Download JSON",
256
- data=json_str,
257
- file_name=f"invoice_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
258
- mime="application/json"
259
- )
260
-
261
- with st.expander("📝 View Raw OCR Text"):
262
- raw_text = data.get('raw_text', '')
263
- if raw_text:
264
- st.text(raw_text)
265
- else:
266
- st.info("No OCR text available.")
267
-
268
- with tab2:
269
- st.header("📚 Sample Invoices")
270
- st.write("Try the sample invoice below to see how the system performs:")
271
-
272
- sample_dir = "data/samples" # ✅ Points to the correct folder
273
- if os.path.exists(sample_dir):
274
- sample_files = [f for f in os.listdir(sample_dir) if f.endswith(('.jpg', '.png', '.jpeg'))]
275
-
276
- if sample_files:
277
- # Display the first sample found
278
- img_path = os.path.join(sample_dir, sample_files[0])
279
- st.image(Image.open(img_path), caption=sample_files[0], use_container_width=True)
280
- st.info("You can download this image and upload it in the 'Upload & Process' tab to test the pipeline.")
281
- else:
282
- st.warning("No sample invoices found in `data/samples/`.")
283
- else:
284
- st.error("The `data/samples` directory was not found.")
285
-
286
- with tab3:
287
- st.header("ℹ️ How It Works (Your Custom Pipeline)")
288
- st.markdown("""
289
- This app follows the exact pipeline you built:
290
- ```
291
- 1. 📸 Image Upload
292
-
293
- 2. 🔄 Preprocessing (OpenCV)
294
- Grayscale conversion and noise removal.
295
-
296
- 3. 🔍 OCR (Tesseract)
297
- Optimized with PSM 6 for receipt layouts.
298
-
299
- 4. 🎯 Rule-Based Extraction (Regex)
300
- Your custom patterns find specific fields.
301
-
302
- 5. ✅ Confidence & Validation
303
- Heuristics to check the quality of the extraction.
304
-
305
- 6. 📊 Output JSON
306
- Presents all extracted data in a structured format.
307
- ```
308
- """)
309
- st.info("This rule-based system is a great foundation. The next step is to replace the extraction logic with an ML model like LayoutLM to handle more diverse formats!")
310
-
311
- # Footer
312
- st.markdown("---")
313
  st.markdown("<div style='text-align: center; color: #666;'>Built with your custom Python pipeline | UI by Streamlit</div>", unsafe_allow_html=True)
 
1
+ import streamlit as st
2
+ import os
3
+ import json
4
+ from datetime import datetime
5
+ from PIL import Image
6
+ import numpy as np
7
+ import pandas as pd
8
+ from pathlib import Path
9
+
10
+ # Import our actual, working pipeline function
11
+ import sys
12
+ sys.path.append('src')
13
+ from pipeline import process_invoice
14
+
15
+ # --- Mock Functions to support the UI without errors ---
16
+ # These functions simulate the ones from your example README.
17
+ # They allow the UI to render without needing to build a complex format detector today.
18
+
19
+ def detect_invoice_format(ocr_text: str):
20
+ """
21
+ A mock function to simulate format detection.
22
+ In a real system, this would analyze the text layout.
23
+ """
24
+ # Simple heuristic: if it contains "SDN BHD", it's our known format.
25
+ if "SDN BHD" in ocr_text:
26
+ return {
27
+ 'name': 'Template A (Retail)',
28
+ 'confidence': 95.0,
29
+ 'supported': True,
30
+ 'indicators': ["Found 'SDN BHD' suffix", "Date format DD/MM/YYYY detected"]
31
+ }
32
+ else:
33
+ return {
34
+ 'name': 'Unknown Format',
35
+ 'confidence': 20.0,
36
+ 'supported': False,
37
+ 'indicators': ["No known company suffixes found"]
38
+ }
39
+
40
+ def get_format_recommendations(format_info):
41
+ """Mock recommendations based on the detected format."""
42
+ if format_info['supported']:
43
+ return ["• Extraction should be highly accurate."]
44
+ else:
45
+ return ["• Results may be incomplete.", "• Consider adding patterns for this format."]
46
+
47
+ # --- Streamlit App ---
48
+
49
+ # Page configuration
50
+ st.set_page_config(
51
+ page_title="Invoice Processor",
52
+ page_icon="📄",
53
+ layout="wide",
54
+ initial_sidebar_state="expanded"
55
+ )
56
+
57
+ # Custom CSS for styling
58
+ st.markdown("""
59
+ <style>
60
+ .main-header {
61
+ font-size: 3rem;
62
+ color: #1f77b4;
63
+ text-align: center;
64
+ margin-bottom: 2rem;
65
+ }
66
+ .success-box {
67
+ padding: 1rem;
68
+ border-radius: 0.5rem;
69
+ background-color: #d4edda;
70
+ border: 1px solid #c3e6cb;
71
+ margin: 1rem 0;
72
+ }
73
+ .warning-box {
74
+ padding: 1rem;
75
+ border-radius: 0.5rem;
76
+ background-color: #fff3cd;
77
+ border: 1px solid #ffeaa7;
78
+ margin: 1rem 0;
79
+ }
80
+ .error-box {
81
+ padding: 1rem;
82
+ border-radius: 0.5rem;
83
+ background-color: #f8d7da;
84
+ border: 1px solid #f5c6cb;
85
+ margin: 1rem 0;
86
+ }
87
+ </style>
88
+ """, unsafe_allow_html=True)
89
+
90
+ # Title
91
+ st.markdown('<h1 class="main-header">📄 Smart Invoice Processor</h1>', unsafe_allow_html=True)
92
+ st.markdown("### Extract structured data from invoices using your custom-built OCR pipeline")
93
+
94
+ # Sidebar
95
+ with st.sidebar:
96
+ st.header("ℹ️ About")
97
+ st.info("""
98
+ This app uses the pipeline you built to automatically extract:
99
+ - Receipt/Invoice number
100
+ - Date
101
+ - Customer information
102
+ - Line items
103
+ - Total amount
104
+
105
+ **Technology Stack:**
106
+ - Tesseract OCR
107
+ - OpenCV
108
+ - Python Regex
109
+ - Streamlit
110
+ """)
111
+
112
+ st.header("📊 Stats")
113
+ if 'processed_count' not in st.session_state:
114
+ st.session_state.processed_count = 0
115
+ st.metric("Invoices Processed Today", st.session_state.processed_count)
116
+
117
+ st.header("⚙️ Configuration")
118
+ extraction_method = st.selectbox(
119
+ "Choose Extraction Method:",
120
+ ('ML-Based (LayoutLMv3)', 'Rule-Based (Regex)'),
121
+ help="ML-Based is more robust but may miss fields not in its training data. Rule-Based is faster but more fragile."
122
+ )
123
+
124
+ # Main content
125
+ tab1, tab2, tab3 = st.tabs(["📤 Upload & Process", "📚 Sample Invoices", "ℹ️ How It Works"])
126
+
127
+ with tab1:
128
+ st.header("Upload an Invoice")
129
+
130
+ uploaded_file = st.file_uploader(
131
+ "Choose an invoice image (JPG, PNG)",
132
+ type=['jpg', 'jpeg', 'png'],
133
+ help="Upload a clear image of an invoice or receipt"
134
+ )
135
+
136
+ if uploaded_file is not None:
137
+ col1, col2 = st.columns([1, 1])
138
+
139
+ with col1:
140
+ st.subheader("📸 Original Image")
141
+ image = Image.open(uploaded_file)
142
+ st.image(image, use_container_width=True)
143
+ st.caption(f"Filename: {uploaded_file.name}")
144
+
145
+ with col2:
146
+ st.subheader("🔄 Processing Status")
147
+
148
+ if st.button("🚀 Extract Data", type="primary"):
149
+ with st.spinner("Executing your custom pipeline..."):
150
+ try:
151
+ # Save the uploaded file to a temporary path to be used by our pipeline
152
+ temp_dir = "temp"
153
+ os.makedirs(temp_dir, exist_ok=True)
154
+ temp_path = os.path.join(temp_dir, uploaded_file.name)
155
+ with open(temp_path, "wb") as f:
156
+ f.write(uploaded_file.getbuffer())
157
+
158
+ # Step 1: Call YOUR full pipeline function
159
+ st.write("✅ Calling `process_invoice`...")
160
+ # Map the user-friendly name from the dropdown to the actual method parameter
161
+ method = 'ml' if extraction_method == 'ML-Based (LayoutLMv3)' else 'rules'
162
+ st.write(f"⚙️ Using **{method.upper()}** extraction method...")
163
+
164
+ # Call the pipeline with the selected method
165
+ extracted_data = process_invoice(temp_path, method=method)
166
+
167
+ # Step 2: Simulate format detection using the extracted data
168
+ st.write("✅ Simulating format detection...")
169
+ format_info = detect_invoice_format(extracted_data.get("raw_text", ""))
170
+
171
+ # Store results in session state to display them
172
+ st.session_state.extracted_data = extracted_data
173
+ st.session_state.format_info = format_info
174
+ st.session_state.processed_count += 1
175
+
176
+ st.success("✅ Pipeline executed successfully!")
177
+
178
+ except Exception as e:
179
+ st.error(f"❌ An error occurred in the pipeline: {str(e)}")
180
+
181
+ # Display results if they exist in the session state
182
+ if 'extracted_data' in st.session_state:
183
+ st.markdown("---")
184
+ st.header("📊 Extraction Results")
185
+
186
+ # --- Format Detection Section ---
187
+ format_info = st.session_state.format_info
188
+ st.subheader("📋 Detected Format (Simulated)")
189
+ col1_fmt, col2_fmt = st.columns([2, 3])
190
+ with col1_fmt:
191
+ st.metric("Format Type", format_info['name'])
192
+ st.metric("Detection Confidence", f"{format_info['confidence']:.0f}%")
193
+ if format_info['supported']: st.success("✅ Fully Supported")
194
+ else: st.warning("⚠️ Limited Support")
195
+ with col2_fmt:
196
+ st.write("**Detected Indicators:**")
197
+ for indicator in format_info['indicators']: st.write(f"• {indicator}")
198
+ st.write("**Recommendations:**")
199
+ for rec in get_format_recommendations(format_info): st.write(rec)
200
+ st.markdown("---")
201
+
202
+ # --- Main Results Section ---
203
+ data = st.session_state.extracted_data
204
+
205
+ # Confidence display
206
+ confidence = data.get('extraction_confidence', 0)
207
+ if confidence >= 80:
208
+ st.markdown(f'<div class="success-box">✅ <strong>High Confidence: {confidence}%</strong> - Most key fields were found.</div>', unsafe_allow_html=True)
209
+ elif confidence >= 50:
210
+ st.markdown(f'<div class="warning-box">⚠️ <strong>Medium Confidence: {confidence}%</strong> - Some fields may be missing.</div>', unsafe_allow_html=True)
211
+ else:
212
+ st.markdown(f'<div class="error-box">❌ <strong>Low Confidence: {confidence}%</strong> - Format likely unsupported.</div>', unsafe_allow_html=True)
213
+
214
+ # Validation display
215
+ if data.get('validation_passed', False):
216
+ st.success("✔️ Validation Passed: Total amount appears consistent with other extracted amounts.")
217
+ else:
218
+ st.warning("⚠️ Validation Failed: Total amount could not be verified against other numbers.")
219
+
220
+ # Key metrics display
221
+ # Key metrics display
222
+ st.metric("🏢 Vendor", data.get('vendor') or "N/A")  # Vendor name extracted by the pipeline
223
+
224
+ res_col1, res_col2, res_col3 = st.columns(3)
225
+ res_col1.metric("📄 Receipt Number", data.get('receipt_number') or "N/A")
226
+ res_col2.metric("📅 Date", data.get('date') or "N/A")
227
+ res_col3.metric("💵 Total Amount", f"${data.get('total_amount'):.2f}" if data.get('total_amount') is not None else "N/A")
228
+
229
+ # Use an expander for longer text fields like address
230
+ with st.expander("Show More Details"):
231
+ st.markdown(f"**👤 Bill To:** {data.get('bill_to', {}).get('name') if data.get('bill_to') else 'N/A'}")
232
+ st.markdown(f"**📍 Vendor Address:** {data.get('address') or 'N/A'}")
233
+
234
+ # Line items table
235
+ if data.get('items'):
236
+ st.subheader("🛒 Line Items")
237
+ # Ensure data is in the right format for DataFrame
238
+ items_df_data = [{
239
+ "Description": item.get("description", "N/A"),
240
+ "Qty": item.get("quantity", "N/A"),
241
+ "Unit Price": f"${item.get('unit_price', 0.0):.2f}",
242
+ "Total": f"${item.get('total', 0.0):.2f}"
243
+ } for item in data['items']]
244
+ df = pd.DataFrame(items_df_data)
245
+ st.dataframe(df, use_container_width=True)
246
+ else:
247
+ st.info("ℹ️ No line items were extracted.")
248
+
249
+ # JSON output and download
250
+ with st.expander("📄 View Full JSON Output"):
251
+ st.json(data)
252
+
253
+ json_str = json.dumps(data, indent=2)
254
+ st.download_button(
255
+ label="💾 Download JSON",
256
+ data=json_str,
257
+ file_name=f"invoice_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
258
+ mime="application/json"
259
+ )
260
+
261
+ with st.expander("📝 View Raw OCR Text"):
262
+ raw_text = data.get('raw_text', '')
263
+ if raw_text:
264
+ st.text(raw_text)
265
+ else:
266
+ st.info("No OCR text available.")
267
+
268
+ with tab2:
269
+ st.header("📚 Sample Invoices")
270
+ st.write("Try the sample invoice below to see how the system performs:")
271
+
272
+ sample_dir = "data/samples" # ✅ Points to the correct folder
273
+ if os.path.exists(sample_dir):
274
+ sample_files = [f for f in os.listdir(sample_dir) if f.endswith(('.jpg', '.png', '.jpeg'))]
275
+
276
+ if sample_files:
277
+ # Display the first sample found
278
+ img_path = os.path.join(sample_dir, sample_files[0])
279
+ st.image(Image.open(img_path), caption=sample_files[0], use_container_width=True)
280
+ st.info("You can download this image and upload it in the 'Upload & Process' tab to test the pipeline.")
281
+ else:
282
+ st.warning("No sample invoices found in `data/samples/`.")
283
+ else:
284
+ st.error("The `data/samples` directory was not found.")
285
+
286
+ with tab3:
287
+ st.header("ℹ️ How It Works (Your Custom Pipeline)")
288
+ st.markdown("""
289
+ This app follows the exact pipeline you built:
290
+ ```
291
+ 1. 📸 Image Upload
292
+
293
+ 2. 🔄 Preprocessing (OpenCV)
294
+ Grayscale conversion and noise removal.
295
+
296
+ 3. 🔍 OCR (Tesseract)
297
+ Optimized with PSM 6 for receipt layouts.
298
+
299
+ 4. 🎯 Rule-Based Extraction (Regex)
300
+ Your custom patterns find specific fields.
301
+
302
+ 5. ✅ Confidence & Validation
303
+ Heuristics to check the quality of the extraction.
304
+
305
+ 6. 📊 Output JSON
306
+ Presents all extracted data in a structured format.
307
+ ```
308
+ """)
309
+ st.info("This rule-based system is a great foundation. The next step is to replace the extraction logic with an ML model like LayoutLM to handle more diverse formats!")
310
+
311
+ # Footer
312
+ st.markdown("---")
313
  st.markdown("<div style='text-align: center; color: #666;'>Built with your custom Python pipeline | UI by Streamlit</div>", unsafe_allow_html=True)
eval_new_dataset.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from src.data_loader import load_unified_dataset
3
+ from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor, DataCollatorForTokenClassification
4
+ from torch.utils.data import DataLoader
5
+ from seqeval.metrics import classification_report
6
+ from tqdm import tqdm
7
+ from train_combined import UnifiedDataset, label2id, id2label, LABEL_LIST
8
+
9
+ # Load Model
10
+ model_path = "./models/layoutlmv3-generalized"
11
+ model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
12
+ processor = LayoutLMv3Processor.from_pretrained(model_path, apply_ocr=False)
13
+ device = torch.device("cuda")
14
+ model.to(device)
15
+
16
+ # Load ONLY the new dataset (validation split)
17
+ # We want to see how well it learned THIS specific dataset
18
+ print("Loading new dataset validation split...")
19
+ val_data = load_unified_dataset(split="valid", sample_size=None)
20
+ dataset = UnifiedDataset(val_data, processor, label2id)
21
+ loader = DataLoader(dataset, batch_size=4, collate_fn=DataCollatorForTokenClassification(processor.tokenizer, padding=True, return_tensors="pt"))
22
+
23
+ print("Running evaluation...")
24
+ model.eval()
25
+ preds, labs = [], []
26
+
27
+ for batch in tqdm(loader):
28
+ batch = {k: v.to(device) for k, v in batch.items()}
29
+ with torch.no_grad():
30
+ outputs = model(**batch)
31
+
32
+ predictions = outputs.logits.argmax(dim=-1)
33
+ labels = batch['labels']
34
+
35
+ for i in range(len(labels)):
36
+ p = [id2label[p.item()] for p, l in zip(predictions[i], labels[i]) if l.item() != -100]
37
+ l = [id2label[l.item()] for l in labels[i] if l.item() != -100]
38
+ preds.append(p)
39
+ labs.append(l)
40
+
41
+ print("\nClassification Report:")
42
+ print(classification_report(labs, preds))
explore_new_dataset.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ import json
3
+ import ast # <--- Added for robust parsing
4
+
5
+ # --- 1. Load the dataset ---
6
+ print("📥 Loading 'mychen76/invoices-and-receipts_ocr_v1' from Hugging Face...")
7
+ try:
8
+ dataset = load_dataset("mychen76/invoices-and-receipts_ocr_v1", split='train')
9
+ print("✅ Dataset loaded successfully!")
10
+ except Exception as e:
11
+ print(f"❌ Failed to load dataset. Error: {e}")
12
+ exit()
13
+
14
+ # --- 2. Print Dataset Information ---
15
+ print("\n" + "="*60)
16
+ print("📊 DATASET INFORMATION & FEATURES")
17
+ print("="*60)
18
+ print(f"Number of examples: {len(dataset)}")
19
+ print(f"\nFeatures (Columns): {dataset.features}")
20
+
21
+
22
+ # --- 3. Explore a Single Example ---
23
+ print("\n" + "="*60)
24
+ print("📄 EXPLORING THE FIRST SAMPLE")
25
+ print("="*60)
26
+ if len(dataset) > 0:
27
+ sample = dataset[0]
28
+
29
+ # Parse the main wrapper JSONs
30
+ try:
31
+ raw_data = json.loads(sample['raw_data'])
32
+ parsed_data = json.loads(sample['parsed_data'])
33
+ except json.JSONDecodeError as e:
34
+ print(f"❌ Error decoding main JSON wrappers: {e}")
35
+ exit()
36
+
37
+ print(f"\nImage object: {sample['image']}")
38
+
39
+ # --- ROBUST PARSING LOGIC ---
40
+ def safe_parse(content):
41
+ """Try JSON, fallback to AST (for single quotes)"""
42
+ if isinstance(content, list):
43
+ return content # Already a list
44
+ if isinstance(content, str):
45
+ try:
46
+ return json.loads(content)
47
+ except json.JSONDecodeError:
48
+ try:
49
+ return ast.literal_eval(content)
50
+ except:
51
+ return None
52
+ return None
53
+
54
+ ocr_words = safe_parse(raw_data.get('ocr_words'))
55
+ ocr_boxes = safe_parse(raw_data.get('ocr_boxes'))
56
+
57
+ if ocr_words and ocr_boxes:
58
+ print(f"\nFound {len(ocr_words)} OCR words.")
59
+ print("Sample Word & Box Format:")
60
+ # Print first 3 to check coordinate format (4 numbers or 8 numbers?)
61
+ for i in range(min(3, len(ocr_words))):
62
+ print(f" Word: '{ocr_words[i]}' | Box: {ocr_boxes[i]}")
63
+ else:
64
+ print("❌ OCR fields missing or could not be parsed.")
65
+
66
+
67
+ else:
68
+ print("Dataset is empty.")
69
+
70
+
71
+ # --- 4. Discover All Unique NER Tags ---
72
+ print("\n" + "="*60)
73
+ print("📋 ALL UNIQUE ENTITY LABELS IN THIS DATASET")
74
+ print("="*60)
75
+ if len(dataset) > 0:
76
+ all_entity_labels = set()
77
+
78
+ print("Scanning dataset for labels...")
79
+ for i, example in enumerate(dataset):
80
+ try:
81
+ # Parse parsed_data
82
+ parsed_example = json.loads(example['parsed_data'])
83
+
84
+ # The 'json' field inside might be a string or a dict
85
+ fields_data = parsed_example.get('json', {})
86
+
87
+ if isinstance(fields_data, str):
88
+ try:
89
+ fields = json.loads(fields_data)
90
+ except:
91
+ fields = ast.literal_eval(fields_data)
92
+ else:
93
+ fields = fields_data
94
+
95
+ if fields:
96
+ all_entity_labels.update(fields.keys())
97
+
98
+ except Exception:
99
+ continue # Skip corrupted examples silently
100
+
101
+ if all_entity_labels:
102
+ print(f"\nFound {len(all_entity_labels)} unique entity labels:")
103
+ print(sorted(list(all_entity_labels)))
104
+ else:
105
+ print("Could not find any entity labels.")
106
+ else:
107
+ print("Cannot analyze tags of an empty dataset.")
108
+
109
+
110
+ # Save the first sample image so it can be inspected manually / used as a test input
111
+ sample = dataset[0]
112
+ sample['image'].save("data/samples/test_invoice_no.jpg")
113
+ print("Saved sample image to data/samples/test_invoice_no.jpg")
load_sroie_dataset.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from PIL import Image
4
+
5
+ def load_sroie(path):
6
+ print(f"🔄 Loading SROIE from local path: {path}")
7
+ path = Path(path)
8
+ dataset = {'train': [], 'test': []}
9
+
10
+ for split in ["train", "test"]:
11
+ split_path = path / split
12
+
13
+ if (split_path / "images").exists(): img_dir = split_path / "images"
14
+ elif (split_path / "img").exists(): img_dir = split_path / "img"
15
+ else: continue
16
+
17
+ if (split_path / "tagged").exists(): ann_dir = split_path / "tagged"
18
+ elif (split_path / "box").exists(): ann_dir = split_path / "box"
19
+ else: continue
20
+
21
+ examples = []
22
+ for img_file in sorted(img_dir.iterdir()):
23
+ if img_file.suffix.lower() not in [".jpg", ".png"]: continue
24
+
25
+ name = img_file.stem
26
+ json_path = ann_dir / f"{name}.json"
27
+ if not json_path.exists(): continue
28
+
29
+ with open(json_path, encoding="utf8") as f:
30
+ data = json.load(f)
31
+
32
+ if "words" in data and "bbox" in data and "labels" in data:
33
+ # --- NORMALIZATION HAPPENS HERE (YOUR FIX) ---
34
+ try:
35
+ with Image.open(img_file) as img:
36
+ width, height = img.size
37
+
38
+ norm_boxes = []
39
+ for box in data["bbox"]:
40
+ # SROIE is raw [x0, y0, x1, y1]
41
+ x0, y0, x1, y1 = box
42
+
43
+ # Normalize and Clamp
44
+ norm_box = [
45
+ int(max(0, min(1000 * (x0 / width), 1000))),
46
+ int(max(0, min(1000 * (y0 / height), 1000))),
47
+ int(max(0, min(1000 * (x1 / width), 1000))),
48
+ int(max(0, min(1000 * (y1 / height), 1000)))
49
+ ]
50
+ norm_boxes.append(norm_box)
51
+
52
+ examples.append({
53
+ "image_path": str(img_file),
54
+ "words": data["words"],
55
+ "bboxes": norm_boxes, # Storing normalized boxes
56
+ "ner_tags": data["labels"]
57
+ })
58
+ except Exception as e:
59
+ print(f"Skipping {name}: {e}")
60
+ continue
61
+
62
+ dataset[split] = examples
63
+ print(f" Mapped {len(examples)} paths for {split}")
64
+
65
+ return dataset
notebooks/test_setup.py DELETED
@@ -1,11 +0,0 @@
1
- # This is just a verification script - you can copy this
2
- import pytesseract
3
- from PIL import Image
4
- import cv2
5
- import numpy as np
6
-
7
- # If Windows, you might need to set this path:
8
- # pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
9
-
10
- print("✅ All imports successful!")
11
- print(f"Tesseract version: {pytesseract.get_tesseract_version()}")
 
 
 
 
 
 
 
 
 
 
 
 
notebooks/test_visual.ipynb DELETED
File without changes
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
 
src/data_loader.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/data_loader.py
2
+
3
+ import json
4
+ import ast
5
+ import numpy as np
6
+ from datasets import load_dataset
7
+ from difflib import SequenceMatcher
8
+
9
+ # --- CONFIGURATION ---
10
+ LABEL_MAPPING = {
11
+ # Vendor/Company
12
+ "seller": "COMPANY",
13
+ "store_name": "COMPANY",
14
+
15
+ # Address
16
+ "store_addr": "ADDRESS",
17
+
18
+ # Date
19
+ "date": "DATE",
20
+ "invoice_date": "DATE",
21
+
22
+ # Total
23
+ "total": "TOTAL",
24
+ "total_gross_worth": "TOTAL",
25
+
26
+ # Receipt Number / Invoice No
27
+ "invoice_no": "INVOICE_NO",
28
+
29
+ # Bill To / Client
30
+ "client": "BILL_TO"
31
+ }
32
+
33
def safe_parse(content):
    """Best-effort conversion of *content* into a Python object.

    Accepts a list (returned untouched), a JSON string, or a Python
    literal string; anything unparseable yields an empty list.
    """
    if isinstance(content, list):
        return content
    if not isinstance(content, str):
        return []
    for parser in (json.loads, ast.literal_eval):
        try:
            return parser(content)
        except (ValueError, SyntaxError):
            # json raises JSONDecodeError (a ValueError); literal_eval
            # raises ValueError/SyntaxError. Try the next parser.
            continue
    return []
47
+
48
def normalize_box(box, width, height):
    """Collapse a quad polygon into an axis-aligned box in [0, 1000].

    Accepts either a ``[polygon, text]`` pair or a bare list of four
    ``[x, y]`` points. Returns ``None`` when the shape is unrecognized
    or any arithmetic fails.
    """
    try:
        if not isinstance(box, list):
            return None
        if len(box) == 2 and isinstance(box[0], list):
            polygon = box[0]        # [polygon, recognized_text] pair
        elif len(box) == 4 and isinstance(box[0], list):
            polygon = box           # bare list of 4 [x, y] points
        else:
            return None

        def clamp(value):
            # Normalized coordinates must stay inside [0, 1000].
            return int(max(0, min(value, 1000)))

        xs = [point[0] for point in polygon]
        ys = [point[1] for point in polygon]
        return [
            clamp(1000 * min(xs) / width),
            clamp(1000 * min(ys) / height),
            clamp(1000 * max(xs) / width),
            clamp(1000 * max(ys) / height),
        ]
    except Exception:
        return None
70
+
71
def tokenize_and_spread_boxes(words, boxes):
    """Split multi-word phrases into single tokens, repeating each box.

    Input:  ['Invoice #123'], [BOX_A]
    Output: ['Invoice', '#123'], [BOX_A, BOX_A]
    """
    out_words, out_boxes = [], []
    for phrase, box in zip(words, boxes):
        # str() guards against non-string OCR payloads.
        for token in str(phrase).split():
            out_words.append(token)
            out_boxes.append(box)
    return out_words, out_boxes
88
+
89
def align_labels(ocr_words, label_map):
    """Assign BIO tags to OCR words by matching ground-truth values.

    ``label_map`` maps a ground-truth string (e.g. the invoice total) to
    its entity class. Every window of OCR words that fuzzily matches the
    target's tokens is tagged ``B-CLASS`` / ``I-CLASS``; all other words
    stay ``"O"``.
    """
    tags = ["O"] * len(ocr_words)
    punct = ".,-:"

    for target_text, label_class in label_map.items():
        if not target_text:
            continue
        target_tokens = str(target_text).split()
        if not target_tokens:
            continue
        n_target = len(target_tokens)

        # Slide a window of n_target words across the OCR sequence.
        for i in range(len(ocr_words) - n_target + 1):
            window = ocr_words[i:i + n_target]

            match = True
            for ocr_tok, tgt_tok in zip(window, target_tokens):
                w_clean = ocr_tok.strip(punct)
                t_clean = tgt_tok.strip(punct)
                # BUGFIX: punctuation-only words strip to "" and
                # "" is a substring of everything, so they previously
                # matched any target token. Require both sides non-empty.
                if not w_clean or not t_clean:
                    match = False
                    break
                if w_clean not in t_clean and t_clean not in w_clean:
                    match = False
                    break

            if match:
                tags[i] = f"B-{label_class}"
                for k in range(1, n_target):
                    tags[i + k] = f"I-{label_class}"

    return tags
121
+
122
def load_unified_dataset(split="train", sample_size=None):
    """Load and preprocess the invoices-and-receipts OCR dataset.

    Parameters
    ----------
    split : str
        Dataset split to load (e.g. "train").
    sample_size : int | None
        If given, only the first ``sample_size`` rows are processed.

    Returns
    -------
    list[dict]
        One dict per usable example with keys ``image``, ``words``,
        ``bboxes`` (normalized to [0, 1000]) and ``ner_tags``.
    """
    print(f"🔄 Loading dataset 'mychen76/invoices-and-receipts_ocr_v1' ({split})...")
    dataset = load_dataset("mychen76/invoices-and-receipts_ocr_v1", split=split)

    if sample_size:
        dataset = dataset.select(range(sample_size))

    processed_data = []

    print("⚙️ Processing, Tokenizing, and Aligning...")
    for example in dataset:
        try:
            image = example['image']
            if image.mode != "RGB":
                image = image.convert("RGB")
            width, height = image.size

            # 1. Parse raw OCR output. FIX: parse the JSON payload once
            #    (it was previously json.loads'ed twice per example).
            raw_payload = json.loads(example['raw_data'])
            raw_words = safe_parse(raw_payload.get('ocr_words'))
            raw_boxes = safe_parse(raw_payload.get('ocr_boxes'))

            if not raw_words or not raw_boxes or len(raw_words) != len(raw_boxes):
                continue

            # 2. Normalize boxes, dropping words whose box is malformed.
            norm_boxes = []
            valid_words = []
            for i, box in enumerate(raw_boxes):
                nb = normalize_box(box, width, height)
                if nb:
                    norm_boxes.append(nb)
                    valid_words.append(raw_words[i])

            # 3. Split phrases into single tokens (box duplicated per token).
            final_words, final_boxes = tokenize_and_spread_boxes(valid_words, norm_boxes)

            # 4. Build {ground-truth value -> entity class} map.
            parsed_json = json.loads(example['parsed_data'])
            fields = safe_parse(parsed_json.get('json', {}))
            label_value_map = {}
            if isinstance(fields, dict):
                for k, v in fields.items():
                    if k in LABEL_MAPPING and v:
                        label_value_map[v] = LABEL_MAPPING[k]

            # 5. Align ground-truth values to OCR tokens as BIO tags.
            final_tags = align_labels(final_words, label_value_map)

            # Keep only examples with at least one non-"O" tag; all-"O"
            # examples add noise to training.
            if len(set(final_tags)) > 1:
                processed_data.append({
                    "image": image,
                    "words": final_words,
                    "bboxes": final_boxes,
                    "ner_tags": final_tags
                })

        except Exception:
            # Skip malformed rows rather than abort the whole load.
            continue

    print(f"✅ Successfully processed {len(processed_data)} examples.")
    return processed_data
185
+
186
if __name__ == "__main__":
    # Smoke test: process a small sample and report which tags appear.
    data = load_unified_dataset(sample_size=20)
    if not data:
        print("No valid examples found in sample.")
    else:
        print(f"\nSample 0 Words: {data[0]['words'][:10]}...")
        print(f"Sample 0 Tags: {data[0]['ner_tags'][:10]}...")

        found_tags = {tag for item in data for tag in item['ner_tags']}
        print(f"\nUnique Tags Found in Sample: {found_tags}")
src/extraction.py CHANGED
@@ -1,273 +1,123 @@
1
- import re
2
- from typing import List, Dict, Optional, Any
3
-
4
-
5
- def extract_dates(text: str) -> List[str]:
6
- if not text:
7
- return []
8
-
9
- dates = []
10
-
11
- pattern1 = r'\d{2}[/-]\d{2}[/-]\d{4}'
12
- pattern2 = r'\d{2}[/-]\d{2}[/-]\d{2}(?!\d)'
13
- pattern3 = r'\d{4}[/-]\d{2}[/-]\d{2}'
14
-
15
- dates.extend(re.findall(pattern1, text))
16
- dates.extend(re.findall(pattern2, text))
17
- dates.extend(re.findall(pattern3, text))
18
-
19
- dates = list(dict.fromkeys(dates))
20
- return dates
21
-
22
-
23
- def extract_amounts(text: str) -> List[float]:
24
- if not text:
25
- return []
26
- # Matches: 123.45, 1,234.56, $123.45, 123.45 RM
27
- pattern = r'(?:RM|Rs\.?|\$|€)?\s*\d{1,3}(?:,\d{3})*[.,]\d{2}'
28
- amounts_strings = (re.findall(pattern, text))
29
-
30
- amounts = []
31
- for amt_str in amounts_strings:
32
- amt_cleaned = re.sub(r'[^\d.,]', '', amt_str)
33
- amt_cleaned = amt_cleaned.replace(',', '.')
34
- try:
35
- amounts.append(float(amt_cleaned))
36
- except ValueError:
37
- continue
38
- return amounts
39
-
40
-
41
- def extract_total(text: str) -> Optional[float]:
42
- if not text:
43
- return None
44
-
45
- pattern = r'(?:TOTAL|GRAND\s*TOTAL|AMOUNT\s*DUE|BALANCE)\s*:?\s*(\d+[.,]\d{2})'
46
- match = re.search(pattern, text, re.IGNORECASE)
47
-
48
- if match:
49
- amount_str = match.group(1).replace(',', '.')
50
- return float(amount_str)
51
-
52
- return None
53
-
54
-
55
- def extract_vendor(text: str) -> Optional[str]:
56
- if not text:
57
- return None
58
-
59
- lines = text.strip().split('\n')
60
-
61
- company_suffixes = ['SDN BHD', 'INC', 'LTD', 'LLC', 'PLC', 'CORP', 'PTY', 'PVT']
62
-
63
- for line in lines:
64
- line = line.strip()
65
-
66
- # Skip empty or very short line
67
- if len(line) < 3:
68
- continue
69
-
70
- # Skip lines with only symbols
71
- if all(c in '*-=_#' for c in line.replace(' ', '')):
72
- continue
73
-
74
- for suffix in company_suffixes:
75
- if suffix in line.upper():
76
- return line
77
-
78
- # If we've gone through 10 lines and found nothing,
79
- # return the first substantial line
80
- # (Vendor is usually in first few lines)
81
-
82
- # Fallback: return first non-trivial line
83
- for line in lines[:10]:
84
- line = line.strip()
85
- if len(line) >= 3 and not all(c in '*-=_#' for c in line.replace(' ', '')):
86
- return line
87
- return None
88
-
89
-
90
- def extract_invoice_number(text: str) -> Optional[str]:
91
- if not text:
92
- return None
93
-
94
- # Look for invoice number patterns (alphanumeric with hyphens, 5+ chars)
95
- # Typically near invoice-related text
96
- lines = text.split('\n')
97
-
98
- for line in lines[:15]: # Check first 15 lines (invoice # is usually at top)
99
- # If line mentions anything invoice-related
100
- if any(keyword in line.lower() for keyword in ['nvoice', 'receipt', 'bill', 'no']):
101
- # Find alphanumeric patterns
102
- patterns = re.findall(r'[A-Z]{2,}[A-Z0-9\-]{3,}', line, re.IGNORECASE)
103
- for pattern in patterns:
104
- # Must be 5+ chars and contain both letters and numbers
105
- if (len(pattern) >= 5 and
106
- any(c.isdigit() for c in pattern) and
107
- any(c.isalpha() for c in pattern)):
108
- return pattern.upper()
109
-
110
- return None
111
-
112
-
113
- def extract_bill_to(text: str) -> Optional[Dict[str, str]]:
114
- if not text:
115
- return None
116
-
117
- bill_to = None
118
-
119
- # Normalize lines and remove empty lines
120
- lines = [line.strip() for line in text.splitlines() if line.strip()]
121
-
122
- # Possible headings
123
- headings = ['bill to', 'billed to', 'billing name', 'customer']
124
-
125
- bill_to_text = None
126
- for i, line in enumerate(lines):
127
- lower_line = line.lower()
128
- if any(h in lower_line for h in headings):
129
- # Capture text after colon or hyphen if present
130
- split_line = re.split(r'[:\-]', line, maxsplit=1)
131
- if len(split_line) > 1:
132
- bill_to_text = split_line[1].strip()
133
- else:
134
- # If name is on next line
135
- if i + 1 < len(lines):
136
- bill_to_text = lines[i + 1].strip()
137
- break
138
-
139
- if not bill_to_text:
140
- return None
141
-
142
- # Extract email if present
143
- email_match = re.search(r'[\w\.-]+@[\w\.-]+\.\w+', bill_to_text)
144
- email = email_match.group(0) if email_match else None
145
-
146
- # Remove email from name
147
- if email:
148
- bill_to_text = bill_to_text.replace(email, '').strip()
149
-
150
- if len(bill_to_text) > 2: # Basic validation
151
- bill_to = {"name": bill_to_text, "email": email}
152
-
153
- return bill_to
154
-
155
-
156
- def extract_line_items(text: str) -> List[Dict[str, Any]]:
157
- """
158
- Extract line items from receipt text more robustly.
159
- Handles:
160
- - Multi-line descriptions
161
- - Prices with or without currency symbols
162
- - Quantities in different formats
163
- - Missing decimals
164
-
165
- Args:
166
- text: Raw OCR text
167
-
168
- Returns:
169
- List of dictionaries with description, quantity, unit_price, total
170
- """
171
- items = []
172
- lines = text.split('\n')
173
-
174
- # Keywords to detect start/end of item section
175
- start_keywords = ['description', 'item', 'qty', 'price', 'amount']
176
- end_keywords = ['total', 'subtotal', 'tax', 'gst']
177
-
178
- # Detect section
179
- start_index = -1
180
- end_index = len(lines)
181
- for i, line in enumerate(lines):
182
- lower = line.lower()
183
- if start_index == -1 and any(k in lower for k in start_keywords):
184
- start_index = i + 1
185
- if start_index != -1 and any(k in lower for k in end_keywords):
186
- end_index = i
187
- break
188
-
189
- if start_index == -1:
190
- return []
191
-
192
- item_lines = lines[start_index:end_index]
193
-
194
- current_description = ""
195
- for line in item_lines:
196
- # Remove currency symbols, commas, etc.
197
- clean_line = re.sub(r'[^\d\.\s]', '', line)
198
-
199
- # Find all numbers (floats or integers)
200
- amounts_on_line = re.findall(r'\d+(?:\.\d+)?', clean_line)
201
-
202
- # Attempt to detect quantity at the start: "2 ", "3 x", etc.
203
- qty_match = re.match(r'^\s*(\d+)\s*(?:x)?', line)
204
- quantity = int(qty_match.group(1)) if qty_match else 1
205
-
206
- # Extract description by removing numbers and common symbols
207
- desc_part = re.sub(r'[\d\.\s]+', '', line).strip()
208
- if len(desc_part) > 0:
209
- if current_description:
210
- current_description += " " + desc_part
211
- else:
212
- current_description = desc_part
213
-
214
- # If there are numbers and a description, create item
215
- if amounts_on_line and current_description:
216
- try:
217
- # Heuristic: last number is total, second last is unit price
218
- item_total = float(amounts_on_line[-1])
219
- unit_price = float(amounts_on_line[-2]) if len(amounts_on_line) > 1 else item_total
220
-
221
- items.append({
222
- "description": current_description.strip(),
223
- "quantity": quantity,
224
- "unit_price": unit_price,
225
- "total": item_total
226
- })
227
- current_description = "" # reset for next item
228
- except ValueError:
229
- current_description = ""
230
- continue
231
-
232
- return items
233
-
234
-
235
- def structure_output(text: str) -> Dict[str, Any]:
236
- """
237
- Extract all information and return in the desired advanced format.
238
- """
239
-
240
- # Old fields
241
- date = extract_dates(text)[0] if extract_dates(text) else None
242
- total = extract_total(text)
243
-
244
- # New fields
245
- bill_to = extract_bill_to(text)
246
- items = extract_line_items(text)
247
- invoice_num = extract_invoice_number(text) # Renamed for clarity
248
-
249
- data = {
250
- "receipt_number": invoice_num,
251
- "date": date,
252
- "bill_to": bill_to,
253
- "items": items,
254
- "total_amount": total,
255
- "raw_text": text
256
- }
257
-
258
- # --- Confidence and Validation ---
259
- fields_to_check = ['receipt_number', 'date', 'bill_to', 'total_amount']
260
- extracted_fields = sum(1 for field in fields_to_check if data.get(field) is not None)
261
- if items: # Count items as an extracted field
262
- extracted_fields += 1
263
-
264
- data['extraction_confidence'] = int((extracted_fields / (len(fields_to_check) + 1)) * 100)
265
-
266
- # A more advanced validation
267
- items_total = sum(item.get('total', 0) for item in items)
268
- data['validation_passed'] = False
269
- if total is not None and abs(total - items_total) < 0.01: # Check if total matches sum of items
270
- data['validation_passed'] = True
271
-
272
- return data
273
-
 
1
+ # src/extraction.py
2
+
3
+ import re
4
+ from typing import List, Dict, Optional, Any
5
+
6
def extract_dates(text: str) -> List[str]:
    """Return every date-like token in *text*, de-duplicated, in order.

    Recognizes DD/MM/YYYY (or '-' separated, 2-4 digit year) and
    ISO-style YYYY-MM-DD forms.
    """
    if not text:
        return []
    patterns = (
        r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b',   # DD/MM/YYYY, D-M-YY, ...
        r'\b\d{4}[/-]\d{1,2}[/-]\d{1,2}\b',     # YYYY-MM-DD
    )
    found: List[str] = []
    for pattern in patterns:
        found.extend(re.findall(pattern, text))
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(found))
17
+
18
def extract_amounts(text: str) -> List[float]:
    """Extract all monetary-looking values (e.g. '1,234.56') as floats."""
    if not text:
        return []
    money_re = r'\b\d{1,3}(?:,\d{3})*\.\d{2}\b'
    results: List[float] = []
    for token in re.findall(money_re, text):
        try:
            results.append(float(token.replace(',', '')))
        except ValueError:
            # Shouldn't happen given the regex, but stay defensive.
            continue
    return results
32
+
33
+ def extract_total(text: str) -> Optional[float]:
34
+ """
35
+ Robust total extraction looking for keywords + largest number context.
36
+ """
37
+ if not text: return None
38
+
39
+ # 1. Try specific "Total" keywords first
40
+ # Looks for "Total: 123.45" or "Total Amount $123.45"
41
+ pattern = r'(?:TOTAL|AMOUNT DUE|GRAND TOTAL|BALANCE|PAYABLE)[\w\s]*[:$]?\s*([\d,]+\.\d{2})'
42
+ matches = re.findall(pattern, text, re.IGNORECASE)
43
+
44
+ if matches:
45
+ # Return the last match (often the grand total at bottom)
46
+ try:
47
+ return float(matches[-1].replace(',', ''))
48
+ except ValueError:
49
+ pass
50
+
51
+ # 2. Fallback: Find the largest monetary value in the bottom half of text
52
+ # (Risky, but better than None)
53
+ amounts = extract_amounts(text)
54
+ if amounts:
55
+ return max(amounts)
56
+
57
+ return None
58
+
59
def extract_vendor(text: str) -> Optional[str]:
    """Guess the vendor name from the top of the receipt.

    Prefers a top-10 line carrying a company suffix (LTD, INC, ...);
    otherwise falls back to the first substantial, non-date line.
    """
    if not text:
        return None

    lines = text.strip().split('\n')
    suffixes = ('SDN BHD', 'INC', 'LTD', 'LLC', 'PLC', 'CORP', 'PTY', 'PVT', 'LIMITED')

    for line in lines[:10]:
        if any(sfx in line.upper() for sfx in suffixes):
            return line.strip()

    # Fallback: first non-trivial line in the header that has no date.
    for line in lines[:5]:
        if len(line.strip()) > 3 and not re.search(r'\d{2}/\d{2}', line):
            return line.strip()
    return None
74
+
75
def extract_invoice_number(text: str) -> Optional[str]:
    """Extract an invoice/receipt identifier from OCR text.

    Strategy 1 looks for an explicitly labelled number
    ("Invoice No: 12345", "Receipt # AB-123"); Strategy 2 scans the
    first 20 lines for an uppercase alphanumeric token near an
    invoice-related keyword. Returns None if nothing plausible is found.
    """
    if not text:
        return None

    # Strategy 1: keyword-labelled identifier.
    keyword_pattern = r'(?:INVOICE|BILL|RECEIPT)\s*(?:NO|NUMBER|#|NUM)?[\s\.:-]*([A-Z0-9\-/]{3,})'
    for match in re.finditer(keyword_pattern, text, re.IGNORECASE):
        candidate = match.group(1)
        # BUGFIX: the first capture used to be returned unchecked, so
        # "Invoice Date" yielded "Date". An identifier must contain at
        # least one digit.
        if any(c.isdigit() for c in candidate):
            return candidate

    # Strategy 2: standalone uppercase token on an invoice-related line.
    for line in text.split('\n')[:20]:
        if any(k in line.lower() for k in ['invoice', 'no', '#']):
            token_match = re.search(r'\b([A-Z0-9-]{4,})\b', line)
            if token_match and any(c.isdigit() for c in token_match.group(1)):
                return token_match.group(1)

    return None
100
+
101
def extract_bill_to(text: str) -> Optional[Dict[str, str]]:
    """Pull the customer name from a 'Bill To:' line, if present.

    Returns {"name": ..., "email": None} or None when no such line exists.
    """
    if not text:
        return None
    hit = re.search(r'(?:BILL|BILLED)\s*TO[:\s]+([^\n]+)', text, re.IGNORECASE)
    if hit is None:
        return None
    return {"name": hit.group(1).strip(), "email": None}
110
+
111
def extract_line_items(text: str) -> List[Dict[str, Any]]:
    """Placeholder: line-item parsing is not implemented yet.

    Intentionally returns no items; kept so callers have a stable API.
    """
    return []
114
+
115
def structure_output(text: str) -> Dict[str, Any]:
    """Run the rule-based extractors over *text* and bundle the results.

    Legacy wrapper kept for the rule-based-only pipeline; returns keys
    receipt_number, date, total_amount, vendor and raw_text.
    """
    # FIX: extract_dates was previously called twice (once for the
    # truthiness check, once for the value); compute it a single time.
    dates = extract_dates(text)
    return {
        "receipt_number": extract_invoice_number(text),
        "date": dates[0] if dates else None,
        "total_amount": extract_total(text),
        "vendor": extract_vendor(text),
        "raw_text": text
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/ml_extraction.py CHANGED
@@ -1,176 +1,144 @@
1
- # src/ml_extraction.py
2
-
3
- import torch
4
- from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
5
- from PIL import Image
6
- import pytesseract
7
- from typing import List, Dict, Any
8
- import re
9
-
10
- # --- CONFIGURATION ---
11
- # The local path where we expect to find/save the model
12
- LOCAL_MODEL_PATH = "./models/layoutlmv3-sroie-best"
13
- # The Hugging Face Hub ID for the model to download if not found locally
14
- HUB_MODEL_ID = "GSoumyajit2005/layoutlmv3-sroie-invoice-extraction"
15
-
16
- # --- Function to load the model ---
17
- def load_model_and_processor(model_path, hub_id):
18
- """
19
- Tries to load the model from a local path. If it fails,
20
- it downloads it from the Hugging Face Hub.
21
- """
22
- try:
23
- # Try loading from local path first
24
- print(f"Attempting to load model from local path: {model_path}...")
25
- processor = LayoutLMv3Processor.from_pretrained(model_path)
26
- model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
27
- print("✅ Model loaded successfully from local path.")
28
- except OSError:
29
- # If it fails, download from the Hub
30
- print(f"Model not found locally. Downloading from Hugging Face Hub: {hub_id}...")
31
- from huggingface_hub import snapshot_download
32
- # Download the model files and save them to the local path
33
- snapshot_download(repo_id=hub_id, local_dir=model_path, local_dir_use_symlinks=False)
34
- # Now load from the local path again
35
- processor = LayoutLMv3Processor.from_pretrained(model_path)
36
- model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
37
- print("✅ Model downloaded and loaded successfully from the Hub.")
38
-
39
- return model, processor
40
-
41
- # --- Load the model and processor only ONCE when the module is imported ---
42
- MODEL, PROCESSOR = load_model_and_processor(LOCAL_MODEL_PATH, HUB_MODEL_ID)
43
-
44
- if MODEL and PROCESSOR:
45
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
- MODEL.to(DEVICE)
47
- MODEL.eval()
48
- print(f"ML Model is ready on device: {DEVICE}")
49
- else:
50
- DEVICE = None
51
- print("❌ Could not load ML model.")
52
-
53
- # --- Helper Function to group entities ---
54
- def _process_predictions(words: List[str], unnormalized_boxes: List[List[int]], encoding, predictions: List[int], id2label: Dict[int, str]) -> Dict[str, Any]:
55
- word_ids = encoding.word_ids(batch_index=0)
56
-
57
- word_level_preds = {}
58
- for idx, word_id in enumerate(word_ids):
59
- if word_id is not None:
60
- label_id = predictions[idx]
61
- if label_id != -100:
62
- if word_id not in word_level_preds:
63
- word_level_preds[word_id] = id2label[label_id]
64
-
65
- entities = {}
66
- for word_idx, label in word_level_preds.items():
67
- if label == 'O': continue
68
-
69
- entity_type = label[2:]
70
- word = words[word_idx]
71
-
72
- if label.startswith('B-'):
73
- entities[entity_type] = {"text": word, "bbox": [unnormalized_boxes[word_idx]]}
74
- elif label.startswith('I-') and entity_type in entities:
75
- if word_idx > 0 and word_level_preds.get(word_idx - 1) in (f'B-{entity_type}', f'I-{entity_type}'):
76
- entities[entity_type]['text'] += " " + word
77
- entities[entity_type]['bbox'].append(unnormalized_boxes[word_idx])
78
- else:
79
- entities[entity_type] = {"text": word, "bbox": [unnormalized_boxes[word_idx]]}
80
-
81
- # Clean up the final text field
82
- for entity in entities.values():
83
- entity['text'] = entity['text'].strip()
84
-
85
- return entities
86
-
87
- # --- Main Function to be called from the pipeline ---
88
- def extract_ml_based(image_path: str) -> Dict[str, Any]:
89
- """
90
- Performs end-to-end ML-based extraction on a single image.
91
-
92
- Args:
93
- image_path: The path to the invoice image.
94
-
95
- Returns:
96
- A dictionary containing the extracted entities.
97
- """
98
- if not MODEL or not PROCESSOR:
99
- raise RuntimeError("ML model is not loaded. Cannot perform extraction.")
100
-
101
- # 1. Load Image
102
- image = Image.open(image_path).convert("RGB")
103
- width, height = image.size
104
-
105
- # 2. Perform OCR
106
- ocr_data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
107
- n_boxes = len(ocr_data['level'])
108
- words = []
109
- unnormalized_boxes = []
110
- for i in range(n_boxes):
111
- if int(ocr_data['conf'][i]) > 60 and ocr_data['text'][i].strip() != '':
112
- word = ocr_data['text'][i]
113
- (x, y, w, h) = (ocr_data['left'][i], ocr_data['top'][i], ocr_data['width'][i], ocr_data['height'][i])
114
- words.append(word)
115
- unnormalized_boxes.append([x, y, x + w, y + h])
116
-
117
- # 3. Normalize Boxes and Prepare Inputs
118
- normalized_boxes = []
119
- for box in unnormalized_boxes:
120
- normalized_boxes.append([
121
- int(1000 * (box[0] / width)),
122
- int(1000 * (box[1] / height)),
123
- int(1000 * (box[2] / width)),
124
- int(1000 * (box[3] / height)),
125
- ])
126
-
127
- # 4. Process with LayoutLMv3 Processor
128
- encoding = PROCESSOR(
129
- image,
130
- text=words,
131
- boxes=normalized_boxes,
132
- truncation=True,
133
- max_length=512,
134
- return_tensors="pt"
135
- ).to(DEVICE)
136
-
137
- # 5. Run Inference
138
- with torch.no_grad():
139
- outputs = MODEL(**encoding)
140
-
141
- predictions = outputs.logits.argmax(-1).squeeze().tolist()
142
-
143
- # 6. Post-process to get final entities
144
- extracted_entities = _process_predictions(words, unnormalized_boxes, encoding, predictions, MODEL.config.id2label)
145
-
146
- # 7. Format the output to be consistent with your rule-based output
147
- # Format the output to be consistent with the desired UI structure
148
- # Format the output to be a superset of all possible fields
149
- final_output = {
150
- # --- Standard UI Fields ---
151
- "receipt_number": None, # SROIE doesn't train for this. Your regex model will provide it.
152
- "date": extracted_entities.get("DATE", {}).get("text"),
153
- "bill_to": None, # SROIE doesn't train for this. Your regex model will provide it.
154
- "items": [], # SROIE doesn't train for line items.
155
- "total_amount": None,
156
-
157
- # --- Additional Fields from ML Model ---
158
- "vendor": extracted_entities.get("COMPANY", {}).get("text"), # The ML model finds 'COMPANY'
159
- "address": extracted_entities.get("ADDRESS", {}).get("text"),
160
-
161
- # --- Debugging Info ---
162
- "raw_text": " ".join(words),
163
- "raw_ocr_words": words,
164
- "raw_predictions": extracted_entities
165
- }
166
-
167
- # Safely extract and convert total
168
- total_text = extracted_entities.get("TOTAL", {}).get("text")
169
- if total_text:
170
- try:
171
- cleaned_total = re.sub(r'[^\d.]', '', total_text)
172
- final_output["total_amount"] = float(cleaned_total)
173
- except (ValueError, TypeError):
174
- final_output["total_amount"] = None
175
-
176
  return final_output
 
1
+ # src/ml_extraction.py
2
+
3
+ import torch
4
+ from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
5
+ from PIL import Image
6
+ import pytesseract
7
+ from typing import List, Dict, Any
8
+ import re
9
+ import numpy as np
10
+ from extraction import extract_invoice_number, extract_total
11
+
12
+ # --- CONFIGURATION ---
13
+ LOCAL_MODEL_PATH = "./models/layoutlmv3-generalized"
14
+ HUB_MODEL_ID = "GSoumyajit2005/layoutlmv3-sroie-invoice-extraction"
15
+
16
+ # --- Load Model ---
17
def load_model_and_processor(model_path, hub_id):
    """Load the LayoutLMv3 processor/model pair, downloading if needed.

    Tries *model_path* on disk first; on OSError pulls *hub_id* from the
    Hugging Face Hub into *model_path* and loads from there.
    """
    def _from_disk(path):
        # Both pieces must come from the same directory.
        return (LayoutLMv3Processor.from_pretrained(path),
                LayoutLMv3ForTokenClassification.from_pretrained(path))

    try:
        print(f"Attempting to load model from local path: {model_path}...")
        processor, model = _from_disk(model_path)
        print("✅ Model loaded successfully from local path.")
    except OSError:
        print(f"Model not found locally. Downloading from Hub: {hub_id}...")
        from huggingface_hub import snapshot_download
        snapshot_download(repo_id=hub_id, local_dir=model_path, local_dir_use_symlinks=False)
        processor, model = _from_disk(model_path)
        print("✅ Model downloaded and loaded successfully.")
    return model, processor
31
+
32
# Load the model/processor exactly once at import time so every caller
# reuses the same weights.
MODEL, PROCESSOR = load_model_and_processor(LOCAL_MODEL_PATH, HUB_MODEL_ID)

if MODEL and PROCESSOR:
    # Prefer GPU when available; the model is frozen for inference only.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MODEL.to(DEVICE)
    MODEL.eval()
    print(f"ML Model is ready on device: {DEVICE}")
else:
    # NOTE(review): load_model_and_processor either returns both objects
    # or raises, so this branch appears unreachable as written — confirm
    # whether it is still needed.
    DEVICE = None
    print("❌ Could not load ML model.")
42
+
43
+ def _process_predictions(words, unnormalized_boxes, encoding, predictions, id2label):
44
+ word_ids = encoding.word_ids(batch_index=0)
45
+ word_level_preds = {}
46
+ for idx, word_id in enumerate(word_ids):
47
+ if word_id is not None:
48
+ label_id = predictions[idx]
49
+ if label_id != -100:
50
+ if word_id not in word_level_preds:
51
+ word_level_preds[word_id] = id2label[label_id]
52
+
53
+ entities = {}
54
+ for word_idx, label in word_level_preds.items():
55
+ if label == 'O': continue
56
+ entity_type = label[2:]
57
+ word = words[word_idx]
58
+
59
+ if label.startswith('B-'):
60
+ entities[entity_type] = {"text": word, "bbox": [unnormalized_boxes[word_idx]]}
61
+ elif label.startswith('I-') and entity_type in entities:
62
+ entities[entity_type]['text'] += " " + word
63
+ entities[entity_type]['bbox'].append(unnormalized_boxes[word_idx])
64
+
65
+ for entity in entities.values():
66
+ entity['text'] = entity['text'].strip()
67
+
68
+ return entities
69
+
70
def extract_ml_based(image_path: str) -> Dict[str, Any]:
    """End-to-end ML extraction: OCR an image, run LayoutLMv3, and
    post-process predictions into a flat result dict.

    Falls back to the rule-based extractors for the total and the
    receipt number when the model yields neither.

    Raises RuntimeError when the module-level model failed to load.
    """
    if not MODEL or not PROCESSOR:
        raise RuntimeError("ML model is not loaded.")

    # 1. Load image and OCR it with Tesseract.
    image = Image.open(image_path).convert("RGB")
    width, height = image.size
    ocr_data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)

    words = []
    unnormalized_boxes = []
    for i in range(len(ocr_data['level'])):
        # Confidence cut-off of 30 keeps low-confidence words too.
        if int(ocr_data['conf'][i]) > 30 and ocr_data['text'][i].strip() != '':
            words.append(ocr_data['text'][i])
            # NOTE(review): this stores [left, top, width, height], not
            # corner coordinates — the "bbox" values exposed through
            # _process_predictions are therefore in x/y/w/h form, unlike
            # the normalized corner boxes below. Confirm consumers expect
            # that.
            unnormalized_boxes.append([
                ocr_data['left'][i], ocr_data['top'][i],
                ocr_data['width'][i], ocr_data['height'][i]
            ])

    raw_text = " ".join(words)

    # 2. Normalize boxes to LayoutLMv3's [0, 1000] space, clamped so a
    #    rounding overshoot can never leave the valid range.
    normalized_boxes = []
    for box in unnormalized_boxes:
        x, y, w, h = box
        x0, y0, x1, y1 = x, y, x + w, y + h

        normalized_boxes.append([
            max(0, min(1000, int(1000 * (x0 / width)))),
            max(0, min(1000, int(1000 * (y0 / height)))),
            max(0, min(1000, int(1000 * (x1 / width)))),
            max(0, min(1000, int(1000 * (y1 / height)))),
        ])

    # 3. Inference (inputs truncated to the model's 512-token window).
    encoding = PROCESSOR(
        image, text=words, boxes=normalized_boxes,
        truncation=True, max_length=512, return_tensors="pt"
    ).to(DEVICE)

    with torch.no_grad():
        outputs = MODEL(**encoding)

    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    extracted_entities = _process_predictions(words, unnormalized_boxes, encoding, predictions, MODEL.config.id2label)

    # 4. Flatten model entities into the pipeline's output schema.
    final_output = {
        "vendor": extracted_entities.get("COMPANY", {}).get("text"),
        "date": extracted_entities.get("DATE", {}).get("text"),
        "address": extracted_entities.get("ADDRESS", {}).get("text"),
        "receipt_number": extracted_entities.get("INVOICE_NO", {}).get("text"),
        "bill_to": extracted_entities.get("BILL_TO", {}).get("text"),
        "total_amount": None,
        "items": [],
        "raw_text": raw_text
    }

    # Parse the model's TOTAL text ("$1,234.56" -> 1234.56); on failure
    # leave total_amount as None so the rule-based fallback can run.
    ml_total = extracted_entities.get("TOTAL", {}).get("text")
    if ml_total:
        try:
            cleaned = re.sub(r'[^\d.,]', '', ml_total).replace(',', '.')
            final_output["total_amount"] = float(cleaned)
        except (ValueError, TypeError):
            pass

    # Rule-based fallbacks over the raw OCR text.
    if final_output["total_amount"] is None:
        final_output["total_amount"] = extract_total(raw_text)

    if not final_output["receipt_number"]:
        final_output["receipt_number"] = extract_invoice_number(raw_text)

    return final_output
src/ocr.py CHANGED
@@ -1,15 +1,15 @@
1
- import pytesseract
2
- import numpy as np
3
- from typing import Optional
4
-
5
- pytesseract.pytesseract.tesseract_cmd = r'C:\\Program Files\\Tesseract-OCR\\tesseract.exe'
6
-
7
- def extract_text(image: np.ndarray, lang: str='eng', config: str='--psm 11') -> str:
8
- if image is None:
9
- raise ValueError("Input image is None")
10
- text = pytesseract.image_to_string(image, lang=lang, config=config)
11
- return text.strip()
12
-
13
- def extract_text_with_boxes(image):
14
- pass
15
-
 
1
+ import pytesseract
2
+ import numpy as np
3
+ from typing import Optional
4
+
5
+ #pytesseract.pytesseract.tesseract_cmd = r'C:\\Program Files\\Tesseract-OCR\\tesseract.exe'
6
+
7
def extract_text(image: np.ndarray, lang: str = 'eng', config: str = '--psm 11') -> str:
    """Run Tesseract OCR over *image* and return the recognized text, stripped.

    Args:
        image: Image array to OCR (as produced by the preprocessing module).
        lang: Tesseract language pack to use.
        config: Extra Tesseract CLI flags (page-segmentation mode by default).

    Raises:
        ValueError: If *image* is None.
    """
    if image is None:
        raise ValueError("Input image is None")
    recognized = pytesseract.image_to_string(image, lang=lang, config=config)
    return recognized.strip()
12
+
13
def extract_text_with_boxes(image):
    """Placeholder for word-level OCR with bounding boxes (not yet implemented).

    Always returns None, exactly like the original stub.
    """
    return None
15
+
src/pipeline.py CHANGED
@@ -1,151 +1,151 @@
1
- """
2
- Main invoice processing pipeline
3
- Orchestrates preprocessing, OCR, and extraction
4
- """
5
-
6
- from typing import Dict, Any, Optional
7
- from pathlib import Path
8
- import json
9
-
10
- # Make sure all your modules are imported
11
- from preprocessing import load_image, convert_to_grayscale, remove_noise
12
- from ocr import extract_text
13
- from extraction import structure_output
14
- from ml_extraction import extract_ml_based
15
-
16
- def process_invoice(image_path: str,
17
- method: str = 'ml', # <-- New parameter: 'ml' or 'rules'
18
- save_results: bool = False,
19
- output_dir: str = 'outputs') -> Dict[str, Any]:
20
- """
21
- Process an invoice image using either rule-based or ML-based extraction.
22
-
23
- Args:
24
- image_path: Path to the invoice image.
25
- method: The extraction method to use ('ml' or 'rules'). Default is 'ml'.
26
- save_results: Whether to save JSON results to a file.
27
- output_dir: Directory to save results.
28
-
29
- Returns:
30
- A dictionary with the extracted invoice data.
31
- """
32
- if not Path(image_path).exists():
33
- raise FileNotFoundError(f"Image not found at path: {image_path}")
34
-
35
- print(f"Processing with '{method}' method...")
36
-
37
- if method == 'ml':
38
- # --- ML-Based Extraction ---
39
- try:
40
- # The ml_extraction function handles everything internally
41
- structured_data = extract_ml_based(image_path)
42
- except Exception as e:
43
- raise ValueError(f"Error during ML-based extraction: {e}")
44
-
45
- elif method == 'rules':
46
- # --- Rule-Based Extraction (Your original logic) ---
47
- try:
48
- image = load_image(image_path)
49
- gray_image = convert_to_grayscale(image)
50
- preprocessed_image = remove_noise(gray_image, kernel_size=3)
51
- text = extract_text(preprocessed_image, config='--psm 6')
52
- structured_data = structure_output(text) # Calls your old extraction.py
53
- except Exception as e:
54
- raise ValueError(f"Error during rule-based extraction: {e}")
55
-
56
- else:
57
- raise ValueError(f"Unknown extraction method: '{method}'. Choose 'ml' or 'rules'.")
58
-
59
- # --- Saving Logic (remains the same) ---
60
- if save_results:
61
- output_path = Path(output_dir)
62
- output_path.mkdir(parents=True, exist_ok=True)
63
- json_path = output_path / (Path(image_path).stem + f"_{method}.json") # Add method to filename
64
- try:
65
- with open(json_path, 'w', encoding='utf-8') as f:
66
- json.dump(structured_data, f, indent=2, ensure_ascii=False)
67
- except Exception as e:
68
- raise IOError(f"Error saving results to {json_path}: {e}")
69
-
70
- return structured_data
71
-
72
-
73
- def process_batch(image_folder: str, output_dir: str = 'outputs') -> list:
74
- """Process multiple invoices in a folder""" # Corrected indentation
75
- results = []
76
-
77
- supported_extensions = ['*.jpg', '*.png', '*.jpeg']
78
-
79
- for ext in supported_extensions:
80
- for img_file in Path(image_folder).glob(ext):
81
- print(f"🔄 Processing: {img_file}")
82
- try:
83
- result = process_invoice(str(img_file), save_results=True, output_dir=output_dir)
84
- results.append(result)
85
- except Exception as e:
86
- print(f"❌ Error processing {img_file}: {e}")
87
-
88
- print(f"\n🎉 Batch processing complete! {len(results)} invoices processed.")
89
- return results
90
-
91
-
92
- def main():
93
- """Command-line interface for invoice processing"""
94
- import argparse
95
-
96
- parser = argparse.ArgumentParser(
97
- description='Process invoice images or folders and extract structured data.',
98
- formatter_class=argparse.RawDescriptionHelpFormatter,
99
- epilog="""
100
- Examples:
101
- # Process a single invoice
102
- python src/pipeline.py data/raw/receipt1.jpg
103
-
104
- # Process and save a single invoice
105
- python src/pipeline.py data/raw/receipt1.jpg --save
106
-
107
- # Process an entire folder of invoices
108
- python src/pipeline.py data/raw --save --output results/
109
- """
110
- )
111
-
112
- # Corrected: Single 'path' argument
113
- parser.add_argument('path', help='Path to an invoice image or a folder of images')
114
- parser.add_argument('--save', action='store_true', help='Save results to JSON files')
115
- parser.add_argument('--output', default='outputs', help='Output directory for JSON files')
116
- parser.add_argument('--method', default='ml', choices=['ml', 'rules'], help="Extraction method: 'ml' or 'rules'")
117
-
118
- args = parser.parse_args()
119
-
120
- try:
121
- # Check if path is a directory or a file
122
- if Path(args.path).is_dir():
123
- process_batch(args.path, output_dir=args.output)
124
- elif Path(args.path).is_file():
125
- # Corrected: Use args.path
126
- print(f"🔄 Processing: {args.path}")
127
- result = process_invoice(args.path, method=args.method, save_results=args.save, output_dir=args.output)
128
-
129
- print("\n📊 Extracted Data:")
130
- print("=" * 60)
131
- print(f"Vendor: {result.get('vendor', 'N/A')}")
132
- print(f"Invoice Number: {result.get('invoice_number', 'N/A')}")
133
- print(f"Date: {result.get('date', 'N/A')}")
134
- print(f"Total: ${result.get('total', 0.0)}")
135
- print("=" * 60)
136
-
137
- if args.save:
138
- print(f"\n💾 JSON saved to: {args.output}/{Path(args.path).stem}.json")
139
- else:
140
- raise FileNotFoundError(f"Path does not exist: {args.path}")
141
-
142
- except Exception as e:
143
- print(f"❌ An error occurred: {e}")
144
- return 1
145
-
146
- return 0
147
-
148
-
149
- if __name__ == '__main__':
150
- import sys
151
  sys.exit(main())
 
1
+ """
2
+ Main invoice processing pipeline
3
+ Orchestrates preprocessing, OCR, and extraction
4
+ """
5
+
6
+ from typing import Dict, Any, Optional
7
+ from pathlib import Path
8
+ import json
9
+
10
+ # Make sure all your modules are imported
11
+ from preprocessing import load_image, convert_to_grayscale, remove_noise
12
+ from ocr import extract_text
13
+ from extraction import structure_output
14
+ from ml_extraction import extract_ml_based
15
+
16
def process_invoice(image_path: str,
                    method: str = 'ml',  # extraction backend: 'ml' or 'rules'
                    save_results: bool = False,
                    output_dir: str = 'outputs') -> Dict[str, Any]:
    """
    Process an invoice image using either rule-based or ML-based extraction.

    Args:
        image_path: Path to the invoice image.
        method: The extraction method to use ('ml' or 'rules'). Default is 'ml'.
        save_results: Whether to save JSON results to a file.
        output_dir: Directory to save results.

    Returns:
        A dictionary with the extracted invoice data.

    Raises:
        FileNotFoundError: If *image_path* does not exist.
        ValueError: If extraction fails or *method* is unknown.
        IOError: If the results cannot be written to disk.
    """
    if not Path(image_path).exists():
        raise FileNotFoundError(f"Image not found at path: {image_path}")

    print(f"Processing with '{method}' method...")

    if method == 'ml':
        # --- ML-Based Extraction ---
        try:
            # The ml_extraction function handles everything internally
            structured_data = extract_ml_based(image_path)
        except Exception as e:
            # Chain the original exception so the real cause stays visible.
            raise ValueError(f"Error during ML-based extraction: {e}") from e

    elif method == 'rules':
        # --- Rule-Based Extraction: preprocess -> OCR -> regex structuring ---
        try:
            image = load_image(image_path)
            gray_image = convert_to_grayscale(image)
            preprocessed_image = remove_noise(gray_image, kernel_size=3)
            text = extract_text(preprocessed_image, config='--psm 6')
            structured_data = structure_output(text)
        except Exception as e:
            raise ValueError(f"Error during rule-based extraction: {e}") from e

    else:
        raise ValueError(f"Unknown extraction method: '{method}'. Choose 'ml' or 'rules'.")

    # --- Saving Logic ---
    if save_results:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        # Method is embedded in the filename so 'ml' and 'rules' runs don't clobber each other.
        json_path = output_path / (Path(image_path).stem + f"_{method}.json")
        try:
            with open(json_path, 'w', encoding='utf-8') as f:
                json.dump(structured_data, f, indent=2, ensure_ascii=False)
        except Exception as e:
            raise IOError(f"Error saving results to {json_path}: {e}") from e

    return structured_data
71
+
72
+
73
def process_batch(image_folder: str, output_dir: str = 'outputs', method: str = 'ml') -> list:
    """Process every supported invoice image found in *image_folder*.

    Args:
        image_folder: Directory containing invoice images (.jpg/.jpeg/.png).
        output_dir: Directory where per-invoice JSON results are written.
        method: Extraction method forwarded to process_invoice ('ml' or 'rules').
            New keyword with the previous implicit default, so existing callers
            are unaffected; the batch path previously could not select 'rules'.

    Returns:
        List of result dictionaries for invoices that processed successfully.
    """
    results = []

    supported_extensions = ['*.jpg', '*.png', '*.jpeg']

    for ext in supported_extensions:
        for img_file in Path(image_folder).glob(ext):
            print(f"🔄 Processing: {img_file}")
            try:
                result = process_invoice(str(img_file), method=method, save_results=True, output_dir=output_dir)
                results.append(result)
            except Exception as e:
                # Best-effort batch: one bad scan must not abort the whole run.
                print(f"❌ Error processing {img_file}: {e}")

    print(f"\n🎉 Batch processing complete! {len(results)} invoices processed.")
    return results
90
+
91
+
92
def main():
    """Command-line interface for invoice processing.

    Returns:
        0 on success, 1 on any error (suitable for sys.exit).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Process invoice images or folders and extract structured data.',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Process a single invoice
  python src/pipeline.py data/raw/receipt1.jpg

  # Process and save a single invoice
  python src/pipeline.py data/raw/receipt1.jpg --save

  # Process an entire folder of invoices
  python src/pipeline.py data/raw --save --output results/
"""
    )

    parser.add_argument('path', help='Path to an invoice image or a folder of images')
    parser.add_argument('--save', action='store_true', help='Save results to JSON files')
    parser.add_argument('--output', default='outputs', help='Output directory for JSON files')
    parser.add_argument('--method', default='ml', choices=['ml', 'rules'], help="Extraction method: 'ml' or 'rules'")

    args = parser.parse_args()

    try:
        # Dispatch on whether the path is a folder (batch) or a single file.
        if Path(args.path).is_dir():
            process_batch(args.path, output_dir=args.output)
        elif Path(args.path).is_file():
            print(f"🔄 Processing: {args.path}")
            result = process_invoice(args.path, method=args.method, save_results=args.save, output_dir=args.output)

            print("\n📊 Extracted Data:")
            print("=" * 60)
            print(f"Vendor: {result.get('vendor', 'N/A')}")
            # FIX: the ML extractor reports the number under 'receipt_number',
            # so fall back to it when 'invoice_number' is absent.
            print(f"Invoice Number: {result.get('invoice_number') or result.get('receipt_number', 'N/A')}")
            print(f"Date: {result.get('date', 'N/A')}")
            print(f"Total: ${result.get('total_amount', 0.0)}")
            print("=" * 60)

            if args.save:
                # FIX: process_invoice appends '_<method>' to the saved filename;
                # print the path of the file that was actually written.
                print(f"\n💾 JSON saved to: {args.output}/{Path(args.path).stem}_{args.method}.json")
        else:
            raise FileNotFoundError(f"Path does not exist: {args.path}")

    except Exception as e:
        print(f"❌ An error occurred: {e}")
        return 1

    return 0
147
+
148
+
149
if __name__ == '__main__':
    # Script entry point: propagate main()'s exit code (0/1) to the shell.
    import sys
    sys.exit(main())
src/preprocessing.py CHANGED
@@ -1,78 +1,78 @@
1
- import cv2
2
- import numpy as np
3
- from pathlib import Path
4
-
5
-
6
-
7
- def load_image(image_path: str) -> np.ndarray:
8
- if not Path(image_path).exists():
9
- raise FileNotFoundError(f"Image not found : {image_path}")
10
- image = cv2.imread(image_path)
11
- if image is None:
12
- raise ValueError(f"Could not load image: {image_path}")
13
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
14
- return image
15
-
16
-
17
- def convert_to_grayscale(image: np.ndarray) -> np.ndarray:
18
- if image is None:
19
- raise ValueError(f"Image is None, cannot convert to grayscale")
20
- if len(image.shape) ==2:
21
- return image
22
- return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
23
-
24
-
25
- def remove_noise(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
26
- if image is None:
27
- raise ValueError(f"Image is None, cannot remove noise")
28
- if kernel_size <= 0:
29
- raise ValueError("Kernel size must be positive")
30
- if kernel_size % 2 == 0:
31
- raise ValueError("Kernel size must be odd")
32
- denoised_image = cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
33
- return denoised_image
34
-
35
-
36
- def binarize(image: np.ndarray, method: str = 'adaptive', block_size: int=11, C: int=2) -> np.ndarray:
37
- if image is None:
38
- raise ValueError(f"Image is None, cannot binarize")
39
- if image.ndim != 2:
40
- raise ValueError("Input image must be grayscale for binarization")
41
- if method == 'simple':
42
- _, binary_image = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)
43
- elif method == 'adaptive':
44
- binary_image = cv2.adaptiveThreshold(image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,cv2.THRESH_BINARY, block_size, C)
45
- else:
46
- raise ValueError(f"Unknown binarization method: {method}")
47
- return binary_image
48
-
49
-
50
- def deskew(image):
51
- pass
52
-
53
-
54
- def preprocess_pipeline(image: np.ndarray,
55
- steps: list = ['grayscale', 'denoise', 'binarize'],
56
- denoise_kernel: int = 3,
57
- binarize_method: str = 'adaptive',
58
- binarize_block_size: int = 11,
59
- binarize_C: int = 2) -> np.ndarray:
60
- if image is None:
61
- raise ValueError("Input image is None")
62
-
63
- processed = image
64
-
65
- for step in steps:
66
- if step == 'grayscale':
67
- processed = convert_to_grayscale(processed)
68
- elif step == 'denoise':
69
- processed = remove_noise(processed, kernel_size=denoise_kernel)
70
- elif step == 'binarize':
71
- processed = binarize(processed,
72
- method=binarize_method,
73
- block_size=binarize_block_size,
74
- C=binarize_C)
75
- else:
76
- raise ValueError(f"Unknown preprocessing step: {step}")
77
-
78
- return processed
 
1
+ import cv2
2
+ import numpy as np
3
+ from pathlib import Path
4
+
5
+
6
+
7
def load_image(image_path: str) -> np.ndarray:
    """Load an image from disk and return it as an RGB numpy array.

    Raises:
        FileNotFoundError: If the path does not exist.
        ValueError: If OpenCV cannot decode the file.
    """
    if not Path(image_path).exists():
        raise FileNotFoundError(f"Image not found : {image_path}")
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise ValueError(f"Could not load image: {image_path}")
    # OpenCV decodes to BGR; convert so downstream code sees RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
15
+
16
+
17
def convert_to_grayscale(image: np.ndarray) -> np.ndarray:
    """Return a single-channel version of *image*; a 2-D input is returned as-is.

    Raises:
        ValueError: If *image* is None.
    """
    if image is None:
        raise ValueError("Image is None, cannot convert to grayscale")
    # Already grayscale (height x width only) — nothing to do.
    if image.ndim == 2:
        return image
    return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
23
+
24
+
25
def remove_noise(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
    """Suppress noise with a square Gaussian blur of side *kernel_size*.

    Raises:
        ValueError: If *image* is None, or *kernel_size* is not a positive odd int.
    """
    if image is None:
        raise ValueError("Image is None, cannot remove noise")
    if kernel_size <= 0:
        raise ValueError("Kernel size must be positive")
    # GaussianBlur requires an odd kernel side length.
    if kernel_size % 2 == 0:
        raise ValueError("Kernel size must be odd")
    return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
34
+
35
+
36
def binarize(image: np.ndarray, method: str = 'adaptive', block_size: int = 11, C: int = 2) -> np.ndarray:
    """Threshold a grayscale image to black/white.

    Args:
        image: 2-D grayscale array.
        method: 'simple' (fixed 127 threshold) or 'adaptive' (Gaussian local mean).
        block_size: Neighbourhood size for the adaptive method.
        C: Constant subtracted from the local mean in the adaptive method.

    Raises:
        ValueError: If the image is None, not 2-D, or *method* is unknown.
    """
    if image is None:
        raise ValueError("Image is None, cannot binarize")
    if image.ndim != 2:
        raise ValueError("Input image must be grayscale for binarization")
    if method == 'simple':
        _, thresholded = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)
        return thresholded
    if method == 'adaptive':
        return cv2.adaptiveThreshold(image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                     cv2.THRESH_BINARY, block_size, C)
    raise ValueError(f"Unknown binarization method: {method}")
48
+
49
+
50
def deskew(image):
    """Placeholder for skew correction (not yet implemented).

    Always returns None, exactly like the original stub.
    """
    return None
52
+
53
+
54
def preprocess_pipeline(image: np.ndarray,
                        steps: list = None,
                        denoise_kernel: int = 3,
                        binarize_method: str = 'adaptive',
                        binarize_block_size: int = 11,
                        binarize_C: int = 2) -> np.ndarray:
    """Run a configurable chain of preprocessing steps over *image*.

    Args:
        image: Input image array.
        steps: Ordered step names among 'grayscale', 'denoise', 'binarize'.
            Defaults to all three. (Previously a mutable list default — replaced
            with a None sentinel; behavior for callers is unchanged.)
        denoise_kernel: Kernel size forwarded to remove_noise.
        binarize_method: Method forwarded to binarize ('simple' or 'adaptive').
        binarize_block_size: Block size forwarded to binarize.
        binarize_C: Constant forwarded to binarize.

    Returns:
        The processed image (the input itself when *steps* is empty).

    Raises:
        ValueError: If *image* is None or a step name is unknown.
    """
    if image is None:
        raise ValueError("Input image is None")

    if steps is None:
        steps = ['grayscale', 'denoise', 'binarize']

    processed = image

    for step in steps:
        if step == 'grayscale':
            processed = convert_to_grayscale(processed)
        elif step == 'denoise':
            processed = remove_noise(processed, kernel_size=denoise_kernel)
        elif step == 'binarize':
            processed = binarize(processed,
                                 method=binarize_method,
                                 block_size=binarize_block_size,
                                 C=binarize_C)
        else:
            raise ValueError(f"Unknown preprocessing step: {step}")

    return processed
tests/test_extraction.py CHANGED
@@ -1,41 +1,41 @@
1
- import sys
2
- sys.path.append('src')
3
-
4
- from extraction import extract_dates, extract_amounts, extract_total, extract_vendor, extract_invoice_number
5
-
6
- receipt_text = """
7
- tan chay yee
8
-
9
- *** COPY ***
10
-
11
- OJC MARKETING SDN BHD.
12
-
13
- ROC NO: 538358-H
14
-
15
- TAX INVOICE
16
-
17
- Invoice No: PEGIV-1030765
18
- Date: 15/01/2019 11:05:16 AM
19
-
20
- TOTAL: 193.00
21
- """
22
-
23
- print("🧪 Testing Extraction Functions")
24
- print("=" * 60)
25
-
26
- dates = extract_dates(receipt_text)
27
- print(f"\n📅 Date: {dates}")
28
-
29
- amounts = extract_amounts(receipt_text)
30
- print(f"\n💰 Amounts: {amounts}")
31
-
32
- total = extract_total(receipt_text)
33
- print(f"\n💵 Total: {total}")
34
-
35
- vendor = extract_vendor(receipt_text)
36
- print(f"\n🏢 Vendor: {vendor}")
37
-
38
- invoice_num = extract_invoice_number(receipt_text)
39
- print(f"\n📄 Invoice Number: {invoice_num}")
40
-
41
  print("\n✅ All extraction tests complete!")
 
1
"""Manual smoke test for the regex-based helpers in src/extraction.py.

Prints each extractor's output for eyeballing; it is a demo script, not an
asserting pytest test.
"""
import sys
sys.path.append('src')

from extraction import extract_dates, extract_amounts, extract_total, extract_vendor, extract_invoice_number

# Sample OCR output from a real receipt (SROIE-style layout); the string is
# program data and is kept verbatim.
receipt_text = """
tan chay yee

*** COPY ***

OJC MARKETING SDN BHD.

ROC NO: 538358-H

TAX INVOICE

Invoice No: PEGIV-1030765
Date: 15/01/2019 11:05:16 AM

TOTAL: 193.00
"""

print("🧪 Testing Extraction Functions")
print("=" * 60)

# Exercise each extractor independently on the same fixture text.
dates = extract_dates(receipt_text)
print(f"\n📅 Date: {dates}")

amounts = extract_amounts(receipt_text)
print(f"\n💰 Amounts: {amounts}")

total = extract_total(receipt_text)
print(f"\n💵 Total: {total}")

vendor = extract_vendor(receipt_text)
print(f"\n🏢 Vendor: {vendor}")

invoice_num = extract_invoice_number(receipt_text)
print(f"\n📄 Invoice Number: {invoice_num}")

print("\n✅ All extraction tests complete!")
tests/test_full_pipeline.py CHANGED
@@ -1,42 +1,42 @@
1
- import sys
2
- sys.path.append('src')
3
-
4
- from preprocessing import load_image, convert_to_grayscale, remove_noise
5
- from ocr import extract_text
6
- from extraction import structure_output
7
- import json
8
-
9
- print("=" * 60)
10
- print("🎯 FULL INVOICE PROCESSING PIPELINE TEST")
11
- print("=" * 60)
12
-
13
- # Step 1: Load and preprocess image
14
- print("\n1️⃣ Loading and preprocessing image...")
15
- image = load_image('data/raw/receipt3.jpg')
16
- gray = convert_to_grayscale(image)
17
- denoised = remove_noise(gray, kernel_size=3)
18
- print("✅ Image preprocessed")
19
-
20
- # Step 2: Extract text with OCR
21
- print("\n2️⃣ Extracting text with OCR...")
22
- text = extract_text(denoised, config='--psm 6')
23
- print(f"✅ Extracted {len(text)} characters")
24
-
25
- # Step 3: Extract structured information
26
- print("\n3️⃣ Extracting structured information...")
27
- result = structure_output(text)
28
- print("✅ Information extracted")
29
-
30
- # Step 4: Display results
31
- print("\n" + "=" * 60)
32
- print("📊 EXTRACTED INVOICE DATA (JSON)")
33
- print("=" * 60)
34
- print(json.dumps(result, indent=2, ensure_ascii=False))
35
- print("=" * 60)
36
-
37
- print("\n🎉 PIPELINE COMPLETE!")
38
- print("\n📋 Summary:")
39
- print(f" Vendor: {result['vendor']}")
40
- print(f" Invoice #: {result['invoice_number']}")
41
- print(f" Date: {result['date']}")
42
  print(f" Total: ${result['total']}")
 
1
"""Smoke test of the rule-based pipeline: preprocess -> OCR -> structuring."""
import sys
sys.path.append('src')

from preprocessing import load_image, convert_to_grayscale, remove_noise
from ocr import extract_text
from extraction import structure_output
import json

print("=" * 60)
print("🎯 FULL INVOICE PROCESSING PIPELINE TEST")
print("=" * 60)

# Step 1: Load and preprocess image
print("\n1️⃣ Loading and preprocessing image...")
image = load_image('data/raw/receipt3.jpg')
gray = convert_to_grayscale(image)
denoised = remove_noise(gray, kernel_size=3)
print("✅ Image preprocessed")

# Step 2: Extract text with OCR
print("\n2️⃣ Extracting text with OCR...")
text = extract_text(denoised, config='--psm 6')
print(f"✅ Extracted {len(text)} characters")

# Step 3: Extract structured information
print("\n3️⃣ Extracting structured information...")
result = structure_output(text)
print("✅ Information extracted")

# Step 4: Display results
print("\n" + "=" * 60)
print("📊 EXTRACTED INVOICE DATA (JSON)")
print("=" * 60)
print(json.dumps(result, indent=2, ensure_ascii=False))
print("=" * 60)

print("\n🎉 PIPELINE COMPLETE!")
print("\n📋 Summary:")
# FIX: use .get() so a missing key degrades to 'N/A' instead of crashing the demo.
print(f" Vendor: {result.get('vendor', 'N/A')}")
print(f" Invoice #: {result.get('invoice_number', 'N/A')}")
print(f" Date: {result.get('date', 'N/A')}")
# NOTE(review): the rest of the repo reports totals under 'total_amount';
# fall back to it in case structure_output uses that key — confirm schema.
print(f" Total: ${result.get('total', result.get('total_amount', 'N/A'))}")
tests/test_ocr.py CHANGED
@@ -1,101 +1,101 @@
1
- import sys
2
- sys.path.append('src')
3
-
4
- from preprocessing import load_image, convert_to_grayscale, remove_noise
5
- from ocr import extract_text
6
- import matplotlib.pyplot as plt
7
- import numpy as np
8
-
9
- print("=" * 60)
10
- print("🎯 OPTIMIZING GRAYSCALE OCR")
11
- print("=" * 60)
12
-
13
- # Load and convert to grayscale
14
- image = load_image('data/raw/receipt3.jpg')
15
- gray = convert_to_grayscale(image)
16
-
17
- # Test 1: Different PSM modes
18
- print("\n📊 Testing different Tesseract PSM modes...\n")
19
-
20
- psm_configs = [
21
- ('', 'Default'),
22
- ('--psm 3', 'Automatic page segmentation'),
23
- ('--psm 4', 'Single column of text'),
24
- ('--psm 6', 'Uniform block of text'),
25
- ('--psm 11', 'Sparse text, find as much as possible'),
26
- ('--psm 12', 'Sparse text with OSD (Orientation and Script Detection)'),
27
- ]
28
-
29
- results = {}
30
- for config, desc in psm_configs:
31
- text = extract_text(gray, config=config)
32
- results[desc] = text
33
- print(f"{desc:50s} → {len(text):4d} chars")
34
-
35
- # Find best result
36
- best_desc = max(results, key=lambda k: len(results[k]))
37
- best_text = results[best_desc]
38
-
39
- print(f"\n✅ WINNER: {best_desc} ({len(best_text)} chars)")
40
-
41
- # Test 2: With slight denoising
42
- print("\n📊 Testing with light denoising...\n")
43
-
44
- denoised = remove_noise(gray, kernel_size=3)
45
- text_denoised = extract_text(denoised, config='--psm 6')
46
- print(f"Grayscale + Denoise (psm 6): {len(text_denoised)} chars")
47
-
48
-
49
- # Display best result
50
- print("\n" + "=" * 60)
51
- print("📄 BEST EXTRACTED TEXT:")
52
- print("=" * 60)
53
- print(best_text)
54
- print("=" * 60)
55
-
56
- # Visualize
57
- fig, axes = plt.subplots(1, 3, figsize=(15, 5))
58
-
59
- axes[0].imshow(image)
60
- axes[0].set_title("Original")
61
- axes[0].axis('off')
62
-
63
- axes[1].imshow(gray, cmap='gray')
64
- axes[1].set_title(f"Grayscale\n({len(best_text)} chars - {best_desc})")
65
- axes[1].axis('off')
66
-
67
- axes[2].imshow(denoised, cmap='gray')
68
- axes[2].set_title(f"Denoised\n({len(text_denoised)} chars)")
69
- axes[2].axis('off')
70
-
71
- plt.tight_layout()
72
- plt.show()
73
-
74
- print(f"\n💡 Recommended pipeline: Grayscale + {best_desc}")
75
-
76
- # Test the combination we missed!
77
- print("\n📊 Testing BEST combination...\n")
78
-
79
- denoised = remove_noise(gray, kernel_size=3)
80
-
81
- # Test PSM 11 on denoised
82
- text_denoised_psm11 = extract_text(denoised, config='--psm 11')
83
- text_denoised_psm6 = extract_text(denoised, config='--psm 6')
84
-
85
- print(f"Denoised + PSM 6: {len(text_denoised_psm6)} chars")
86
- print(f"Denoised + PSM 11: {len(text_denoised_psm11)} chars")
87
-
88
- if len(text_denoised_psm11) > len(text_denoised_psm6):
89
- print(f"\n✅ PSM 11 wins! ({len(text_denoised_psm11)} chars)")
90
- best_config = '--psm 11'
91
- best_text_final = text_denoised_psm11
92
- else:
93
- print(f"\n✅ PSM 6 wins! ({len(text_denoised_psm6)} chars)")
94
- best_config = '--psm 6'
95
- best_text_final = text_denoised_psm6
96
-
97
- print(f"\n🏆 FINAL WINNER: Denoised + {best_config}")
98
- print("\nFull text:")
99
- print("=" * 60)
100
- print(best_text_final)
101
  print("=" * 60)
 
1
"""Experiment script: compare Tesseract PSM modes and light denoising on one
receipt, to pick the best grayscale OCR configuration."""
import sys
sys.path.append('src')

from preprocessing import load_image, convert_to_grayscale, remove_noise
from ocr import extract_text
import matplotlib.pyplot as plt
import numpy as np

print("=" * 60)
print("🎯 OPTIMIZING GRAYSCALE OCR")
print("=" * 60)

# Load and convert to grayscale
image = load_image('data/raw/receipt3.jpg')
gray = convert_to_grayscale(image)

# Test 1: Different PSM modes
print("\n📊 Testing different Tesseract PSM modes...\n")

psm_configs = [
    ('', 'Default'),
    ('--psm 3', 'Automatic page segmentation'),
    ('--psm 4', 'Single column of text'),
    ('--psm 6', 'Uniform block of text'),
    ('--psm 11', 'Sparse text, find as much as possible'),
    ('--psm 12', 'Sparse text with OSD (Orientation and Script Detection)'),
]

results = {}
for config, desc in psm_configs:
    text = extract_text(gray, config=config)
    results[desc] = text
    print(f"{desc:50s} → {len(text):4d} chars")

# Find best result — "best" is approximated as "most characters extracted".
best_desc = max(results, key=lambda k: len(results[k]))
best_text = results[best_desc]

print(f"\n✅ WINNER: {best_desc} ({len(best_text)} chars)")

# Test 2: With slight denoising
print("\n📊 Testing with light denoising...\n")

denoised = remove_noise(gray, kernel_size=3)
text_denoised = extract_text(denoised, config='--psm 6')
print(f"Grayscale + Denoise (psm 6): {len(text_denoised)} chars")


# Display best result
print("\n" + "=" * 60)
print("📄 BEST EXTRACTED TEXT:")
print("=" * 60)
print(best_text)
print("=" * 60)

# Visualize original vs grayscale vs denoised side by side.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(image)
axes[0].set_title("Original")
axes[0].axis('off')

axes[1].imshow(gray, cmap='gray')
axes[1].set_title(f"Grayscale\n({len(best_text)} chars - {best_desc})")
axes[1].axis('off')

axes[2].imshow(denoised, cmap='gray')
axes[2].set_title(f"Denoised\n({len(text_denoised)} chars)")
axes[2].axis('off')

plt.tight_layout()
plt.show()

print(f"\n💡 Recommended pipeline: Grayscale + {best_desc}")

# Test the combination we missed!
print("\n📊 Testing BEST combination...\n")

denoised = remove_noise(gray, kernel_size=3)

# Test PSM 11 on denoised
text_denoised_psm11 = extract_text(denoised, config='--psm 11')
text_denoised_psm6 = extract_text(denoised, config='--psm 6')

print(f"Denoised + PSM 6: {len(text_denoised_psm6)} chars")
print(f"Denoised + PSM 11: {len(text_denoised_psm11)} chars")

# Pick the denoised configuration that recovered more text.
if len(text_denoised_psm11) > len(text_denoised_psm6):
    print(f"\n✅ PSM 11 wins! ({len(text_denoised_psm11)} chars)")
    best_config = '--psm 11'
    best_text_final = text_denoised_psm11
else:
    print(f"\n✅ PSM 6 wins! ({len(text_denoised_psm6)} chars)")
    best_config = '--psm 6'
    best_text_final = text_denoised_psm6

print(f"\n🏆 FINAL WINNER: Denoised + {best_config}")
print("\nFull text:")
print("=" * 60)
print(best_text_final)
print("=" * 60)
tests/test_pipeline.py CHANGED
@@ -1,96 +1,96 @@
1
- import sys
2
- import json
3
- from pathlib import Path
4
-
5
- # Add the 'src' directory to the Python path
6
- sys.path.append('src')
7
-
8
- from pipeline import process_invoice
9
-
10
- def test_full_pipeline():
11
- """
12
- Tests the full invoice processing pipeline on a sample receipt
13
- and prints the advanced JSON structure.
14
- """
15
- print("=" * 60)
16
- print("🎯 ADVANCED INVOICE PROCESSING PIPELINE TEST")
17
- print("=" * 60)
18
-
19
- # --- Configuration ---
20
- image_path = 'data/raw/receipt1.jpg'
21
- save_output = True
22
- output_dir = 'outputs'
23
-
24
- # Check if the image exists
25
- if not Path(image_path).exists():
26
- print(f"❌ ERROR: Test image not found at '{image_path}'")
27
- return
28
-
29
- # --- Processing ---
30
- print(f"\n🔄 Processing invoice: {image_path}...")
31
- try:
32
- # Call the main processing function
33
- result = process_invoice(image_path, save_results=save_output, output_dir=output_dir)
34
- print("✅ Invoice processed successfully!")
35
- except Exception as e:
36
- print(f"❌ An error occurred during processing: {e}")
37
- # Print traceback for detailed debugging
38
- import traceback
39
- traceback.print_exc()
40
- return
41
-
42
- # --- Display Results ---
43
- print("\n" + "=" * 60)
44
- print("📊 EXTRACTED INVOICE DATA (Advanced JSON)")
45
- print("=" * 60)
46
-
47
- # Pretty-print the JSON to the console
48
- print(json.dumps(result, indent=2, ensure_ascii=False))
49
-
50
- print("\n" + "=" * 60)
51
- print("📋 SUMMARY OF KEY EXTRACTED FIELDS")
52
- print("=" * 60)
53
-
54
- # --- Print a clean summary ---
55
- print(f"📄 Receipt Number: {result.get('receipt_number', 'N/A')}")
56
- print(f"📅 Date: {result.get('date', 'N/A')}")
57
-
58
- # Print Bill To info safely
59
- bill_to = result.get('bill_to')
60
- if bill_to and isinstance(bill_to, dict):
61
- print(f"👤 Bill To: {bill_to.get('name', 'N/A')}")
62
- else:
63
- print("👤 Bill To: N/A")
64
-
65
- # Print line items
66
- print("\n🛒 Line Items:")
67
- items = result.get('items', [])
68
- if items:
69
- for i, item in enumerate(items, 1):
70
- desc = item.get('description', 'No Description')
71
- qty = item.get('quantity', 1)
72
- total = item.get('total', 0.0)
73
- print(f" - Item {i}: {desc[:40]:<40} | Qty: {qty} | Total: {total:.2f}")
74
- else:
75
- print(" - No line items extracted.")
76
-
77
- # Print total and validation status
78
- print(f"\n💵 Total Amount: ${result.get('total_amount', 0.0):.2f}")
79
-
80
- confidence = result.get('extraction_confidence', 0)
81
- print(f"📈 Confidence: {confidence}%")
82
-
83
- validation = "✅ Passed" if result.get('validation_passed', False) else "❌ Failed"
84
- print(f"✔️ Validation: {validation}")
85
-
86
- print("\n" + "=" * 60)
87
-
88
- if save_output:
89
- json_path = Path(output_dir) / (Path(image_path).stem + '.json')
90
- print(f"\n💾 Full JSON output saved to: {json_path}")
91
-
92
- print("\n🎉 PIPELINE TEST COMPLETE!")
93
-
94
-
95
- if __name__ == '__main__':
96
  test_full_pipeline()
 
1
+ import sys
2
+ import json
3
+ from pathlib import Path
4
+
5
+ # Add the 'src' directory to the Python path
6
+ sys.path.append('src')
7
+
8
+ from pipeline import process_invoice
9
+
10
def test_full_pipeline():
    """
    Tests the full invoice processing pipeline on a sample receipt
    and prints the advanced JSON structure.
    """
    print("=" * 60)
    print("🎯 ADVANCED INVOICE PROCESSING PIPELINE TEST")
    print("=" * 60)

    # --- Configuration ---
    image_path = 'data/raw/receipt1.jpg'
    method = 'ml'  # made explicit so the saved-file path below can use it
    save_output = True
    output_dir = 'outputs'

    # Check if the image exists
    if not Path(image_path).exists():
        print(f"❌ ERROR: Test image not found at '{image_path}'")
        return

    # --- Processing ---
    print(f"\n🔄 Processing invoice: {image_path}...")
    try:
        # Call the main processing function
        result = process_invoice(image_path, method=method, save_results=save_output, output_dir=output_dir)
        print("✅ Invoice processed successfully!")
    except Exception as e:
        print(f"❌ An error occurred during processing: {e}")
        # Print traceback for detailed debugging
        import traceback
        traceback.print_exc()
        return

    # --- Display Results ---
    print("\n" + "=" * 60)
    print("📊 EXTRACTED INVOICE DATA (Advanced JSON)")
    print("=" * 60)

    # Pretty-print the JSON to the console
    print(json.dumps(result, indent=2, ensure_ascii=False))

    print("\n" + "=" * 60)
    print("📋 SUMMARY OF KEY EXTRACTED FIELDS")
    print("=" * 60)

    # --- Print a clean summary ---
    print(f"📄 Receipt Number: {result.get('receipt_number', 'N/A')}")
    print(f"📅 Date: {result.get('date', 'N/A')}")

    # Print Bill To info safely
    bill_to = result.get('bill_to')
    if bill_to and isinstance(bill_to, dict):
        print(f"👤 Bill To: {bill_to.get('name', 'N/A')}")
    else:
        print("👤 Bill To: N/A")

    # Print line items
    print("\n🛒 Line Items:")
    items = result.get('items', [])
    if items:
        for i, item in enumerate(items, 1):
            desc = item.get('description', 'No Description')
            qty = item.get('quantity', 1)
            total = item.get('total', 0.0)
            print(f" - Item {i}: {desc[:40]:<40} | Qty: {qty} | Total: {total:.2f}")
    else:
        print(" - No line items extracted.")

    # Print total and validation status.
    # FIX: the ML extractor may legitimately return total_amount=None, and
    # formatting None with :.2f raises — coerce to 0.0 first.
    print(f"\n💵 Total Amount: ${(result.get('total_amount') or 0.0):.2f}")

    confidence = result.get('extraction_confidence', 0)
    print(f"📈 Confidence: {confidence}%")

    validation = "✅ Passed" if result.get('validation_passed', False) else "❌ Failed"
    print(f"✔️ Validation: {validation}")

    print("\n" + "=" * 60)

    if save_output:
        # FIX: process_invoice saves '<stem>_<method>.json'; point at the file
        # that was actually written instead of '<stem>.json'.
        json_path = Path(output_dir) / (Path(image_path).stem + f'_{method}.json')
        print(f"\n💾 Full JSON output saved to: {json_path}")

    print("\n🎉 PIPELINE TEST COMPLETE!")
93
+
94
+
95
if __name__ == '__main__':
    # Allow running this smoke test directly as a standalone script.
    test_full_pipeline()
tests/test_preprocessing.py CHANGED
@@ -1,177 +1,177 @@
1
- import sys
2
- sys.path.append('src') # So Python can find our modules
3
-
4
- from preprocessing import load_image, convert_to_grayscale, remove_noise, binarize, preprocess_pipeline
5
- import numpy as np
6
- import matplotlib.pyplot as plt
7
-
8
- # Test 1: Load a valid image
9
- print("Test 1: Loading receipt1.jpg...")
10
- image = load_image('data/raw/receipt1.jpg')
11
- print(f"✅ Success! Image shape: {image.shape}")
12
- print(f" Data type: {image.dtype}")
13
- print(f" Value range: {image.min()} to {image.max()}")
14
-
15
- # Test 2: Visualize it
16
- print("\nTest 2: Displaying image...")
17
- plt.imshow(image)
18
- plt.title("Loaded Receipt")
19
- plt.axis('off')
20
- plt.show()
21
- print("✅ If you see the receipt image, it worked!")
22
-
23
- # Test 3: Try loading non-existent file
24
- print("\nTest 3: Testing error handling...")
25
- try:
26
- load_image('data/raw/fake_image.jpg')
27
- print("❌ Should have raised FileNotFoundError!")
28
- except FileNotFoundError as e:
29
- print(f"✅ Correctly raised error: {e}")
30
-
31
- # Test 4: Grayscale conversion
32
- print("\nTest 4: Converting to grayscale...")
33
- gray = convert_to_grayscale(image)
34
- print(f"✅ Success! Grayscale shape: {gray.shape}")
35
- print(f" Original had 3 channels, now has: {len(gray.shape)} dimensions")
36
-
37
- # Visualize side-by-side
38
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
39
- ax1.imshow(image)
40
- ax1.set_title("Original (RGB)")
41
- ax1.axis('off')
42
-
43
- ax2.imshow(gray, cmap='gray') # cmap='gray' tells matplotlib to display in grayscale
44
- ax2.set_title("Grayscale")
45
- ax2.axis('off')
46
-
47
- plt.tight_layout()
48
- plt.show()
49
-
50
- # Test 5: Already grayscale (should return as-is)
51
- print("\nTest 5: Converting already-grayscale image...")
52
- gray_again = convert_to_grayscale(gray)
53
- print(f"✅ Returned without error: {gray_again.shape}")
54
- assert gray_again is gray, "Should return same object if already grayscale"
55
- print("✅ Correctly returned the same image!")
56
-
57
- print("\n🎉 Grayscale tests passed!")
58
-
59
- # Test 6: Binarization - Simple method
60
- print("\nTest 6: Simple binarization...")
61
- binary_simple = binarize(gray, method='simple')
62
- print(f"✅ Success! Binary shape: {binary_simple.shape}")
63
- print(f" Unique values: {np.unique(binary_simple)}") # Should be [0, 255]
64
-
65
- # Test 7: Binarization - Adaptive method
66
- print("\nTest 7: Adaptive binarization...")
67
- binary_adaptive = binarize(gray, method='adaptive', block_size=11, C=2)
68
- print(f"✅ Success! Binary shape: {binary_adaptive.shape}")
69
- print(f" Unique values: {np.unique(binary_adaptive)}")
70
-
71
- # Visualize comparison
72
- fig, axes = plt.subplots(2, 2, figsize=(12, 10))
73
-
74
- axes[0, 0].imshow(image)
75
- axes[0, 0].set_title("1. Original (RGB)")
76
- axes[0, 0].axis('off')
77
-
78
- axes[0, 1].imshow(gray, cmap='gray')
79
- axes[0, 1].set_title("2. Grayscale")
80
- axes[0, 1].axis('off')
81
-
82
- axes[1, 0].imshow(binary_simple, cmap='gray')
83
- axes[1, 0].set_title("3. Simple Threshold")
84
- axes[1, 0].axis('off')
85
-
86
- axes[1, 1].imshow(binary_adaptive, cmap='gray')
87
- axes[1, 1].set_title("4. Adaptive Threshold")
88
- axes[1, 1].axis('off')
89
-
90
- plt.tight_layout()
91
- plt.show()
92
-
93
- # Test 8: Error handling
94
- print("\nTest 8: Testing error handling...")
95
- try:
96
- binarize(image, method='adaptive') # RGB image (3D) should fail
97
- print("❌ Should have raised ValueError!")
98
- except ValueError as e:
99
- print(f"✅ Correctly raised error: {e}")
100
-
101
- print("\n🎉 Binarization tests passed!")
102
-
103
- # Test 9: Noise removal
104
- print("\nTest 9: Noise removal...")
105
- denoised = remove_noise(gray, kernel_size=3)
106
- print(f"✅ Success! Denoised shape: {denoised.shape}")
107
-
108
- # Test different kernel sizes
109
- denoised_light = remove_noise(gray, kernel_size=3)
110
- denoised_heavy = remove_noise(gray, kernel_size=7)
111
-
112
- # Visualize comparison
113
- fig, axes = plt.subplots(1, 3, figsize=(15, 5))
114
-
115
- axes[0].imshow(gray, cmap='gray')
116
- axes[0].set_title("Original Grayscale")
117
- axes[0].axis('off')
118
-
119
- axes[1].imshow(denoised_light, cmap='gray')
120
- axes[1].set_title("Denoised (kernel=3)")
121
- axes[1].axis('off')
122
-
123
- axes[2].imshow(denoised_heavy, cmap='gray')
124
- axes[2].set_title("Denoised (kernel=7)")
125
- axes[2].axis('off')
126
-
127
- plt.tight_layout()
128
- plt.show()
129
- print(" Notice: kernel=7 is blurrier but removes more noise")
130
-
131
- # Test 10: Error handling
132
- print("\nTest 10: Noise removal error handling...")
133
- try:
134
- remove_noise(gray, kernel_size=4) # Even number
135
- print("❌ Should have raised ValueError!")
136
- except ValueError as e:
137
- print(f"✅ Correctly raised error: {e}")
138
-
139
- print("\n🎉 Noise removal tests passed!")
140
-
141
- # Test 11: Full pipeline
142
- print("\nTest 11: Full preprocessing pipeline...")
143
-
144
- # Test with all steps
145
- full_processed = preprocess_pipeline(image,
146
- steps=['grayscale', 'denoise', 'binarize'],
147
- denoise_kernel=3,
148
- binarize_method='adaptive')
149
- print(f"✅ Full pipeline success! Shape: {full_processed.shape}")
150
-
151
- # Test with selective steps (your clean images)
152
- clean_processed = preprocess_pipeline(image,
153
- steps=['grayscale', 'binarize'],
154
- binarize_method='adaptive')
155
- print(f"✅ Clean pipeline success! Shape: {clean_processed.shape}")
156
-
157
- # Visualize comparison
158
- fig, axes = plt.subplots(1, 3, figsize=(15, 5))
159
-
160
- axes[0].imshow(image)
161
- axes[0].set_title("Original")
162
- axes[0].axis('off')
163
-
164
- axes[1].imshow(full_processed, cmap='gray')
165
- axes[1].set_title("Full Pipeline\n(grayscale → denoise → binarize)")
166
- axes[1].axis('off')
167
-
168
- axes[2].imshow(clean_processed, cmap='gray')
169
- axes[2].set_title("Clean Pipeline\n(grayscale → binarize)")
170
- axes[2].axis('off')
171
-
172
- plt.tight_layout()
173
- plt.show()
174
-
175
- print("\n🎉 Pipeline tests passed!")
176
-
177
- print("\n🎉 All tests passed!")
 
1
"""Interactive smoke tests for the preprocessing module.

Runs top-to-bottom as a script (NOT under pytest): each "Test N" section
transforms data/raw/receipt1.jpg and opens matplotlib windows for a human
to inspect. plt.show() blocks until each window is closed.
"""
import sys
sys.path.append('src')  # So Python can find our modules

from preprocessing import load_image, convert_to_grayscale, remove_noise, binarize, preprocess_pipeline
import numpy as np
import matplotlib.pyplot as plt

# Test 1: Load a valid image
print("Test 1: Loading receipt1.jpg...")
image = load_image('data/raw/receipt1.jpg')
print(f"✅ Success! Image shape: {image.shape}")
print(f" Data type: {image.dtype}")
print(f" Value range: {image.min()} to {image.max()}")

# Test 2: Visualize it (visual check — requires a human looking at the window)
print("\nTest 2: Displaying image...")
plt.imshow(image)
plt.title("Loaded Receipt")
plt.axis('off')
plt.show()
print("✅ If you see the receipt image, it worked!")

# Test 3: Try loading non-existent file (must raise FileNotFoundError)
print("\nTest 3: Testing error handling...")
try:
    load_image('data/raw/fake_image.jpg')
    print("❌ Should have raised FileNotFoundError!")
except FileNotFoundError as e:
    print(f"✅ Correctly raised error: {e}")

# Test 4: Grayscale conversion (RGB HxWx3 -> single-channel HxW)
print("\nTest 4: Converting to grayscale...")
gray = convert_to_grayscale(image)
print(f"✅ Success! Grayscale shape: {gray.shape}")
print(f" Original had 3 channels, now has: {len(gray.shape)} dimensions")

# Visualize side-by-side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.imshow(image)
ax1.set_title("Original (RGB)")
ax1.axis('off')

ax2.imshow(gray, cmap='gray')  # cmap='gray' tells matplotlib to display in grayscale
ax2.set_title("Grayscale")
ax2.axis('off')

plt.tight_layout()
plt.show()

# Test 5: Already grayscale (must return the SAME object, not a copy)
print("\nTest 5: Converting already-grayscale image...")
gray_again = convert_to_grayscale(gray)
print(f"✅ Returned without error: {gray_again.shape}")
assert gray_again is gray, "Should return same object if already grayscale"
print("✅ Correctly returned the same image!")

print("\n🎉 Grayscale tests passed!")

# Test 6: Binarization - Simple (global threshold) method
print("\nTest 6: Simple binarization...")
binary_simple = binarize(gray, method='simple')
print(f"✅ Success! Binary shape: {binary_simple.shape}")
print(f" Unique values: {np.unique(binary_simple)}")  # Should be [0, 255]

# Test 7: Binarization - Adaptive (local threshold) method
print("\nTest 7: Adaptive binarization...")
binary_adaptive = binarize(gray, method='adaptive', block_size=11, C=2)
print(f"✅ Success! Binary shape: {binary_adaptive.shape}")
print(f" Unique values: {np.unique(binary_adaptive)}")

# Visualize all four stages in a 2x2 grid for comparison
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

axes[0, 0].imshow(image)
axes[0, 0].set_title("1. Original (RGB)")
axes[0, 0].axis('off')

axes[0, 1].imshow(gray, cmap='gray')
axes[0, 1].set_title("2. Grayscale")
axes[0, 1].axis('off')

axes[1, 0].imshow(binary_simple, cmap='gray')
axes[1, 0].set_title("3. Simple Threshold")
axes[1, 0].axis('off')

axes[1, 1].imshow(binary_adaptive, cmap='gray')
axes[1, 1].set_title("4. Adaptive Threshold")
axes[1, 1].axis('off')

plt.tight_layout()
plt.show()

# Test 8: binarize() must reject RGB (3D) input with ValueError
print("\nTest 8: Testing error handling...")
try:
    binarize(image, method='adaptive')  # RGB image (3D) should fail
    print("❌ Should have raised ValueError!")
except ValueError as e:
    print(f"✅ Correctly raised error: {e}")

print("\n🎉 Binarization tests passed!")

# Test 9: Noise removal
print("\nTest 9: Noise removal...")
denoised = remove_noise(gray, kernel_size=3)
print(f"✅ Success! Denoised shape: {denoised.shape}")

# Test different kernel sizes (larger kernel = stronger smoothing)
denoised_light = remove_noise(gray, kernel_size=3)
denoised_heavy = remove_noise(gray, kernel_size=7)

# Visualize comparison
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(gray, cmap='gray')
axes[0].set_title("Original Grayscale")
axes[0].axis('off')

axes[1].imshow(denoised_light, cmap='gray')
axes[1].set_title("Denoised (kernel=3)")
axes[1].axis('off')

axes[2].imshow(denoised_heavy, cmap='gray')
axes[2].set_title("Denoised (kernel=7)")
axes[2].axis('off')

plt.tight_layout()
plt.show()
print(" Notice: kernel=7 is blurrier but removes more noise")

# Test 10: remove_noise() must reject even kernel sizes with ValueError
print("\nTest 10: Noise removal error handling...")
try:
    remove_noise(gray, kernel_size=4)  # Even number
    print("❌ Should have raised ValueError!")
except ValueError as e:
    print(f"✅ Correctly raised error: {e}")

print("\n🎉 Noise removal tests passed!")

# Test 11: Full pipeline (steps applied in the listed order)
print("\nTest 11: Full preprocessing pipeline...")

# Test with all steps
full_processed = preprocess_pipeline(image,
                                     steps=['grayscale', 'denoise', 'binarize'],
                                     denoise_kernel=3,
                                     binarize_method='adaptive')
print(f"✅ Full pipeline success! Shape: {full_processed.shape}")

# Test with selective steps (your clean images)
clean_processed = preprocess_pipeline(image,
                                      steps=['grayscale', 'binarize'],
                                      binarize_method='adaptive')
print(f"✅ Clean pipeline success! Shape: {clean_processed.shape}")

# Visualize comparison
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(image)
axes[0].set_title("Original")
axes[0].axis('off')

axes[1].imshow(full_processed, cmap='gray')
axes[1].set_title("Full Pipeline\n(grayscale → denoise → binarize)")
axes[1].axis('off')

axes[2].imshow(clean_processed, cmap='gray')
axes[2].set_title("Clean Pipeline\n(grayscale → binarize)")
axes[2].axis('off')

plt.tight_layout()
plt.show()

print("\n🎉 Pipeline tests passed!")

print("\n🎉 All tests passed!")
tests/utils.py CHANGED
@@ -1,7 +1,7 @@
1
- def save_image(image, path):
2
-
3
- def visualize_boxes(image, boxes, text):
4
-
5
- def validate_output(data):
6
-
7
  def format_currency(amount):
 
1
def save_image(image, path):
    """Persist *image* to *path*.

    Not implemented yet.

    Raises:
        NotImplementedError: always, until implemented.
    """
    # FIX: the original file had `def` statements with no bodies, which is a
    # SyntaxError — the module could not even be imported. Explicit
    # NotImplementedError stubs keep the file importable and fail loudly.
    raise NotImplementedError("save_image is not implemented yet")


def visualize_boxes(image, boxes, text):
    """Draw OCR bounding boxes and recognized text over *image*.

    Not implemented yet.

    Raises:
        NotImplementedError: always, until implemented.
    """
    raise NotImplementedError("visualize_boxes is not implemented yet")


def validate_output(data):
    """Validate an extraction-result mapping.

    Not implemented yet.

    Raises:
        NotImplementedError: always, until implemented.
    """
    raise NotImplementedError("validate_output is not implemented yet")


def format_currency(amount):
    """Format *amount* as a currency string.

    Not implemented yet.

    Raises:
        NotImplementedError: always, until implemented.
    """
    raise NotImplementedError("format_currency is not implemented yet")
train_combined.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor, DataCollatorForTokenClassification
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+ from seqeval.metrics import f1_score
7
+ from pathlib import Path
8
+ import numpy as np
9
+ import random
10
+ import os
11
+
12
+ # --- IMPORTS ---
13
+ from load_sroie_dataset import load_sroie
14
+ from src.data_loader import load_unified_dataset
15
+
16
# --- CONFIGURATION ---
# Points to your local SROIE copy
SROIE_DATA_PATH = "data/sroie"
# Base HuggingFace checkpoint to fine-tune.
MODEL_CHECKPOINT = "microsoft/layoutlmv3-base"
# Where the best (highest validation F1) model/processor pair is saved.
OUTPUT_DIR = "models/layoutlmv3-generalized"

# Standard Label Set: BIO tags covering both SROIE fields and the extra
# general-invoice fields (invoice number, bill-to).
LABEL_LIST = ['O', 'B-COMPANY', 'I-COMPANY', 'B-DATE', 'I-DATE',
              'B-ADDRESS', 'I-ADDRESS', 'B-TOTAL', 'I-TOTAL',
              'B-INVOICE_NO', 'I-INVOICE_NO','B-BILL_TO', 'I-BILL_TO']
# Bidirectional tag <-> integer-id maps shared by the model and the dataset.
label2id = {label: idx for idx, label in enumerate(LABEL_LIST)}
id2label = {idx: label for idx, label in enumerate(LABEL_LIST)}
29
class UnifiedDataset(Dataset):
    """Adapter feeding mixed SROIE / general-invoice examples to LayoutLMv3.

    Each example is a dict carrying 'words', 'bboxes' (already on the
    0-1000 LayoutLM grid), 'ner_tags', and either an in-memory PIL 'image'
    or an 'image_path' on disk.
    """

    def __init__(self, data, processor, label2id):
        self.data = data              # list of example dicts
        self.processor = processor    # LayoutLMv3Processor (apply_ocr=False)
        self.label2id = label2id      # tag string -> integer id

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _resolve_image(example):
        """Return the example's RGB page, or a blank white page on any failure."""
        try:
            candidate = example.get('image')
            if isinstance(candidate, Image.Image):
                return candidate
            if 'image_path' in example:
                return Image.open(example['image_path']).convert("RGB")
        except Exception:
            pass
        return Image.new('RGB', (224, 224), color='white')

    @staticmethod
    def _clamp(box):
        """Clip a 4-coordinate box onto the 0-1000 grid as plain ints."""
        return [max(0, min(int(box[i]), 1000)) for i in range(4)]

    def __getitem__(self, idx):
        example = self.data[idx]

        page = self._resolve_image(example)

        # Boxes arrive pre-normalized; clamp defensively anyway.
        safe_boxes = [self._clamp(raw) for raw in example['bboxes']]

        # Unknown tag strings fall back to id 0 ('O').
        tag_ids = [self.label2id.get(tag, 0) for tag in example['ner_tags']]

        encoded = self.processor(
            page,
            text=example['words'],
            boxes=safe_boxes,
            word_labels=tag_ids,
            truncation=True,
            padding="max_length",
            max_length=512,
            return_tensors="pt",
        )

        # Drop the singleton batch dimension added by the processor.
        return {name: tensor.squeeze(0) for name, tensor in encoded.items()}
83
+
84
def train():
    """Fine-tune LayoutLMv3 on the combined SROIE + general-invoice corpus.

    Loads the local SROIE copy and a sampled general-invoice dataset, merges
    them into one train/eval pool, trains for NUM_EPOCHS epochs with AdamW,
    and saves the checkpoint with the best validation F1 to OUTPUT_DIR.
    Returns early (with a message) if SROIE_DATA_PATH does not exist.
    """
    print(f"{'='*40}\n🚀 STARTING HYBRID TRAINING\n{'='*40}")

    # Check SROIE path before doing any heavy work.
    if not os.path.exists(SROIE_DATA_PATH):
        print(f"❌ Error: SROIE path not found at {SROIE_DATA_PATH}")
        print("Please make sure you copied the 'sroie' folder into 'data/'.")
        return

    # 1. Load SROIE (returns a dict with 'train' / 'test' example lists).
    print("📦 Loading SROIE dataset...")
    sroie_data = load_sroie(SROIE_DATA_PATH)
    print(f" - SROIE Train: {len(sroie_data['train'])}")
    print(f" - SROIE Test: {len(sroie_data['test'])}")

    # 2. Load New Dataset and carve out a 90/10 train/test split.
    print("📦 Loading General Invoice dataset...")
    # Reduced sample size slightly to stay safe on RAM
    new_data = load_unified_dataset(split='train', sample_size=600)

    # NOTE(review): shuffle is unseeded, so the split differs across runs.
    random.shuffle(new_data)
    split_idx = int(len(new_data) * 0.9)
    new_train = new_data[:split_idx]
    new_test = new_data[split_idx:]

    print(f" - General Train: {len(new_train)}")
    print(f" - General Test: {len(new_test)}")

    # 3. Merge both sources.
    full_train_data = sroie_data['train'] + new_train
    full_test_data = sroie_data['test'] + new_test
    print(f"\n🔗 COMBINED DATASET SIZE: {len(full_train_data)} Training Images")

    # 4. Setup model + processor (apply_ocr=False: we supply words/boxes).
    processor = LayoutLMv3Processor.from_pretrained(MODEL_CHECKPOINT, apply_ocr=False)
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        MODEL_CHECKPOINT, num_labels=len(LABEL_LIST),
        id2label=id2label, label2id=label2id
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f" - Device: {device}")

    # 5. Dataloaders (collator pads each batch to its longest sequence).
    train_ds = UnifiedDataset(full_train_data, processor, label2id)
    test_ds = UnifiedDataset(full_test_data, processor, label2id)

    collator = DataCollatorForTokenClassification(processor.tokenizer, padding=True, return_tensors="pt")
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, collate_fn=collator)
    test_loader = DataLoader(test_ds, batch_size=2, collate_fn=collator)

    # 6. Optimize & Train
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
    best_f1 = 0.0
    NUM_EPOCHS = 5

    print("\n🔥 Beginning Fine-Tuning...")
    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0

        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
        for batch in progress:
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            progress.set_postfix({"loss": f"{loss.item():.4f}"})

        # FIX: total_loss was accumulated but never used — report the epoch's
        # average training loss (matches train_layoutlm.py's behavior).
        avg_loss = total_loss / max(len(train_loader), 1)
        print(f" 📉 Epoch {epoch+1} Avg Train Loss: {avg_loss:.4f}")

        # --- Evaluation ---
        model.eval()
        all_preds, all_labels = [], []
        print(" Running Validation...")
        with torch.no_grad():
            for batch in test_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                predictions = outputs.logits.argmax(dim=-1)
                labels = batch['labels']

                # Drop positions labelled -100 (special tokens / padding).
                for i in range(len(labels)):
                    true_labels = [id2label[l.item()] for l in labels[i] if l.item() != -100]
                    pred_labels = [id2label[p.item()] for p, l in zip(predictions[i], labels[i]) if l.item() != -100]
                    all_labels.append(true_labels)
                    all_preds.append(pred_labels)

        f1 = f1_score(all_labels, all_preds)
        print(f" 📊 Epoch {epoch+1} F1 Score: {f1:.4f}")

        # Persist only when the validation F1 improves.
        if f1 > best_f1:
            best_f1 = f1
            print(f" 💾 Saving Improved Model to {OUTPUT_DIR}")
            Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
            model.save_pretrained(OUTPUT_DIR)
            processor.save_pretrained(OUTPUT_DIR)

if __name__ == "__main__":
    train()
train_layoutlm.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor, DataCollatorForTokenClassification
4
+ from load_sroie_dataset import load_sroie # Assumes your helper script is in the root
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+ from seqeval.metrics import f1_score, precision_score, recall_score
8
+ from pathlib import Path
9
+
10
# --- 1. Global Configuration & Label Mapping ---
print("Setting up configuration...")
# BIO tag set for the four SROIE entity types (company, date, address, total).
label_list = ['O', 'B-COMPANY', 'I-COMPANY', 'B-DATE', 'I-DATE',
              'B-ADDRESS', 'I-ADDRESS', 'B-TOTAL', 'I-TOTAL']
# Bidirectional tag <-> integer-id maps shared by model and dataset.
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for idx, label in enumerate(label_list)}

MODEL_CHECKPOINT = "microsoft/layoutlmv3-base"
# NOTE(review): hardcoded, machine-specific absolute Windows path — replace
# with a relative path or environment variable before sharing this script.
SROIE_DATA_PATH = "C:\\Users\\Soumyajit Ghosh\\Downloads\\sroie\\sroie" # Make sure this path is correct
19
+
20
+ # --- 2. PyTorch Dataset Class ---
21
class SROIEDataset(Dataset):
    """PyTorch Dataset that encodes raw SROIE examples for LayoutLMv3.

    Each example dict carries 'image_path', 'words', 'bboxes' in
    (x, y, w, h) pixel coordinates, and BIO 'ner_tags'.
    """

    def __init__(self, data, processor, label2id):
        self.data = data              # list of raw example dicts
        self.processor = processor    # LayoutLMv3Processor (apply_ocr=False)
        self.label2id = label2id      # tag string -> integer id

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _normalize_box(raw_box, width, height):
        """Convert an (x, y, w, h) pixel box to a clipped [x0, y0, x1, y1]
        box on LayoutLM's 0-1000 coordinate grid."""
        x, y, w, h = raw_box
        corners = (x, y, x + w, y + h)
        extents = (width, height, width, height)
        return [max(0, min(int((c / e) * 1000), 1000))
                for c, e in zip(corners, extents)]

    def __getitem__(self, idx):
        sample = self.data[idx]

        # Load the page image and grab its pixel dimensions for scaling.
        page = Image.open(sample['image_path']).convert("RGB")
        page_w, page_h = page.size

        norm_boxes = [self._normalize_box(b, page_w, page_h)
                      for b in sample['bboxes']]

        # Map BIO tag strings to integer ids (unknown tags raise KeyError).
        tag_ids = [self.label2id[tag] for tag in sample['ner_tags']]

        encoded = self.processor(
            page,
            text=sample['words'],
            boxes=norm_boxes,
            word_labels=tag_ids,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )

        # Strip the singleton batch dimension so the collator can re-batch.
        return {name: tensor.squeeze(0) for name, tensor in encoded.items()}
74
+
75
# --- 3. Main Training Script ---
def train():
    """Fine-tune LayoutLMv3 on SROIE and keep the best-F1 checkpoint.

    Loads the raw dataset, wraps it in SROIEDataset, trains for NUM_EPOCHS
    epochs with AdamW (lr=5e-5, batch size 2), evaluates entity-level
    precision/recall/F1 with seqeval after each epoch, and saves the model
    and processor to ./models/layoutlmv3-sroie-best whenever validation F1
    improves.
    """
    # --- Load Data ---
    print("Loading SROIE dataset...")
    raw_dataset = load_sroie(SROIE_DATA_PATH)

    # --- Load Processor ---
    # apply_ocr=False: we supply words and boxes ourselves.
    print("Creating processor...")
    processor = LayoutLMv3Processor.from_pretrained(MODEL_CHECKPOINT, apply_ocr=False)

    # --- Create PyTorch Datasets and DataLoaders ---
    print("Creating PyTorch datasets and dataloaders...")
    train_dataset = SROIEDataset(raw_dataset['train'], processor, label2id)
    test_dataset = SROIEDataset(raw_dataset['test'], processor, label2id)

    # Collator pads each batch dynamically to its longest sequence.
    data_collator = DataCollatorForTokenClassification(
        tokenizer=processor.tokenizer,
        padding=True,
        return_tensors="pt"
    )

    train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=data_collator)
    test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=data_collator)

    # --- Load Model ---
    print("Loading LayoutLMv3 model for fine-tuning...")
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        MODEL_CHECKPOINT,
        num_labels=len(label_list),
        id2label=id2label,
        label2id=label2id
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Training on: {device}")

    # --- Setup Optimizer ---
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

    # --- Training Loop ---
    best_f1 = 0
    NUM_EPOCHS = 10

    for epoch in range(NUM_EPOCHS):
        print(f"\n{'='*60}\nEpoch {epoch + 1}/{NUM_EPOCHS}\n{'='*60}")

        # --- Training Step ---
        model.train()
        total_train_loss = 0
        train_progress_bar = tqdm(train_dataloader, desc=f"Training Epoch {epoch+1}")
        for batch in train_progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(**batch)
            loss = outputs.loss

            # backward -> step -> zero_grad: grads are cleared after each
            # update, ready for the next batch's backward pass.
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_train_loss += loss.item()
            train_progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

        avg_train_loss = total_train_loss / len(train_dataloader)

        # --- Validation Step ---
        model.eval()
        all_predictions = []
        all_labels = []
        with torch.no_grad():
            for batch in tqdm(test_dataloader, desc="Validation"):
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)

                predictions = outputs.logits.argmax(dim=-1)
                labels = batch['labels']

                # Drop positions labelled -100 (special tokens / padding)
                # before handing tag sequences to seqeval.
                for i in range(labels.shape[0]):
                    true_labels_i = [id2label[l.item()] for l in labels[i] if l.item() != -100]
                    pred_labels_i = [id2label[p.item()] for p, l in zip(predictions[i], labels[i]) if l.item() != -100]
                    all_labels.append(true_labels_i)
                    all_predictions.append(pred_labels_i)

        # --- Calculate Metrics (entity-level, via seqeval) ---
        f1 = f1_score(all_labels, all_predictions)
        precision = precision_score(all_labels, all_predictions)
        recall = recall_score(all_labels, all_predictions)

        print(f"\n📊 Epoch {epoch + 1} Results:")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  F1 Score: {f1:.4f}")
        print(f"  Precision: {precision:.4f}")
        print(f"  Recall: {recall:.4f}")

        # --- Save Best Model ---
        if f1 > best_f1:
            best_f1 = f1
            print(f"  🌟 New best F1! Saving model...")
            save_path = Path("./models/layoutlmv3-sroie-best")
            save_path.mkdir(parents=True, exist_ok=True)
            model.save_pretrained(save_path)
            processor.save_pretrained(save_path)

    print(f"\n🎉 TRAINING COMPLETE! Best F1 Score: {best_f1:.4f}")
    print(f"Model saved to: ./models/layoutlmv3-sroie-best")


if __name__ == '__main__':
    train()