SamadhiDBS's picture
Update app/main.py
b1f3a3c verified
import os
from datetime import datetime, timezone

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from sqlalchemy import text

from app.database import engine
from app.ml_models import load_models
from app.routes import customers, predictions
# Load the ML models once, at import time, before the app object is built.
# NOTE(review): /health reports this return value as-is; its type is not
# visible here — confirm it is JSON-serialisable (presumably a bool).
models_loaded = load_models()

# Markdown rendered at the top of the generated OpenAPI / Swagger UI page.
# NOTE(review): "accuracy" is unusual wording for a regression (CLV) model;
# confirm which metric the 95% refers to.
_API_DESCRIPTION = """
This API provides customer insights from your e-commerce data.
## Features
* **Customer Segmentation** - Predict customer segments using K-Means
* **CLV Prediction** - Predict Customer Lifetime Value
* **Customer Data** - Access customer information from database
## Models
* K-Means Clustering (4 segments)
* Random Forest CLV Predictor (95% accuracy)
"""

# TODO: replace these placeholder contact details before publishing the docs.
_API_CONTACT = {
    "name": "Your Name",
    "email": "your.email@example.com",
}

# The FastAPI application; routes, middleware and routers are attached below.
app = FastAPI(
    title="E-Commerce Customer Intelligence API",
    description=_API_DESCRIPTION,
    version="1.0.0",
    contact=_API_CONTACT,
)
# The bare root URL has no content of its own; send visitors to the Swagger UI.
# Hidden from the OpenAPI schema so it does not clutter the docs.
@app.get("/", include_in_schema=False)
async def redirect_to_docs():
    """Answer requests for ``/`` with a redirect to the Swagger UI at ``/docs``."""
    docs_path = "/docs"
    return RedirectResponse(docs_path)
# Add CORS middleware.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable (e.g. "https://shop.example.com,https://admin.example.com").
# When it is unset, or yields no non-blank entry, every origin ("*") is
# allowed, as before.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentials (cookies / Authorization headers) are only allowed with an
    # explicit origin list. The CORS spec forbids pairing them with a wildcard
    # origin, and Starlette works around that by echoing back whatever Origin
    # the browser sends. That would let any website make authenticated
    # requests against this API.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount the feature routers on the app, customers first and then predictions.
for _router_module in (customers, predictions):
    app.include_router(_router_module.router)
# Drop the loop variable so it does not linger in the module namespace.
del _router_module
@app.get("/health", tags=["Health"])
def health_check():
    """Report API status, ML-model load state and database connectivity.

    The database is probed by running ``SELECT 1`` on a fresh connection from
    the shared SQLAlchemy ``engine``.

    Returns:
        dict with keys:
          * ``status`` -- ``"healthy"`` when the database probe succeeds,
            ``"degraded"`` when it fails.
          * ``models_loaded`` -- the value ``load_models()`` returned at
            startup, reported under both ``kmeans`` and ``clv``.
          * ``database`` -- ``"connected"`` or ``"error: <exception text>"``.
          * ``timestamp`` -- current UTC time as ``YYYY-MM-DDTHH:MM:SSZ``.
    """
    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
    except Exception as e:  # any failure here means the DB is not usable
        # NOTE(review): str(e) may include driver details (host, user name);
        # consider trimming it if /health is exposed publicly.
        db_status = f"error: {str(e)}"
    else:
        db_status = "connected"

    return {
        # Report a failed DB probe as "degraded" rather than "healthy" so
        # monitors polling this endpoint can detect the outage.
        "status": "healthy" if db_status == "connected" else "degraded",
        # NOTE(review): a single startup flag covers both models; per-model
        # status would need load_models() to report each one separately.
        "models_loaded": {
            "kmeans": models_loaded,
            "clv": models_loaded,
        },
        "database": db_status,
        # Real, timezone-aware timestamp (previously a hard-coded placeholder).
        "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
    }
@app.get("/info", tags=["Info"])
def api_info():
    """Get API information and available endpoints.

    ``name`` and ``version`` are read from the FastAPI app's own metadata, so
    they always match what the OpenAPI docs advertise.

    Returns:
        dict with ``name``, ``version``, an ``endpoints`` map of
        "METHOD path" -> description, and a ``models`` summary.
    """
    return {
        "name": app.title,
        "version": app.version,
        # NOTE(review): this map is hand-written. The real customer/prediction
        # routes are defined in app.routes and are not visible here; keep the
        # two in sync.
        "endpoints": {
            "GET /": "Redirects to Swagger UI",
            "GET /health": "Health check",
            "GET /info": "This information",
            "GET /customers": "List all customers",
            "GET /customers/{id}": "Get customer by ID",
            "GET /customers/{id}/transactions": "Get customer transactions",
            "POST /predict/segment": "Predict customer segment from RFM",
            "POST /predict/clv": "Predict Customer Lifetime Value",
        },
        "models": {
            "customer_segmentation": "K-Means (4 clusters)",
            "clv_prediction": "Random Forest (95% accuracy)",
        },
    }