Add tabgan/adversarial_model.py
Browse files- tabgan/adversarial_model.py +223 -0
tabgan/adversarial_model.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from lightgbm import LGBMClassifier
|
| 4 |
+
from scipy.stats import rankdata
|
| 5 |
+
from sklearn.metrics import roc_auc_score
|
| 6 |
+
from sklearn.model_selection import StratifiedKFold
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
from tabgan.encoders import MultipleEncoder, DoubleValidationEncoderNumerical
|
| 10 |
+
|
| 11 |
+
warnings.filterwarnings("ignore", message="No further splits with positive gain")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class AdversarialModel:
    def __init__(
        self,
        cat_validation="Single",
        encoders_names=("OrdinalEncoder",),
        cat_cols=None,
        model_validation=StratifiedKFold(n_splits=5, shuffle=True, random_state=42),
        model_params=None,
    ):
        """
        Adversarial validation: trains a classifier to distinguish two dataframes.

        A cross-validated AUC close to 0.5 means the two frames are hard to tell
        apart (similar distributions); an AUC near 1.0 means they differ clearly.

        Args:
            cat_validation: categorical validation scheme: "None", "Single" or "Double"
            encoders_names: categorical encoder names from the category_encoders
                library, e.g. "CatBoostEncoder"
            cat_cols: list of categorical column names
            model_validation: cross-validation splitter from
                sklearn.model_selection, e.g. StratifiedKFold(5).
                NOTE(review): the default splitter instance is shared across all
                AdversarialModel instances (mutable default argument); splitting is
                stateless, so this is safe in practice, but kept for interface
                compatibility.
            model_params: model training hyperparameters forwarded to the booster
        """
        self.metrics = None  # filled by adversarial_test()
        self.trained_model = None  # filled by adversarial_test()
        self.cat_validation = cat_validation
        self.encoders_names = encoders_names
        self.cat_cols = cat_cols
        self.model_validation = model_validation
        self.model_params = model_params

    def adversarial_test(self, left_df, right_df):
        """
        Trains an adversarial model to distinguish ``left_df`` (label 0) from
        ``right_df`` (label 1).

        Both inputs are shuffled and truncated to the size of the smaller frame
        so the two classes are balanced. Scores are stored in ``self.metrics``
        and the fitted model in ``self.trained_model``.

        :param left_df: first dataframe, labelled 0
        :param right_df: second dataframe, labelled 1
        :return: the trained ``Model`` instance
        """
        # Shuffle rows so the truncation below takes a random sample.
        left_df = left_df.copy().sample(frac=1).reset_index(drop=True)
        right_df = right_df.copy().sample(frac=1).reset_index(drop=True)

        # Balance the classes: truncate both frames to the smaller row count.
        left_df = left_df.head(right_df.shape[0])
        right_df = right_df.head(left_df.shape[0])

        left_df["gt"] = 0
        right_df["gt"] = 1

        concated = pd.concat([left_df, right_df], ignore_index=True)
        lgb_model = Model(
            cat_validation=self.cat_validation,
            encoders_names=self.encoders_names,
            cat_cols=self.cat_cols,
            model_validation=self.model_validation,
            model_params=self.model_params,
        )
        train_score, val_score, avg_num_trees = lgb_model.fit(
            concated.drop("gt", axis=1), concated["gt"]
        )
        self.metrics = {
            "train_score": train_score,
            "val_score": val_score,
            "avg_num_trees": avg_num_trees,
        }
        self.trained_model = lgb_model
        # BUGFIX: the docstring promised the trained model, but the method
        # previously returned None. Returning it is backward-compatible.
        return lgb_model
class Model:
    def __init__(
        self,
        cat_validation="None",
        encoders_names=None,
        cat_cols=None,
        model_validation=StratifiedKFold(n_splits=5, shuffle=True, random_state=42),
        model_params=None,
    ):
        """
        Cross-validated LightGBM classifier with optional categorical encoding.

        Args:
            cat_validation: categorical validation scheme: "None", "Single" or "Double"
            encoders_names: categorical encoder names from the category_encoders
                library, e.g. "CatBoostEncoder"
            cat_cols: list of categorical column names
            model_validation: cross-validation splitter from
                sklearn.model_selection, e.g. StratifiedKFold(5)
            model_params: model training hyperparameters; defaults to a small
                AUC-oriented LightGBM configuration
        """
        self.cat_validation = cat_validation
        self.encoders_names = encoders_names
        self.cat_cols = cat_cols
        self.model_validation = model_validation

        if model_params is None:
            self.model_params = {
                "metrics": "AUC",
                "n_estimators": 150,
                "learning_rate": 0.04,
                "random_state": 42,
            }
        else:
            self.model_params = model_params

        # Per-fit state; re-initialised at the top of fit().
        self.encoders_list = []
        self.models_list = []
        self.scores_list_train = []
        self.scores_list_val = []
        self.models_trees = []

    def fit(self, X: pd.DataFrame, y: pd.Series) -> tuple:
        """
        Fits one model per CV fold with the params specified in __init__.

        Args:
            X: input training dataframe
            y: target for X (pandas Series — positional indexing is used, and
               the original code already required a Series via ``.loc``)

        Returns:
            (mean_score_train, mean_score_val, avg_num_trees); avg_num_trees is
            None when the booster reports no best iteration (no early stopping).
        """
        warnings.filterwarnings("ignore", category=UserWarning)

        # BUGFIX: reset per-fit state so a second fit() call does not mix
        # scores/models from a previous fit into the averages.
        self.encoders_list = []
        self.models_list = []
        self.scores_list_train = []
        self.scores_list_val = []
        self.models_trees = []

        # "None" validation: encode once on the full data before splitting.
        if self.cat_validation == "None":
            encoder = MultipleEncoder(
                cols=self.cat_cols, encoders_names_tuple=self.encoders_names
            )
            X = encoder.fit_transform(X, y)

        for n_fold, (train_idx, val_idx) in enumerate(
            self.model_validation.split(X, y)
        ):
            # BUGFIX: split() yields positional indices; .loc was only correct
            # for a default RangeIndex. .iloc is correct for any index.
            X_train = X.iloc[train_idx]
            y_train = y.iloc[train_idx]
            X_val = X.iloc[val_idx]
            y_val = y.iloc[val_idx]

            if self.cat_cols is not None:
                if self.cat_validation == "Single":
                    encoder = MultipleEncoder(
                        cols=self.cat_cols, encoders_names_tuple=self.encoders_names
                    )
                    X_train = encoder.fit_transform(X_train, y_train)
                    X_val = encoder.transform(X_val)
                if self.cat_validation == "Double":
                    encoder = DoubleValidationEncoderNumerical(
                        cols=self.cat_cols, encoders_names_tuple=self.encoders_names
                    )
                    X_train = encoder.fit_transform(X_train, y_train)
                    X_val = encoder.transform(X_val)
                self.encoders_list.append(encoder)

            # OrdinalEncoder output must be a pandas "category" dtype so
            # LightGBM treats it as categorical rather than numeric.
            for col in [col for col in X_train.columns if "OrdinalEncoder" in col]:
                X_train[col] = X_train[col].astype("category")
                X_val[col] = X_val[col].astype("category")

            # BUGFIX: merging dicts avoids a duplicate-keyword TypeError when
            # the caller's model_params already contains "verbose" (the
            # caller's value wins; -1 is only the silent default).
            model = LGBMClassifier(**{"verbose": -1, **self.model_params})
            model.fit(
                X_train,
                y_train,
                eval_set=[(X_train, y_train), (X_val, y_val)],
            )
            self.models_trees.append(model.best_iteration_)
            self.models_list.append(model)

            y_hat = model.predict_proba(X_train)[:, 1]
            score_train = roc_auc_score(y_train, y_hat)
            self.scores_list_train.append(score_train)
            y_hat = model.predict_proba(X_val)[:, 1]
            score_val = roc_auc_score(y_val, y_hat)
            self.scores_list_val.append(score_val)

        mean_score_train = np.mean(self.scores_list_train)
        mean_score_val = np.mean(self.scores_list_val)
        if None in self.models_trees:
            # Fitting without early stopping leaves best_iteration_ as None.
            avg_num_trees = None
        else:
            avg_num_trees = int(np.mean(self.models_trees))

        return mean_score_train, mean_score_val, avg_num_trees

    def predict(self, X: pd.DataFrame) -> np.array:
        """
        Making inference with the per-fold trained models for an input dataframe.

        Args:
            X: input dataframe for inference

        Returns:
            Summed per-model prediction ranks (rank-averaged ensemble, unscaled).

        NOTE(review): when cat_validation == "None" and cat_cols is None,
        encoders_list stays empty and X is passed to the models unencoded —
        verify this matches how such a model was fitted.
        """
        y_hat = np.zeros(X.shape[0])
        if self.encoders_list is not None and self.encoders_list != []:
            for encoder, model in zip(self.encoders_list, self.models_list):
                X_test = X.copy()
                X_test = encoder.transform(X_test)

                # Restore categorical dtype for OrdinalEncoder columns,
                # matching what fit() did at training time.
                for col in [col for col in X_test.columns if "OrdinalEncoder" in col]:
                    X_test[col] = X_test[col].astype("category")

                unranked_preds = model.predict_proba(X_test)[:, 1]
                y_hat += rankdata(unranked_preds)
        else:
            for model in self.models_list:
                X_test = X.copy()

                unranked_preds = model.predict_proba(X_test)[:, 1]
                y_hat += rankdata(unranked_preds)
        return y_hat