rework the data quality check
Browse files- .env.example +0 -1
- src/main.py +16 -20
- src/service/data_quality.py +58 -24
.env.example
CHANGED
|
@@ -11,4 +11,3 @@ AWS_SECRET_ACCESS_KEY=
|
|
| 11 |
# Data quality
|
| 12 |
EVIDENTLY_API_KEY=
|
| 13 |
EVIDENTLY_PROJECT_ID=
|
| 14 |
-
EVIDENTLY_REF_DATASET_ID=
|
|
|
|
| 11 |
# Data quality
|
| 12 |
EVIDENTLY_API_KEY=
|
| 13 |
EVIDENTLY_PROJECT_ID=
|
|
|
src/main.py
CHANGED
|
@@ -23,8 +23,7 @@ from dotenv import load_dotenv
|
|
| 23 |
from mlflow.exceptions import RestException
|
| 24 |
|
| 25 |
from src.entity.model import ModelInput, ModelOutput
|
| 26 |
-
from src.
|
| 27 |
-
from src.service.data_quality import DataChecker
|
| 28 |
from src.service.model import (
|
| 29 |
run_experiment,
|
| 30 |
predict,
|
|
@@ -172,10 +171,8 @@ async def deploy_model_to_production(
|
|
| 172 |
logger.error(e)
|
| 173 |
|
| 174 |
# Return HTTP error 404
|
| 175 |
-
return
|
| 176 |
-
|
| 177 |
-
detail=f"Model {model_name} (version {version}) not found"
|
| 178 |
-
)
|
| 179 |
|
| 180 |
return {"message": f"Model {model_name} deployed to production"}
|
| 181 |
|
|
@@ -191,34 +188,33 @@ async def undeploy_model_from_production(model_name: str = Query(description="Th
|
|
| 191 |
logger.error(e)
|
| 192 |
|
| 193 |
# Return HTTP error 404
|
| 194 |
-
return
|
| 195 |
-
|
| 196 |
-
detail=f"Model {model_name} not found or not in production"
|
| 197 |
-
)
|
| 198 |
|
| 199 |
return {"message": f"Model {model_name} undeployed from production"}
|
| 200 |
|
| 201 |
@app.get("/check_data_quality", tags=["data"], description="Check the data quality")
|
| 202 |
-
async def check_data_quality(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
"""
|
| 204 |
Check the data quality
|
| 205 |
"""
|
| 206 |
# Get the API key and project ID from the environment variables
|
| 207 |
api_key = os.getenv("EVIDENTLY_API_KEY")
|
| 208 |
-
project_id = os.getenv("EVIDENTLY_PROJECT_ID")
|
| 209 |
-
ref_dataset_id = os.getenv("EVIDENTLY_REF_DATASET_ID")
|
| 210 |
|
| 211 |
# Check if the API key and project ID are set
|
| 212 |
-
if not api_key or not project_id
|
| 213 |
-
return JSONResponse(content={"
|
| 214 |
status_code=HTTP_503_SERVICE_UNAVAILABLE)
|
| 215 |
|
| 216 |
-
# Get the newest data from the database
|
| 217 |
-
df = load_model_data()
|
| 218 |
-
|
| 219 |
# Schedule the data quality check
|
| 220 |
-
background_tasks.add_task(func=
|
| 221 |
-
|
|
|
|
| 222 |
|
| 223 |
return {"message": "Data quality check scheduled"}
|
| 224 |
|
|
|
|
| 23 |
from mlflow.exceptions import RestException
|
| 24 |
|
| 25 |
from src.entity.model import ModelInput, ModelOutput
|
| 26 |
+
from src.service.data_quality import DataChecker, check_model_data
|
|
|
|
| 27 |
from src.service.model import (
|
| 28 |
run_experiment,
|
| 29 |
predict,
|
|
|
|
| 171 |
logger.error(e)
|
| 172 |
|
| 173 |
# Return HTTP error 404
|
| 174 |
+
return JSONResponse(content={"message": f"Model {model_name} (version {version}) not found"},
|
| 175 |
+
status_code=HTTP_404_NOT_FOUND)
|
|
|
|
|
|
|
| 176 |
|
| 177 |
return {"message": f"Model {model_name} deployed to production"}
|
| 178 |
|
|
|
|
| 188 |
logger.error(e)
|
| 189 |
|
| 190 |
# Return HTTP error 404
|
| 191 |
+
return JSONResponse(content={"message": f"Model {model_name} not found or not in production"},
|
| 192 |
+
status_code=HTTP_404_NOT_FOUND)
|
|
|
|
|
|
|
| 193 |
|
| 194 |
return {"message": f"Model {model_name} undeployed from production"}
|
| 195 |
|
| 196 |
@app.get("/check_data_quality", tags=["data"], description="Check the data quality")
async def check_data_quality(
    background_tasks: BackgroundTasks,
    model_name: str = Query(description="The name of the model to check"),
    project_id: Optional[str] = Query(default=None, description="The ID of the project to send the data quality report to"),
):
    """
    Schedule a data quality check for the given model.

    The check itself (loading the newest data and comparing it against the
    model's training dataset with Evidently) runs as a background task;
    this endpoint only validates configuration and enqueues the work.

    Returns HTTP 503 when the Evidently credentials are missing or the
    Evidently project cannot be resolved; HTTP 200 once the task is queued.
    """
    # Get the API key and project ID from the environment variables.
    # The query parameter, when provided, overrides the env project ID.
    api_key = os.getenv("EVIDENTLY_API_KEY")
    project_id = project_id or os.getenv("EVIDENTLY_PROJECT_ID")

    # Check if the API key and project ID are set
    if not api_key or not project_id:
        return JSONResponse(content={"message": "Evidently API key or project ID not set"},
                            status_code=HTTP_503_SERVICE_UNAVAILABLE)

    # DataChecker.__init__ raises ValueError when the Evidently project
    # cannot be found; surface that as a 503 instead of an unhandled 500.
    try:
        checker = DataChecker(api_key, project_id)
    except ValueError as e:
        logger.error(e)
        return JSONResponse(content={"message": str(e)},
                            status_code=HTTP_503_SERVICE_UNAVAILABLE)

    # Schedule the data quality check
    background_tasks.add_task(func=check_model_data,
                              model_name=model_name,
                              checker=checker)

    return {"message": "Data quality check scheduled"}
|
| 220 |
|
src/service/data_quality.py
CHANGED
|
@@ -1,14 +1,39 @@
|
|
| 1 |
import pandas as pd
|
| 2 |
-
from typing import Optional
|
| 3 |
from evidently import Dataset, DataDefinition, Report
|
| 4 |
from evidently.presets import DataSummaryPreset, DataDriftPreset
|
| 5 |
from evidently.ui.workspace import CloudWorkspace
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
class DataChecker:
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
self._api_key = api_key
|
| 10 |
self._project_id = project_id
|
| 11 |
-
self._ref_dataset_id = ref_dataset_id
|
| 12 |
|
| 13 |
self.workspace = CloudWorkspace(
|
| 14 |
token=api_key,
|
|
@@ -23,26 +48,14 @@ class DataChecker:
|
|
| 23 |
if not self.project:
|
| 24 |
raise ValueError("Project not found. Please check your project ID.")
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
if not self.ref_dataset:
|
| 29 |
-
raise ValueError("Reference dataset not found. Please check your reference dataset ID.")
|
| 30 |
-
|
| 31 |
-
def check_data(self, df: pd.DataFrame, schema: Optional[DataDefinition] = None) -> str:
|
| 32 |
-
if not schema:
|
| 33 |
-
schema = DataDefinition(
|
| 34 |
-
id_column="match_id",
|
| 35 |
-
datetime_columns=["date"],
|
| 36 |
-
numerical_columns=["winner_rank", "loser_rank", "winner_points", "loser_points", "w_height_cm", "w_weight_kg",
|
| 37 |
-
"w_year_of_birth", "w_pro_year", "l_height_cm", "l_weight_kg", "l_year_of_birth", "l_pro_year"],
|
| 38 |
-
categorical_columns=["tournament_name", "tournament_series", "tournament_surface", "tournament_court",
|
| 39 |
-
"tournament_location", "winner_name", "w_first_name", "w_last_name", "w_play_hand",
|
| 40 |
-
"w_back_hand", "loser_name", "l_first_name", "l_last_name", "l_play_hand", "l_back_hand"],
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
eval_data = Dataset.from_pandas(
|
| 44 |
data=df,
|
| 45 |
-
data_definition=
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
)
|
| 47 |
|
| 48 |
report = Report(
|
|
@@ -54,8 +67,29 @@ class DataChecker:
|
|
| 54 |
)
|
| 55 |
|
| 56 |
# Run the report
|
| 57 |
-
|
|
|
|
| 58 |
|
| 59 |
# Save the evaluation to the workspace
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
+
from typing import Dict, Optional, List
|
| 3 |
from evidently import Dataset, DataDefinition, Report
|
| 4 |
from evidently.presets import DataSummaryPreset, DataDriftPreset
|
| 5 |
from evidently.ui.workspace import CloudWorkspace
|
| 6 |
+
from src.repository.model_data import load_model_data
|
| 7 |
+
from src.service.model import get_training_dataset
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
class DataChecker:
|
| 14 |
+
schemas: Dict[str, DataDefinition] = {
|
| 15 |
+
"raw": DataDefinition(
|
| 16 |
+
id_column="match_id",
|
| 17 |
+
datetime_columns=["date"],
|
| 18 |
+
numerical_columns=["winner_rank", "loser_rank", "winner_points", "loser_points", "w_height_cm", "w_weight_kg",
|
| 19 |
+
"w_year_of_birth", "w_pro_year", "l_height_cm", "l_weight_kg", "l_year_of_birth", "l_pro_year"],
|
| 20 |
+
categorical_columns=["tournament_name", "tournament_series", "tournament_surface", "tournament_court",
|
| 21 |
+
"tournament_location", "winner_name", "w_first_name", "w_last_name", "w_play_hand",
|
| 22 |
+
"w_back_hand", "loser_name", "l_first_name", "l_last_name", "l_play_hand", "l_back_hand"],
|
| 23 |
+
),
|
| 24 |
+
"cleaned": DataDefinition(
|
| 25 |
+
numerical_columns=["diff_rank", "mean_rank",
|
| 26 |
+
"diff_height_cm", "mean_height_cm",
|
| 27 |
+
"diff_weight_kg", "mean_weight_kg",
|
| 28 |
+
"diff_nb_pro_years", "diff_age",
|
| 29 |
+
"diff_play_hand", "diff_back_hand"],
|
| 30 |
+
categorical_columns=["tournament_series", "tournament_surface", "tournament_court",],
|
| 31 |
+
)
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
def __init__(self, api_key: str, project_id: str):
|
| 35 |
self._api_key = api_key
|
| 36 |
self._project_id = project_id
|
|
|
|
| 37 |
|
| 38 |
self.workspace = CloudWorkspace(
|
| 39 |
token=api_key,
|
|
|
|
| 48 |
if not self.project:
|
| 49 |
raise ValueError("Project not found. Please check your project ID.")
|
| 50 |
|
| 51 |
+
def check_data(self, df: pd.DataFrame, ref_df: pd.DataFrame, tags: Optional[List[str]] = None) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
eval_data = Dataset.from_pandas(
|
| 53 |
data=df,
|
| 54 |
+
data_definition=self.schemas['cleaned']
|
| 55 |
+
)
|
| 56 |
+
ref_data = Dataset.from_pandas(
|
| 57 |
+
data=ref_df,
|
| 58 |
+
data_definition=self.schemas['cleaned']
|
| 59 |
)
|
| 60 |
|
| 61 |
report = Report(
|
|
|
|
| 67 |
)
|
| 68 |
|
| 69 |
# Run the report
|
| 70 |
+
logger.info("Running the report...")
|
| 71 |
+
my_eval = report.run(current_data=eval_data, reference_data=ref_data, tags=tags)
|
| 72 |
|
| 73 |
# Save the evaluation to the workspace
|
| 74 |
+
logger.info("Saving the evaluation to the workspace...")
|
| 75 |
+
snapshot_id = self.workspace.add_run(self.project.id, my_eval, include_data=False)
|
| 76 |
+
|
| 77 |
+
logger.info(f"Evaluation saved with snapshot ID: {snapshot_id}")
|
| 78 |
+
|
| 79 |
+
return snapshot_id
|
| 80 |
+
|
| 81 |
+
def check_model_data(
    model_name: str,
    checker: DataChecker,
) -> str:
    """
    Check the model data using Evidently.

    Loads the newest rows from the database as the current dataset, uses
    the model's training dataset as the drift reference, and delegates
    the report generation to *checker*. Returns the snapshot ID produced
    by the checker.
    """
    # Newest data from the database is the "current" side of the report.
    current_df = load_model_data()

    # The model's training dataset serves as the drift reference.
    reference_df = get_training_dataset(model_name=model_name)

    # Tag the run with the model name so reports are easy to filter.
    return checker.check_data(current_df, reference_df, tags=[model_name])