Commit 2a944a5
Parent: 097a95c

feat: PDF preview, database integration, and improved error handling
- Add PDF preview support using pdf2image
- Enable bounding box overlay visualization for PDFs
- Implement database persistence with SQLModel (Invoice, LineItem)
- Add InvoiceRepository with save and duplicate detection
- Improve DB status messages (show 'unavailable' once at startup)
- Show 'Demo Mode' toast only once per session
- Fix torch.load and transformers deprecation warnings
- Add conda environment.yml for reproducible setup
- Update README with conda installation instructions
- README.md +78 -29
- app.py +88 -39
- docker-compose.yml +1 -1
- environment.yml +31 -0
- requirements.txt +14 -11
- src/api.py +2 -4
- src/database.py +48 -16
- src/extraction.py +67 -13
- src/ml_extraction.py +5 -3
- src/models.py +34 -29
- src/pipeline.py +40 -6
- src/repository.py +66 -20
README.md
CHANGED

````diff
@@ -46,6 +46,8 @@ A production-grade Hybrid Invoice Extraction System that combines the semantic u
 - **Defensive Data Handling:** Implemented coordinate clamping to prevent model crashes from negative OCR bounding boxes.
 - **GPU-Accelerated OCR:** DocTR (Mindee) with automatic CUDA acceleration for faster inference in production.
 - **Clean JSON Output:** Normalized schema handling nested entities, line items, and validation flags.
+- **Defensive Persistence:** Optional PostgreSQL integration that automatically saves extracted data when credentials are present, but gracefully degrades (skips saving) in serverless/demo environments like Hugging Face Spaces.
+- **Duplicate Prevention:** Implemented *Semantic Hashing* (Vendor + Date + Total + ID) to automatically detect and prevent duplicate invoice entries.
 
 ### 💻 Usability
 
@@ -91,6 +93,14 @@ The system outputs a clean JSON with the following fields:
 - `extraction_confidence`: The confidence of the extraction (0-100).
 - `validation_passed`: Whether the validation passed (true/false).
 
+### 5. Defensive Database Architecture
+
+To support both local development (with full persistence) and lightweight cloud demos (without databases), the system uses a **"Soft Fail" Persistence Layer**:
+
+1. **Connection Check:** On startup, the system checks for PostgreSQL credentials. If missing, the database engine is disabled.
+2. **Repository Guard:** All CRUD operations check for an active session. If the database is disabled, save operations are skipped silently without crashing the pipeline.
+3. **Semantic Hashing:** Before saving, a content-based hash is generated to ensure idempotency.
+
 ---
 
 ## 📊 Demo
 
@@ -156,37 +166,35 @@ _UI shows simple format hints and confidence._
 ### Prerequisites
 
 - Python 3.10+
-
+- Conda / Miniforge (recommended)
+- NVIDIA GPU with CUDA (strongly recommended for usable performance)
+
+⚠️ CPU-only execution is supported but significantly slower
+(5–10s per invoice) and intended only for testing.
 
-### Installation
+### Installation (Conda – Recommended)
 
-1. Clone the repository
+1. Clone the repository:
 
 ```bash
 git clone https://github.com/GSoumyajit2005/invoice-processor-ml
 cd invoice-processor-ml
 ```
 
-2. Create and
-
-- **Linux / macOS**:
-
-```bash
-python3 -m venv venv
-source venv/bin/activate
-```
-
-- **Windows**:
+2. Create and activate the Conda environment:
 
 ```bash
-
-
+conda env create -f environment.yml
+conda activate invoice-ml
 ```
 
-3.
+3. Verify CUDA availability (recommended):
 
 ```bash
-
+python - <<EOF
+import torch
+print(torch.cuda.is_available())
+EOF
 ```
 
 4. Run the web app
 
@@ -195,6 +203,9 @@ pip install -r requirements.txt
 streamlit run app.py
 ```
 
+> Note: `requirements.txt` is consumed internally by `environment.yml`.
+> Do not install it manually with pip.
+
 ### Training the Model (Optional)
 
 To retrain the model from scratch using the provided scripts:
 
@@ -205,6 +216,27 @@ python scripts/train_combined.py
 
 (Note: Requires SROIE dataset in data/sroie)
 
+### API Usage (Optional)
+
+To run the API server:
+
+```bash
+python src/api.py
+```
+
+The API provides endpoints for processing invoices and extracting information.
+
+### Running with Database (Optional)
+
+To enable data persistence, run the included Docker Compose file to spin up PostgreSQL:
+
+```bash
+docker-compose up -d
+```
+
+The application will automatically detect the database and start saving invoices.
+
 ## 💻 Usage
 
 ### Web Interface (Recommended)
 
@@ -290,10 +322,16 @@ print(json.dumps(result, indent=2))
 │   Post-process   │
 │ validate, scores │
 └────────┬─────────┘
-
-
-
-
+         │
+   ┌─────┴──────────────┐
+   │                    │
+   ▼                    ▼
+┌──────────────────┐  ┌────────────────────┐
+│   JSON Output    │  │  DB (PostgreSQL)   │
+└──────────────────┘  │  (Optional Save)   │
+                      └────────────────────┘
+
 ```
 
 ## 📁 Project Structure
 
@@ -326,14 +364,14 @@ invoice-processor-ml/
 ├── src/
 │   ├── api.py            # FastAPI REST endpoint for API access
 │   ├── data_loader.py    # Unified data loader for training
-│   ├── database.py       #
+│   ├── database.py       # Database connection with environment-aware 'soft fail' check
 │   ├── extraction.py     # Regex-based information extraction logic
 │   ├── ml_extraction.py  # ML-based extraction (LayoutLMv3 + DocTR)
-│   ├── models.py         # SQLModel tables
+│   ├── models.py         # SQLModel tables (Invoice, LineItem) with schema validation
 │   ├── pdf_utils.py      # PDF text extraction and image conversion
 │   ├── pipeline.py       # Main orchestrator for the pipeline and CLI
 │   ├── preprocessing.py  # Image preprocessing functions (grayscale, denoise)
-│   ├── repository.py     # CRUD operations
+│   ├── repository.py     # CRUD operations with session safety handling
 │   ├── schema.py         # Pydantic models for API response validation
 │   ├── sroie_loader.py   # SROIE dataset loading logic
 │   └── utils.py          # Utility functions (semantic hashing, etc.)
 
@@ -346,6 +384,10 @@ invoice-processor-ml/
 │
 ├── app.py                # Streamlit web interface
 ├── requirements.txt      # Python dependencies
+├── environment.yml       # Conda environment configuration
+├── docker-compose.yml    # Docker Compose configuration for PostgreSQL
+├── Dockerfile            # Dockerfile for building the application container
+├── .gitignore            # Git ignore file
 └── README.md             # You are Here!
 ```
 
@@ -364,16 +406,21 @@ invoice-processor-ml/
 ## 📈 Performance
 
 - **OCR Precision**: State-of-the-art hierarchical detection using **DocTR (ResNet-50)**. Outperforms Tesseract on complex/noisy layouts.
-- **ML-based Extraction**:
-  - **Accuracy**: ~83% F1 Score on SROIE +
-  - **Speed**:
+- **ML-based Extraction**:
+  - **Accuracy**: ~83% F1 Score on SROIE + custom invoices
+  - **Speed**:
+    - **GPU (recommended)**: <1s per invoice
+    - **CPU (fallback)**: ~5–7s per invoice
+
+⚠️ CPU-only execution is supported for testing and experimentation but results
+in significantly higher latency due to the heavy OCR and layout-aware models.
 
 ## ⚠️ Known Limitations
 
 1. **Layout Sensitivity**: The ML model was fine‑tuned on SROIE (retail receipts) and mychen76/invoices-and-receipts_ocr_v1 (English). Professional multi-column invoices may underperform until you fine‑tune on more diverse datasets.
 2. **Invoice Number**: SROIE dataset lacks invoice number labels. The system solves this by using the Hybrid Fallback Engine, which successfully extracts invoice numbers using Regex whenever the ML model output is empty.
 3. **Line Items/Tables**: Not trained for table extraction yet. Rule-based supports simple totals; table extraction comes later.
-4. **Inference Latency**:
+4. **Inference Latency**: CPU execution is significantly slower due to heavy OCR and layout-aware models.
 
 ## 🔮 Future Enhancements
 
@@ -386,7 +433,7 @@ invoice-processor-ml/
 - [x] CI/CD pipeline (GitHub Actions → HuggingFace Spaces auto-deploy)
 - [ ] Multilingual OCR (PaddleOCR) and multilingual fine‑tuning
 - [ ] Confidence calibration and better validation rules
-- [
+- [x] Database persistence layer (PostgreSQL with SQLModel & Redundancy checks)
 
 ## 🛠️ Tech Stack
 
@@ -400,6 +447,8 @@ invoice-processor-ml/
 | Data Format | JSON |
 | CI/CD | GitHub Actions → HuggingFace Spaces |
 | Containerization | Docker |
+| Database | PostgreSQL, SQLModel |
+| Containerization | Docker & Docker Compose |
 
 ## 📚 What I Learned
````
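The *Semantic Hashing* scheme described in the README changes above (Vendor + Date + Total + ID) can be sketched as follows. This is an illustrative implementation, not the actual code in `src/utils.py`; the normalization choices (lowercasing, two-decimal totals) are assumptions:

```python
import hashlib

def semantic_hash(vendor: str, date: str, total: float, invoice_id: str) -> str:
    """Content-based hash so the same invoice always maps to the same key."""
    # Normalize fields so trivial formatting differences hash identically
    key = "|".join([
        (vendor or "").strip().lower(),
        (date or "").strip(),
        f"{float(total or 0):.2f}",
        (invoice_id or "").strip().lower(),
    ])
    return hashlib.sha256(key.encode("utf-8")).hexdigest()

h1 = semantic_hash("ACME Corp", "22/03/2018", 105.5, "INV-001")
h2 = semantic_hash("  acme corp", "22/03/2018", 105.50, "inv-001")
print(h1 == h2)  # True: formatting differences collapse to one hash
```

Because the hash is derived from content rather than the upload, re-processing the same invoice file (or a rescan of it) produces the same key, which is what makes the duplicate check idempotent.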
app.py
CHANGED

````diff
@@ -7,12 +7,21 @@ from PIL import Image, ImageDraw
 import pandas as pd
 import sys
 
+# PDF to image conversion
+try:
+    from pdf2image import convert_from_bytes
+    PDF_SUPPORT = True
+except ImportError:
+    PDF_SUPPORT = False
+
 # --------------------------------------------------
 # Pipeline import (PURE DATA ONLY)
 # --------------------------------------------------
-
-from
+from src.pipeline import process_invoice
+from src.database import init_db
+
+# Initialize database
+init_db()
 
 # --------------------------------------------------
 # Mock format detection (UI-level, safe)
@@ -119,17 +128,22 @@ with tab1:
 
     if uploaded_file:
         st.caption(f"File: {uploaded_file.name}")
 
+        # Handle PDF preview
         if uploaded_file.type == "application/pdf":
-
+            if PDF_SUPPORT:
+                pdf_bytes = uploaded_file.read()
+                uploaded_file.seek(0)  # Reset for later processing
+                pages = convert_from_bytes(pdf_bytes, first_page=1, last_page=1)
+                if pages:
+                    pdf_preview_image = pages[0]
+                    st.session_state.pdf_preview = pdf_preview_image
+                    st.image(pdf_preview_image, width=250, caption="PDF Preview (Page 1)")
+            else:
+                st.warning("PDF preview requires pdf2image. Install with: `pip install pdf2image`")
         else:
             image = Image.open(uploaded_file)
-            st.image(
-                image,
-                width=250,
-                caption="Uploaded Invoice"
-            )
+            st.image(image, width=250, caption="Uploaded Invoice")
 
 # -----------------------------
@@ -149,12 +163,39 @@ with tab1:
                 f.write(uploaded_file.getbuffer())
 
             method = "ml" if "ML" in extraction_method else "rules"
+
+            # CALL PIPELINE
             result = process_invoice(str(temp_path), method=method)
 
-            #
+            # --- SMART STATUS NOTIFICATIONS ---
+            db_status = result.get('_db_status', 'disabled')
+
+            if db_status == 'saved':
+                st.success("✅ Extraction & Storage Complete")
+                st.toast("Invoice saved to Database!", icon="💾")
+
+            elif db_status == 'duplicate':
+                st.success("✅ Extraction Complete")
+                st.toast("Duplicate invoice (already in database)", icon="⚠️")
+
+            elif db_status == 'disabled':
+                st.success("✅ Extraction Complete")
+                # Only show "Demo Mode" toast once per session
+                if not st.session_state.get('_db_warning_shown', False):
+                    st.toast("Database disabled (Demo Mode)", icon="ℹ️")
+                    st.session_state['_db_warning_shown'] = True
+
+            else:
+                st.success("✅ Extraction Complete")
+
+            # Hard guard — prevents DeltaGenerator bugs
             if not isinstance(result, dict):
                 st.error("Pipeline returned invalid data.")
                 st.stop()
+
+            # Remove the metadata field so it doesn't show up in the JSON view
+            if '_db_status' in result:
+                del result['_db_status']
 
             st.session_state.data = result
             st.session_state.format_info = detect_invoice_format(
@@ -162,35 +203,43 @@ with tab1:
             )
             st.session_state.processed_count += 1
 
-            st.success("Extraction Complete")
-
             # --- AI Detection Overlay Visualization ---
             raw_predictions = result.get("raw_predictions")
-            if raw_predictions
-                #
-                uploaded_file.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if raw_predictions:
+                # Get the base image for annotation
+                if uploaded_file.type == "application/pdf":
+                    # Use the converted PDF preview image
+                    if "pdf_preview" in st.session_state:
+                        overlay_image = st.session_state.pdf_preview.copy().convert("RGB")
+                    else:
+                        overlay_image = None
+                else:
+                    # Reload the original image for annotation
+                    uploaded_file.seek(0)
+                    overlay_image = Image.open(uploaded_file).convert("RGB")
+
+                if overlay_image:
+                    draw = ImageDraw.Draw(overlay_image)
+
+                    # Draw red rectangles around each detected entity's bounding boxes
+                    for entity_name, entity_data in raw_predictions.items():
+                        bboxes = entity_data.get("bbox", [])
+                        for box in bboxes:
+                            # bbox format: [x, y, width, height]
+                            x, y, w, h = box
+                            draw.rectangle(
+                                [x, y, x + w, y + h],
+                                outline="red",
+                                width=2
+                            )
+
+                    overlay_image.thumbnail((800, 800))
+
+                    st.image(
+                        overlay_image,
+                        caption="AI Detection Overlay",
+                        width="content"
+                    )
 
         except Exception as e:
             st.error(f"Pipeline error: {e}")
@@ -276,7 +325,7 @@ with tab2:
         st.image(
             Image.open(samples[0]),
             caption=samples[0].name,
-
+            width=250
         )
     else:
         st.info("No sample invoices found.")
````
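The coordinate clamping mentioned in the README's "Defensive Data Handling" bullet is not part of this diff. A minimal sketch of the idea, using the same `[x, y, width, height]` box format the overlay code assumes (function name hypothetical):

```python
def clamp_bbox(box, img_w, img_h):
    """Clamp an [x, y, w, h] box to image bounds so negative or
    oversized OCR coordinates cannot crash downstream models."""
    x, y, w, h = box
    x = max(0, min(int(x), img_w - 1))
    y = max(0, min(int(y), img_h - 1))
    w = max(1, min(int(w), img_w - x))  # keep at least 1px inside the image
    h = max(1, min(int(h), img_h - y))
    return [x, y, w, h]

print(clamp_bbox([-5, 10, 3000, 40], 800, 600))  # [0, 10, 800, 40]
```

Running every OCR box through a guard like this before model input or drawing means a stray negative coordinate degrades to an edge-aligned box instead of an exception.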
docker-compose.yml
CHANGED

```diff
@@ -11,7 +11,7 @@ services:
       POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-password}
       POSTGRES_DB: ${POSTGRES_DB:-invoices_db}
     ports:
-      - "
+      - "5433:5432"
     volumes:
       - postgres_data:/var/lib/postgresql/data
     healthcheck:
```
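Since `src/database.py` builds its connection string from `POSTGRES_*` variables, a local run against this container needs them to match the compose settings; note the host port is 5433, not the default 5432. A sketch (the `postgres` user name is an assumption based on the official image's default when `POSTGRES_USER` is unset):

```python
import os

# Hypothetical settings matching the compose file above; set these before
# importing src.database (e.g. via a .env file read by python-dotenv).
os.environ.update({
    "POSTGRES_USER": "postgres",      # assumption: official image default
    "POSTGRES_PASSWORD": "password",
    "POSTGRES_DB": "invoices_db",
    "POSTGRES_HOST": "localhost",
    "POSTGRES_PORT": "5433",          # host side of the "5433:5432" mapping
})

url = (f"postgresql://{os.environ['POSTGRES_USER']}"
       f":***@{os.environ['POSTGRES_HOST']}"
       f":{os.environ['POSTGRES_PORT']}/{os.environ['POSTGRES_DB']}")
print(url)  # postgresql://postgres:***@localhost:5433/invoices_db
```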
environment.yml
ADDED

```yaml
name: invoice-ml

channels:
  - pytorch
  - nvidia
  - conda-forge

dependencies:
  # ----- Python -----
  - python=3.10
  - pip

  # ----- CUDA-enabled PyTorch -----
  - pytorch
  - torchvision
  - torchaudio
  - pytorch-cuda=11.8

  # ----- Core numeric / system -----
  - numpy
  - certifi
  - openssl
  - ca-certificates

  # ----- Computer Vision / PDF -----
  - poppler
  - ghostscript

  # ----- App-level Python deps -----
  - pip:
      - -r requirements.txt
```
requirements.txt
CHANGED

```diff
@@ -1,23 +1,26 @@
 # ----- Streamlit -----
 streamlit>=1.28.0
 
 # ----- OCR -----
-python-doctr
-opencv-python>=4.8.0
+python-doctr>=0.8.0
+opencv-python-headless>=4.8.0
 Pillow>=10.0.0
 
-# -----
-numpy>=1.24.0
-pandas>=2.0.0
-
-# ----- Machine Learning -----
-torch>=2.0.0
-torchvision>=0.15.0
+# ----- NLP / Transformers -----
 transformers>=4.30.0
 datasets>=2.14.0
 huggingface-hub>=0.17.0
 seqeval>=1.2.2
 
+# ----- Utilities -----
+python-dotenv>=1.0.0
+httpx>=0.28.0
+tenacity>=8.0.0
+validators>=0.22.0
+langdetect>=1.0.9
+RapidFuzz>=3.0.0
+python-dateutil>=2.9.0
+
 # ----- Data Validation -----
 pydantic>=2.12.0
```
src/api.py
CHANGED

```diff
@@ -8,10 +8,8 @@ from pathlib import Path
 import uuid
 import sys
 
-
-
-from pipeline import process_invoice
-from schema import InvoiceData
+from src.pipeline import process_invoice
+from src.schema import InvoiceData
 
 app = FastAPI(
     title="Invoice Extraction API",
```
src/database.py
CHANGED

```diff
@@ -1,33 +1,65 @@
 # src/database.py
 
 from sqlmodel import SQLModel, create_engine, Session
-from
+from sqlalchemy import text
+from typing import Generator, Optional
 import os
+from dotenv import load_dotenv
 
+load_dotenv()
 
-# 1. Get credentials from environment variables (Os.getenv)
-# 2. Construct the DATABASE_URL string: postgresql://user:pass@host:port/db
-# 3. Create the SQLModel engine
-# 4. Implement the init_db and get_session functions
+# 1. Get credentials (with defaults to avoid immediate crashes if vars are missing)
+POSTGRES_USER = os.getenv("POSTGRES_USER")
+POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD")
+POSTGRES_DB = os.getenv("POSTGRES_DB")
+POSTGRES_HOST = os.getenv("POSTGRES_HOST", "localhost")
+POSTGRES_PORT = os.getenv("POSTGRES_PORT", "5432")
 
-#
+# 2. Construct DATABASE_URL conditionally
+DATABASE_URL = None
+engine = None
+DB_CONNECTED = False  # Track actual connection status
 
-#
+if POSTGRES_USER and POSTGRES_PASSWORD and POSTGRES_DB:
+    DATABASE_URL = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}"
+
+    try:
+        # 3. Create the engine only if we have credentials
+        engine = create_engine(DATABASE_URL, echo=False)
+
+        # 4. Test actual connection (once at startup)
+        with engine.connect() as conn:
+            conn.execute(text("SELECT 1"))
+        DB_CONNECTED = True
+        print("✅ Database connection verified.")
+    except Exception as e:
+        print(f"⚠️ Database unavailable: {e}")
+        DB_CONNECTED = False
+else:
+    print("⚠️ Database credentials missing. Database features will be disabled.")
 
 def init_db():
     """
     Idempotent DB initialization.
-
+    Only runs if engine is successfully configured AND connected.
     """
-
-
+    if engine and DB_CONNECTED:
+        try:
+            SQLModel.metadata.create_all(engine)
+            print("✅ Database tables created/verified.")
+        except Exception as e:
+            print(f"❌ Error initializing database: {e}")
+    # Silent skip when DB is not connected - message already shown at startup
 
-def get_session() -> Generator[Session, None, None]:
+def get_session() -> Generator[Optional[Session], None, None]:
     """
     Dependency for yielding a database session.
-
+    Yields None if database is not configured.
     """
-
-
+    if engine:
+        with Session(engine) as session:
+            yield session
+    else:
+        # Yield None so code depending on this doesn't crash immediately,
+        # but can check 'if session is None'.
+        yield None
```
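`src/repository.py` is listed in this commit but its diff is not shown here. Based on the README's "Repository Guard" description, its save path presumably looks something like this sketch (class and method names are assumptions, and the real code uses SQLModel queries rather than the placeholder below):

```python
from typing import Optional

class InvoiceRepository:
    """Sketch: every operation guards on an active session ('soft fail')."""

    def __init__(self, session: Optional[object]):
        self.session = session  # None when the database is disabled

    def save(self, invoice: dict) -> str:
        if self.session is None:
            return "disabled"        # skip silently; pipeline keeps running
        if self._is_duplicate(invoice):
            return "duplicate"       # semantic-hash match already stored
        # self.session.add(Invoice(**invoice)); self.session.commit()
        return "saved"

    def _is_duplicate(self, invoice: dict) -> bool:
        # Placeholder for the semantic-hash lookup described in the README
        return False

print(InvoiceRepository(session=None).save({"vendor": "ACME"}))  # disabled
```

Returning a status string rather than raising lines up with the `_db_status` values (`saved`, `duplicate`, `disabled`) that `app.py` inspects for its toast notifications.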
src/extraction.py
CHANGED
|
@@ -7,29 +7,83 @@ from difflib import SequenceMatcher
|
|
| 7 |
|
| 8 |
def extract_dates(text: str) -> List[str]:
|
| 9 |
"""
|
| 10 |
-
Robust date extraction that handles
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
| 12 |
"""
|
| 13 |
if not text: return []
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
valid_dates = []
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
try:
|
| 23 |
-
# Try to parse it to check if it's a real date
|
| 24 |
-
 def extract_dates(text: str) -> List[str]:
     """
+    Robust date extraction that handles:
+    - Numeric formats: DD/MM/YYYY, DD-MM-YYYY, DD.MM.YYYY
+    - Text month formats: 22 Mar 18, March 22, 2018, 22-Mar-2018
+    - OCR noise like pipes (|) instead of slashes
+    Validates using datetime to ensure semantic correctness.
     """
     if not text: return []
 
+    # Month name mappings
+    MONTH_MAP = {
+        'jan': 1, 'january': 1,
+        'feb': 2, 'february': 2,
+        'mar': 3, 'march': 3,
+        'apr': 4, 'april': 4,
+        'may': 5,
+        'jun': 6, 'june': 6,
+        'jul': 7, 'july': 7,
+        'aug': 8, 'august': 8,
+        'sep': 9, 'sept': 9, 'september': 9,
+        'oct': 10, 'october': 10,
+        'nov': 11, 'november': 11,
+        'dec': 12, 'december': 12
+    }
 
     valid_dates = []
+
+    # Pattern 1: Numeric dates - DD/MM/YYYY, DD-MM-YYYY, DD.MM.YYYY, DD MM YYYY
+    # Also handles OCR noise like pipes (|) instead of slashes
+    numeric_pattern = r'\b(\d{1,2})[\s/|.-](\d{1,2})[\s/|.-](\d{2,4})\b'
+    for d, m, y in re.findall(numeric_pattern, text):
+        try:
+            year = int(y)
+            if year < 100:
+                year = 2000 + year if year < 50 else 1900 + year
+            dt = datetime(year, int(m), int(d))
+            valid_dates.append(dt.strftime("%d/%m/%Y"))
+        except ValueError:
+            continue
+
+    # Pattern 2: DD Mon YY/YYYY (e.g., "22 Mar 18", "22-Mar-2018", "22 March 2018")
+    text_month_pattern1 = r'\b(\d{1,2})[\s/.-]?([A-Za-z]{3,9})[\s/.-]?(\d{2,4})\b'
+    for d, m, y in re.findall(text_month_pattern1, text, re.IGNORECASE):
+        month_num = MONTH_MAP.get(m.lower())
+        if month_num:
+            try:
+                year = int(y)
+                if year < 100:
+                    year = 2000 + year if year < 50 else 1900 + year
+                dt = datetime(year, month_num, int(d))
+                valid_dates.append(dt.strftime("%d/%m/%Y"))
+            except ValueError:
+                continue
+
+    # Pattern 3: Mon DD, YYYY (e.g., "March 22, 2018", "Mar 22 2018")
+    text_month_pattern2 = r'\b([A-Za-z]{3,9})[\s.-]?(\d{1,2})[,\s.-]+(\d{2,4})\b'
+    for m, d, y in re.findall(text_month_pattern2, text, re.IGNORECASE):
+        month_num = MONTH_MAP.get(m.lower())
+        if month_num:
+            try:
+                year = int(y)
+                if year < 100:
+                    year = 2000 + year if year < 50 else 1900 + year
+                dt = datetime(year, month_num, int(d))
+                valid_dates.append(dt.strftime("%d/%m/%Y"))
+            except ValueError:
+                continue
+
+    # Pattern 4: YYYY-MM-DD (ISO format)
+    iso_pattern = r'\b(\d{4})[-/](\d{1,2})[-/](\d{1,2})\b'
+    for y, m, d in re.findall(iso_pattern, text):
         try:
-            # This filters out "99/99/2000" or random phone numbers like 12 34 5678
-            # Assuming Day-Month-Year format which is common in SROIE/International
-            # For US format, you might swap d and m
             dt = datetime(int(y), int(m), int(d))
             valid_dates.append(dt.strftime("%d/%m/%Y"))
         except ValueError:
             continue
 
-    return list(dict.fromkeys(valid_dates))
+    return list(dict.fromkeys(valid_dates))  # Deduplicate while preserving order
 
 def extract_amounts(text: str) -> List[float]:
     if not text: return []
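The rewritten `extract_dates` above leans on two ideas: a two-digit-year pivot (00-49 maps to the 2000s, 50-99 to the 1900s) and `datetime` as the semantic validator that discards impossible matches. A minimal, self-contained sketch of Pattern 1 only (the function name `parse_numeric_dates` is illustrative, not part of the module):

```python
import re
from datetime import datetime

def parse_numeric_dates(text: str) -> list[str]:
    """Minimal sketch of Pattern 1: numeric DD/MM/YYYY with OCR noise tolerance."""
    # Separator class includes '|' because OCR often misreads '/' as a pipe
    pattern = r'\b(\d{1,2})[\s/|.-](\d{1,2})[\s/|.-](\d{2,4})\b'
    found = []
    for d, m, y in re.findall(pattern, text):
        year = int(y)
        if year < 100:
            # Two-digit year pivot: 00-49 -> 2000s, 50-99 -> 1900s
            year = 2000 + year if year < 50 else 1900 + year
        try:
            dt = datetime(year, int(m), int(d))  # rejects e.g. 99/99/2000
        except ValueError:
            continue
        found.append(dt.strftime("%d/%m/%Y"))
    return list(dict.fromkeys(found))  # deduplicate while preserving order

print(parse_numeric_dates("Paid 22/03/18, again 22|03|2018, junk 99/99/2000"))
# -> ['22/03/2018']
```

Note that the pipe-noise variant and the two-digit form normalize to the same string, so the ordered dedupe collapses them.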
src/ml_extraction.py
CHANGED

@@ -8,7 +8,7 @@ from PIL import Image
 from typing import List, Dict, Any, Tuple
 import re
 import numpy as np
-from extraction import extract_invoice_number, extract_total, extract_address
+from src.extraction import extract_invoice_number, extract_total, extract_address
 from doctr.io import DocumentFile
 from doctr.models import ocr_predictor
 
@@ -219,10 +219,12 @@ def extract_ml_based(image_path: str) -> Dict[str, Any]:
     encoding = PROCESSOR(
         image, text=words, boxes=normalized_boxes,
         truncation=True, max_length=512, return_tensors="pt"
-        )
+    )
+    # Move tensors to device for inference, but keep original encoding for word_ids()
+    model_inputs = {k: v.to(DEVICE) for k, v in encoding.items()}
 
     with torch.no_grad():
-        outputs = MODEL(**encoding)
+        outputs = MODEL(**model_inputs)
 
     predictions = outputs.logits.argmax(-1).squeeze().tolist()
     extracted_entities = _process_predictions(words, unnormalized_boxes, encoding, predictions, MODEL.config.id2label)
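The key change in `extract_ml_based` is moving only the tensors to the device while keeping the original `encoding` object intact, since `_process_predictions` still needs its `word_ids()` helper, which a plain dict of tensors would not have. A framework-free sketch of that pattern, using stand-in classes instead of real `torch.Tensor`/`BatchEncoding` objects (all names here are illustrative):

```python
# Stand-ins so the pattern runs without torch/transformers installed.
class FakeTensor:
    def __init__(self, name):
        self.name, self.device = name, "cpu"

    def to(self, device):
        # Like torch.Tensor.to(), returns a copy on the target device
        moved = FakeTensor(self.name)
        moved.device = device
        return moved

class FakeEncoding(dict):
    def word_ids(self):
        # Tokenizer helper that a plain dict of tensors lacks
        return [0, 1, 1, 2]

encoding = FakeEncoding(input_ids=FakeTensor("input_ids"), bbox=FakeTensor("bbox"))

# The pattern from the diff: move the values, keep the original object intact.
model_inputs = {k: v.to("cuda") for k, v in encoding.items()}

print([v.device for v in model_inputs.values()])  # copies moved to the device
print(encoding.word_ids())                        # original helper still usable
```

Had the code done `encoding = encoding.to(DEVICE)` and later wrapped the result in a plain dict, `word_ids()` would have been lost; keeping both objects side by side avoids that.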
src/models.py
CHANGED

@@ -2,49 +2,54 @@
 
 from typing import List, Optional
 from datetime import date as DateType
+from datetime import datetime
 from decimal import Decimal
 from sqlmodel import SQLModel, Field, Relationship
 
-
-# SQLModel classes should mirror the Pydantic models in src/schema.py
-# but with database-specific configurations (primary keys, foreign keys).
+SQLModel.metadata.clear()
 
 class Invoice(SQLModel, table=True):
     __tablename__ = "invoices"
+    __table_args__ = {"extend_existing": True}
 
-    #
-    #
-    # - semantic_hash (str, unique, indexed) -> Critical for deduplication
-    #
-    # - validation_errors (str) -> Store as JSON string since we don't need to query inside it yet
-    # - created_at (DateType) -> Default to today
-    #
+    # Primary Key
+    id: Optional[int] = Field(default=None, primary_key=True)
+
+    # Data Fields
+    receipt_number: Optional[str] = Field(default=None, index=True)
+    date: Optional[DateType] = Field(default=None)
+    total_amount: Optional[Decimal] = Field(default=None, max_digits=10, decimal_places=2)
+    vendor: Optional[str] = Field(default=None)
+    address: Optional[str] = Field(default=None)
+
+    # Critical for Deduplication
+    semantic_hash: str = Field(unique=True, index=True)
+
+    # Metadata Fields
+    validation_status: str = Field(default="unknown")
+    # Store validation_errors as a JSON string because SQLModel/SQLite doesn't always support arrays out of the box
+    validation_errors: Optional[str] = Field(default="[]")
+    created_at: DateType = Field(default_factory=datetime.now)
+
+    # Relationship to LineItem (One-to-Many)
+    items: List["LineItem"] = Relationship(back_populates="invoice")
 
 
 class LineItem(SQLModel, table=True):
     __tablename__ = "line_items"
+    __table_args__ = {"extend_existing": True}
 
-    #
-    #
-    #
-    pass
+    # Primary Key
+    id: Optional[int] = Field(default=None, primary_key=True)
+
+    # Foreign Key
+    invoice_id: Optional[int] = Field(default=None, foreign_key="invoices.id")
+
+    # Data Fields
+    description: str
+    quantity: int = Field(default=1)
+    unit_price: Optional[Decimal] = Field(default=None, max_digits=10, decimal_places=2)
+    total: Optional[Decimal] = Field(default=None, max_digits=10, decimal_places=2)
+
+    # Relationship back to Invoice
+    invoice: Optional[Invoice] = Relationship(back_populates="items")
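Since `validation_errors` is persisted as a JSON string, every writer must `json.dumps` the list and every reader must `json.loads` it back; the `"[]"` default keeps the read path branch-free. A quick stdlib round-trip (error strings here are illustrative, not the schema's actual codes):

```python
import json

errors = ["total_mismatch", "missing_date"]

# Write path: list -> JSON string, what actually lands in the column
stored = json.dumps(errors)
assert isinstance(stored, str)

# Read path: JSON string -> list, what callers must do after loading a row
assert json.loads(stored) == errors

# The column default "[]" decodes to an empty list, so readers never see None
assert json.loads("[]") == []
print(stored)
```

The trade-off named in the diff comment is that the column cannot be queried by individual error, which is acceptable as long as the list is only ever read back whole.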
src/pipeline.py
CHANGED

@@ -12,12 +12,14 @@ from pydantic import ValidationError
 import cv2
 
 # --- IMPORTS ---
-from preprocessing import load_image, convert_to_grayscale, remove_noise
-from extraction import structure_output
-from ml_extraction import extract_ml_based
-from schema import InvoiceData
-from pdf_utils import extract_text_from_pdf, convert_pdf_to_images
-from utils import generate_semantic_hash
+from src.preprocessing import load_image, convert_to_grayscale, remove_noise
+from src.extraction import structure_output
+from src.ml_extraction import extract_ml_based
+from src.schema import InvoiceData
+from src.pdf_utils import extract_text_from_pdf, convert_pdf_to_images
+from src.utils import generate_semantic_hash
+from src.repository import InvoiceRepository
+from src.database import DB_CONNECTED
 
 def process_invoice(image_path: str,
                     method: str = 'ml',

@@ -136,6 +138,38 @@ def process_invoice(image_path: str,
     # We calculate the hash based on the final (or raw) data.
     # This gives us a unique fingerprint for this specific business transaction.
     final_data['semantic_hash'] = generate_semantic_hash(final_data)
+
+    # --- DATABASE SAVE (The Integration) ---
+    if not DB_CONNECTED:
+        # Database not available - skip save entirely (message shown once at startup)
+        final_data['_db_status'] = 'disabled'
+    else:
+        final_data['_db_status'] = 'disabled'  # Default assumption
+        try:
+            print("💾 Attempting to save to Database...")
+            repo = InvoiceRepository()
+
+            if repo.session:
+                saved_record = repo.save_invoice(final_data)
+                if saved_record:
+                    print(f"   ✅ Successfully saved Invoice #{saved_record.id}")
+                    final_data['_db_status'] = 'saved'
+                else:
+                    # Check if it's a duplicate by looking up the hash
+                    existing = repo.get_by_hash(final_data.get('semantic_hash', ''))
+                    if existing:
+                        print("   ⚠️ Duplicate invoice detected (already in database)")
+                        final_data['_db_status'] = 'duplicate'
+                    else:
+                        print("   ⚠️ Save failed (unknown error)")
+                        final_data['_db_status'] = 'error'
+            else:
+                print("   ⚠️ Skipped DB Save (Database disabled)")
+                final_data['_db_status'] = 'disabled'
+
+        except Exception as e:
+            print(f"   ⚠️ Database Error (Ignored): {e}")
+            final_data['_db_status'] = 'error'
 
     # --- SAVING STEP ---
     if save_results:
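The save branch above resolves `_db_status` to one of four values: `saved`, `duplicate`, `error`, or `disabled`. Stripped of the printing and the repository calls, the decision logic can be sketched as a pure function (the name `resolve_db_status` is hypothetical, not in the codebase):

```python
def resolve_db_status(db_connected: bool, have_session: bool,
                      saved: bool, is_duplicate: bool) -> str:
    """Mirror of the branch structure in process_invoice's database-save step."""
    if not db_connected or not have_session:
        # DB unreachable at startup, or the repository could not open a session
        return "disabled"
    if saved:
        return "saved"
    # save_invoice returned None: distinguish a rejected duplicate from a real failure
    return "duplicate" if is_duplicate else "error"

print(resolve_db_status(True, True, saved=False, is_duplicate=True))
# -> duplicate
```

Writing it this way makes the invariant explicit: the UI-facing status is always set, even when the save path raises, which is what lets the app report DB state without ever blocking extraction.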
src/repository.py
CHANGED

@@ -3,39 +3,85 @@
 from sqlmodel import Session, select
 from typing import Dict, Any, Optional
 import json
+from datetime import date
 
 from src.models import Invoice, LineItem
-from src.database import get_session, engine
+from src.database import get_session, engine, DB_CONNECTED
 
 class InvoiceRepository:
-    def __init__(self, session: Session = None):
+    def __init__(self, session: Optional[Session] = None):
         """
         Initialize with an optional session.
-
+        If no session is provided, try to get a new one from the engine.
+        Only creates session if database is actually connected.
         """
-
+        if session:
+            self.session = session
+        elif engine and DB_CONNECTED:
+            self.session = Session(engine)
+        else:
+            self.session = None
 
-    def save_invoice(self, invoice_data: Dict[str, Any]) -> Invoice:
+    def save_invoice(self, invoice_data: Dict[str, Any]) -> Optional[Invoice]:
         """
         Saves an invoice and its line items to the database.
-
-        Steps to implement:
-        1. Manage Session: If self.session is None, create a new one using 'engine'.
-        2. Clean Data: Separate 'items' list from the main invoice properties.
-        3. Create Invoice: Instantiate the Invoice SQLModel.
-        4. Deserialize Complex Types: e.g. 'validation_errors' list -> JSON string.
-        5. Process Items: Iterate 'items', create LineItem models, check keys match, and append to invoice.items.
-        6. Commit: Add to session, commit, and refresh.
-        7. Error Handling: Wrap in try/except to rollback on failure.
+        Returns the saved Invoice object or None if DB is disabled/failed.
         """
-
-
+        if not self.session:
+            print("⚠️ DB Session missing. Skipping save.")
+            return None
+
+        try:
+            # 1. Prepare Data
+            data = invoice_data.copy()
+
+            # Serialize complex types (validation_errors)
+            if 'validation_errors' in data and isinstance(data['validation_errors'], list):
+                data['validation_errors'] = json.dumps(data['validation_errors'])
+
+            # Extract items to process separately
+            items_data = data.pop('items', [])
+
+            # 2. Create Invoice Record
+            invoice = Invoice(**data)
+
+            # 3. Process Items
+            for item in items_data:
+                # Ensure item is a dict (if it's a Pydantic model, convert it)
+                if hasattr(item, 'model_dump'):
+                    item_dict = item.model_dump()
+                elif isinstance(item, dict):
+                    item_dict = item
+                else:
+                    continue
+
+                line_item = LineItem(**item_dict)
+                invoice.items.append(line_item)
+
+            # 4. Commit
+            self.session.add(invoice)
+            self.session.commit()
+            self.session.refresh(invoice)
+
+            print(f"✅ Invoice {invoice.id} saved to DB.")
+            return invoice
+
+        except Exception as e:
+            print(f"❌ Error saving invoice to DB: {e}")
+            self.session.rollback()
+            return None
 
     def get_by_hash(self, semantic_hash: str) -> Optional[Invoice]:
         """
         Check if invoice already exists using the semantic hash.
         """
-
-
-
+        if not self.session:
+            return None
+
+        try:
+            statement = select(Invoice).where(Invoice.semantic_hash == semantic_hash)
+            results = self.session.exec(statement)
+            return results.first()
+        except Exception as e:
+            print(f"❌ Error checking hash: {e}")
+            return None
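`save_invoice` returns `None` both on a real failure and when the unique index on `semantic_hash` rejects a duplicate insert, which is why the pipeline follows up with `get_by_hash` to tell the two cases apart. That contract can be sketched with an in-memory stand-in (the class `InMemoryRepo` is illustrative; the real repository uses a SQLModel `Session` and a real unique constraint):

```python
class InMemoryRepo:
    """Dict-backed stand-in for the invoices table, keyed by semantic_hash."""

    def __init__(self):
        self._by_hash = {}

    def save_invoice(self, data):
        h = data["semantic_hash"]
        if h in self._by_hash:
            # Mirrors the unique-index violation -> rollback -> return None path
            return None
        self._by_hash[h] = data
        return data

    def get_by_hash(self, h):
        # Mirrors select(Invoice).where(Invoice.semantic_hash == h)
        return self._by_hash.get(h)

repo = InMemoryRepo()
invoice = {"semantic_hash": "abc123", "vendor": "ACME"}
assert repo.save_invoice(invoice) is not None    # first save succeeds
assert repo.save_invoice(invoice) is None        # duplicate is rejected
# The caller disambiguates: None + a hash hit means "duplicate", not "error"
assert repo.get_by_hash("abc123")["vendor"] == "ACME"
print("duplicate correctly detected")
```

Returning a sentinel and letting the caller probe the hash keeps `save_invoice` free of status strings, at the cost of a second query on the duplicate path.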