nivakaran commited on
Commit
23606de
·
verified ·
1 Parent(s): bca442e

Create utils.py

Browse files
Files changed (1) hide show
  1. src/utils/main_utils/utils.py +102 -0
src/utils/main_utils/utils.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ from src.exception.exception import DeliveryTimeException
3
+ from src.logging.logger import logging
4
+ import os, sys
5
+ import numpy as np
6
+ import pickle
7
+
8
+ from sklearn.metrics import r2_score
9
+ from sklearn.model_selection import GridSearchCV
10
+
11
+ def read_yaml_file(file_path:str) -> dict:
12
+ try:
13
+ with open(file_path, 'rb') as yaml_file:
14
+ return yaml.safe_load(yaml_file)
15
+
16
+ except Exception as e:
17
+ DeliveryTimeException(e, sys)
18
+
19
+ def write_yaml_file(file_path:str, content:object, replace:bool=False) -> None:
20
+ try:
21
+ if replace:
22
+ if os.path.exists(file_path):
23
+ os.remove(file_path)
24
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
25
+ with open(file_path, 'w') as file:
26
+ yaml.dump(content, file)
27
+
28
+ except Exception as e:
29
+ raise DeliveryTimeException(e, sys)
30
+
31
+ def save_numpy_array_data(file_path:str, array:np.array):
32
+ """
33
+ Save numpy array data to file
34
+ file_path: str location of file to save
35
+ array:np.array data to save
36
+ """
37
+ try:
38
+ dir_path=os.path.dirname(file_path)
39
+ os.makedirs(dir_path, exist_ok=True)
40
+ with open(file_path, 'wb') as file_obj:
41
+ np.save(file_obj, array)
42
+ except Exception as e:
43
+ raise DeliveryTimeException(e, sys)
44
+
45
+ def save_object(file_path:str, obj:object) -> None:
46
+ try:
47
+ logging.info("Entered the save_object method of MainUtils class")
48
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
49
+ with open(file_path, "wb") as file_obj:
50
+ pickle.dump(obj, file_obj)
51
+ logging.info("Exited the save_object method of MainUtils class")
52
+ except Exception as e:
53
+ raise DeliveryTimeException(e, sys)
54
+
55
+ def load_object(file_path:str) ->object:
56
+ try:
57
+ if not os.path.exists(file_path):
58
+ raise Exception(f"The file: {file_path} does not exist")
59
+ with open(file_path, 'rb') as file_obj:
60
+ print(file_obj)
61
+ return pickle.load(file_obj)
62
+ except Exception as e:
63
+ raise DeliveryTimeException(e, sys)
64
+
65
+ def load_numpy_array_data(file_path:str) -> np.array:
66
+ """
67
+ Load numpy array data from file
68
+ file_path: str location of file to load
69
+ return: np.array data loaded
70
+ """
71
+ try:
72
+ with open(file_path, 'rb') as file_obj:
73
+ return np.load(file_obj)
74
+ except Exception as e:
75
+ raise DeliveryTimeException(e, sys)
76
+
77
+ def evaluate_models(X_train, y_train, X_test, y_test, models, param):
78
+ try:
79
+ report = {}
80
+
81
+ for i in range(len(list(models))):
82
+ model = list(models.values())[i]
83
+ para = param[list(models.keys())[i]]
84
+
85
+ gs = GridSearchCV(model, para, cv=3)
86
+ gs.fit(X_train, y_train)
87
+
88
+ model.set_params(**gs.best_params_)
89
+ model.fit(X_train, y_train)
90
+
91
+ y_train_pred = model.predict(X_train)
92
+ y_test_pred = model.predict(X_test)
93
+
94
+ test_model_score = r2_score(y_test, y_test_pred)
95
+ report[list(models.keys())[i]] = test_model_score
96
+
97
+
98
+ return report
99
+
100
+ except Exception as e:
101
+ raise DeliveryTimeException(e, sys)
102
+