File size: 9,394 Bytes
a738b50
 
 
 
ea7db66
a738b50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a395d5
5ef0c38
a738b50
 
 
 
 
851f0b2
 
a738b50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2f5aca
 
 
a738b50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8027bf
a738b50
 
 
 
 
0192d50
a738b50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea7db66
 
a738b50
 
 
ea7db66
a738b50
851f0b2
 
 
 
 
 
 
 
 
 
 
 
 
 
5ef0c38
 
851f0b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ef0c38
 
851f0b2
 
 
4bb489e
5ef0c38
 
 
 
 
4bb489e
 
 
 
 
5ef0c38
4bb489e
 
5ef0c38
 
4bb489e
 
 
5ef0c38
 
 
4bb489e
 
 
a738b50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import os
import joblib
import logging
import secrets
from typing import Generator, Optional, Annotated, List
from fastapi import (
    FastAPI,
    Request,
    HTTPException,
    Query,
    Security,
    Depends
)
from fastapi.responses import RedirectResponse, JSONResponse
from fastapi.background import BackgroundTasks
from fastapi.security.api_key import APIKeyHeader
from starlette.status import (
    HTTP_200_OK,
    HTTP_403_FORBIDDEN,
    HTTP_404_NOT_FOUND,
    HTTP_503_SERVICE_UNAVAILABLE)
from dotenv import load_dotenv
from mlflow.exceptions import RestException

from src.entity.model import ModelInput, ModelOutput
from src.service.data_quality import DataChecker, check_model_data
from src.service.model import (
    run_experiment,
    predict,
    list_registered_models,
    load_model,
    deploy_model,
    undeploy_model,
)
from src.repository.common import get_connection
from psycopg import Connection

# Load environment variables from a local .env file (if present).
load_dotenv()

# Log to the stream handler so container platforms can collect the output.
logging.basicConfig(level=logging.INFO,
                    handlers=[logging.StreamHandler()])
logger = logging.getLogger(__name__)

def provide_connection() -> Generator[Connection, None, None]:
    """Yield a database connection for the duration of one request.

    FastAPI dependency: the connection comes from ``get_connection()``'s
    context manager and is released when the request handler finishes.
    """
    connection_ctx = get_connection()
    with connection_ctx as connection:
        yield connection

# ------------------------------------------------------------------------------

# API key expected in the 'Authorization' header; when unset, auth is disabled
# (see the FastAPI constructor below).
FASTAPI_API_KEY = os.getenv("FASTAPI_API_KEY")
# Client hosts allowed to call the API without presenting the key.
safe_clients = ['127.0.0.1']

# auto_error=False: a missing header yields key=None instead of an immediate 403,
# so trusted clients can still get through the check below.
api_key_header = APIKeyHeader(name='Authorization', auto_error=False)

async def validate_api_key(request: Request, key: str = Security(api_key_header)):
    '''
    Check if the API key is valid.

    A request is accepted when it comes from a trusted client host, or
    when the presented key matches the configured one (compared with
    `secrets.compare_digest` to avoid timing attacks).

    Args:
        request (Request): The incoming request (used to read the client host)
        key (str): The API key to check

    Raises:
        HTTPException: If the API key is invalid
    '''
    client_is_trusted = request.client.host in safe_clients
    key_matches = secrets.compare_digest(str(key), str(FASTAPI_API_KEY))
    if not (client_is_trusted or key_matches):
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Unauthorized - API Key is wrong"
        )
    return None

# Require the API key on every route only when a key is configured;
# with no key set the dependency list is None and auth is disabled.
app = FastAPI(dependencies=[Depends(validate_api_key)] if FASTAPI_API_KEY else None,
              title="Tennis Insights ML API",
              description="API for the Tennis Insights ML module",)

# ------------------------------------------------------------------------------
@app.get("/", include_in_schema=False)
def redirect_to_docs():
    '''
    Send visitors of the root URL to the interactive API documentation.
    '''
    docs_url = '/docs'
    return RedirectResponse(url=docs_url)

@app.get("/run_experiment", tags=["model"], description="Schedule a run of the ML experiment")
async def run_xp(background_tasks: BackgroundTasks,
                algo: str = Query(default="LogisticRegression", description="The algorithm to use for training"),
                registered_model_name: Optional[str] = Query(default=None, description="The name of the registered model"),
                experiment_name: Optional[str] = Query(default="Tennis Prediction", description="The name of the experiment")):
    """
    Schedule a training run of the experiment.

    The run executes in the background; the endpoint returns immediately.
    """
    background_tasks.add_task(
        func=run_experiment,
        algo=algo,
        registered_model_name=registered_model_name,
        experiment_name=experiment_name,
    )
    return {"message": "Experiment scheduled"}

@app.get("/predict",
         tags=["model"],
         description="Predict the outcome of a tennis match",
         response_model=ModelOutput)
async def make_prediction(params: Annotated[ModelInput, Query()]):
    """
    Predict the outcome of a tennis match.

    Loads either the local on-disk pipeline ('/data/model.pkl') when no
    registered model is requested, or the registered model named in
    `params.model` (resolved by alias), then runs the prediction.

    Raises:
        HTTPException: 404 when no trained model is available on disk, or
            when the requested registered model does not exist.
    """
    if not params.model:
        # check the presence of 'model.pkl' file in data/
        if not os.path.exists("/data/model.pkl"):
            # BUGFIX: the original returned a bare message dict, which
            # fails validation against response_model=ModelOutput; signal
            # the missing model explicitly instead.
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND,
                detail="Model not trained. Please train the model first."
            )

        # Load the model
        pipeline = joblib.load("/data/model.pkl")
    else:
        # Get the model info from the registry
        try:
            pipeline = load_model(name=params.model, alias=params.alias)
        except RestException as e:
            logger.error(e)

            # BUGFIX: the exception must be raised, not returned —
            # returning it produced a 200 response containing a
            # serialized exception instead of an actual 404.
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND,
                detail=f"Model {params.model} not found"
            )

    # Make the prediction
    prediction = predict(
        pipeline=pipeline,
        series=params.series,
        surface=params.surface,
        court=params.court,
        p1_rank=params.p1_rank,
        p1_play_hand=params.p1_play_hand,
        p1_back_hand=params.p1_back_hand,
        p1_height=params.p1_height,
        p1_weight=params.p1_weight,
        p1_year_of_birth=params.p1_year_of_birth,
        p1_pro_year=params.p1_pro_year,
        p2_rank=params.p2_rank,
        p2_play_hand=params.p2_play_hand,
        p2_back_hand=params.p2_back_hand,
        p2_height=params.p2_height,
        p2_weight=params.p2_weight,
        p2_year_of_birth=params.p2_year_of_birth,
        p2_pro_year=params.p2_pro_year,
    )

    logger.info(prediction)

    return prediction

@app.get("/list_available_models", tags=["model"], description="List the available models")
async def list_available_models(
    aliases: Optional[List[str]] = Query(default=None, description="List of model aliases to filter the models")):
    """
    Return the registered models, optionally restricted to the given aliases.
    """
    registered = list_registered_models(alias_filter=aliases)
    return registered

@app.post("/deploy_model", tags=["model"], description="Deploy a model")
async def deploy_model_to_production(
    model_name: str = Query(description="The name of the model to deploy"),
    version: str = Query(description="The version of the model to deploy")):
    """
    Promote the given model version to production.

    Returns a 404 JSON payload when the model/version is unknown to the
    registry instead of surfacing the registry error.
    """
    try:
        deploy_model(model_name=model_name, model_version=version)
    except RestException as err:
        logger.error(err)
        # Unknown model or version in the registry: report it as a 404.
        not_found = {"message": f"Model {model_name} (version {version}) not found"}
        return JSONResponse(content=not_found, status_code=HTTP_404_NOT_FOUND)

    return {"message": f"Model {model_name} deployed to production"}

@app.post("/undeploy_model", tags=["model"], description="Undeploy a model")
async def undeploy_model_from_production(model_name: str = Query(description="The name of the model to undeploy")):
    """
    Remove the given model from production.

    Returns a 404 JSON payload when the model is unknown to the registry
    or not currently deployed.
    """
    try:
        undeploy_model(model_name=model_name)
    except RestException as err:
        logger.error(err)
        # Unknown model, or model not currently in production: 404.
        not_found = {"message": f"Model {model_name} not found or not in production"}
        return JSONResponse(content=not_found, status_code=HTTP_404_NOT_FOUND)

    return {"message": f"Model {model_name} undeployed from production"}

@app.get("/check_data_quality", tags=["data"], description="Check the data quality")
async def check_data_quality(
    background_tasks: BackgroundTasks,
    model_name: str = Query(description="The name of the model to check"),
    project_id: Optional[str] = Query(default=None, description="The ID of the project to send the data quality report to"),
):
    """
    Schedule a data-quality check for the given model.

    The Evidently API key always comes from the environment; the query
    parameter, when provided, overrides the EVIDENTLY_PROJECT_ID variable.
    """
    api_key = os.getenv("EVIDENTLY_API_KEY")
    effective_project = project_id if project_id else os.getenv("EVIDENTLY_PROJECT_ID")

    # Without both credentials the checker cannot publish its report.
    if not api_key or not effective_project:
        return JSONResponse(content={"message": "Evidently API key or project ID not set"},
                            status_code=HTTP_503_SERVICE_UNAVAILABLE)

    # Run the (slow) check after the response has been sent.
    checker = DataChecker(api_key, effective_project)
    background_tasks.add_task(func=check_model_data,
                              model_name=model_name,
                              checker=checker)

    return {"message": "Data quality check scheduled"}

# ------------------------------------------------------------------------------
@app.get("/check_health", tags=["general"], description="Check the health of the ML module")
async def check_health(session: Connection = Depends(provide_connection)):
    """
    Check all the services in the infrastructure are working.

    Verifies the database connection and, when MLFLOW_SERVER_URI is set,
    pings the MLflow server's /health endpoint.

    Returns:
        JSONResponse: {"status": "healthy"} with HTTP 200 when every
        configured service responds, otherwise {"status": "unhealthy",
        "detail": ...} with HTTP 503.
    """
    def _unhealthy(detail: str) -> JSONResponse:
        # Single place to build the 503 failure payload.
        return JSONResponse(content={"status": "unhealthy", "detail": detail},
                            status_code=HTTP_503_SERVICE_UNAVAILABLE)

    # Check if the database is alive
    try:
        with session.cursor() as cursor:
            cursor.execute("SELECT 1").fetchall()
    except Exception as e:
        logger.error(f"DB check failed: {e}")
        return _unhealthy("Database not reachable")

    # Check if the mlflow endpoint is reachable (only when configured)
    if MLFLOW_SERVER_URI := os.getenv("MLFLOW_SERVER_URI"):
        # Imported lazily: only needed when an MLflow server is configured.
        import requests

        try:
            # Ping the mlflow server health endpoint
            response = requests.get(MLFLOW_SERVER_URI + "/health", timeout=5)
        except requests.RequestException as e:
            # Typo fix: "Mlfow" -> "MLflow" in the log and response messages.
            logger.error(f"MLflow server check failed: {e}")
            return _unhealthy("MLflow server not reachable")
        if response.status_code != HTTP_200_OK:
            logger.error(f"MLflow server check failed: {response.status_code}")
            return _unhealthy("MLflow server not reachable")

    return JSONResponse(content={"status": "healthy"}, status_code=HTTP_200_OK)