| | import os |
| | import joblib |
| | import logging |
| | import secrets |
| | from typing import Generator, Optional, Annotated, List |
| | from fastapi import ( |
| | FastAPI, |
| | Request, |
| | HTTPException, |
| | Query, |
| | Security, |
| | Depends |
| | ) |
| | from fastapi.responses import RedirectResponse, JSONResponse |
| | from fastapi.background import BackgroundTasks |
| | from fastapi.security.api_key import APIKeyHeader |
| | from starlette.status import ( |
| | HTTP_200_OK, |
| | HTTP_403_FORBIDDEN, |
| | HTTP_404_NOT_FOUND, |
| | HTTP_503_SERVICE_UNAVAILABLE) |
| | from dotenv import load_dotenv |
| | from mlflow.exceptions import RestException |
| |
|
| | from src.entity.model import ModelInput, ModelOutput |
| | from src.service.data_quality import DataChecker, check_model_data |
| | from src.service.model import ( |
| | run_experiment, |
| | predict, |
| | list_registered_models, |
| | load_model, |
| | deploy_model, |
| | undeploy_model, |
| | ) |
| | from src.repository.common import get_connection |
| | from psycopg import Connection |
| |
|
# Load environment variables from a local .env file (FASTAPI_API_KEY,
# EVIDENTLY_API_KEY, MLFLOW_SERVER_URI, ...) before any os.getenv() below.
load_dotenv()

# Log INFO and above to the console; module-level logger for this API.
logging.basicConfig(level=logging.INFO,
                    handlers=[logging.StreamHandler()])
logger = logging.getLogger(__name__)
| |
|
def provide_connection() -> Generator[Connection, None, None]:
    """FastAPI dependency that yields a database connection.

    The connection comes from the repository layer's context manager, so it
    is released automatically once the request handler has finished.
    """
    with get_connection() as connection:
        yield connection
| |
|
| | |
| |
|
# API key expected in the 'Authorization' header; when unset, the key check
# is not installed at all (see the FastAPI app construction in this file).
FASTAPI_API_KEY = os.getenv("FASTAPI_API_KEY")
# Client hosts exempt from the API key check (loopback traffic).
safe_clients = ['127.0.0.1']

# auto_error=False: a missing header yields key=None instead of an immediate
# 403, so the safe-client exemption can still be applied.
api_key_header = APIKeyHeader(name='Authorization', auto_error=False)
| |
|
async def validate_api_key(request: Request, key: str = Security(api_key_header)):
    '''
    Check that the request is authorized.

    Requests coming from a host listed in `safe_clients` are always accepted.
    Any other client must present an 'Authorization' header matching
    FASTAPI_API_KEY; the comparison uses `secrets.compare_digest` to avoid
    leaking information through timing.

    Args:
        request (Request): The incoming request (used for the client host)
        key (str): The API key extracted from the 'Authorization' header

    Raises:
        HTTPException: 403 if the client is untrusted and the key is invalid
    '''
    # request.client can be None (e.g. certain test clients / transports);
    # treat that case as an untrusted host so the key check still applies.
    client_host = request.client.host if request.client else None
    if client_host in safe_clients:
        return None
    # str() guards against key being None when the header is absent.
    if not secrets.compare_digest(str(key), str(FASTAPI_API_KEY)):
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Unauthorized - API Key is wrong"
        )
    return None
| |
|
# Install the API key check on every route only when a key is configured;
# with no FASTAPI_API_KEY set the API is left open (local development).
app = FastAPI(dependencies=[Depends(validate_api_key)] if FASTAPI_API_KEY else None,
              title="Tennis Insights ML API",
              description="API for the Tennis Insights ML module",)
| |
|
| | |
| | @app.get("/", include_in_schema=False) |
| | def redirect_to_docs(): |
| | ''' |
| | Redirect to the API documentation. |
| | ''' |
| | return RedirectResponse(url='/docs') |
| |
|
| | @app.get("/run_experiment", tags=["model"], description="Schedule a run of the ML experiment") |
| | async def run_xp(background_tasks: BackgroundTasks, |
| | algo: str = Query(default="LogisticRegression", description="The algorithm to use for training"), |
| | registered_model_name: Optional[str] = Query(default=None, description="The name of the registered model"), |
| | experiment_name: Optional[str] = Query(default="Tennis Prediction", description="The name of the experiment")): |
| | """ |
| | Train the model |
| | """ |
| | background_tasks.add_task(func=run_experiment, |
| | algo=algo, |
| | registered_model_name=registered_model_name, |
| | experiment_name=experiment_name,) |
| | |
| | return {"message": "Experiment scheduled"} |
| |
|
| | @app.get("/predict", |
| | tags=["model"], |
| | description="Predict the outcome of a tennis match", |
| | response_model=ModelOutput) |
| | async def make_prediction(params: Annotated[ModelInput, Query()]): |
| | """ |
| | Predict the matches |
| | """ |
| | if not params.model: |
| | |
| | if not os.path.exists("/data/model.pkl"): |
| | return {"message": "Model not trained. Please train the model first."} |
| | |
| | |
| | pipeline = joblib.load("/data/model.pkl") |
| | else: |
| | |
| | try: |
| | pipeline = load_model(name=params.model, alias=params.alias) |
| | except RestException as e: |
| | logger.error(e) |
| |
|
| | |
| | return HTTPException( |
| | status_code=HTTP_404_NOT_FOUND, |
| | detail=f"Model {params.model} not found" |
| | ) |
| |
|
| | |
| | prediction = predict( |
| | pipeline=pipeline, |
| | series=params.series, |
| | surface=params.surface, |
| | court=params.court, |
| | p1_rank=params.p1_rank, |
| | p1_play_hand=params.p1_play_hand, |
| | p1_back_hand=params.p1_back_hand, |
| | p1_height=params.p1_height, |
| | p1_weight=params.p1_weight, |
| | p1_year_of_birth=params.p1_year_of_birth, |
| | p1_pro_year=params.p1_pro_year, |
| | p2_rank=params.p2_rank, |
| | p2_play_hand=params.p2_play_hand, |
| | p2_back_hand=params.p2_back_hand, |
| | p2_height=params.p2_height, |
| | p2_weight=params.p2_weight, |
| | p2_year_of_birth=params.p2_year_of_birth, |
| | p2_pro_year=params.p2_pro_year, |
| | ) |
| |
|
| | logger.info(prediction) |
| |
|
| | return prediction |
| |
|
| | @app.get("/list_available_models", tags=["model"], description="List the available models") |
| | async def list_available_models( |
| | aliases: Optional[List[str]] = Query(default=None, description="List of model aliases to filter the models")): |
| | """ |
| | List the available models |
| | """ |
| | return list_registered_models(alias_filter=aliases) |
| |
|
| | @app.post("/deploy_model", tags=["model"], description="Deploy a model") |
| | async def deploy_model_to_production( |
| | model_name: str = Query(description="The name of the model to deploy"), |
| | version: str = Query(description="The version of the model to deploy")): |
| | """ |
| | Deploy a model |
| | """ |
| | |
| | try: |
| | deploy_model(model_name=model_name, model_version=version) |
| | except RestException as e: |
| | logger.error(e) |
| |
|
| | |
| | return JSONResponse(content={"message": f"Model {model_name} (version {version}) not found"}, |
| | status_code=HTTP_404_NOT_FOUND) |
| |
|
| | return {"message": f"Model {model_name} deployed to production"} |
| |
|
| | @app.post("/undeploy_model", tags=["model"], description="Undeploy a model") |
| | async def undeploy_model_from_production(model_name: str = Query(description="The name of the model to undeploy")): |
| | """ |
| | Undeploy a model |
| | """ |
| | |
| | try: |
| | undeploy_model(model_name=model_name) |
| | except RestException as e: |
| | logger.error(e) |
| |
|
| | |
| | return JSONResponse(content={"message": f"Model {model_name} not found or not in production"}, |
| | status_code=HTTP_404_NOT_FOUND) |
| |
|
| | return {"message": f"Model {model_name} undeployed from production"} |
| |
|
| | @app.get("/check_data_quality", tags=["data"], description="Check the data quality") |
| | async def check_data_quality( |
| | background_tasks: BackgroundTasks, |
| | model_name: str = Query(description="The name of the model to check"), |
| | project_id: Optional[str] = Query(default=None, description="The ID of the project to send the data quality report to"), |
| | ): |
| | """ |
| | Check the data quality |
| | """ |
| | |
| | api_key = os.getenv("EVIDENTLY_API_KEY") |
| | project_id = project_id or os.getenv("EVIDENTLY_PROJECT_ID") |
| |
|
| | |
| | if not api_key or not project_id: |
| | return JSONResponse(content={"message": "Evidently API key or project ID not set"}, |
| | status_code=HTTP_503_SERVICE_UNAVAILABLE) |
| | |
| | |
| | background_tasks.add_task(func=check_model_data, |
| | model_name=model_name, |
| | checker=DataChecker(api_key, project_id)) |
| | |
| | return {"message": "Data quality check scheduled"} |
| |
|
| | |
| | @app.get("/check_health", tags=["general"], description="Check the health of the ML module") |
| | async def check_health(session: Connection = Depends(provide_connection)): |
| | """ |
| | Check all the services in the infrastructure are working |
| | """ |
| | |
| | try: |
| | with session.cursor() as cursor: |
| | cursor.execute("SELECT 1").fetchall() |
| | except Exception as e: |
| | logger.error(f"DB check failed: {e}") |
| | return JSONResponse(content={"status": "unhealthy", "detail": "Database not reachable"}, |
| | status_code=HTTP_503_SERVICE_UNAVAILABLE) |
| | |
| | |
| | if MLFLOW_SERVER_URI := os.getenv("MLFLOW_SERVER_URI"): |
| | import requests |
| |
|
| | try: |
| | |
| | response = requests.get(MLFLOW_SERVER_URI + "/health", timeout=5) |
| | if response.status_code != HTTP_200_OK: |
| | logger.error(f"Mlfow server check failed: {response.status_code}") |
| | return JSONResponse(content={"status": "unhealthy", "detail": "Mlfow server not reachable"}, |
| | status_code=HTTP_503_SERVICE_UNAVAILABLE) |
| | except requests.RequestException as e: |
| | logger.error(f"Mlfow server check failed: {e}") |
| | return JSONResponse(content={"status": "unhealthy", "detail": "Mlfow server not reachable"}, |
| | status_code=HTTP_503_SERVICE_UNAVAILABLE) |
| | |
| | return JSONResponse(content={"status": "healthy"}, status_code=HTTP_200_OK) |
| |
|