SlimG committed on
Commit
5ef0c38
·
1 Parent(s): 863ede9

rework the data quality check

Browse files
Files changed (3) hide show
  1. .env.example +0 -1
  2. src/main.py +16 -20
  3. src/service/data_quality.py +58 -24
.env.example CHANGED
@@ -11,4 +11,3 @@ AWS_SECRET_ACCESS_KEY=
11
  # Data quality
12
  EVIDENTLY_API_KEY=
13
  EVIDENTLY_PROJECT_ID=
14
- EVIDENTLY_REF_DATASET_ID=
 
11
  # Data quality
12
  EVIDENTLY_API_KEY=
13
  EVIDENTLY_PROJECT_ID=
 
src/main.py CHANGED
@@ -23,8 +23,7 @@ from dotenv import load_dotenv
23
  from mlflow.exceptions import RestException
24
 
25
  from src.entity.model import ModelInput, ModelOutput
26
- from src.repository.model_data import load_model_data
27
- from src.service.data_quality import DataChecker
28
  from src.service.model import (
29
  run_experiment,
30
  predict,
@@ -172,10 +171,8 @@ async def deploy_model_to_production(
172
  logger.error(e)
173
 
174
  # Return HTTP error 404
175
- return HTTPException(
176
- status_code=HTTP_404_NOT_FOUND,
177
- detail=f"Model {model_name} (version {version}) not found"
178
- )
179
 
180
  return {"message": f"Model {model_name} deployed to production"}
181
 
@@ -191,34 +188,33 @@ async def undeploy_model_from_production(model_name: str = Query(description="Th
191
  logger.error(e)
192
 
193
  # Return HTTP error 404
194
- return HTTPException(
195
- status_code=HTTP_404_NOT_FOUND,
196
- detail=f"Model {model_name} not found or not in production"
197
- )
198
 
199
  return {"message": f"Model {model_name} undeployed from production"}
200
 
201
  @app.get("/check_data_quality", tags=["data"], description="Check the data quality")
202
- async def check_data_quality(background_tasks: BackgroundTasks):
 
 
 
 
203
  """
204
  Check the data quality
205
  """
206
  # Get the API key and project ID from the environment variables
207
  api_key = os.getenv("EVIDENTLY_API_KEY")
208
- project_id = os.getenv("EVIDENTLY_PROJECT_ID")
209
- ref_dataset_id = os.getenv("EVIDENTLY_REF_DATASET_ID")
210
 
211
  # Check if the API key and project ID are set
212
- if not api_key or not project_id or not ref_dataset_id:
213
- return JSONResponse(content={"status": "unhealthy", "detail": "Evidently API key or project ID not set"},
214
  status_code=HTTP_503_SERVICE_UNAVAILABLE)
215
 
216
- # Get the newest data from the database
217
- df = load_model_data()
218
-
219
  # Schedule the data quality check
220
- background_tasks.add_task(func=DataChecker(api_key, project_id, ref_dataset_id).check_data,
221
- df=df)
 
222
 
223
  return {"message": "Data quality check scheduled"}
224
 
 
23
  from mlflow.exceptions import RestException
24
 
25
  from src.entity.model import ModelInput, ModelOutput
26
+ from src.service.data_quality import DataChecker, check_model_data
 
27
  from src.service.model import (
28
  run_experiment,
29
  predict,
 
171
  logger.error(e)
172
 
173
  # Return HTTP error 404
174
+ return JSONResponse(content={"message": f"Model {model_name} (version {version}) not found"},
175
+ status_code=HTTP_404_NOT_FOUND)
 
 
176
 
177
  return {"message": f"Model {model_name} deployed to production"}
178
 
 
188
  logger.error(e)
189
 
190
  # Return HTTP error 404
191
+ return JSONResponse(content={"message": f"Model {model_name} not found or not in production"},
192
+ status_code=HTTP_404_NOT_FOUND)
 
 
193
 
194
  return {"message": f"Model {model_name} undeployed from production"}
195
 
196
  @app.get("/check_data_quality", tags=["data"], description="Check the data quality")
197
+ async def check_data_quality(
198
+ background_tasks: BackgroundTasks,
199
+ model_name: str = Query(description="The name of the model to check"),
200
+ project_id: Optional[str] = Query(default=None, description="The ID of the project to send the data quality report to"),
201
+ ):
202
  """
203
  Check the data quality
204
  """
205
  # Get the API key and project ID from the environment variables
206
  api_key = os.getenv("EVIDENTLY_API_KEY")
207
+ project_id = project_id or os.getenv("EVIDENTLY_PROJECT_ID")
 
208
 
209
  # Check if the API key and project ID are set
210
+ if not api_key or not project_id:
211
+ return JSONResponse(content={"message": "Evidently API key or project ID not set"},
212
  status_code=HTTP_503_SERVICE_UNAVAILABLE)
213
 
 
 
 
214
  # Schedule the data quality check
215
+ background_tasks.add_task(func=check_model_data,
216
+ model_name=model_name,
217
+ checker=DataChecker(api_key, project_id))
218
 
219
  return {"message": "Data quality check scheduled"}
220
 
src/service/data_quality.py CHANGED
@@ -1,14 +1,39 @@
1
  import pandas as pd
2
- from typing import Optional
3
  from evidently import Dataset, DataDefinition, Report
4
  from evidently.presets import DataSummaryPreset, DataDriftPreset
5
  from evidently.ui.workspace import CloudWorkspace
 
 
 
 
 
 
6
 
7
  class DataChecker:
8
- def __init__(self, api_key: str, project_id: str, ref_dataset_id: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  self._api_key = api_key
10
  self._project_id = project_id
11
- self._ref_dataset_id = ref_dataset_id
12
 
13
  self.workspace = CloudWorkspace(
14
  token=api_key,
@@ -23,26 +48,14 @@ class DataChecker:
23
  if not self.project:
24
  raise ValueError("Project not found. Please check your project ID.")
25
 
26
- self.ref_dataset = self.workspace.load_dataset(ref_dataset_id)
27
-
28
- if not self.ref_dataset:
29
- raise ValueError("Reference dataset not found. Please check your reference dataset ID.")
30
-
31
- def check_data(self, df: pd.DataFrame, schema: Optional[DataDefinition] = None) -> str:
32
- if not schema:
33
- schema = DataDefinition(
34
- id_column="match_id",
35
- datetime_columns=["date"],
36
- numerical_columns=["winner_rank", "loser_rank", "winner_points", "loser_points", "w_height_cm", "w_weight_kg",
37
- "w_year_of_birth", "w_pro_year", "l_height_cm", "l_weight_kg", "l_year_of_birth", "l_pro_year"],
38
- categorical_columns=["tournament_name", "tournament_series", "tournament_surface", "tournament_court",
39
- "tournament_location", "winner_name", "w_first_name", "w_last_name", "w_play_hand",
40
- "w_back_hand", "loser_name", "l_first_name", "l_last_name", "l_play_hand", "l_back_hand"],
41
- )
42
-
43
  eval_data = Dataset.from_pandas(
44
  data=df,
45
- data_definition=schema
 
 
 
 
46
  )
47
 
48
  report = Report(
@@ -54,8 +67,29 @@ class DataChecker:
54
  )
55
 
56
  # Run the report
57
- my_eval = report.run(eval_data, self.ref_dataset)
 
58
 
59
  # Save the evaluation to the workspace
60
- return self.workspace.add_run(self.project.id, my_eval, include_data=False)
61
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
+ from typing import Dict, Optional, List
3
  from evidently import Dataset, DataDefinition, Report
4
  from evidently.presets import DataSummaryPreset, DataDriftPreset
5
  from evidently.ui.workspace import CloudWorkspace
6
+ from src.repository.model_data import load_model_data
7
+ from src.service.model import get_training_dataset
8
+
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
 
13
  class DataChecker:
14
+ schemas: Dict[str, DataDefinition] = {
15
+ "raw": DataDefinition(
16
+ id_column="match_id",
17
+ datetime_columns=["date"],
18
+ numerical_columns=["winner_rank", "loser_rank", "winner_points", "loser_points", "w_height_cm", "w_weight_kg",
19
+ "w_year_of_birth", "w_pro_year", "l_height_cm", "l_weight_kg", "l_year_of_birth", "l_pro_year"],
20
+ categorical_columns=["tournament_name", "tournament_series", "tournament_surface", "tournament_court",
21
+ "tournament_location", "winner_name", "w_first_name", "w_last_name", "w_play_hand",
22
+ "w_back_hand", "loser_name", "l_first_name", "l_last_name", "l_play_hand", "l_back_hand"],
23
+ ),
24
+ "cleaned": DataDefinition(
25
+ numerical_columns=["diff_rank", "mean_rank",
26
+ "diff_height_cm", "mean_height_cm",
27
+ "diff_weight_kg", "mean_weight_kg",
28
+ "diff_nb_pro_years", "diff_age",
29
+ "diff_play_hand", "diff_back_hand"],
30
+ categorical_columns=["tournament_series", "tournament_surface", "tournament_court",],
31
+ )
32
+ }
33
+
34
+ def __init__(self, api_key: str, project_id: str):
35
  self._api_key = api_key
36
  self._project_id = project_id
 
37
 
38
  self.workspace = CloudWorkspace(
39
  token=api_key,
 
48
  if not self.project:
49
  raise ValueError("Project not found. Please check your project ID.")
50
 
51
+ def check_data(self, df: pd.DataFrame, ref_df: pd.DataFrame, tags: Optional[List[str]] = None) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  eval_data = Dataset.from_pandas(
53
  data=df,
54
+ data_definition=self.schemas['cleaned']
55
+ )
56
+ ref_data = Dataset.from_pandas(
57
+ data=ref_df,
58
+ data_definition=self.schemas['cleaned']
59
  )
60
 
61
  report = Report(
 
67
  )
68
 
69
  # Run the report
70
+ logger.info("Running the report...")
71
+ my_eval = report.run(current_data=eval_data, reference_data=ref_data, tags=tags)
72
 
73
  # Save the evaluation to the workspace
74
+ logger.info("Saving the evaluation to the workspace...")
75
+ snapshot_id = self.workspace.add_run(self.project.id, my_eval, include_data=False)
76
+
77
+ logger.info(f"Evaluation saved with snapshot ID: {snapshot_id}")
78
+
79
+ return snapshot_id
80
+
81
+ def check_model_data(
82
+ model_name: str,
83
+ checker: DataChecker,
84
+ ) -> str:
85
+ """
86
+ Check the model data using Evidently.
87
+ """
88
+ # Get the newest data from the database
89
+ df = load_model_data()
90
+
91
+ # Get the training dataset
92
+ ref_df = get_training_dataset(model_name=model_name)
93
+
94
+ # Check the data
95
+ return checker.check_data(df, ref_df, tags=[model_name])