tennis-ml / src /main.py
SlimG's picture
fixup: HTTPException has no 'status' attribute
0192d50
import os
import joblib
import logging
import secrets
from typing import Generator, Optional, Annotated, List
from fastapi import (
FastAPI,
Request,
HTTPException,
Query,
Security,
Depends
)
from fastapi.responses import RedirectResponse, JSONResponse
from fastapi.background import BackgroundTasks
from fastapi.security.api_key import APIKeyHeader
from starlette.status import (
HTTP_200_OK,
HTTP_403_FORBIDDEN,
HTTP_404_NOT_FOUND,
HTTP_503_SERVICE_UNAVAILABLE)
from dotenv import load_dotenv
from mlflow.exceptions import RestException
from src.entity.model import ModelInput, ModelOutput
from src.service.data_quality import DataChecker, check_model_data
from src.service.model import (
run_experiment,
predict,
list_registered_models,
load_model,
deploy_model,
undeploy_model,
)
from src.repository.common import get_connection
from psycopg import Connection
# Pull settings from a local .env file into the process environment before
# anything below reads them with os.getenv.
load_dotenv()

# Send INFO-and-above records to a stream handler.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def provide_connection() -> Generator[Connection, None, None]:
    """
    Yield a database connection for use as a FastAPI dependency.

    The connection is obtained from ``get_connection()`` and stays inside its
    context manager until the consumer is done with it.

    Yields:
        Connection: The connection produced by ``get_connection()``
    """
    connection_scope = get_connection()
    with connection_scope as connection:
        yield connection
# ------------------------------------------------------------------------------
# Expected API key; None when the environment variable is not set.
FASTAPI_API_KEY = os.environ.get("FASTAPI_API_KEY")

# Client hosts that validate_api_key lets through without checking the key.
safe_clients = ["127.0.0.1"]

# The key is read from the 'Authorization' header; auto_error=False lets
# validate_api_key decide what to do when the header is missing.
api_key_header = APIKeyHeader(auto_error=False, name="Authorization")
async def validate_api_key(request: Request, key: Optional[str] = Security(api_key_header)):
    """
    Check if the API key is valid

    Requests whose client host is listed in ``safe_clients`` are accepted
    without a key. Every other request must carry a key equal to
    ``FASTAPI_API_KEY``.

    Args:
        request (Request): The incoming request, used to read the client host
        key (Optional[str]): The API key to check (None if the header is missing)

    Raises:
        HTTPException: 403 if the API key is missing or invalid
    """
    # request.client may be None (no peer address available), so guard
    # before reading .host instead of failing with an AttributeError.
    client = request.client
    if client is not None and client.host in safe_clients:
        return None

    # Fail closed when either side is missing: str(None) == "None" must never
    # count as a match. Compare as bytes because secrets.compare_digest
    # rejects non-ASCII str arguments with a TypeError.
    if key and FASTAPI_API_KEY and secrets.compare_digest(
        key.encode("utf-8"), FASTAPI_API_KEY.encode("utf-8")
    ):
        return None

    raise HTTPException(
        status_code=HTTP_403_FORBIDDEN, detail="Unauthorized - API Key is wrong"
    )
# The API-key check is attached to every route only when a key is configured;
# otherwise the application is created without global dependencies.
app = FastAPI(
    title="Tennis Insights ML API",
    description="API for the Tennis Insights ML module",
    dependencies=[Depends(validate_api_key)] if FASTAPI_API_KEY else None,
)
# ------------------------------------------------------------------------------
@app.get("/", include_in_schema=False)
def redirect_to_docs():
    """
    Send visitors of the root path to the documentation page at '/docs'.
    """
    docs_redirect = RedirectResponse(url='/docs')
    return docs_redirect
@app.get("/run_experiment", tags=["model"], description="Schedule a run of the ML experiment")
async def run_xp(
    background_tasks: BackgroundTasks,
    algo: str = Query(default="LogisticRegression", description="The algorithm to use for training"),
    registered_model_name: Optional[str] = Query(default=None, description="The name of the registered model"),
    experiment_name: Optional[str] = Query(default="Tennis Prediction", description="The name of the experiment"),
):
    """
    Schedule a training experiment as a background task.

    The experiment is only queued here; the response is returned without
    waiting for it.

    Args:
        background_tasks (BackgroundTasks): Task queue the experiment is added to
        algo (str): The algorithm to use for training
        registered_model_name (Optional[str]): The name of the registered model
        experiment_name (Optional[str]): The name of the experiment

    Returns:
        dict: A confirmation message
    """
    experiment_options = {
        "algo": algo,
        "registered_model_name": registered_model_name,
        "experiment_name": experiment_name,
    }
    background_tasks.add_task(run_experiment, **experiment_options)

    return {"message": "Experiment scheduled"}
@app.get("/predict",
         tags=["model"],
         description="Predict the outcome of a tennis match",
         response_model=ModelOutput)
async def make_prediction(params: Annotated[ModelInput, Query()]):
    """
    Predict the outcome of a match.

    When no model name is given, the pipeline is loaded from the local file
    '/data/model.pkl'. Otherwise it is fetched with ``load_model`` using the
    requested name and alias.

    Args:
        params (ModelInput): The query parameters describing the match, plus
            the optional model name and alias

    Returns:
        The value returned by ``predict`` for the given match

    Raises:
        HTTPException: 404 if the requested model cannot be loaded
    """
    if not params.model:
        # check the presence of 'model.pkl' file in data/
        if not os.path.exists("/data/model.pkl"):
            # NOTE(review): this payload is validated against ModelOutput by
            # FastAPI; confirm ModelOutput accepts a bare "message" field,
            # otherwise this branch should raise an HTTPException instead.
            return {"message": "Model not trained. Please train the model first."}

        # Load the model.
        # NOTE(review): joblib.load unpickles the file, so '/data/model.pkl'
        # must only ever be written by a trusted process.
        pipeline = joblib.load("/data/model.pkl")
    else:
        # Get the model info
        try:
            pipeline = load_model(name=params.model, alias=params.alias)
        except RestException as e:
            logger.error(e)
            # Raise (not return) the error: a returned HTTPException would be
            # treated as the response body and never produce a 404.
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND,
                detail=f"Model {params.model} not found"
            ) from e

    # Make the prediction
    prediction = predict(
        pipeline=pipeline,
        series=params.series,
        surface=params.surface,
        court=params.court,
        p1_rank=params.p1_rank,
        p1_play_hand=params.p1_play_hand,
        p1_back_hand=params.p1_back_hand,
        p1_height=params.p1_height,
        p1_weight=params.p1_weight,
        p1_year_of_birth=params.p1_year_of_birth,
        p1_pro_year=params.p1_pro_year,
        p2_rank=params.p2_rank,
        p2_play_hand=params.p2_play_hand,
        p2_back_hand=params.p2_back_hand,
        p2_height=params.p2_height,
        p2_weight=params.p2_weight,
        p2_year_of_birth=params.p2_year_of_birth,
        p2_pro_year=params.p2_pro_year,
    )

    logger.info(prediction)

    return prediction
@app.get("/list_available_models", tags=["model"], description="List the available models")
async def list_available_models(
    aliases: Optional[List[str]] = Query(default=None, description="List of model aliases to filter the models"),
):
    """
    Return the registered models, optionally filtered by alias.

    Args:
        aliases (Optional[List[str]]): Aliases forwarded to
            ``list_registered_models`` as its alias filter

    Returns:
        The value returned by ``list_registered_models``
    """
    registered = list_registered_models(alias_filter=aliases)
    return registered
@app.post("/deploy_model", tags=["model"], description="Deploy a model")
async def deploy_model_to_production(
    model_name: str = Query(description="The name of the model to deploy"),
    version: str = Query(description="The version of the model to deploy"),
):
    """
    Deploy the given version of a model.

    Args:
        model_name (str): The name of the model to deploy
        version (str): The version of the model to deploy

    Returns:
        A confirmation message, or a 404 JSON response when ``deploy_model``
        raises a RestException
    """
    try:
        deploy_model(model_name=model_name, model_version=version)
    except RestException as e:
        logger.error(e)
        # Report the failure to the caller as a 404
        not_found_body = {"message": f"Model {model_name} (version {version}) not found"}
        return JSONResponse(status_code=HTTP_404_NOT_FOUND, content=not_found_body)
    else:
        return {"message": f"Model {model_name} deployed to production"}
@app.post("/undeploy_model", tags=["model"], description="Undeploy a model")
async def undeploy_model_from_production(
    model_name: str = Query(description="The name of the model to undeploy"),
):
    """
    Undeploy the given model.

    Args:
        model_name (str): The name of the model to undeploy

    Returns:
        A confirmation message, or a 404 JSON response when ``undeploy_model``
        raises a RestException
    """
    try:
        undeploy_model(model_name=model_name)
    except RestException as e:
        logger.error(e)
        # Report the failure to the caller as a 404
        not_found_body = {"message": f"Model {model_name} not found or not in production"}
        return JSONResponse(status_code=HTTP_404_NOT_FOUND, content=not_found_body)
    else:
        return {"message": f"Model {model_name} undeployed from production"}
@app.get("/check_data_quality", tags=["data"], description="Check the data quality")
async def check_data_quality(
    background_tasks: BackgroundTasks,
    model_name: str = Query(description="The name of the model to check"),
    project_id: Optional[str] = Query(default=None, description="The ID of the project to send the data quality report to"),
):
    """
    Schedule a data quality check for a model as a background task.

    Args:
        background_tasks (BackgroundTasks): Task queue the check is added to
        model_name (str): The name of the model to check
        project_id (Optional[str]): The project to report to; falls back to
            the EVIDENTLY_PROJECT_ID environment variable when not given

    Returns:
        A confirmation message, or a 503 JSON response when the Evidently API
        key or the project ID is not available
    """
    # The API key always comes from the environment; the project ID from the
    # query parameter first, then from the environment.
    evidently_key = os.getenv("EVIDENTLY_API_KEY")
    target_project = project_id or os.getenv("EVIDENTLY_PROJECT_ID")

    # Both values are needed to build the checker
    if not (evidently_key and target_project):
        return JSONResponse(
            status_code=HTTP_503_SERVICE_UNAVAILABLE,
            content={"message": "Evidently API key or project ID not set"},
        )

    # Queue the check; it runs after the response has been sent
    checker = DataChecker(evidently_key, target_project)
    background_tasks.add_task(check_model_data, model_name=model_name, checker=checker)

    return {"message": "Data quality check scheduled"}
# ------------------------------------------------------------------------------
@app.get("/check_health", tags=["general"], description="Check the health of the ML module")
async def check_health(session: Connection = Depends(provide_connection)):
    """
    Check all the services in the infrastructure are working

    The database is probed with a trivial query. The MLflow server is probed
    through its '/health' endpoint, but only when the MLFLOW_SERVER_URI
    environment variable is set.

    Args:
        session (Connection): Database connection supplied by ``provide_connection``

    Returns:
        JSONResponse: 200 with status "healthy", or 503 with status
        "unhealthy" and a detail naming the failing service
    """
    # Check if the database is alive
    try:
        with session.cursor() as cursor:
            cursor.execute("SELECT 1").fetchall()
    except Exception as e:
        # Any failure here means the health check itself must report 503
        logger.error(f"DB check failed: {e}")
        return JSONResponse(content={"status": "unhealthy", "detail": "Database not reachable"},
                            status_code=HTTP_503_SERVICE_UNAVAILABLE)

    # Check if the mlflow endpoint is reachable
    if mlflow_server_uri := os.getenv("MLFLOW_SERVER_URI"):
        import requests

        # Drop trailing slashes so a URI such as "http://host:5000/" does not
        # produce "http://host:5000//health".
        health_url = mlflow_server_uri.rstrip("/") + "/health"
        try:
            # Ping the mlflow server endpoint
            response = requests.get(health_url, timeout=5)
            if response.status_code != HTTP_200_OK:
                logger.error(f"MLflow server check failed: {response.status_code}")
                return JSONResponse(content={"status": "unhealthy", "detail": "MLflow server not reachable"},
                                    status_code=HTTP_503_SERVICE_UNAVAILABLE)
        except requests.RequestException as e:
            logger.error(f"MLflow server check failed: {e}")
            return JSONResponse(content={"status": "unhealthy", "detail": "MLflow server not reachable"},
                                status_code=HTTP_503_SERVICE_UNAVAILABLE)

    return JSONResponse(content={"status": "healthy"}, status_code=HTTP_200_OK)