import os
import joblib
import logging
import secrets
from typing import Generator, Optional, Annotated, List
from fastapi import (
FastAPI,
Request,
HTTPException,
Query,
Security,
Depends
)
from fastapi.responses import RedirectResponse, JSONResponse
from fastapi.background import BackgroundTasks
from fastapi.security.api_key import APIKeyHeader
from starlette.status import (
HTTP_200_OK,
HTTP_403_FORBIDDEN,
HTTP_404_NOT_FOUND,
HTTP_503_SERVICE_UNAVAILABLE)
from dotenv import load_dotenv
from mlflow.exceptions import RestException
from src.entity.model import ModelInput, ModelOutput
from src.service.data_quality import DataChecker, check_model_data
from src.service.model import (
run_experiment,
predict,
list_registered_models,
load_model,
deploy_model,
undeploy_model,
)
from src.repository.common import get_connection
from psycopg import Connection
# Load environment variables (API keys, service URIs) from a local .env file.
load_dotenv()
# Log INFO and above to stderr via a plain stream handler.
logging.basicConfig(level=logging.INFO,
                    handlers=[logging.StreamHandler()])
logger = logging.getLogger(__name__)
def provide_connection() -> Generator[Connection, None, None]:
    """FastAPI dependency that yields a database connection.

    The connection is acquired from the repository layer and released
    automatically by the context manager once the request completes.
    """
    with get_connection() as connection:
        yield connection
# ------------------------------------------------------------------------------
# API key expected in the 'Authorization' header; when unset, auth is disabled
# at app construction time (see the FastAPI(dependencies=...) call below).
FASTAPI_API_KEY = os.getenv("FASTAPI_API_KEY")
# Client hosts allowed to bypass API-key validation (e.g. local health checks).
safe_clients = ['127.0.0.1']
# auto_error=False so a missing header reaches validate_api_key instead of
# being rejected with 403 before our own check runs.
api_key_header = APIKeyHeader(name='Authorization', auto_error=False)
async def validate_api_key(request: Request, key: str = Security(api_key_header)):
    '''
    Check if the API key is valid.

    Clients listed in `safe_clients` (e.g. localhost) bypass the check.

    Args:
        request (Request): Incoming request, used to read the client host.
        key (str): The API key from the 'Authorization' header (None if absent).

    Raises:
        HTTPException: If the API key is missing or invalid.
    '''
    # request.client can be None (e.g. under some test clients); treat an
    # unknown host as untrusted instead of raising AttributeError.
    client_host = request.client.host if request.client else None
    if client_host in safe_clients:
        return None
    # compare_digest gives a constant-time comparison (timing-attack safe).
    if key is None or not secrets.compare_digest(str(key), str(FASTAPI_API_KEY)):
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Unauthorized - API Key is wrong"
        )
    return None
# Enforce API-key auth on every route only when a key is configured;
# otherwise the API is open (useful for local development).
app = FastAPI(dependencies=[Depends(validate_api_key)] if FASTAPI_API_KEY else None,
              title="Tennis Insights ML API",
              description="API for the Tennis Insights ML module",)
# ------------------------------------------------------------------------------
@app.get("/", include_in_schema=False)
def redirect_to_docs():
    """Send visitors of the root URL to the interactive API documentation."""
    docs_url = '/docs'
    return RedirectResponse(url=docs_url)
@app.get("/run_experiment", tags=["model"], description="Schedule a run of the ML experiment")
async def run_xp(background_tasks: BackgroundTasks,
                 algo: str = Query(default="LogisticRegression", description="The algorithm to use for training"),
                 registered_model_name: Optional[str] = Query(default=None, description="The name of the registered model"),
                 experiment_name: Optional[str] = Query(default="Tennis Prediction", description="The name of the experiment")):
    """
    Schedule a training run of the ML experiment.

    The run executes in the background after the response is sent, so the
    endpoint replies immediately.
    """
    background_tasks.add_task(
        func=run_experiment,
        algo=algo,
        registered_model_name=registered_model_name,
        experiment_name=experiment_name,
    )
    return {"message": "Experiment scheduled"}
@app.get("/predict",
         tags=["model"],
         description="Predict the outcome of a tennis match",
         response_model=ModelOutput)
async def make_prediction(params: Annotated[ModelInput, Query()]):
    """
    Predict the outcome of a tennis match.

    Uses the registered model named in `params.model` when provided,
    otherwise falls back to the locally trained '/data/model.pkl' pipeline.

    Raises:
        HTTPException: 404 if no local model exists or the requested
            registered model is not found.
    """
    if not params.model:
        # check the presence of 'model.pkl' file in data/
        if not os.path.exists("/data/model.pkl"):
            # FIX: returning a plain dict would fail ModelOutput response
            # validation (HTTP 500); surface a proper 404 instead.
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND,
                detail="Model not trained. Please train the model first."
            )
        # Load the model
        pipeline = joblib.load("/data/model.pkl")
    else:
        # Get the model info
        try:
            pipeline = load_model(name=params.model, alias=params.alias)
        except RestException as e:
            logger.error(e)
            # FIX: the HTTPException must be raised (not returned) so the
            # client actually receives a 404 instead of a 200 with the
            # exception serialized into the body.
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND,
                detail=f"Model {params.model} not found"
            ) from e
    # Make the prediction
    prediction = predict(
        pipeline=pipeline,
        series=params.series,
        surface=params.surface,
        court=params.court,
        p1_rank=params.p1_rank,
        p1_play_hand=params.p1_play_hand,
        p1_back_hand=params.p1_back_hand,
        p1_height=params.p1_height,
        p1_weight=params.p1_weight,
        p1_year_of_birth=params.p1_year_of_birth,
        p1_pro_year=params.p1_pro_year,
        p2_rank=params.p2_rank,
        p2_play_hand=params.p2_play_hand,
        p2_back_hand=params.p2_back_hand,
        p2_height=params.p2_height,
        p2_weight=params.p2_weight,
        p2_year_of_birth=params.p2_year_of_birth,
        p2_pro_year=params.p2_pro_year,
    )
    logger.info(prediction)
    return prediction
@app.get("/list_available_models", tags=["model"], description="List the available models")
async def list_available_models(
        aliases: Optional[List[str]] = Query(default=None, description="List of model aliases to filter the models")):
    """
    Return the registered models, optionally restricted to the given aliases.
    """
    models = list_registered_models(alias_filter=aliases)
    return models
@app.post("/deploy_model", tags=["model"], description="Deploy a model")
async def deploy_model_to_production(
        model_name: str = Query(description="The name of the model to deploy"),
        version: str = Query(description="The version of the model to deploy")):
    """
    Promote a registered model version to production.

    Replies with 404 when the model/version pair is unknown to the registry.
    """
    try:
        deploy_model(model_name=model_name, model_version=version)
    except RestException as err:
        logger.error(err)
        failure = {"message": f"Model {model_name} (version {version}) not found"}
        return JSONResponse(content=failure, status_code=HTTP_404_NOT_FOUND)
    return {"message": f"Model {model_name} deployed to production"}
@app.post("/undeploy_model", tags=["model"], description="Undeploy a model")
async def undeploy_model_from_production(model_name: str = Query(description="The name of the model to undeploy")):
    """
    Remove a model from production.

    Replies with 404 when the model is unknown or not currently deployed.
    """
    try:
        undeploy_model(model_name=model_name)
    except RestException as err:
        logger.error(err)
        failure = {"message": f"Model {model_name} not found or not in production"}
        return JSONResponse(content=failure, status_code=HTTP_404_NOT_FOUND)
    return {"message": f"Model {model_name} undeployed from production"}
@app.get("/check_data_quality", tags=["data"], description="Check the data quality")
async def check_data_quality(
    background_tasks: BackgroundTasks,
    model_name: str = Query(description="The name of the model to check"),
    project_id: Optional[str] = Query(default=None, description="The ID of the project to send the data quality report to"),
):
    """
    Schedule a data quality check for the given model.

    Credentials come from EVIDENTLY_API_KEY / EVIDENTLY_PROJECT_ID env vars;
    the query parameter, when given, overrides the project ID.
    """
    api_key = os.getenv("EVIDENTLY_API_KEY")
    effective_project_id = project_id if project_id else os.getenv("EVIDENTLY_PROJECT_ID")
    # Without credentials the checker cannot report anywhere: fail fast.
    if api_key and effective_project_id:
        background_tasks.add_task(func=check_model_data,
                                  model_name=model_name,
                                  checker=DataChecker(api_key, effective_project_id))
        return {"message": "Data quality check scheduled"}
    return JSONResponse(content={"message": "Evidently API key or project ID not set"},
                        status_code=HTTP_503_SERVICE_UNAVAILABLE)
# ------------------------------------------------------------------------------
@app.get("/check_health", tags=["general"], description="Check the health of the ML module")
async def check_health(session: Connection = Depends(provide_connection)):
    """
    Check all the services in the infrastructure are working.

    Verifies the database answers a trivial query and, when
    MLFLOW_SERVER_URI is set, that the MLflow server responds on /health.

    Returns:
        JSONResponse: {"status": "healthy"} with 200, or
            {"status": "unhealthy", "detail": ...} with 503 on first failure.
    """
    # Check if the database is alive (psycopg3 execute() returns the cursor,
    # so the call can be chained with fetchall()).
    try:
        with session.cursor() as cursor:
            cursor.execute("SELECT 1").fetchall()
    except Exception as e:
        # Lazy %-formatting so the message is only built when emitted.
        logger.error("DB check failed: %s", e)
        return JSONResponse(content={"status": "unhealthy", "detail": "Database not reachable"},
                            status_code=HTTP_503_SERVICE_UNAVAILABLE)
    # Check if the mlflow endpoint is reachable (only when configured)
    if MLFLOW_SERVER_URI := os.getenv("MLFLOW_SERVER_URI"):
        import requests  # lazy import: only needed when MLflow is configured
        try:
            # Ping the mlflow server endpoint
            response = requests.get(MLFLOW_SERVER_URI + "/health", timeout=5)
            if response.status_code != HTTP_200_OK:
                # FIX: "Mlfow" typo corrected to "MLflow" in log/response text.
                logger.error("MLflow server check failed: %s", response.status_code)
                return JSONResponse(content={"status": "unhealthy", "detail": "MLflow server not reachable"},
                                    status_code=HTTP_503_SERVICE_UNAVAILABLE)
        except requests.RequestException as e:
            logger.error("MLflow server check failed: %s", e)
            return JSONResponse(content={"status": "unhealthy", "detail": "MLflow server not reachable"},
                                status_code=HTTP_503_SERVICE_UNAVAILABLE)
    return JSONResponse(content={"status": "healthy"}, status_code=HTTP_200_OK)
|