nivakaran commited on
Commit
8d044e1
·
verified ·
1 Parent(s): c304002

Create pipeline/training_pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline/training_pipeline.py +122 -0
src/pipeline/training_pipeline.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ from src.exception.exception import DeliveryTimeException
5
+ from src.logging.logger import logging
6
+
7
+ from src.components.data_ingestion import DataIngestion
8
+ from src.components.data_validation import DataValidation
9
+ from src.components.data_transformation import DataTransformation
10
+ from src.components.model_trainer import ModelTrainer
11
+
12
+ from src.entity.config_entity import (
13
+ TrainingPipelineConfig,
14
+ DataIngestionConfig,
15
+ DataValidationConfig,
16
+ DataTransformationConfig,
17
+ ModelTrainerConfig
18
+ )
19
+
20
+ from src.entity.artifact_entity import (
21
+ DataIngestionArtifact,
22
+ DataValidationArtifact,
23
+ DataTransformationArtifact,
24
+ ModelTrainerArtifact
25
+ )
26
+
27
+ from src.constants.training_pipeline import TRAINING_BUCKET_NAME
28
+ from src.cloud.s3_syncer import S3Sync
29
+ from src.constants.training_pipeline import SAVED_MODEL_DIR
30
+ import sys
31
+
32
+ class TrainingPipeline:
33
+ def __init__(self):
34
+ self.training_pipeline_config=TrainingPipelineConfig()
35
+ self.s3_sync = S3Sync()
36
+
37
+ def start_data_ingestion(self):
38
+ try:
39
+ self.data_ingestion_config=DataIngestionConfig(training_pipeline_config=self.training_pipeline_config)
40
+ logging.info("Start data Ingestion")
41
+ data_ingestion=DataIngestion(data_ingestion_config=self.data_ingestion_config)
42
+ data_ingestion_artifact=data_ingestion.initiate_date_ingestion()
43
+ logging.info(f"Data Ingestion completed with artifact: {data_ingestion_artifact}")
44
+ return data_ingestion_artifact
45
+ except Exception as e:
46
+ raise DeliveryTimeException(e, sys)
47
+
48
+ def start_data_validation(self, data_ingestion_artifact:DataIngestionArtifact):
49
+ try:
50
+ data_validation_config=DataValidationConfig(training_pipeline_config=self.training_pipeline_config)
51
+ data_validation=DataValidation(data_ingestion_artifact=data_ingestion_artifact, data_validation_config=data_validation_config)
52
+ logging.info("Initiate data validation")
53
+ data_validation_artifact=data_validation.initiate_data_validation()
54
+ return data_validation_artifact
55
+ except Exception as e:
56
+ raise DeliveryTimeException(e, sys)
57
+
58
+ def start_data_transformation(self, data_validation_artifact:DataValidationArtifact):
59
+ try:
60
+ data_transformation_config=DataTransformationConfig(training_pipeline_config=self.training_pipeline_config)
61
+ data_transformation=DataTransformation(data_validation_artifact=data_validation_artifact,
62
+ data_transformation_config=data_transformation_config)
63
+ data_transformation_artifact=data_transformation.initiate_data_transformation()
64
+ return data_transformation_artifact
65
+ except Exception as e:
66
+ raise DeliveryTimeException(e, sys)
67
+
68
+ def start_model_trainer(self, data_transformation_artifact:DataTransformationArtifact) -> ModelTrainerArtifact:
69
+ try:
70
+ self.model_trainer_config:ModelTrainerConfig = ModelTrainerConfig(
71
+ training_pipeline_config=self.training_pipeline_config
72
+ )
73
+
74
+ model_trainer=ModelTrainer(
75
+ data_transformation_artifact=data_transformation_artifact,
76
+ model_trainer_config=self.model_trainer_config
77
+ )
78
+
79
+ model_trainer_artifact=model_trainer.initiate_model_trainer()
80
+
81
+ return model_trainer_artifact
82
+ except Exception as e:
83
+ raise DeliveryTimeException(e, sys)
84
+
85
+
86
+ def sync_artifact_dir_to_s3(self):
87
+ try:
88
+ aws_bucket_url = f"s3://{TRAINING_BUCKET_NAME}/artifact/{self.training_pipeline_config.timestamp}"
89
+ self.s3_sync.sync_folder_to_s3(folder=self.training_pipeline_config.artifact_dir, aws_bucket_url=aws_bucket_url)
90
+
91
+ except Exception as e:
92
+ raise DeliveryTimeException(e, sys)
93
+
94
+ def sync_saved_model_dir_to_s3(self):
95
+ try:
96
+ aws_bucket_url = f"s3://{TRAINING_BUCKET_NAME}/final_model/{self.training_pipeline_config.timestamp}"
97
+ self.s3_sync.sync_folder_to_s3(folder=self.training_pipeline_config.model_dir, aws_bucket_url=aws_bucket_url)
98
+ except Exception as e:
99
+ raise DeliveryTimeException(e, sys)
100
+
101
+
102
+ def run_pipeline(self):
103
+ try:
104
+ data_ingestion_artifact=self.start_data_ingestion()
105
+ data_validation_artifact=self.start_data_validation(data_ingestion_artifact=data_ingestion_artifact)
106
+ data_transformation_artifact=self.start_data_transformation(data_validation_artifact=data_validation_artifact)
107
+ model_trainer_artifact=self.start_model_trainer(data_transformation_artifact=data_transformation_artifact)
108
+
109
+ self.sync_artifact_dir_to_s3()
110
+ self.sync_saved_model_dir_to_s3()
111
+
112
+ return model_trainer_artifact
113
+
114
+ except Exception as e:
115
+ raise DeliveryTimeException(e, sys)
116
+
117
+ if __name__ == "__main__":
118
+ try:
119
+ training_pipeline = TrainingPipeline()
120
+ training_pipeline.run_pipeline()
121
+ except Exception as e:
122
+ raise DeliveryTimeException(e, sys)