# Last commit: 0ad7b2b by perachon — "fix & améliorations"
from fastapi import FastAPI, HTTPException
import pandas as pd
from core.middlewares import LoggingMiddleware
import time
import json
from core.logging_config import get_logger
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi import Request
import os
# import cProfile
import pstats
import io
from src.model import load_model
from src.schemas import ClientData
# Business decision threshold: a predicted default probability greater than or
# equal to this value leads to a "REFUSED" decision.
THRESHOLD = 0.2

# Model instance shared by the endpoints, cached at module level.
# It is None until a model has been loaded (the startup hook fills it in).
_model = None

app = FastAPI(
    title="Credit Scoring API",
    version="1.0",
    description="API de prédiction du risque de défaut de paiement",
)
# Load the model once, when the application starts.
@app.on_event("startup")
def load_model_on_startup():
    """Fill the module-level model cache when the application starts.

    NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    favour of lifespan handlers; it is kept here to preserve behaviour.
    """
    global _model
    loaded = load_model()
    _model = loaded
def get_model():
    """Return the cached model, loading it on first use if necessary.

    The model is normally loaded by the startup hook. When that hook has not
    run, the original code handed ``None`` to callers and the endpoint failed.
    One such case is a test client used without a ``with`` block, since
    startup events only fire inside one. The model is therefore loaded here
    lazily instead.

    Returns:
        Whatever ``load_model()`` returned; the result is cached in ``_model``.
    """
    global _model
    if _model is None:
        _model = load_model()
    return _model
# Logging: register the request-logging middleware on the app and obtain the
# project logger used by the handlers in this module.
app.add_middleware(LoggingMiddleware)
logger = get_logger()
@app.get("/health")
def health_check():
    """Liveness probe: report service status and whether a model is in memory."""
    model_ready = _model is not None
    return {"status": "ok", "model_loaded": model_ready}
@app.post("/predict")
def predict(client: ClientData):
    """Score one client and return the default probability and the decision.

    The client's fields are put into a one-row DataFrame and passed to the
    model's ``predict_proba``. Column 1 of the single output row is the
    default probability, and it is compared to ``THRESHOLD``: ``>=`` means
    "REFUSED", anything else "ACCEPTED". Each prediction also emits one
    structured JSON log line (consumed by monitoring/export_prod_data.py).

    Setting the environment variable ``ENABLE_PROFILING=1`` profiles the
    request with cProfile and logs the 10 most expensive calls, sorted by
    cumulative time.

    Args:
        client: validated client features.

    Returns:
        dict with ``probability_default`` (rounded to 4 decimals),
        ``threshold`` and ``decision``.

    Raises:
        HTTPException: 503 when no model is available.
    """
    # Profiling is off by default: it is too expensive for production traffic.
    enable_profiling = os.getenv("ENABLE_PROFILING", "0") == "1"
    pr = None
    if enable_profiling:
        import cProfile

        pr = cProfile.Profile()
        try:
            pr.enable()
        except ValueError:
            # On Python 3.12+ only one profiler may be active at a time, so a
            # concurrent profiled request makes enable() raise. Profiling is
            # diagnostic only: skip it instead of failing the prediction.
            logger.warning("Profiling skipped: another profiler is already active")
            pr = None

    try:
        # perf_counter is monotonic, unlike time.time(), so the measured
        # latency cannot be skewed by system clock adjustments.
        start_time = time.perf_counter()
        model = get_model()
        if model is None:
            # Fail explicitly instead of an AttributeError (HTTP 500) below.
            raise HTTPException(status_code=503, detail="Model not loaded")
        X = pd.DataFrame([client.model_dump()])
        proba = float(model.predict_proba(X)[0, 1])
        decision = "REFUSED" if proba >= THRESHOLD else "ACCEPTED"
        latency_ms = (time.perf_counter() - start_time) * 1000

        # Structured log (used by monitoring/export_prod_data.py).
        logger.info(
            json.dumps(
                {
                    # gmtime() so that the trailing "Z" (UTC) is actually true.
                    "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
                    "event": "prediction",
                    "model": {"name": "lightgbm_credit_scoring", "version": "v1"},
                    "input": {"num_features": len(X.columns)},
                    "output": {
                        "probability_default": round(proba, 4),
                        "decision": decision,
                        "threshold": THRESHOLD,
                    },
                    "latency_ms": round(latency_ms, 2),
                },
                ensure_ascii=False,
            )
        )
    finally:
        # Always stop the profiler, even when the request fails; a profiler
        # left enabled keeps collecting data and blocks later profiling.
        if pr is not None:
            pr.disable()

    if pr is not None:
        s = io.StringIO()
        pstats.Stats(pr, stream=s).sort_stats("cumulative").print_stats(10)
        logger.info(s.getvalue())

    return {
        "probability_default": round(proba, 4),
        "threshold": THRESHOLD,
        "decision": decision,
    }
# Global handler for request validation errors.
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Log a structured validation-error event and answer with a JSON-safe 422.

    Only ``loc``, ``msg`` (forced to ``str``) and ``type`` of each error entry
    are kept, so the payload is JSON-serializable; other keys of the raw error
    entries are dropped.

    Args:
        request: the incoming request (path and method are logged).
        exc: the validation error raised by FastAPI.

    Returns:
        JSONResponse with status 422 and body ``{"detail": [...]}``.
    """
    status_code = 422
    safe_errors = [
        {
            "loc": err.get("loc"),
            "msg": str(err.get("msg")),
            "type": err.get("type"),
        }
        for err in exc.errors()
    ]
    log_entry = {
        # gmtime() so that the trailing "Z" (UTC) is actually true.
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "event": "validation_error",
        "endpoint": request.url.path,
        "method": request.method,
        "status_code": status_code,
        "error": safe_errors,
    }
    # ensure_ascii=False keeps accented characters readable in the log.
    logger.info(json.dumps(log_entry, ensure_ascii=False))
    # Same top-level response shape as FastAPI's default 422 ({"detail": [...]}).
    return JSONResponse(status_code=status_code, content={"detail": safe_errors})