InsafQ committed on
Commit
93e8299
·
verified ·
1 Parent(s): b2faf0e

Add tabgan/adversarial_model.py

Browse files
Files changed (1) hide show
  1. tabgan/adversarial_model.py +223 -0
tabgan/adversarial_model.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ from lightgbm import LGBMClassifier
4
+ from scipy.stats import rankdata
5
+ from sklearn.metrics import roc_auc_score
6
+ from sklearn.model_selection import StratifiedKFold
7
+ import warnings
8
+
9
+ from tabgan.encoders import MultipleEncoder, DoubleValidationEncoderNumerical
10
+
11
+ warnings.filterwarnings("ignore", message="No further splits with positive gain")
12
+
13
+
14
class AdversarialModel:
    def __init__(
        self,
        cat_validation="Single",
        encoders_names=("OrdinalEncoder",),
        cat_cols=None,
        model_validation=None,
        model_params=None,
    ):
        """
        Adversarial-validation wrapper: trains a classifier to distinguish one
        dataframe from another. Several categorical encoders are supported.

        Args:
            cat_validation: categorical validation type: "None", "Single" or "Double"
            encoders_names: categorical encoder names from the category_encoders
                library, e.g. "CatBoostEncoder"
            cat_cols: list of categorical column names
            model_validation: cross-validation splitter from sklearn.model_selection;
                defaults to StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
            model_params: model training hyperparameters (passed through to Model)
        """
        self.metrics = None
        self.trained_model = None
        self.cat_validation = cat_validation
        self.encoders_names = encoders_names
        self.cat_cols = cat_cols
        # Avoid a mutable default argument: instantiating StratifiedKFold in
        # the signature would share one splitter object across every instance
        # created with defaults. Build a fresh one per instance instead.
        if model_validation is None:
            model_validation = StratifiedKFold(
                n_splits=5, shuffle=True, random_state=42
            )
        self.model_validation = model_validation
        self.model_params = model_params

    def adversarial_test(self, left_df, right_df):
        """
        Trains an adversarial model to distinguish left_df from right_df.

        Populates self.metrics (train/val ROC-AUC and average tree count) and
        self.trained_model with the fitted Model instance.

        :param left_df: dataframe labelled 0
        :param right_df: dataframe labelled 1
        :return: None (results are stored on the instance)
        """
        # Shuffle copies of both frames so the truncation below drops a
        # random subset rather than the tail of the original ordering.
        left_df = left_df.copy().sample(frac=1).reset_index(drop=True)
        right_df = right_df.copy().sample(frac=1).reset_index(drop=True)

        # Balance the two classes: truncate both frames to the smaller length
        # (single min() instead of the order-dependent double-head dance).
        n_rows = min(left_df.shape[0], right_df.shape[0])
        left_df = left_df.head(n_rows)
        right_df = right_df.head(n_rows)

        # Ground-truth label: which frame each row came from.
        left_df["gt"] = 0
        right_df["gt"] = 1

        concated = pd.concat([left_df, right_df], ignore_index=True)
        lgb_model = Model(
            cat_validation=self.cat_validation,
            encoders_names=self.encoders_names,
            cat_cols=self.cat_cols,
            model_validation=self.model_validation,
            model_params=self.model_params,
        )
        train_score, val_score, avg_num_trees = lgb_model.fit(
            concated.drop("gt", axis=1), concated["gt"]
        )
        self.metrics = {
            "train_score": train_score,
            "val_score": val_score,
            "avg_num_trees": avg_num_trees,
        }
        self.trained_model = lgb_model
76
+
77
class Model:
    def __init__(
        self,
        cat_validation="None",
        encoders_names=None,
        cat_cols=None,
        model_validation=None,
        model_params=None,
    ):
        """
        Cross-validated LightGBM classifier with optional categorical encoding.

        Args:
            cat_validation: categorical validation type: "None", "Single" or "Double"
            encoders_names: categorical encoder names from the category_encoders
                library, e.g. "CatBoostEncoder"
            cat_cols: list of categorical column names
            model_validation: cross-validation splitter from sklearn.model_selection;
                defaults to StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
            model_params: LightGBM hyperparameters; a sensible default is used
                when None
        """
        self.cat_validation = cat_validation
        self.encoders_names = encoders_names
        self.cat_cols = cat_cols
        # Avoid a mutable default argument: a StratifiedKFold built in the
        # signature would be a single object shared by every Model instance
        # constructed with defaults.
        if model_validation is None:
            model_validation = StratifiedKFold(
                n_splits=5, shuffle=True, random_state=42
            )
        self.model_validation = model_validation

        if model_params is None:
            self.model_params = {
                "metrics": "AUC",
                "n_estimators": 150,
                "learning_rate": 0.04,
                "random_state": 42,
            }
        else:
            self.model_params = model_params

        # Per-fold artifacts collected during fit().
        self.encoders_list = []       # fitted encoders (Single/Double modes only)
        self.models_list = []         # fitted LGBMClassifier per fold
        self.scores_list_train = []   # train ROC-AUC per fold
        self.scores_list_val = []     # validation ROC-AUC per fold
        self.models_trees = []        # best_iteration_ per fold (may be None)

    def fit(self, X: pd.DataFrame, y: pd.Series) -> tuple:
        """
        Fits one model per CV fold with the parameters given at init.

        Args:
            X: input training dataframe
            y: target for X (a pandas Series — positional .loc indexing is used,
               which assumes a default RangeIndex aligned with X)

        Returns:
            (mean_score_train, mean_score_val, avg_num_trees); avg_num_trees is
            None when LightGBM reports no best_iteration_ (no early stopping).
        """
        warnings.filterwarnings("ignore", category=UserWarning)

        # "None" mode: encode the whole frame once, outside the CV loop.
        # NOTE(review): in this mode the fitted encoder is never appended to
        # self.encoders_list, so predict() will see raw (unencoded) features —
        # looks inconsistent with how fit() trained; confirm against callers.
        if self.cat_validation == "None":
            encoder = MultipleEncoder(
                cols=self.cat_cols, encoders_names_tuple=self.encoders_names
            )
            X = encoder.fit_transform(X, y)

        for n_fold, (train_idx, val_idx) in enumerate(
            self.model_validation.split(X, y)
        ):
            X_train = X.loc[train_idx]
            y_train = y.loc[train_idx]

            X_val = X.loc[val_idx]
            y_val = y.loc[val_idx]

            if self.cat_cols is not None:
                # "Single": fit the encoder on the train fold only.
                if self.cat_validation == "Single":
                    encoder = MultipleEncoder(
                        cols=self.cat_cols, encoders_names_tuple=self.encoders_names
                    )
                    X_train = encoder.fit_transform(X_train, y_train)
                    X_val = encoder.transform(X_val)
                # "Double": nested-validation encoding to reduce target leakage.
                if self.cat_validation == "Double":
                    encoder = DoubleValidationEncoderNumerical(
                        cols=self.cat_cols, encoders_names_tuple=self.encoders_names
                    )
                    X_train = encoder.fit_transform(X_train, y_train)
                    X_val = encoder.transform(X_val)
                self.encoders_list.append(encoder)

                # Columns produced by OrdinalEncoder must be typed as pandas
                # "category" so LightGBM treats them as categorical features.
                for col in [col for col in X_train.columns if "OrdinalEncoder" in col]:
                    X_train[col] = X_train[col].astype("category")
                    X_val[col] = X_val[col].astype("category")

            # Fit one model per fold.
            # NOTE(review): verbose=-1 will clash if model_params also sets
            # "verbose" — confirm callers never pass it.
            model = LGBMClassifier(**self.model_params, verbose=-1)
            model.fit(
                X_train,
                y_train,
                eval_set=[(X_train, y_train), (X_val, y_val)],
            )
            self.models_trees.append(model.best_iteration_)
            self.models_list.append(model)

            y_hat = model.predict_proba(X_train)[:, 1]
            score_train = roc_auc_score(y_train, y_hat)
            self.scores_list_train.append(score_train)
            y_hat = model.predict_proba(X_val)[:, 1]
            score_val = roc_auc_score(y_val, y_hat)
            self.scores_list_val.append(score_val)

        mean_score_train = np.mean(self.scores_list_train)
        mean_score_val = np.mean(self.scores_list_val)
        if None in self.models_trees:
            # Fitting without early stopping leaves best_iteration_ as None.
            avg_num_trees = None
        else:
            avg_num_trees = int(np.mean(self.models_trees))

        return mean_score_train, mean_score_val, avg_num_trees

    def predict(self, X: pd.DataFrame) -> np.ndarray:
        """
        Inference with all trained fold models on an input dataframe.

        Each fold's predicted probabilities are converted to ranks and summed,
        so the output is a rank-ensemble score, not a probability.

        Args:
            X: input dataframe for inference

        Returns:
            Summed per-fold prediction ranks, shape (len(X),).
        """
        y_hat = np.zeros(X.shape[0])
        if self.encoders_list:
            # One fitted encoder per fold — re-encode a fresh copy each time.
            for encoder, model in zip(self.encoders_list, self.models_list):
                X_test = X.copy()
                X_test = encoder.transform(X_test)

                # Restore the categorical dtype LightGBM was trained with.
                for col in [col for col in X_test.columns if "OrdinalEncoder" in col]:
                    X_test[col] = X_test[col].astype("category")

                unranked_preds = model.predict_proba(X_test)[:, 1]
                y_hat += rankdata(unranked_preds)
        else:
            for model in self.models_list:
                X_test = X.copy()

                unranked_preds = model.predict_proba(X_test)[:, 1]
                y_hat += rankdata(unranked_preds)
        return y_hat