Spaces:
Build error
Build error
| import pandas as pd | |
| from causalml.metrics import * | |
class ModelEvaluator:
    """Score an evaluation frame with an uplift model and build Qini inputs.

    The evaluator keeps a reference to ``df_eval`` (it is not copied), so
    ``predict_cate`` adds its prediction columns to the caller's frame.
    """

    # Columns that receive the per-treatment CATE predictions, in the same
    # order as the columns of the model's prediction output.
    _CATE_COLUMNS = ['cate_discount_05', 'cate_discount_10', 'cate_discount_15']

    def __init__(self, model, df_eval, X_names):
        """
        Store the model and evaluation data.

        Args:
            model: object exposing ``predict(X=..., treatment=...)``.
            df_eval: evaluation DataFrame; must contain the ``X_names``
                columns and a ``treatment_group_key`` column.
            X_names: names of the feature columns passed to the model.
        """
        self.model = model
        self.df_eval = df_eval
        self.X_names = X_names

    def predict_cate(self, discount=None):
        """
        Predict the Conditional Average Treatment Effect (CATE) for every row.

        Writes a ``cate`` column (one list of predictions per row) and the
        three ``cate_discount_*`` columns into ``self.df_eval``.

        Args:
            discount: accepted for interface compatibility; the predictions
                for all discount levels are computed regardless of its value.
        """
        predictions = self.model.predict(
            X=self.df_eval[self.X_names].values,
            treatment=self.df_eval['treatment_group_key'].values,
        ).tolist()
        self.df_eval['cate'] = predictions
        # Expand the per-row prediction lists into one column per discount.
        self.df_eval[self._CATE_COLUMNS] = pd.DataFrame(
            self.df_eval.cate.tolist(),
            index=self.df_eval.index,
        )

    @staticmethod
    def _build_qini_frame(df, score_col, outcome_col):
        """Return a frame with columns S (score), w (treatment flag), y (outcome)."""
        return pd.DataFrame(
            [
                df[score_col].to_numpy(),
                df['treatment_num'].to_numpy(),
                df[outcome_col].to_numpy(),
            ],
            index=['S', 'w', 'y'],
        ).T

    def eval_performance(self, discount):
        """
        Evaluate the model for one discount level using Qini curves.

        Only rows whose ``treatment_group_key`` is ``'control'`` or
        ``discount`` are used. ``self.df_eval`` must contain ``conversion``
        and ``benefit`` columns, and ``'cate_' + discount`` must be one of
        the columns produced by ``predict_cate``.

        Args:
            discount: treatment group key to compare against control.

        Returns:
            Tuple ``(cd_conversion, cd_benefit)``: the ``get_qini`` output
            for each outcome, doubled, index reset, then divided by its
            number of rows.
        """
        # Ensure CATE predictions are available
        if 'cate' not in self.df_eval.columns:
            self.predict_cate(discount)

        in_scope = self.df_eval['treatment_group_key'].isin(['control', discount])
        # Copy so that adding treatment_num never writes into a slice of
        # self.df_eval (avoids SettingWithCopyWarning / silent no-op).
        df_eval_disc = self.df_eval[in_scope].copy()
        # Vectorized 0/1 flag: 0 for control rows, 1 for the discount rows.
        df_eval_disc['treatment_num'] = (
            df_eval_disc['treatment_group_key'] != 'control'
        ).astype(int)

        cate_col = 'cate_{}'.format(discount)
        df_eval_qini_conversion = self._build_qini_frame(
            df_eval_disc, cate_col, 'conversion'
        )
        df_eval_qini_benefit = self._build_qini_frame(
            df_eval_disc, cate_col, 'benefit'
        )

        # get_qini is expected to come from the file's causalml.metrics import.
        # NOTE(review): the division below applies to every column, including
        # the index column created by reset_index() - confirm that is intended.
        cd_conversion = (get_qini(df_eval_qini_conversion) * 2).reset_index()
        cd_conversion = cd_conversion / cd_conversion.shape[0]
        cd_benefit = (get_qini(df_eval_qini_benefit) * 2).reset_index()
        cd_benefit = cd_benefit / cd_benefit.shape[0]
        return cd_conversion, cd_benefit