# Invoice digitization & payment-prediction FastAPI service.
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from typing import Optional, Dict | |
| import sqlite3 | |
| import joblib | |
| import pandas as pd | |
| from datetime import datetime, timedelta | |
| from pathlib import Path | |
| from filelock import FileLock | |
| from fastapi.responses import JSONResponse | |
| import json | |
| import sys | |
| import logging | |
# ----- Module-level configuration (runs once at import time) -----

# Setup logging for entire app; force=True replaces any handlers installed
# earlier (e.g. by an imported library) so all output goes to stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True
)

# Setup paths, anchored at this file's directory.
BASE_DIR = Path(__file__).parent
DB_PATH = BASE_DIR / "data" / "invoices.db"  # Inside container: /app/data/invoices.db
LOCK_PATH = BASE_DIR / "data" / "invoices.db.lock"  # cross-process lock guarding DB_PATH
MODEL_PATH = BASE_DIR / "ml" / "models" / "payment_predictor_model_20251124_194847.pkl"
LOG_DIR = BASE_DIR / "data" / "logs"
PREDICTIONS_LOG = LOG_DIR / "predictions.csv"  # append-only CSV audit log of predictions

# Ensure directories exist
LOG_DIR.mkdir(parents=True, exist_ok=True)

# Add backend to path so the `backend.*` imports below resolve when run as a script
sys.path.append(str(BASE_DIR / "backend"))

# Import feature builder
from backend.feature_builder.feature_builder import build_features, features_to_dataframe
from backend.ingest.ingest_invoice_sqlite import ingest_invoice as ingest_func
# ============================================
# IMPORT INGEST ROUTER (NEW)
# ============================================
from backend.app.api.ingest import router as ingest_router
# Load ML model artifacts once at startup. The app degrades gracefully when
# the model file is missing or unreadable: `model` stays None and /predict
# answers 503 instead of crashing the whole service.
logger = logging.getLogger(__name__)

logger.info("Loading ML model...")
try:
    model_artifacts = joblib.load(MODEL_PATH)
    model = model_artifacts['model']
    logger.info("Model loaded: %s", MODEL_PATH.name)
except Exception as e:
    # Broad catch is deliberate: any load failure must not abort import.
    # Use the configured logger (basicConfig above) instead of bare print.
    logger.error("Failed to load model: %s", e)
    model = None
# FastAPI app (fixed user-facing typo: "Degitize" -> "Digitize")
app = FastAPI(
    title="Invoice Digitization",
    description="Digitize invoice",
    version="1.0.0"
)

# ============================================
# REGISTER INGEST ROUTER (NEW)
# ============================================
app.include_router(ingest_router)
| # ============================================ | |
| # Pydantic Models | |
| # ============================================ | |
class InvoiceIngest(BaseModel):
    """Request payload for invoice ingestion.

    Dates are plain strings (presumably ISO "YYYY-MM-DD" — confirm against
    the ingest pipeline); optional fields default to None and are derived
    downstream.
    """
    invoice_id: int
    business_code: str
    cust_number: str
    name_customer: Optional[str] = None
    posting_date: str  # date the invoice was posted
    document_create_date: Optional[str] = None
    document_create_date_alt: Optional[str] = None
    due_in_date: Optional[str] = None
    baseline_create_date: Optional[str] = None
    clear_date: Optional[str] = None  # set once the invoice has been paid
    total_open_amount: float
    invoice_currency: str = "USD"
    document_type: Optional[str] = "RV"
    cust_payment_terms: Optional[str] = None
    posting_id: Optional[float] = None
    business_year: Optional[int] = None
class PredictionRequest(BaseModel):
    """Request payload for payment-time prediction.

    Only customer number, posting date and open amount are required; the
    remaining fields fall back to common defaults used by the feature
    builder.
    """
    invoice_id: Optional[int] = None
    cust_number: str
    posting_date: str  # parsed with "%Y-%m-%d" in predict()
    total_open_amount: float
    business_code: str = "U001"
    cust_payment_terms: str = "NAH4"
    invoice_currency: str = "USD"
    document_type: str = "RV"
    due_in_date: Optional[str] = None  # parsed with "%Y-%m-%d" when present
    business_year: Optional[int] = None
| # ============================================ | |
| # Helper Functions | |
| # ============================================ | |
def get_customer_aggregates(cust_number: str) -> Optional[Dict]:
    """Fetch the customer_aggregates row for *cust_number* from SQLite.

    Returns the row as a dict, or None when the customer is unknown or the
    lookup fails. Best-effort: errors are logged, never raised.
    """
    try:
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            try:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT * FROM customer_aggregates WHERE cust_number = ?",
                    (cust_number,)
                )
                row = cursor.fetchone()
            finally:
                # Close even when the query raises (original leaked the
                # connection on error paths).
                conn.close()
            if row:
                return dict(row)
    except Exception as e:
        logging.getLogger(__name__).error("Error fetching customer aggregates: %s", e)
    return None
def get_payment_terms_aggregates(payment_terms: str) -> Optional[Dict]:
    """Fetch the payment_terms_aggregates row for *payment_terms* from SQLite.

    Returns the row as a dict, or None when no row exists or the lookup
    fails. Best-effort: errors are logged, never raised.
    """
    try:
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            try:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT * FROM payment_terms_aggregates WHERE cust_payment_terms = ?",
                    (payment_terms,)
                )
                row = cursor.fetchone()
            finally:
                # Close even when the query raises (original leaked the
                # connection on error paths).
                conn.close()
            if row:
                return dict(row)
    except Exception as e:
        logging.getLogger(__name__).error("Error fetching payment terms: %s", e)
    return None
def get_business_code_aggregates(business_code: str) -> Optional[Dict]:
    """Fetch the business_code_aggregates row for *business_code* from SQLite.

    Returns the row as a dict, or None when no row exists or the lookup
    fails. Best-effort: errors are logged, never raised.
    """
    try:
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            try:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT * FROM business_code_aggregates WHERE business_code = ?",
                    (business_code,)
                )
                row = cursor.fetchone()
            finally:
                # Close even when the query raises (original leaked the
                # connection on error paths).
                conn.close()
            if row:
                return dict(row)
    except Exception as e:
        logging.getLogger(__name__).error("Error fetching business code: %s", e)
    return None
def log_prediction_to_csv(prediction_data: Dict, path: Optional[Path] = None):
    """Append one prediction record to the CSV audit log.

    Args:
        prediction_data: flat dict of prediction fields (one CSV row).
        path: destination CSV; defaults to the module-level PREDICTIONS_LOG
            (parameterized so callers/tests may redirect the log).

    The header is written only when the file does not exist yet, so repeated
    calls yield a single-header, append-only CSV.
    """
    target = PREDICTIONS_LOG if path is None else Path(path)
    df = pd.DataFrame([prediction_data])
    # Append mode also creates the file on first use; header only then.
    df.to_csv(target, mode='a', header=not target.exists(), index=False)
def log_prediction_to_db(prediction_data: Dict):
    """Insert one prediction into the SQLite predictions_log table.

    Returns:
        The new row id on success, or None on failure. Best-effort: errors
        are logged, never raised to the caller.
    """
    try:
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            try:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO predictions_log (
                        invoice_id, cust_number, posting_date, total_open_amount,
                        business_code, cust_payment_terms, predicted_days_to_clear,
                        predicted_clear_date, model_version, features_json
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    prediction_data.get('invoice_id'),
                    prediction_data['cust_number'],
                    prediction_data['posting_date'],
                    prediction_data['total_open_amount'],
                    prediction_data.get('business_code'),
                    prediction_data.get('cust_payment_terms'),
                    prediction_data['predicted_days_to_clear'],
                    prediction_data['predicted_clear_date'],
                    prediction_data.get('model_version', 'v1.0'),
                    # Features are stored as a JSON blob for later auditing.
                    json.dumps(prediction_data.get('features', {}))
                ))
                prediction_id = cursor.lastrowid
                conn.commit()
                return prediction_id
            finally:
                # Close even when the INSERT raises (original leaked the
                # connection on error paths).
                conn.close()
    except Exception as e:
        logging.getLogger(__name__).error("Error logging to DB: %s", e)
        return None
| # ============================================ | |
| # API Endpoints | |
| # ============================================ | |
# NOTE(review): these handlers were defined without route decorators and so
# were never registered on the app; "/" restored here — confirm the path.
@app.get("/")
def root():
    """Root endpoint: service metadata and model availability."""
    return {
        "service": "Invoice Digitization",
        "version": "1.0.0",
        "status": "operational",
        "model_loaded": model is not None
    }
# NOTE(review): route decorator restored — the handler was defined but never
# registered; confirm the "/health" path against deploy probes.
@app.get("/health")
def health():
    """Health probe: reports model load status and database presence."""
    return JSONResponse(
        content={
            "status": "ok",
            "model_loaded": model is not None,
            "db_exists": DB_PATH.exists()
        },
        media_type="application/json"
    )
# NOTE(review): route decorator restored — this handler was unregistered.
# The ingest router included on the app may already serve ingestion; confirm
# this legacy endpoint's path (and whether it should exist at all).
@app.post("/ingest")
def ingest_invoice(invoice: InvoiceIngest):
    """
    Ingest invoice into SQLite database.
    Computes derived fields and stores data.

    Raises:
        HTTPException 500 when the underlying ingest function fails.
    """
    try:
        result = ingest_func(invoice.dict())
        return {
            "status": "success",
            "message": "Invoice ingested successfully",
            "data": result
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Ingestion failed: {str(e)}")
# NOTE(review): route decorator restored — the handler was defined but never
# registered; path inferred from the path parameter, confirm against clients.
@app.get("/features/{cust_number}")
def get_features(cust_number: str):
    """
    Get customer aggregate features.
    Returns cached aggregates, or documented defaults for new customers.
    """
    customer_agg = get_customer_aggregates(cust_number)
    if not customer_agg:
        # Unknown customer: fall back to fixed priors used by the model.
        return {
            "cust_number": cust_number,
            "status": "new_customer",
            "message": "No historical data found, using defaults",
            "features": {
                "cust_avg_days": 18.0,
                "cust_median_days": 15.0,
                "cust_invoice_count": 0
            }
        }
    return {
        "cust_number": cust_number,
        "status": "existing_customer",
        "features": customer_agg
    }
# NOTE(review): route decorator restored — the handler was defined but never
# registered; confirm the "/predict" path against clients.
@app.post("/predict")
def predict(request: PredictionRequest):
    """
    Predict payment clearing time for an invoice.

    Returns:
        - predicted_days_to_clear
        - predicted_clear_date
        - confidence info (customer_history, model_version, prediction_id)

    Raises:
        HTTPException 503 when the ML model is not loaded.
        HTTPException 500 on any failure while building features or predicting.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="ML model not loaded")
    try:
        # Fetch cached aggregates (each may be None for unseen values).
        customer_agg = get_customer_aggregates(request.cust_number)
        payment_agg = get_payment_terms_aggregates(request.cust_payment_terms)
        business_agg = get_business_code_aggregates(request.business_code)

        invoice_data = request.dict()

        # Parse posting_date once; reused for the due-date delta and the
        # predicted clear date (original parsed it twice).
        posting_dt = datetime.strptime(request.posting_date, "%Y-%m-%d")

        # Compute days_posting_to_due if due_in_date provided.
        if request.due_in_date:
            due_dt = datetime.strptime(request.due_in_date, "%Y-%m-%d")
            invoice_data['days_posting_to_due'] = (due_dt - posting_dt).days
        else:
            invoice_data['days_posting_to_due'] = 15  # Default

        # Build model-ready features.
        features = build_features(invoice_data, customer_agg, payment_agg, business_agg)
        features_df = features_to_dataframe(features)

        # Predict days-to-clear and derive the calendar date.
        predicted_days = float(model.predict(features_df)[0])
        predicted_clear_dt = posting_dt + timedelta(days=predicted_days)

        response = {
            "invoice_id": request.invoice_id,
            "cust_number": request.cust_number,
            "posting_date": request.posting_date,
            "total_open_amount": request.total_open_amount,
            "predicted_days_to_clear": round(predicted_days, 2),
            "predicted_clear_date": predicted_clear_dt.strftime("%Y-%m-%d"),
            "customer_history": "available" if customer_agg else "new_customer",
            "model_version": "v1.0"
        }

        # Audit logging: CSV plus DB (DB insert returns the row id).
        log_prediction_to_csv(response)
        prediction_id = log_prediction_to_db({
            **response,
            'business_code': request.business_code,
            'cust_payment_terms': request.cust_payment_terms,
            'features': features
        })
        response['prediction_id'] = prediction_id
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
# NOTE(review): route decorator restored — the handler was defined but never
# registered; confirm the "/predictions/recent" path against clients.
@app.get("/predictions/recent")
def get_recent_predictions(limit: int = 10):
    """Get the most recent predictions from predictions_log.

    Args:
        limit: maximum number of rows to return (query parameter).

    Raises:
        HTTPException 500 when the database query fails.
    """
    try:
        with FileLock(str(LOCK_PATH), timeout=10):
            conn = sqlite3.connect(str(DB_PATH))
            try:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute("""
                    SELECT
                        prediction_id,
                        cust_number,
                        posting_date,
                        predicted_days_to_clear,
                        predicted_clear_date,
                        predicted_at
                    FROM predictions_log
                    ORDER BY predicted_at DESC
                    LIMIT ?
                """, (limit,))
                rows = cursor.fetchall()
            finally:
                # Close even when the query raises (original leaked the
                # connection on error paths).
                conn.close()
        return {
            "count": len(rows),
            "predictions": [dict(row) for row in rows]
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to fetch predictions: {str(e)}")
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on the Spaces port.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        timeout_keep_alive=75,
        timeout_graceful_shutdown=10,
    )