# app.py — FastAPI service for invoice ingestion and payment-date prediction.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, Dict
import sqlite3
import joblib
import pandas as pd
from datetime import datetime, timedelta
from pathlib import Path
from filelock import FileLock
from fastapi.responses import JSONResponse
import json
import sys
import logging
# Setup logging for entire app; force=True replaces any handlers a framework
# may have installed before this module ran.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True
)

# Setup paths (all relative to this file's directory)
BASE_DIR = Path(__file__).parent
DB_PATH = BASE_DIR / "data" / "invoices.db"  # Inside container: /app/data/invoices.db
LOCK_PATH = BASE_DIR / "data" / "invoices.db.lock"  # FileLock guarding DB_PATH access
MODEL_PATH = BASE_DIR / "ml" / "models" / "payment_predictor_model_20251124_194847.pkl"
LOG_DIR = BASE_DIR / "data" / "logs"
PREDICTIONS_LOG = LOG_DIR / "predictions.csv"  # append-only CSV audit trail

# Ensure directories exist
LOG_DIR.mkdir(parents=True, exist_ok=True)

# Add backend to path so the backend.* imports below resolve
sys.path.append(str(BASE_DIR / "backend"))
# Import feature builder
from backend.feature_builder.feature_builder import build_features, features_to_dataframe
from backend.ingest.ingest_invoice_sqlite import ingest_invoice as ingest_func
# ============================================
# IMPORT INGEST ROUTER (NEW)
# ============================================
from backend.app.api.ingest import router as ingest_router
# Load ML model at import time. The app still starts if loading fails:
# endpoints that need the model respond with 503 instead of crashing.
logger = logging.getLogger(__name__)

logger.info("Loading ML model...")
try:
    model_artifacts = joblib.load(MODEL_PATH)
    model = model_artifacts['model']  # artifact dict holds the estimator under 'model'
    logger.info("Model loaded: %s", MODEL_PATH.name)
except Exception:
    # Deliberate best-effort: log the full traceback (print() discarded it)
    # and keep the API up without the model.
    logger.exception("Failed to load model from %s", MODEL_PATH)
    model = None
# FastAPI application instance (title/description appear in the OpenAPI docs)
app = FastAPI(
    title="Invoice Digitization",
    description="Digitize invoice",  # fixed typo: "Degitize" -> "Digitize"
    version="1.0.0"
)

# ============================================
# REGISTER INGEST ROUTER (NEW)
# ============================================
app.include_router(ingest_router)
# ============================================
# Pydantic Models
# ============================================
class InvoiceIngest(BaseModel):
    """Request payload for POST /ingest — one raw invoice record.

    Date fields are strings (presumably "YYYY-MM-DD" — matches the format
    used by /predict; confirm against the ingest pipeline). Only the
    fields without defaults are required.
    """
    invoice_id: int
    business_code: str
    cust_number: str  # customer identifier, used as aggregate lookup key
    name_customer: Optional[str] = None
    posting_date: str
    document_create_date: Optional[str] = None
    document_create_date_alt: Optional[str] = None
    due_in_date: Optional[str] = None
    baseline_create_date: Optional[str] = None
    clear_date: Optional[str] = None  # present only for already-paid invoices
    total_open_amount: float
    invoice_currency: str = "USD"
    document_type: Optional[str] = "RV"
    cust_payment_terms: Optional[str] = None
    posting_id: Optional[float] = None
    business_year: Optional[int] = None
class PredictionRequest(BaseModel):
    """Request payload for POST /predict.

    Defaults ("U001", "NAH4", "USD", "RV") stand in for the most common
    values so callers can omit them — presumably the modal categories in
    the training data; confirm with the model owner.
    """
    invoice_id: Optional[int] = None
    cust_number: str
    posting_date: str  # "YYYY-MM-DD" (parsed with strptime in /predict)
    total_open_amount: float
    business_code: str = "U001"
    cust_payment_terms: str = "NAH4"
    invoice_currency: str = "USD"
    document_type: str = "RV"
    due_in_date: Optional[str] = None  # "YYYY-MM-DD"; enables days_posting_to_due
    business_year: Optional[int] = None
# ============================================
# Helper Functions
# ============================================
def get_customer_aggregates(cust_number: str) -> Optional[Dict]:
    """Fetch customer aggregates from SQLite.

    Returns the customer_aggregates row as a dict, or None when the
    customer is unknown or any DB/lock error occurs (best-effort: errors
    are printed, never raised, so callers can fall back to defaults).
    """
    try:
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            try:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute("""
                SELECT * FROM customer_aggregates WHERE cust_number = ?
                """, (cust_number,))
                row = cursor.fetchone()
            finally:
                # Close even when the query raises — the original leaked
                # the connection on any exception before conn.close().
                conn.close()
            if row:
                return dict(row)
    except Exception as e:
        print(f"Error fetching customer aggregates: {e}")
    return None
def get_payment_terms_aggregates(payment_terms: str) -> Optional[Dict]:
    """Fetch payment terms aggregates from SQLite.

    Returns the payment_terms_aggregates row as a dict, or None when the
    terms are unknown or any DB/lock error occurs (best-effort).
    """
    try:
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            try:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute("""
                SELECT * FROM payment_terms_aggregates WHERE cust_payment_terms = ?
                """, (payment_terms,))
                row = cursor.fetchone()
            finally:
                # Close even when the query raises (fixes connection leak).
                conn.close()
            if row:
                return dict(row)
    except Exception as e:
        print(f"Error fetching payment terms: {e}")
    return None
def get_business_code_aggregates(business_code: str) -> Optional[Dict]:
    """Fetch business code aggregates from SQLite.

    Returns the business_code_aggregates row as a dict, or None when the
    code is unknown or any DB/lock error occurs (best-effort).
    """
    try:
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            try:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute("""
                SELECT * FROM business_code_aggregates WHERE business_code = ?
                """, (business_code,))
                row = cursor.fetchone()
            finally:
                # Close even when the query raises (fixes connection leak).
                conn.close()
            if row:
                return dict(row)
    except Exception as e:
        print(f"Error fetching business code: {e}")
    return None
def log_prediction_to_csv(prediction_data: Dict):
    """Append one prediction record to the CSV audit log.

    mode='a' creates the file if it does not exist; the header is written
    only on that first write, so repeated calls grow a single valid CSV.
    (Collapses the original exists()/else branch into one call — same
    behavior, and the exists() check now sits in the same expression.)
    """
    df = pd.DataFrame([prediction_data])
    df.to_csv(PREDICTIONS_LOG, mode='a', header=not PREDICTIONS_LOG.exists(), index=False)
def log_prediction_to_db(prediction_data: Dict):
    """Insert a prediction into the SQLite predictions_log table.

    Returns the new row id, or None on any error — logging is best-effort
    and must never break the /predict response.
    """
    try:
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            try:
                cursor = conn.cursor()
                cursor.execute("""
                INSERT INTO predictions_log (
                invoice_id, cust_number, posting_date, total_open_amount,
                business_code, cust_payment_terms, predicted_days_to_clear,
                predicted_clear_date, model_version, features_json
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    prediction_data.get('invoice_id'),
                    prediction_data['cust_number'],
                    prediction_data['posting_date'],
                    prediction_data['total_open_amount'],
                    prediction_data.get('business_code'),
                    prediction_data.get('cust_payment_terms'),
                    prediction_data['predicted_days_to_clear'],
                    prediction_data['predicted_clear_date'],
                    prediction_data.get('model_version', 'v1.0'),
                    json.dumps(prediction_data.get('features', {}))
                ))
                prediction_id = cursor.lastrowid
                conn.commit()
            finally:
                # Close even when insert/commit raises (fixes connection leak).
                conn.close()
            return prediction_id
    except Exception as e:
        print(f"Error logging to DB: {e}")
        return None
# ============================================
# API Endpoints
# ============================================
@app.get("/")
def root():
    """Root endpoint: static service metadata plus live model status."""
    info = {
        "service": "Invoice Digitization",
        "version": "1.0.0",
        "status": "operational",
    }
    # model is the module-level estimator loaded at startup (None on failure).
    info["model_loaded"] = model is not None
    return info
@app.get("/health")
def health():
    """Health probe: reports model availability and DB file presence."""
    payload = {
        "status": "ok",
        "model_loaded": model is not None,
        "db_exists": DB_PATH.exists(),
    }
    return JSONResponse(content=payload, media_type="application/json")
@app.post("/ingest")
def ingest_invoice(invoice: InvoiceIngest):
    """
    Ingest invoice into SQLite database.
    Computes derived fields and stores data (delegated to ingest_func).

    Raises:
        HTTPException(500): when ingestion fails for any reason.
    """
    try:
        result = ingest_func(invoice.dict())
    except Exception as e:
        # Chain the cause so server logs keep the original traceback
        # (the original `raise` dropped it).
        raise HTTPException(status_code=500, detail=f"Ingestion failed: {str(e)}") from e
    return {
        "status": "success",
        "message": "Invoice ingested successfully",
        "data": result
    }
@app.get("/features/{cust_number}")
def get_features(cust_number: str):
    """
    Get customer aggregate features.
    Returns cached aggregates when history exists, otherwise defaults
    suitable for a brand-new customer.
    """
    customer_agg = get_customer_aggregates(cust_number)
    if customer_agg:
        return {
            "cust_number": cust_number,
            "status": "existing_customer",
            "features": customer_agg
        }
    # No history found: serve conservative default features.
    default_features = {
        "cust_avg_days": 18.0,
        "cust_median_days": 15.0,
        "cust_invoice_count": 0
    }
    return {
        "cust_number": cust_number,
        "status": "new_customer",
        "message": "No historical data found, using defaults",
        "features": default_features
    }
@app.post("/predict")
def predict(request: PredictionRequest):
    """
    Predict payment clearing time for an invoice.

    Returns:
        - predicted_days_to_clear (float, rounded to 2 decimals)
        - predicted_clear_date (posting_date + predicted days)
        - customer_history: "available" | "new_customer"
        - prediction_id: row id in predictions_log (None if DB log failed)

    Raises:
        HTTPException(503): ML model not loaded at startup.
        HTTPException(500): feature building or prediction failed.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="ML model not loaded")
    try:
        # Fetch aggregates; each helper returns None for unseen values.
        customer_agg = get_customer_aggregates(request.cust_number)
        payment_agg = get_payment_terms_aggregates(request.cust_payment_terms)
        business_agg = get_business_code_aggregates(request.business_code)

        # Build invoice data dict
        invoice_data = request.dict()

        # Parse posting_date once — it is needed both for the due-date
        # delta and for the predicted clear date (original parsed twice).
        posting_dt = datetime.strptime(request.posting_date, "%Y-%m-%d")

        # Compute days_posting_to_due if due_in_date provided
        if request.due_in_date:
            due_dt = datetime.strptime(request.due_in_date, "%Y-%m-%d")
            invoice_data['days_posting_to_due'] = (due_dt - posting_dt).days
        else:
            invoice_data['days_posting_to_due'] = 15  # Default

        # Build features
        features = build_features(invoice_data, customer_agg, payment_agg, business_agg)
        features_df = features_to_dataframe(features)

        # Predict
        predicted_days = float(model.predict(features_df)[0])

        # Calculate predicted clear date
        predicted_clear_dt = posting_dt + timedelta(days=predicted_days)

        # Prepare response
        response = {
            "invoice_id": request.invoice_id,
            "cust_number": request.cust_number,
            "posting_date": request.posting_date,
            "total_open_amount": request.total_open_amount,
            "predicted_days_to_clear": round(predicted_days, 2),
            "predicted_clear_date": predicted_clear_dt.strftime("%Y-%m-%d"),
            "customer_history": "available" if customer_agg else "new_customer",
            "model_version": "v1.0"
        }

        # Log prediction (best-effort: DB logger returns None on failure)
        log_prediction_to_csv(response)
        prediction_id = log_prediction_to_db({
            **response,
            'business_code': request.business_code,
            'cust_payment_terms': request.cust_payment_terms,
            'features': features
        })
        response['prediction_id'] = prediction_id
        return response
    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
@app.get("/predictions/recent")
def get_recent_predictions(limit: int = 10):
    """Get recent predictions from the log, newest first, up to `limit`.

    Raises:
        HTTPException(500): on any DB or lock error.
    """
    try:
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            try:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute("""
                SELECT
                prediction_id,
                cust_number,
                posting_date,
                predicted_days_to_clear,
                predicted_clear_date,
                predicted_at
                FROM predictions_log
                ORDER BY predicted_at DESC
                LIMIT ?
                """, (limit,))
                rows = cursor.fetchall()
            finally:
                # Close even when the query raises (fixes connection leak).
                conn.close()
        return {
            "count": len(rows),
            "predictions": [dict(row) for row in rows]
        }
    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(status_code=500, detail=f"Failed to fetch predictions: {str(e)}") from e
if __name__ == "__main__":
    import uvicorn
    # Development entry point: serve on all interfaces.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,  # presumably the hosting platform's expected port — confirm deployment config
        timeout_keep_alive=75,
        timeout_graceful_shutdown=10
    )