invoice-processor-ml / src /repository.py
GSoumyajit2005's picture
feat: PDF preview, database integration, and improved error handling
2a944a5
# src/repository.py
from sqlmodel import Session, select
from typing import Dict, Any, Optional
import json
from datetime import date
from src.models import Invoice, LineItem
from src.database import get_session, engine, DB_CONNECTED
class InvoiceRepository:
    """Thin persistence layer for ``Invoice`` records and their line items.

    Every operation degrades to a no-op returning ``None`` when no database
    session is available, so callers can use the repository unconditionally.

    The repository can be used as a context manager; on exit it closes the
    session, but only if the repository created that session itself.
    """

    def __init__(self, session: Optional["Session"] = None):
        """
        Initialize with an optional session.

        If no session is provided, a new one is created from the module-level
        engine, but only when the database is actually connected. Otherwise
        ``self.session`` is ``None`` and all operations are skipped.

        Args:
            session: An existing session to reuse. The caller keeps ownership
                of it; this repository will never close it.
        """
        if session is not None:
            self.session = session
            self._owns_session = False
        elif engine and DB_CONNECTED:
            self.session = Session(engine)
            # Remember that we opened it, so close() knows it may release it.
            self._owns_session = True
        else:
            self.session = None
            self._owns_session = False

    def __enter__(self) -> "InvoiceRepository":
        """Return the repository itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Release the session on exit; exceptions are never suppressed."""
        self.close()

    def close(self) -> None:
        """
        Close the session if (and only if) this repository created it.

        Sessions injected by the caller are left untouched. Safe to call
        more than once and safe to call when no session exists.
        """
        # getattr keeps this safe for instances built without __init__.
        if self.session is not None and getattr(self, "_owns_session", False):
            self.session.close()
            self.session = None
            self._owns_session = False

    def _rollback_quietly(self) -> None:
        """Roll back the current transaction without letting a rollback error escape."""
        try:
            self.session.rollback()
        except Exception as rollback_error:
            print(f"❌ Error rolling back DB session: {rollback_error}")

    def save_invoice(self, invoice_data: Dict[str, Any]) -> Optional["Invoice"]:
        """
        Saves an invoice and its line items to the database.

        Args:
            invoice_data: Invoice fields as keyword arguments for ``Invoice``.
                A list under ``'validation_errors'`` is stored as a JSON
                string. An optional ``'items'`` entry holds the line items,
                each either a dict or an object exposing ``model_dump()``;
                anything else is skipped. The input dict is not modified.

        Returns:
            The saved and refreshed ``Invoice``, or ``None`` if there is no
            session or the save failed (the transaction is rolled back).
        """
        if self.session is None:
            print("⚠️ DB Session missing. Skipping save.")
            return None
        try:
            # 1. Prepare data on a shallow copy so the caller's dict is untouched.
            data = invoice_data.copy()
            # Serialize complex types (validation_errors)
            if isinstance(data.get('validation_errors'), list):
                data['validation_errors'] = json.dumps(data['validation_errors'])
            # Extract items to process separately; an explicit None means "no items".
            items_data = data.pop('items', None) or []

            # 2. Create Invoice Record
            invoice = Invoice(**data)

            # 3. Process Items
            for item in items_data:
                # Ensure item is a dict (if it's a Pydantic model, convert it)
                if hasattr(item, 'model_dump'):
                    item_dict = item.model_dump()
                elif isinstance(item, dict):
                    item_dict = item
                else:
                    continue
                invoice.items.append(LineItem(**item_dict))

            # 4. Commit
            self.session.add(invoice)
            self.session.commit()
            self.session.refresh(invoice)
            print(f"✅ Invoice {invoice.id} saved to DB.")
            return invoice
        except Exception as e:
            print(f"❌ Error saving invoice to DB: {e}")
            self._rollback_quietly()
            return None

    def get_by_hash(self, semantic_hash: str) -> Optional["Invoice"]:
        """
        Check if invoice already exists using the semantic hash.

        Args:
            semantic_hash: Value compared against ``Invoice.semantic_hash``.

        Returns:
            The first matching ``Invoice``, or ``None`` when there is no
            match, no session, or the lookup failed.
        """
        if self.session is None:
            return None
        try:
            statement = select(Invoice).where(Invoice.semantic_hash == semantic_hash)
            return self.session.exec(statement).first()
        except Exception as e:
            print(f"❌ Error checking hash: {e}")
            # Mirror save_invoice: do not leave a failed transaction open on
            # a session that later calls will reuse.
            self._rollback_quietly()
            return None