GSoumyajit2005 committed
Commit 2a944a5 · 1 Parent(s): 097a95c

feat: PDF preview, database integration, and improved error handling


- Add PDF preview support using pdf2image
- Enable bounding box overlay visualization for PDFs
- Implement database persistence with SQLModel (Invoice, LineItem)
- Add InvoiceRepository with save and duplicate detection
- Improve DB status messages (show 'unavailable' once at startup)
- Show 'Demo Mode' toast only once per session
- Fix torch.load and transformers deprecation warnings
- Add conda environment.yml for reproducible setup
- Update README with conda installation instructions

Files changed (12)
  1. README.md +78 -29
  2. app.py +88 -39
  3. docker-compose.yml +1 -1
  4. environment.yml +31 -0
  5. requirements.txt +14 -11
  6. src/api.py +2 -4
  7. src/database.py +48 -16
  8. src/extraction.py +67 -13
  9. src/ml_extraction.py +5 -3
  10. src/models.py +34 -29
  11. src/pipeline.py +40 -6
  12. src/repository.py +66 -20
README.md CHANGED
@@ -46,6 +46,8 @@ A production-grade Hybrid Invoice Extraction System that combines the semantic u
46
  - **Defensive Data Handling:** Implemented coordinate clamping to prevent model crashes from negative OCR bounding boxes.
47
  - **GPU-Accelerated OCR:** DocTR (Mindee) with automatic CUDA acceleration for faster inference in production.
48
  - **Clean JSON Output:** Normalized schema handling nested entities, line items, and validation flags.
 
 
49
 
50
  ### 💻 Usability
51
 
@@ -91,6 +93,14 @@ The system outputs a clean JSON with the following fields:
91
  - `extraction_confidence`: The confidence of the extraction (0-100).
92
  - `validation_passed`: Whether the validation passed (true/false).
93
 
94
  ---
95
 
96
  ## 📊 Demo
@@ -156,37 +166,35 @@ _UI shows simple format hints and confidence._
156
  ### Prerequisites
157
 
158
  - Python 3.10+
159
- - (Optional) CUDA-capable GPU for training/inference speed
160
 
161
- ### Installation
162
 
163
- 1. Clone the repository
164
 
165
  ```bash
166
  git clone https://github.com/GSoumyajit2005/invoice-processor-ml
167
  cd invoice-processor-ml
168
  ```
169
 
170
- 2. Create and Activate Virtual Environment (Recommended) Ensures the correct Python version and isolates dependencies.
171
-
172
- - **Linux / macOS**:
173
-
174
- ```bash
175
- python3 -m venv venv
176
- source venv/bin/activate
177
- ```
178
-
179
- - **Windows**:
180
 
181
  ```bash
182
- python -m venv venv
183
- .\venv\Scripts\activate
184
  ```
185
 
186
- 3. Install dependencies
187
 
188
  ```bash
189
- pip install -r requirements.txt
190
  ```
191
 
192
  4. Run the web app
@@ -195,6 +203,9 @@ pip install -r requirements.txt
195
  streamlit run app.py
196
  ```
197
 
198
  ### Training the Model (Optional)
199
 
200
  To retrain the model from scratch using the provided scripts:
@@ -205,6 +216,27 @@ python scripts/train_combined.py
205
 
206
  (Note: Requires SROIE dataset in data/sroie)
207
 
208
  ## 💻 Usage
209
 
210
  ### Web Interface (Recommended)
@@ -290,10 +322,16 @@ print(json.dumps(result, indent=2))
290
  │ Post-process │
291
  │ validate, scores │
292
  └────────┬─────────┘
293
-
294
- ┌──────────────────┐
295
- JSON Output
296
- └──────────────────┘
297
  ```
298
 
299
  ## 📁 Project Structure
@@ -326,14 +364,14 @@ invoice-processor-ml/
326
  ├── src/
327
  │ ├── api.py # FastAPI REST endpoint for API access
328
  │ ├── data_loader.py # Unified data loader for training
329
- │ ├── database.py # PostgreSQL connection (scaffolded)
330
  │ ├── extraction.py # Regex-based information extraction logic
331
  │ ├── ml_extraction.py # ML-based extraction (LayoutLMv3 + DocTR)
332
- │ ├── models.py # SQLModel tables for persistence (scaffolded)
333
  │ ├── pdf_utils.py # PDF text extraction and image conversion
334
  │ ├── pipeline.py # Main orchestrator for the pipeline and CLI
335
  │ ├── preprocessing.py # Image preprocessing functions (grayscale, denoise)
336
- │ ├── repository.py # CRUD operations for invoices (scaffolded)
337
  │ ├── schema.py # Pydantic models for API response validation
338
  │ ├── sroie_loader.py # SROIE dataset loading logic
339
  │ └── utils.py # Utility functions (semantic hashing, etc.)
@@ -346,6 +384,10 @@ invoice-processor-ml/
346
 
347
  ├── app.py # Streamlit web interface
348
  ├── requirements.txt # Python dependencies
349
  └── README.md # You are Here!
350
  ```
351
 
@@ -364,16 +406,21 @@ invoice-processor-ml/
364
  ## 📈 Performance
365
 
366
  - **OCR Precision**: State-of-the-art hierarchical detection using **DocTR (ResNet-50)**. Outperforms Tesseract on complex/noisy layouts.
367
- - **ML-based Extraction**:
368
- - **Accuracy**: ~83% F1 Score on SROIE + Custom Dataset.
369
- - **Speed**: ~5-7s per invoice (CPU) / <1s (GPU). Prioritizes high precision over raw speed.
370
 
371
  ## ⚠️ Known Limitations
372
 
373
  1. **Layout Sensitivity**: The ML model was fine‑tuned on SROIE (retail receipts) and mychen76/invoices-and-receipts_ocr_v1 (English). Professional multi-column invoices may underperform until you fine‑tune on more diverse datasets.
374
  2. **Invoice Number**: SROIE dataset lacks invoice number labels. The system solves this by using the Hybrid Fallback Engine, which successfully extracts invoice numbers using Regex whenever the ML model output is empty.
375
  3. **Line Items/Tables**: Not trained for table extraction yet. Rule-based supports simple totals; table extraction comes later.
376
- 4. **Inference Latency**: Using the heavy **DocTR (ResNet-50)** backbone ensures maximum accuracy but results in higher inference time (~5-7s on CPU) compared to lightweight engines.
377
 
378
  ## 🔮 Future Enhancements
379
 
@@ -386,7 +433,7 @@ invoice-processor-ml/
386
  - [x] CI/CD pipeline (GitHub Actions → HuggingFace Spaces auto-deploy)
387
  - [ ] Multilingual OCR (PaddleOCR) and multilingual fine‑tuning
388
  - [ ] Confidence calibration and better validation rules
389
- - [ ] Database persistence layer (PostgreSQL - scaffolded, ready for implementation)
390
 
391
  ## 🛠️ Tech Stack
392
 
@@ -400,6 +447,8 @@ invoice-processor-ml/
400
  | Data Format | JSON |
401
  | CI/CD | GitHub Actions → HuggingFace Spaces |
402
  | Containerization | Docker |
403
 
404
  ## 📚 What I Learned
405
 
 
46
  - **Defensive Data Handling:** Implemented coordinate clamping to prevent model crashes from negative OCR bounding boxes.
47
  - **GPU-Accelerated OCR:** DocTR (Mindee) with automatic CUDA acceleration for faster inference in production.
48
  - **Clean JSON Output:** Normalized schema handling nested entities, line items, and validation flags.
49
+ - **Defensive Persistence:** Optional PostgreSQL integration that automatically saves extracted data when credentials are present, but gracefully degrades (skips saving) in serverless/demo environments like Hugging Face Spaces.
50
+ - **Duplicate Prevention:** Implemented *Semantic Hashing* (Vendor + Date + Total + ID) to automatically detect and prevent duplicate invoice entries.
51
 
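The *Semantic Hashing* bullet above can be sketched in a few lines. The field combination (vendor, date, total, invoice ID) comes from this commit; the helper name, the normalization, and the choice of SHA-256 are illustrative assumptions (the real helper lives in `src/utils.py`).

```python
import hashlib

def semantic_hash(vendor: str, date: str, total: str, invoice_id: str) -> str:
    """Content-based hash: the same invoice always maps to the same key."""
    # Normalize case/whitespace so trivial OCR differences don't break dedup
    raw = "|".join(s.strip().lower() for s in (vendor, date, total, invoice_id))
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()

h1 = semantic_hash("ACME Ltd", "22/03/2018", "129.99", "INV-001")
h2 = semantic_hash("  acme ltd", "22/03/2018", "129.99", "INV-001")
```

Because the hash column is unique and indexed, a second save of the same invoice produces the same key and is rejected as a duplicate.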
52
  ### 💻 Usability
53
 
 
93
  - `extraction_confidence`: The confidence of the extraction (0-100).
94
  - `validation_passed`: Whether the validation passed (true/false).
95
 
96
+ ### 5. Defensive Database Architecture
97
+
98
+ To support both local development (with full persistence) and lightweight cloud demos (without databases), the system uses a **"Soft Fail" Persistence Layer**:
99
+
100
+ 1. **Connection Check:** On startup, the system checks for PostgreSQL credentials. If missing, the database engine is disabled.
101
+ 2. **Repository Guard:** All CRUD operations check for an active session. If the database is disabled, save operations are skipped silently without crashing the pipeline.
102
+ 3. **Semantic Hashing:** Before saving, a content-based hash is generated to ensure idempotency.
103
+
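The "Repository Guard" step (point 2) can be illustrated with a toy version; the function name and the use of a plain `set` standing in for the invoices table are assumptions for this sketch, not the actual API of `src/repository.py`.

```python
from typing import Optional

def save_invoice(session: Optional[set], invoice_hash: str) -> str:
    """Soft-fail save: skip silently when no DB session is available."""
    if session is None:              # DB disabled (e.g. demo deployment)
        return "disabled"
    if invoice_hash in session:      # duplicate check via semantic hash
        return "duplicate"
    session.add(invoice_hash)        # real code would add and commit a row
    return "saved"

db = set()
statuses = [save_invoice(None, "h1"), save_invoice(db, "h1"), save_invoice(db, "h1")]
```

The pipeline reports this status back (see `_db_status` in `app.py`) so the UI can pick the matching toast.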
104
  ---
105
 
106
  ## 📊 Demo
 
166
  ### Prerequisites
167
 
168
  - Python 3.10+
169
+ - Conda / Miniforge (recommended)
170
+ - NVIDIA GPU with CUDA (strongly recommended for usable performance)
171
+
172
+ ⚠️ CPU-only execution is supported but significantly slower
173
+ (~5–7s per invoice) and intended only for testing.
174
 
175
+ ### Installation (Conda – Recommended)
176
 
177
+ 1. Clone the repository:
178
 
179
  ```bash
180
  git clone https://github.com/GSoumyajit2005/invoice-processor-ml
181
  cd invoice-processor-ml
182
  ```
183
 
184
+ 2. Create and activate the Conda environment:
185
 
186
  ```bash
187
+ conda env create -f environment.yml
188
+ conda activate invoice-ml
189
  ```
190
 
191
+ 3. Verify CUDA availability (recommended):
192
 
193
  ```bash
194
+ python - <<EOF
195
+ import torch
196
+ print(torch.cuda.is_available())
197
+ EOF
198
  ```
199
 
200
  4. Run the web app
 
203
  streamlit run app.py
204
  ```
205
 
206
+ > Note: `requirements.txt` is consumed internally by `environment.yml`.
207
+ > Do not install it manually with pip.
208
+
209
  ### Training the Model (Optional)
210
 
211
  To retrain the model from scratch using the provided scripts:
 
216
 
217
  (Note: Requires SROIE dataset in data/sroie)
218
 
219
+ ### API Usage (Optional)
220
+
221
+ To run the API server:
222
+
223
+ ```bash
224
+ python src/api.py
225
+ ```
226
+
227
+ The API provides endpoints for processing invoices and extracting information.
228
+
229
+ ### Running with Database (Optional)
230
+
231
+ To enable data persistence, run the included Docker Compose file to spin up PostgreSQL:
232
+
233
+ ```bash
234
+ docker-compose up -d
235
+ ```
236
+
237
+ The application will automatically detect the database and start saving invoices.
238
+
239
+
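For local persistence, the variables read by `src/database.py` can be placed in a `.env` file (loaded via `python-dotenv`). The variable names come from this commit; the user value is an assumption, while the password and database defaults match `docker-compose.yml`, and the host port matches the 5433 remap:

```bash
# .env example for the bundled docker-compose setup
POSTGRES_USER=postgres
POSTGRES_PASSWORD=password
POSTGRES_DB=invoices_db
POSTGRES_HOST=localhost
POSTGRES_PORT=5433   # host port remapped from 5432 in docker-compose.yml
```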
240
  ## 💻 Usage
241
 
242
  ### Web Interface (Recommended)
 
322
  │ Post-process │
323
  │ validate, scores │
324
  └────────┬─────────┘
325
+
326
+ ┌──────────────┴──────────────┐
327
+
328
+ ▼ ▼
329
+ ┌──────────────────┐ ┌────────────────────┐
330
+ │ JSON Output │ │ DB (PostgreSQL) │
331
+ └──────────────────┘ │ (Optional Save) │
332
+ └────────────────────┘
333
+
334
+
335
  ```
336
 
337
  ## 📁 Project Structure
 
364
  ├── src/
365
  │ ├── api.py # FastAPI REST endpoint for API access
366
  │ ├── data_loader.py # Unified data loader for training
367
+ │ ├── database.py # Database connection with environment-aware 'soft fail' check
368
  │ ├── extraction.py # Regex-based information extraction logic
369
  │ ├── ml_extraction.py # ML-based extraction (LayoutLMv3 + DocTR)
370
+ │ ├── models.py # SQLModel tables (Invoice, LineItem) with schema validation
371
  │ ├── pdf_utils.py # PDF text extraction and image conversion
372
  │ ├── pipeline.py # Main orchestrator for the pipeline and CLI
373
  │ ├── preprocessing.py # Image preprocessing functions (grayscale, denoise)
374
+ │ ├── repository.py # CRUD operations with session safety handling
375
  │ ├── schema.py # Pydantic models for API response validation
376
  │ ├── sroie_loader.py # SROIE dataset loading logic
377
  │ └── utils.py # Utility functions (semantic hashing, etc.)
 
384
 
385
  ├── app.py # Streamlit web interface
386
  ├── requirements.txt # Python dependencies
387
+ ├── environment.yml # Conda environment configuration
388
+ ├── docker-compose.yml # Docker Compose configuration for PostgreSQL
389
+ ├── Dockerfile # Dockerfile for building the application container
390
+ ├── .gitignore # Git ignore file
391
  └── README.md # You are Here!
392
  ```
393
 
 
406
  ## 📈 Performance
407
 
408
  - **OCR Precision**: State-of-the-art hierarchical detection using **DocTR (ResNet-50)**. Outperforms Tesseract on complex/noisy layouts.
409
+ - **ML-based Extraction**:
410
+ - **Accuracy**: ~83% F1 Score on SROIE + custom invoices
411
+ - **Speed**:
412
+ - **GPU (recommended)**: <1s per invoice
413
+ - **CPU (fallback)**: ~5–7s per invoice
414
+
415
+ ⚠️ CPU-only execution is supported for testing and experimentation but results
416
+ in significantly higher latency due to the heavy OCR and layout-aware models.
417
 
418
  ## ⚠️ Known Limitations
419
 
420
  1. **Layout Sensitivity**: The ML model was fine‑tuned on SROIE (retail receipts) and mychen76/invoices-and-receipts_ocr_v1 (English). Professional multi-column invoices may underperform until you fine‑tune on more diverse datasets.
421
  2. **Invoice Number**: SROIE dataset lacks invoice number labels. The system solves this by using the Hybrid Fallback Engine, which successfully extracts invoice numbers using Regex whenever the ML model output is empty.
422
  3. **Line Items/Tables**: Not trained for table extraction yet. Rule-based supports simple totals; table extraction comes later.
423
+ 4. **Inference Latency**: CPU execution is significantly slower due to heavy OCR and layout-aware models.
424
 
425
  ## 🔮 Future Enhancements
426
 
 
433
  - [x] CI/CD pipeline (GitHub Actions → HuggingFace Spaces auto-deploy)
434
  - [ ] Multilingual OCR (PaddleOCR) and multilingual fine‑tuning
435
  - [ ] Confidence calibration and better validation rules
436
+ - [x] Database persistence layer (PostgreSQL with SQLModel & duplicate detection)
437
 
438
  ## 🛠️ Tech Stack
439
 
 
447
  | Data Format | JSON |
448
  | CI/CD | GitHub Actions → HuggingFace Spaces |
449
  | Containerization | Docker |
450
+ | Database | PostgreSQL, SQLModel |
451
+ | Orchestration | Docker Compose |
452
 
453
  ## 📚 What I Learned
454
 
app.py CHANGED
@@ -7,12 +7,21 @@ from PIL import Image, ImageDraw
7
  import pandas as pd
8
  import sys
9
 
10
  # --------------------------------------------------
11
  # Pipeline import (PURE DATA ONLY)
12
  # --------------------------------------------------
13
- sys.path.append("src")
14
- from pipeline import process_invoice
15
16
 
17
  # --------------------------------------------------
18
  # Mock format detection (UI-level, safe)
@@ -119,17 +128,22 @@ with tab1:
119
 
120
  if uploaded_file:
121
  st.caption(f"File: {uploaded_file.name}")
122
-
 
123
  if uploaded_file.type == "application/pdf":
124
- st.info("PDF uploaded (preview not available)")
125
  else:
126
  image = Image.open(uploaded_file)
127
-
128
- st.image(
129
- image,
130
- width=250,
131
- caption="Uploaded Invoice"
132
- )
133
 
134
 
135
  # -----------------------------
@@ -149,12 +163,39 @@ with tab1:
149
  f.write(uploaded_file.getbuffer())
150
 
151
  method = "ml" if "ML" in extraction_method else "rules"
 
 
152
  result = process_invoice(str(temp_path), method=method)
153
 
154
- # Hard guard prevents DeltaGenerator bugs forever
155
  if not isinstance(result, dict):
156
  st.error("Pipeline returned invalid data.")
157
  st.stop()
158
 
159
  st.session_state.data = result
160
  st.session_state.format_info = detect_invoice_format(
@@ -162,35 +203,43 @@ with tab1:
162
  )
163
  st.session_state.processed_count += 1
164
 
165
- st.success("Extraction Complete")
166
-
167
  # --- AI Detection Overlay Visualization ---
168
  raw_predictions = result.get("raw_predictions")
169
- if raw_predictions and uploaded_file.type != "application/pdf":
170
- # Reload the original image for annotation
171
- uploaded_file.seek(0)
172
- overlay_image = Image.open(uploaded_file).convert("RGB")
173
- draw = ImageDraw.Draw(overlay_image)
174
-
175
- # Draw red rectangles around each detected entity's bounding boxes
176
- for entity_name, entity_data in raw_predictions.items():
177
- bboxes = entity_data.get("bbox", [])
178
- for box in bboxes:
179
- # bbox format: [x, y, width, height]
180
- x, y, w, h = box
181
- draw.rectangle(
182
- [x, y, x + w, y + h],
183
- outline="red",
184
- width=2
185
- )
186
-
187
- overlay_image.thumbnail((800, 800))
188
-
189
- st.image(
190
- overlay_image,
191
- caption="AI Detection Overlay",
192
- width="content"
193
- )
194
 
195
  except Exception as e:
196
  st.error(f"Pipeline error: {e}")
@@ -276,7 +325,7 @@ with tab2:
276
  st.image(
277
  Image.open(samples[0]),
278
  caption=samples[0].name,
279
- use_container_width=True
280
  )
281
  else:
282
  st.info("No sample invoices found.")
 
7
  import pandas as pd
8
  import sys
9
 
10
+ # PDF to image conversion
11
+ try:
12
+ from pdf2image import convert_from_bytes
13
+ PDF_SUPPORT = True
14
+ except ImportError:
15
+ PDF_SUPPORT = False
16
+
17
  # --------------------------------------------------
18
  # Pipeline import (PURE DATA ONLY)
19
  # --------------------------------------------------
20
+ from src.pipeline import process_invoice
21
+ from src.database import init_db
22
 
23
+ # Initialize database
24
+ init_db()
25
 
26
  # --------------------------------------------------
27
  # Mock format detection (UI-level, safe)
 
128
 
129
  if uploaded_file:
130
  st.caption(f"File: {uploaded_file.name}")
131
+
132
+ # Handle PDF preview
133
  if uploaded_file.type == "application/pdf":
134
+ if PDF_SUPPORT:
135
+ pdf_bytes = uploaded_file.read()
136
+ uploaded_file.seek(0) # Reset for later processing
137
+ pages = convert_from_bytes(pdf_bytes, first_page=1, last_page=1)
138
+ if pages:
139
+ pdf_preview_image = pages[0]
140
+ st.session_state.pdf_preview = pdf_preview_image
141
+ st.image(pdf_preview_image, width=250, caption="PDF Preview (Page 1)")
142
+ else:
143
+ st.warning("PDF preview requires pdf2image. Install with: `pip install pdf2image`")
144
  else:
145
  image = Image.open(uploaded_file)
146
+ st.image(image, width=250, caption="Uploaded Invoice")
147
 
148
 
149
  # -----------------------------
 
163
  f.write(uploaded_file.getbuffer())
164
 
165
  method = "ml" if "ML" in extraction_method else "rules"
166
+
167
+ # CALL PIPELINE
168
  result = process_invoice(str(temp_path), method=method)
169
 
170
+ # --- SMART STATUS NOTIFICATIONS ---
171
+ db_status = result.get('_db_status', 'disabled')
172
+
173
+ if db_status == 'saved':
174
+ st.success("✅ Extraction & Storage Complete")
175
+ st.toast("Invoice saved to Database!", icon="💾")
176
+
177
+ elif db_status == 'duplicate':
178
+ st.success("✅ Extraction Complete")
179
+ st.toast("Duplicate invoice (already in database)", icon="⚠️")
180
+
181
+ elif db_status == 'disabled':
182
+ st.success("✅ Extraction Complete")
183
+ # Only show "Demo Mode" toast once per session
184
+ if not st.session_state.get('_db_warning_shown', False):
185
+ st.toast("Database disabled (Demo Mode)", icon="ℹ️")
186
+ st.session_state['_db_warning_shown'] = True
187
+
188
+ else:
189
+ st.success("✅ Extraction Complete")
190
+
191
+ # Hard guard — prevents DeltaGenerator bugs
192
  if not isinstance(result, dict):
193
  st.error("Pipeline returned invalid data.")
194
  st.stop()
195
+
196
+ # Remove the metadata field so it doesn't show up in the JSON view
197
+ if '_db_status' in result:
198
+ del result['_db_status']
199
 
200
  st.session_state.data = result
201
  st.session_state.format_info = detect_invoice_format(
 
203
  )
204
  st.session_state.processed_count += 1
205
 
 
 
206
  # --- AI Detection Overlay Visualization ---
207
  raw_predictions = result.get("raw_predictions")
208
+ if raw_predictions:
209
+ # Get the base image for annotation
210
+ if uploaded_file.type == "application/pdf":
211
+ # Use the converted PDF preview image
212
+ if "pdf_preview" in st.session_state:
213
+ overlay_image = st.session_state.pdf_preview.copy().convert("RGB")
214
+ else:
215
+ overlay_image = None
216
+ else:
217
+ # Reload the original image for annotation
218
+ uploaded_file.seek(0)
219
+ overlay_image = Image.open(uploaded_file).convert("RGB")
220
+
221
+ if overlay_image:
222
+ draw = ImageDraw.Draw(overlay_image)
223
+
224
+ # Draw red rectangles around each detected entity's bounding boxes
225
+ for entity_name, entity_data in raw_predictions.items():
226
+ bboxes = entity_data.get("bbox", [])
227
+ for box in bboxes:
228
+ # bbox format: [x, y, width, height]
229
+ x, y, w, h = box
230
+ draw.rectangle(
231
+ [x, y, x + w, y + h],
232
+ outline="red",
233
+ width=2
234
+ )
235
+
236
+ overlay_image.thumbnail((800, 800))
237
+
238
+ st.image(
239
+ overlay_image,
240
+ caption="AI Detection Overlay",
241
+ width="content"
242
+ )
243
 
244
  except Exception as e:
245
  st.error(f"Pipeline error: {e}")
 
325
  st.image(
326
  Image.open(samples[0]),
327
  caption=samples[0].name,
328
+ width=250
329
  )
330
  else:
331
  st.info("No sample invoices found.")
docker-compose.yml CHANGED
@@ -11,7 +11,7 @@ services:
11
  POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-password}
12
  POSTGRES_DB: ${POSTGRES_DB:-invoices_db}
13
  ports:
14
- - "5432:5432"
15
  volumes:
16
  - postgres_data:/var/lib/postgresql/data
17
  healthcheck:
 
11
  POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-password}
12
  POSTGRES_DB: ${POSTGRES_DB:-invoices_db}
13
  ports:
14
+ - "5433:5432"
15
  volumes:
16
  - postgres_data:/var/lib/postgresql/data
17
  healthcheck:
environment.yml ADDED
@@ -0,0 +1,31 @@
1
+ name: invoice-ml
2
+
3
+ channels:
4
+ - pytorch
5
+ - nvidia
6
+ - conda-forge
7
+
8
+ dependencies:
9
+ # ----- Python -----
10
+ - python=3.10
11
+ - pip
12
+
13
+ # ----- CUDA-enabled PyTorch -----
14
+ - pytorch
15
+ - torchvision
16
+ - torchaudio
17
+ - pytorch-cuda=11.8
18
+
19
+ # ----- Core numeric / system -----
20
+ - numpy
21
+ - certifi
22
+ - openssl
23
+ - ca-certificates
24
+
25
+ # ----- Computer Vision / PDF -----
26
+ - poppler
27
+ - ghostscript
28
+
29
+ # ----- App-level Python deps -----
30
+ - pip:
31
+ - -r requirements.txt
requirements.txt CHANGED
@@ -1,23 +1,26 @@
1
- # ----- Streamlit -----
2
  streamlit>=1.28.0
3
 
4
- # ----- OCR -----
5
- python-doctr[torch]>=0.8.0
6
- opencv-python>=4.8.0
7
  Pillow>=10.0.0
8
 
9
- # ----- Data -----
10
- numpy>=1.24.0
11
- pandas>=2.0.0
12
-
13
- # ----- Machine Learning -----
14
- torch>=2.0.0
15
- torchvision>=0.15.0
16
  transformers>=4.30.0
17
  datasets>=2.14.0
18
  huggingface-hub>=0.17.0
19
  seqeval>=1.2.2
20
 
21
  # ----- Data Validation -----
22
  pydantic>=2.12.0
23
 
 
1
+ # ----- Streamlit -----
2
  streamlit>=1.28.0
3
 
4
+ # ----- OCR -----
5
+ python-doctr>=0.8.0
6
+ opencv-python-headless>=4.8.0
7
  Pillow>=10.0.0
8
 
9
+ # ----- NLP / Transformers -----
 
10
  transformers>=4.30.0
11
  datasets>=2.14.0
12
  huggingface-hub>=0.17.0
13
  seqeval>=1.2.2
14
 
15
+ # ----- Utilities -----
16
+ python-dotenv>=1.0.0
17
+ httpx>=0.28.0
18
+ tenacity>=8.0.0
19
+ validators>=0.22.0
20
+ langdetect>=1.0.9
21
+ RapidFuzz>=3.0.0
22
+ python-dateutil>=2.9.0
23
+
24
  # ----- Data Validation -----
25
  pydantic>=2.12.0
26
 
src/api.py CHANGED
@@ -8,10 +8,8 @@ from pathlib import Path
8
  import uuid
9
  import sys
10
 
11
- # Import modules
12
- sys.path.append(str(Path(__file__).resolve().parent))
13
- from pipeline import process_invoice
14
- from schema import InvoiceData
15
 
16
  app = FastAPI(
17
  title="Invoice Extraction API",
 
8
  import uuid
9
  import sys
10
 
11
+ from src.pipeline import process_invoice
12
+ from src.schema import InvoiceData
 
 
13
 
14
  app = FastAPI(
15
  title="Invoice Extraction API",
src/database.py CHANGED
@@ -1,33 +1,65 @@
1
  # src/database.py
2
 
3
  from sqlmodel import SQLModel, create_engine, Session
4
- from typing import Generator
 
5
  import os
 
6
 
7
- # --- INSTRUCTIONS ---
8
- # 1. Get credentials from environment variables (Os.getenv)
9
- # 2. Construct the DATABASE_URL string: postgresql://user:pass@host:port/db
10
- # 3. Create the SQLModel engine
11
- # 4. Implement the init_db and get_session functions
12
 
13
- # TODO: Define constants for DB params (POSTGRES_USER, etc.)
14
 
15
- # TODO: Define DATABASE_URL
16
 
17
- # TODO: Create 'engine' using create_engine(DATABASE_URL)
18
 
19
  def init_db():
20
  """
21
  Idempotent DB initialization.
22
- Should create all tables defined in SQLModel metadata.
23
  """
24
- # TODO: Call SQLModel.metadata.create_all(engine)
25
- pass
26
 
27
- def get_session() -> Generator[Session, None, None]:
28
  """
29
  Dependency for yielding a database session.
30
- Useful for FastAPI dependencies or context usage.
31
  """
32
- # TODO: Open a session with the engine, yield it, and ensure it closes
33
- pass
1
  # src/database.py
2
 
3
  from sqlmodel import SQLModel, create_engine, Session
4
+ from sqlalchemy import text
5
+ from typing import Generator, Optional
6
  import os
7
+ from dotenv import load_dotenv
8
 
9
+ load_dotenv()
10
 
11
+ # 1. Get credentials (host/port fall back to defaults; user/password/db must be set explicitly)
12
+ POSTGRES_USER = os.getenv("POSTGRES_USER")
13
+ POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD")
14
+ POSTGRES_DB = os.getenv("POSTGRES_DB")
15
+ POSTGRES_HOST = os.getenv("POSTGRES_HOST", "localhost")
16
+ POSTGRES_PORT = os.getenv("POSTGRES_PORT", "5432")
17
 
18
+ # 2. Construct DATABASE_URL conditionally
19
+ DATABASE_URL = None
20
+ engine = None
21
+ DB_CONNECTED = False # Track actual connection status
22
 
23
+ if POSTGRES_USER and POSTGRES_PASSWORD and POSTGRES_DB:
24
+ DATABASE_URL = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}"
25
+
26
+ try:
27
+ # 3. Create the engine only if we have credentials
28
+ engine = create_engine(DATABASE_URL, echo=False)
29
+
30
+ # 4. Test actual connection (once at startup)
31
+ with engine.connect() as conn:
32
+ conn.execute(text("SELECT 1"))
33
+ DB_CONNECTED = True
34
+ print("✅ Database connection verified.")
35
+ except Exception as e:
36
+ print(f"⚠️ Database unavailable: {e}")
37
+ DB_CONNECTED = False
38
+ else:
39
+ print("⚠️ Database credentials missing. Database features will be disabled.")
40
 
41
  def init_db():
42
  """
43
  Idempotent DB initialization.
44
+ Only runs if engine is successfully configured AND connected.
45
  """
46
+ if engine and DB_CONNECTED:
47
+ try:
48
+ SQLModel.metadata.create_all(engine)
49
+ print("✅ Database tables created/verified.")
50
+ except Exception as e:
51
+ print(f"❌ Error initializing database: {e}")
52
+ # Silent skip when DB is not connected - message already shown at startup
53
 
54
+ def get_session() -> Generator[Optional[Session], None, None]:
55
  """
56
  Dependency for yielding a database session.
57
+ Yields None if database is not configured.
58
  """
59
+ if engine:
60
+ with Session(engine) as session:
61
+ yield session
62
+ else:
63
+ # Yield None so code depending on this doesn't crash immediately,
64
+ # but can check 'if session is None'.
65
+ yield None
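Callers consume `get_session()` above by checking for `None` rather than assuming a live session. A self-contained sketch of that caller pattern, with a dummy class standing in for `sqlmodel.Session`:

```python
from typing import Generator, Optional

class DummySession:
    """Stand-in for sqlmodel.Session in this sketch."""

engine = None  # simulate "credentials missing", i.e. DB disabled

def get_session() -> Generator[Optional[DummySession], None, None]:
    if engine:
        yield DummySession()
    else:
        yield None  # callers must guard on None instead of crashing

session = next(get_session())
mode = "demo (no persistence)" if session is None else "persistent"
```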
src/extraction.py CHANGED
@@ -7,29 +7,83 @@ from difflib import SequenceMatcher
7
 
8
  def extract_dates(text: str) -> List[str]:
9
  """
10
- Robust date extraction that handles noisy OCR separators (spaces, pipes, dots)
11
- and validates using datetime to ensure semantic correctness.
12
  """
13
  if not text: return []
14
 
15
- # Matches DD/MM/YYYY, DD-MM-YYYY, DD.MM.YYYY, DD MM YYYY
16
- # Also handles OCR noise like pipes (|) instead of slashes
17
- pattern = r'\b(\d{1,2})[\s/|.-](\d{1,2})[\s/|.-](\d{2,4})\b'
18
- matches = re.findall(pattern, text)
19
 
20
  valid_dates = []
21
- for d, m, y in matches:
22
  try:
23
- # Try to parse it to check if it's a real date
24
- # This filters out "99/99/2000" or random phone numbers like 12 34 5678
25
- # Assuming Day-Month-Year format which is common in SROIE/International
26
- # For US format, you might swap d and m
27
  dt = datetime(int(y), int(m), int(d))
28
  valid_dates.append(dt.strftime("%d/%m/%Y"))
29
  except ValueError:
30
- continue # Invalid date logic (e.g. Month 13 or Day 32)
31
 
32
- return list(dict.fromkeys(valid_dates)) # Deduplicate
33
 
34
  def extract_amounts(text: str) -> List[float]:
35
  if not text: return []
 
7
 
8
  def extract_dates(text: str) -> List[str]:
9
  """
10
+ Robust date extraction that handles:
11
+ - Numeric formats: DD/MM/YYYY, DD-MM-YYYY, DD.MM.YYYY
12
+ - Text month formats: 22 Mar 18, March 22, 2018, 22-Mar-2018
13
+ - OCR noise like pipes (|) instead of slashes
14
+ Validates using datetime to ensure semantic correctness.
15
  """
16
  if not text: return []
17
 
18
+ # Month name mappings
19
+ MONTH_MAP = {
20
+ 'jan': 1, 'january': 1,
21
+ 'feb': 2, 'february': 2,
22
+ 'mar': 3, 'march': 3,
23
+ 'apr': 4, 'april': 4,
24
+ 'may': 5,
25
+ 'jun': 6, 'june': 6,
26
+ 'jul': 7, 'july': 7,
27
+ 'aug': 8, 'august': 8,
28
+ 'sep': 9, 'sept': 9, 'september': 9,
29
+ 'oct': 10, 'october': 10,
30
+ 'nov': 11, 'november': 11,
31
+ 'dec': 12, 'december': 12
32
+ }
33
 
34
  valid_dates = []
35
+
36
+ # Pattern 1: Numeric dates - DD/MM/YYYY, DD-MM-YYYY, DD.MM.YYYY, DD MM YYYY
37
+ # Also handles OCR noise like pipes (|) instead of slashes
38
+ numeric_pattern = r'\b(\d{1,2})[\s/|.-](\d{1,2})[\s/|.-](\d{2,4})\b'
39
+ for d, m, y in re.findall(numeric_pattern, text):
40
+ try:
41
+ year = int(y)
42
+ if year < 100:
43
+ year = 2000 + year if year < 50 else 1900 + year
44
+ dt = datetime(year, int(m), int(d))
45
+ valid_dates.append(dt.strftime("%d/%m/%Y"))
46
+ except ValueError:
47
+ continue
48
+
49
+ # Pattern 2: DD Mon YY/YYYY (e.g., "22 Mar 18", "22-Mar-2018", "22 March 2018")
50
+ text_month_pattern1 = r'\b(\d{1,2})[\s/.-]?([A-Za-z]{3,9})[\s/.-]?(\d{2,4})\b'
51
+ for d, m, y in re.findall(text_month_pattern1, text, re.IGNORECASE):
52
+ month_num = MONTH_MAP.get(m.lower())
53
+ if month_num:
54
+ try:
55
+ year = int(y)
56
+ if year < 100:
57
+ year = 2000 + year if year < 50 else 1900 + year
58
+ dt = datetime(year, month_num, int(d))
59
+ valid_dates.append(dt.strftime("%d/%m/%Y"))
60
+ except ValueError:
61
+ continue
62
+
63
+ # Pattern 3: Mon DD, YYYY (e.g., "March 22, 2018", "Mar 22 2018")
64
+ text_month_pattern2 = r'\b([A-Za-z]{3,9})[\s.-]?(\d{1,2})[,\s.-]+(\d{2,4})\b'
65
+ for m, d, y in re.findall(text_month_pattern2, text, re.IGNORECASE):
66
+ month_num = MONTH_MAP.get(m.lower())
67
+ if month_num:
68
+ try:
69
+ year = int(y)
70
+ if year < 100:
71
+ year = 2000 + year if year < 50 else 1900 + year
72
+ dt = datetime(year, month_num, int(d))
73
+ valid_dates.append(dt.strftime("%d/%m/%Y"))
74
+ except ValueError:
75
+ continue
76
+
77
+ # Pattern 4: YYYY-MM-DD (ISO format)
78
+ iso_pattern = r'\b(\d{4})[-/](\d{1,2})[-/](\d{1,2})\b'
79
+ for y, m, d in re.findall(iso_pattern, text):
80
  try:
81
  dt = datetime(int(y), int(m), int(d))
82
  valid_dates.append(dt.strftime("%d/%m/%Y"))
83
  except ValueError:
84
+ continue
85
 
86
+ return list(dict.fromkeys(valid_dates)) # Deduplicate while preserving order
87
 
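Two idioms used in `extract_dates` above are worth isolating: `dict.fromkeys` deduplicates while preserving first-seen order (a plain `set` would not), and the two-digit-year pivot expands `18` to 2018 but `97` to 1997:

```python
# Order-preserving dedup, as used in the final return statement
dates = ["22/03/2018", "01/01/2020", "22/03/2018"]
deduped = list(dict.fromkeys(dates))

# Two-digit-year pivot, as applied in each pattern branch
def expand_year(y: int) -> int:
    return y if y >= 100 else (2000 + y if y < 50 else 1900 + y)

years = (expand_year(18), expand_year(97), expand_year(2018))
```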
88
  def extract_amounts(text: str) -> List[float]:
89
  if not text: return []
src/ml_extraction.py CHANGED
@@ -8,7 +8,7 @@ from PIL import Image
 from typing import List, Dict, Any, Tuple
 import re
 import numpy as np
-from extraction import extract_invoice_number, extract_total, extract_address
+from src.extraction import extract_invoice_number, extract_total, extract_address
 from doctr.io import DocumentFile
 from doctr.models import ocr_predictor
 
@@ -219,10 +219,12 @@ def extract_ml_based(image_path: str) -> Dict[str, Any]:
     encoding = PROCESSOR(
         image, text=words, boxes=normalized_boxes,
         truncation=True, max_length=512, return_tensors="pt"
-    ).to(DEVICE)
+    )
+    # Move tensors to device for inference, but keep original encoding for word_ids()
+    model_inputs = {k: v.to(DEVICE) for k, v in encoding.items()}
 
     with torch.no_grad():
-        outputs = MODEL(**encoding)
+        outputs = MODEL(**model_inputs)
 
     predictions = outputs.logits.argmax(-1).squeeze().tolist()
     extracted_entities = _process_predictions(words, unnormalized_boxes, encoding, predictions, MODEL.config.id2label)
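The `.to(DEVICE)` change above matters because a dict comprehension over the encoding yields a plain `dict`: the tensors move to the GPU, but `BatchEncoding` helpers like `word_ids()` are lost, which is why the original `encoding` is kept around for `_process_predictions`. A runnable sketch of that distinction using stand-in classes (no torch or transformers required; `FakeTensor` and `FakeEncoding` are illustrative stubs, not real APIs):

```python
class FakeTensor:
    """Stand-in for a torch tensor: .to() returns a new tensor on the target device."""
    def __init__(self, name, device="cpu"):
        self.name, self.device = name, device
    def to(self, device):
        return FakeTensor(self.name, device)

class FakeEncoding(dict):
    """Stand-in for transformers' BatchEncoding: a dict that also
    carries tokenizer metadata such as word_ids()."""
    def word_ids(self):
        return [0, 1, 2]

encoding = FakeEncoding(input_ids=FakeTensor("input_ids"), bbox=FakeTensor("bbox"))
# Moving via dict comprehension produces a plain dict -> tensors on device,
# but no word_ids(); keep `encoding` for token-to-word alignment.
model_inputs = {k: v.to("cuda") for k, v in encoding.items()}
```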
src/models.py CHANGED
@@ -2,49 +2,54 @@
 
 from typing import List, Optional
 from datetime import date as DateType
+from datetime import datetime
 from decimal import Decimal
 from sqlmodel import SQLModel, Field, Relationship
 
-# --- INSTRUCTIONS ---
-# SQLModel classes should mirror the Pydantic models in src/schema.py
-# but with database-specific configurations (primary keys, foreign keys).
+SQLModel.metadata.clear()
 
 class Invoice(SQLModel, table=True):
     __tablename__ = "invoices"
+    __table_args__ = {"extend_existing": True}
 
-    # TODO: Define Primary Key 'id' (int, optional, default=None)
+    # Primary Key
+    id: Optional[int] = Field(default=None, primary_key=True)
 
-    # TODO: Add Data Fields
-    # - receipt_number (str, indexed)
-    # - date (DateType)
-    # - total_amount (Decimal, max_digits=10, decimal_places=2)
-    # - vendor (str)
-    # - address (str)
-    # - semantic_hash (str, unique, indexed) -> Critical for deduplication
+    # Data Fields
+    receipt_number: Optional[str] = Field(default=None, index=True)
+    date: Optional[DateType] = Field(default=None)
+    total_amount: Optional[Decimal] = Field(default=None, max_digits=10, decimal_places=2)
+    vendor: Optional[str] = Field(default=None)
+    address: Optional[str] = Field(default=None)
 
-    # TODO: Add Metadata Fields
-    # - validation_status (str)
-    # - validation_errors (str) -> Store as JSON string since we don't need to query inside it yet
-    # - created_at (DateType) -> Default to today
+    # Critical for Deduplication
+    semantic_hash: str = Field(unique=True, index=True)
 
-    # TODO: Define relationship to LineItem (One-to-Many)
-    # items: List["LineItem"] = Relationship(...)
-    pass
+    # Metadata Fields
+    validation_status: str = Field(default="unknown")
+    # Store validation_errors as a JSON string since SQLModel/SQLite doesn't support arrays out of the box
+    validation_errors: Optional[str] = Field(default="[]")
+    created_at: DateType = Field(default_factory=datetime.now)
+
+    # Relationship to LineItem (One-to-Many)
+    items: List["LineItem"] = Relationship(back_populates="invoice")
 
 
 class LineItem(SQLModel, table=True):
     __tablename__ = "line_items"
+    __table_args__ = {"extend_existing": True}
 
-    # TODO: Define Primary Key
+    # Primary Key
+    id: Optional[int] = Field(default=None, primary_key=True)
 
-    # TODO: Define Foreign Key 'invoice_id' linking to 'invoices.id'
+    # Foreign Key
+    invoice_id: Optional[int] = Field(default=None, foreign_key="invoices.id")
 
-    # TODO: Add Data Fields
-    # - description (str)
-    # - quantity (int)
-    # - unit_price (Decimal)
-    # - total (Decimal)
+    # Data Fields
+    description: str
+    quantity: int = Field(default=1)
+    unit_price: Optional[Decimal] = Field(default=None, max_digits=10, decimal_places=2)
+    total: Optional[Decimal] = Field(default=None, max_digits=10, decimal_places=2)
 
-    # TODO: Define relationship back to Invoice
-    # invoice: Optional[Invoice] = Relationship(...)
-    pass
+    # Relationship back to Invoice
+    invoice: Optional[Invoice] = Relationship(back_populates="items")
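The `semantic_hash: str = Field(unique=True, index=True)` declaration is what lets the repository reject duplicates at the database level. The same constraint can be sketched with the stdlib `sqlite3` module (columns reduced for brevity; the table name matches `__tablename__` above):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE invoices (
        id INTEGER PRIMARY KEY,
        receipt_number TEXT,
        semantic_hash TEXT UNIQUE NOT NULL
    )
""")

conn.execute(
    "INSERT INTO invoices (receipt_number, semantic_hash) VALUES (?, ?)",
    ("INV-001", "abc123"),
)
try:
    # Same semantic_hash -> the UNIQUE constraint refuses the row
    conn.execute(
        "INSERT INTO invoices (receipt_number, semantic_hash) VALUES (?, ?)",
        ("INV-001-copy", "abc123"),
    )
    duplicate_rejected = False
except sqlite3.IntegrityError:
    duplicate_rejected = True
```

This is why the pipeline can treat a failed save plus a successful `get_by_hash` lookup as a duplicate rather than an error.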
 
src/pipeline.py CHANGED
@@ -12,12 +12,14 @@ from pydantic import ValidationError
 import cv2
 
 # --- IMPORTS ---
-from preprocessing import load_image, convert_to_grayscale, remove_noise
-from extraction import structure_output
-from ml_extraction import extract_ml_based
-from schema import InvoiceData
-from pdf_utils import extract_text_from_pdf, convert_pdf_to_images
-from utils import generate_semantic_hash
+from src.preprocessing import load_image, convert_to_grayscale, remove_noise
+from src.extraction import structure_output
+from src.ml_extraction import extract_ml_based
+from src.schema import InvoiceData
+from src.pdf_utils import extract_text_from_pdf, convert_pdf_to_images
+from src.utils import generate_semantic_hash
+from src.repository import InvoiceRepository
+from src.database import DB_CONNECTED
 
 def process_invoice(image_path: str,
                     method: str = 'ml',
@@ -136,6 +138,38 @@ def process_invoice(image_path: str,
     # We calculate the hash based on the final (or raw) data.
     # This gives us a unique fingerprint for this specific business transaction.
     final_data['semantic_hash'] = generate_semantic_hash(final_data)
+
+    # --- DATABASE SAVE (The Integration) ---
+    if not DB_CONNECTED:
+        # Database not available - skip save entirely (message shown once at startup)
+        final_data['_db_status'] = 'disabled'
+    else:
+        final_data['_db_status'] = 'disabled'  # Default assumption
+        try:
+            print("💾 Attempting to save to Database...")
+            repo = InvoiceRepository()
+
+            if repo.session:
+                saved_record = repo.save_invoice(final_data)
+                if saved_record:
+                    print(f"   ✅ Successfully saved Invoice #{saved_record.id}")
+                    final_data['_db_status'] = 'saved'
+                else:
+                    # Check if it's a duplicate by looking up the hash
+                    existing = repo.get_by_hash(final_data.get('semantic_hash', ''))
+                    if existing:
+                        print("   ⚠️ Duplicate invoice detected (already in database)")
+                        final_data['_db_status'] = 'duplicate'
+                    else:
+                        print("   ⚠️ Save failed (unknown error)")
+                        final_data['_db_status'] = 'error'
+            else:
+                print("   ⚠️ Skipped DB Save (Database disabled)")
+                final_data['_db_status'] = 'disabled'
+
+        except Exception as e:
+            print(f"   ⚠️ Database Error (Ignored): {e}")
+            final_data['_db_status'] = 'error'
 
     # --- SAVING STEP ---
     if save_results:
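The `_db_status` branching above reduces to a small decision table. A hypothetical helper capturing it (the function name is illustrative; the commit inlines this logic rather than factoring it out):

```python
def resolve_db_status(db_connected: bool, has_session: bool,
                      saved: bool, duplicate: bool) -> str:
    """Map a save attempt's outcome to the pipeline's _db_status string,
    mirroring the branches in the hunk above."""
    if not db_connected or not has_session:
        return "disabled"   # no DB at all, or no session could be opened
    if saved:
        return "saved"      # repo.save_invoice returned a record
    # save returned None: distinguish duplicates from genuine failures
    return "duplicate" if duplicate else "error"
```

The key design choice is that every outcome, including exceptions, still yields a status string, so the UI layer can always show something without re-touching the database.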
src/repository.py CHANGED
@@ -3,39 +3,85 @@
 from sqlmodel import Session, select
 from typing import Dict, Any, Optional
 import json
+from datetime import date
 
 from src.models import Invoice, LineItem
-from src.database import get_session, engine
+from src.database import get_session, engine, DB_CONNECTED
 
 class InvoiceRepository:
-    def __init__(self, session: Session = None):
+    def __init__(self, session: Optional[Session] = None):
         """
         Initialize with an optional session.
-        Allows dependency injection for testing or API usage.
+        If no session is provided, try to open a new one from the engine.
+        Only creates a session if the database is actually connected.
         """
-        self.session = session
+        if session:
+            self.session = session
+        elif engine and DB_CONNECTED:
+            self.session = Session(engine)
+        else:
+            self.session = None
 
-    def save_invoice(self, invoice_data: Dict[str, Any]) -> Invoice:
+    def save_invoice(self, invoice_data: Dict[str, Any]) -> Optional[Invoice]:
         """
         Saves an invoice and its line items to the database.
-
-        Steps to implement:
-        1. Manage Session: If self.session is None, create a new one using 'engine'.
-        2. Clean Data: Separate 'items' list from the main invoice properties.
-        3. Create Invoice: Instantiate the Invoice SQLModel.
-        4. Deserialize Complex Types: e.g. 'validation_errors' list -> JSON string.
-        5. Process Items: Iterate 'items', create LineItem models, check keys match, and append to invoice.items.
-        6. Commit: Add to session, commit, and refresh.
-        7. Error Handling: Wrap in try/except to rollback on failure.
+        Returns the saved Invoice object, or None if the DB is disabled or the save failed.
         """
-        # TODO: Implementation
-        raise NotImplementedError("Implement the save logic.")
+        if not self.session:
+            print("⚠️ DB Session missing. Skipping save.")
+            return None
+
+        try:
+            # 1. Prepare Data
+            data = invoice_data.copy()
+
+            # Serialize complex types (validation_errors)
+            if 'validation_errors' in data and isinstance(data['validation_errors'], list):
+                data['validation_errors'] = json.dumps(data['validation_errors'])
+
+            # Extract items to process separately
+            items_data = data.pop('items', [])
+
+            # 2. Create Invoice Record
+            invoice = Invoice(**data)
+
+            # 3. Process Items
+            for item in items_data:
+                # Ensure item is a dict (if it's a Pydantic model, convert it)
+                if hasattr(item, 'model_dump'):
+                    item_dict = item.model_dump()
+                elif isinstance(item, dict):
+                    item_dict = item
+                else:
+                    continue
+
+                line_item = LineItem(**item_dict)
+                invoice.items.append(line_item)
+
+            # 4. Commit
+            self.session.add(invoice)
+            self.session.commit()
+            self.session.refresh(invoice)
+
+            print(f"✅ Invoice {invoice.id} saved to DB.")
+            return invoice
+
+        except Exception as e:
+            print(f"❌ Error saving invoice to DB: {e}")
+            self.session.rollback()
+            return None
 
     def get_by_hash(self, semantic_hash: str) -> Optional[Invoice]:
         """
        Check if an invoice already exists, using the semantic hash.
         """
-        # TODO: Create session if needed
-        # TODO: Execute SELECT statement filtering by hash
-        # TODO: Return first result or None
-        raise NotImplementedError("Implement the query logic.")
+        if not self.session:
+            return None
+
+        try:
+            statement = select(Invoice).where(Invoice.semantic_hash == semantic_hash)
+            results = self.session.exec(statement)
+            return results.first()
+        except Exception as e:
+            print(f"❌ Error checking hash: {e}")
+            return None
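`generate_semantic_hash` itself lives in `src/utils.py` and is not part of this diff. A plausible stdlib sketch of the idea, hashing only the business-identifying fields so the same transaction re-uploaded under a different filename produces the same hash and collides on `get_by_hash` (the exact field choice here is an assumption, not the commit's implementation):

```python
import hashlib
import json

def generate_semantic_hash(invoice_data: dict) -> str:
    """Hypothetical sketch: fingerprint an invoice by its business fields,
    ignoring metadata like file paths or validation status."""
    key_fields = {
        "receipt_number": invoice_data.get("receipt_number"),
        "date": invoice_data.get("date"),
        "total_amount": str(invoice_data.get("total_amount")),
        "vendor": invoice_data.get("vendor"),
    }
    # sort_keys makes the JSON canonical, so key order never changes the hash
    canonical = json.dumps(key_fields, sort_keys=True)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
```

With a fingerprint like this, `repo.get_by_hash(data["semantic_hash"])` before (or after) a failed save is enough to distinguish "already stored" from "storage error".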