nivakaran commited on
Commit
122a65f
·
verified ·
1 Parent(s): 7080f90

Create data_validation.py

Browse files
Files changed (1) hide show
  1. src/components/data_validation.py +113 -0
src/components/data_validation.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.entity.artifact_entity import DataIngestionArtifact, DataValidationArtifact
2
+ from src.entity.config_entity import DataValidationConfig
3
+ from src.exception.exception import DeliveryTimeException
4
+ from src.logging.logger import logging
5
+ from src.constants.training_pipeline import SCHEMA_FILE_PATH
6
+ from scipy.stats import ks_2samp
7
+ import pandas as pd
8
+ import os, sys
9
+ from src.utils.main_utils.utils import read_yaml_file, write_yaml_file
10
+
11
+
12
+ class DataValidation:
13
+ def __init__(self, data_ingestion_artifact:DataIngestionArtifact,
14
+ data_validation_config:DataValidationConfig):
15
+ try:
16
+ self.data_ingestion_artifact=data_ingestion_artifact
17
+ self.data_validation_config=data_validation_config
18
+ self._schema_config=read_yaml_file(SCHEMA_FILE_PATH)
19
+ except Exception as e:
20
+ raise DeliveryTimeException(e, sys)
21
+
22
+
23
+ @staticmethod
24
+ def read_data(file_path)->pd.DataFrame:
25
+ try:
26
+ return pd.read_csv(file_path)
27
+ except Exception as e:
28
+ raise DeliveryTimeException(e, sys)
29
+
30
+ def validate_number_of_columns(self, dataframe:pd.DataFrame)->bool:
31
+ try:
32
+ number_of_columns=len(self._schema_config)
33
+ logging.info(f"Required number of columns: {number_of_columns}")
34
+ logging.info(f"Data frame has columns: {len(dataframe.columns)}")
35
+ if len(dataframe.columns) == number_of_columns:
36
+ return True
37
+ return False
38
+
39
+ except Exception as e:
40
+ raise DeliveryTimeException(e, sys)
41
+
42
+ def detect_dataset_drift(self, base_df, current_df, threshold=0.05)->bool:
43
+ try:
44
+ status=True
45
+ report={}
46
+ for column in base_df.columns:
47
+ d1=base_df[column]
48
+ d2=current_df[column]
49
+ is_same_dist=ks_2samp(d1, d2)
50
+ if threshold <= is_same_dist.pvalue:
51
+ is_found=False
52
+ else:
53
+ is_found=True
54
+ status=False
55
+ report.update({column:{
56
+ "p_value":float(is_same_dist.pvalue),
57
+ "drift_status":is_found
58
+ }})
59
+
60
+ drift_report_file_path=self.data_validation_config.drift_report_file_path
61
+
62
+ # Create directory
63
+ dir_path=os.path.dirname(drift_report_file_path)
64
+ os.makedirs(dir_path, exist_ok=True)
65
+ write_yaml_file(file_path=drift_report_file_path, content=report)
66
+
67
+ except Exception as e:
68
+ raise DeliveryTimeException(e, sys)
69
+
70
+
71
+ def initiate_data_validation(self)->DataValidationArtifact:
72
+ try:
73
+ train_file_path=self.data_ingestion_artifact.trained_file_path
74
+ test_file_path=self.data_ingestion_artifact.test_file_path
75
+
76
+ # Read the data from train and test
77
+ train_dataframe=DataValidation.read_data(train_file_path)
78
+ test_dataframe=DataValidation.read_data(test_file_path)
79
+
80
+ # Validate the numer of columns
81
+ status=self.validate_number_of_columns(dataframe=train_dataframe)
82
+ if not status:
83
+ error_message=f"Train dataframe does not contain all columns. \n"
84
+ status = self.validate_number_of_columns(dataframe=test_dataframe)
85
+ if not status:
86
+ error_message=f"Test dataframe does not contain all columns.\n"
87
+
88
+ # Let's check datadrift
89
+ status=self.detect_dataset_drift(base_df=train_dataframe, current_df=test_dataframe)
90
+ dir_path=os.path.dirname(self.data_validation_config.valid_train_file_path)
91
+ os.makedirs(dir_path, exist_ok=True)
92
+
93
+ train_dataframe.to_csv(
94
+ self.data_validation_config.valid_train_file_path, index=False, header=True
95
+ )
96
+
97
+ test_dataframe.to_csv(
98
+ self.data_validation_config.valid_test_file_path, index=False, header=True
99
+ )
100
+
101
+ data_validation_artifact=DataValidationArtifact(
102
+ validation_status=status,
103
+ valid_train_file_path=self.data_ingestion_artifact.trained_file_path,
104
+ valid_test_file_path=self.data_ingestion_artifact.test_file_path,
105
+ invalid_train_file_path=None,
106
+ invalid_test_file_path=None,
107
+ drift_report_file_path=self.data_validation_config.drift_report_file_path,
108
+ )
109
+
110
+ return data_validation_artifact
111
+
112
+ except Exception as e:
113
+ raise DeliveryTimeException(e, sys)