Spaces:
Sleeping
Sleeping
Commit ·
b5a495b
1
Parent(s): e46d7eb
Initial commit
Browse files- src/models/ct_cbc.cbm +0 -0
- src/models/sm_cbc.cbm +0 -0
- src/models/tm_ctrl_cbc.cbm +0 -0
- src/models/tm_dependend_ctrl_cbc.cbm +0 -0
- src/models/tm_dependend_trmnt_cbc.cbm +0 -0
- src/models/tm_trmnt_cbc.cbm +0 -0
- src/test.ipynb +0 -0
- src/web_app.py +13 -46
src/models/ct_cbc.cbm
DELETED
|
Binary file (466 kB)
|
|
|
src/models/sm_cbc.cbm
DELETED
|
Binary file (498 kB)
|
|
|
src/models/tm_ctrl_cbc.cbm
DELETED
|
Binary file (544 kB)
|
|
|
src/models/tm_dependend_ctrl_cbc.cbm
DELETED
|
Binary file (544 kB)
|
|
|
src/models/tm_dependend_trmnt_cbc.cbm
DELETED
|
Binary file (361 kB)
|
|
|
src/models/tm_trmnt_cbc.cbm
DELETED
|
Binary file (361 kB)
|
|
|
src/test.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/web_app.py
CHANGED
|
@@ -12,22 +12,11 @@ import tools
|
|
| 12 |
# загрузим датасет
|
| 13 |
dataset, target, treatment = tools.get_data()
|
| 14 |
|
| 15 |
-
# загрузим
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
sm_cbc_model.load_model('src/models/sm_cbc.cbm')
|
| 21 |
-
# загрузим модели для независимого класификатора
|
| 22 |
-
tm_ctrl_cbc_model = catboost.CatBoostClassifier()
|
| 23 |
-
tm_ctrl_cbc_model.load_model('src/models/tm_ctrl_cbc.cbm')
|
| 24 |
-
tm_trmnt_cbc_model = catboost.CatBoostClassifier()
|
| 25 |
-
tm_ctrl_cbc_model.load_model('src/models/tm_trmnt_cbc.cbm')
|
| 26 |
-
# загрузим модели для зависимого класификатора
|
| 27 |
-
tm_dependend_ctrl_cbc = catboost.CatBoostClassifier()
|
| 28 |
-
tm_ctrl_cbc_model.load_model('src/models/tm_dependend_ctrl_cbc.cbm')
|
| 29 |
-
tm_dependend_trmntl_cbc = catboost.CatBoostClassifier()
|
| 30 |
-
tm_dependend_trmntl_cbc.load_model('src/models/tm_dependend_trmnt_cbc.cbm')
|
| 31 |
|
| 32 |
# загрузим данные
|
| 33 |
data_train_index = pd.read_csv('data/data_train_index.csv')
|
|
@@ -223,7 +212,7 @@ with st.expander(label='Посмотреть пример пользовател
|
|
| 223 |
st.dataframe(example)
|
| 224 |
res = st.button('Обновить')
|
| 225 |
|
| 226 |
-
with st.
|
| 227 |
# считаем метрики для пользователя
|
| 228 |
user_metric_uplift_at_k = uplift_at_k(target_filtered, uplift, treatment_filtered, strategy='overall', k=k)
|
| 229 |
user_metric_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift, treatment_filtered)
|
|
@@ -236,12 +225,6 @@ with st.form(key='user_metricks'):
|
|
| 236 |
col3.metric(label=f'Weighted average uplift', value=f'{user_metric_weighted_average_uplift:.4f}')
|
| 237 |
st.write('Uplift по процентилям')
|
| 238 |
st.write(user_metric_uplift_by_percentile)
|
| 239 |
-
# отображаем графики
|
| 240 |
-
st.form_submit_button('Обновить графики', help='При изменении флагов')
|
| 241 |
-
perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
|
| 242 |
-
st.pyplot(plot_qini_curve(target_filtered, uplift, treatment_filtered, perfect=perfect_qini).figure_)
|
| 243 |
-
prefect_uplift = st.checkbox('Отрисовать идеальную метрику uplift')
|
| 244 |
-
st.pyplot(plot_uplift_curve(target_filtered, uplift, treatment_filtered, perfect=prefect_uplift).figure_)
|
| 245 |
|
| 246 |
|
| 247 |
show_ml_reasons = st.checkbox('Показать решения с помощью ML')
|
|
@@ -249,29 +232,13 @@ if show_ml_reasons:
|
|
| 249 |
with st.expander('Решение с помощью CatBoost'):
|
| 250 |
with st.form(key='catboost_metricks'):
|
| 251 |
|
| 252 |
-
|
| 253 |
-
estimator_trmnt=tm_dependend_trmntl_cbc,
|
| 254 |
-
estimator_ctrl=tm_dependend_ctrl_cbc,
|
| 255 |
-
method='ddr_control'
|
| 256 |
-
)
|
| 257 |
|
| 258 |
-
tm_ctrl = tm_ctrl.fit(
|
| 259 |
-
data_train, target_train, treatment_train,
|
| 260 |
-
estimator_trmnt_fit_params={
|
| 261 |
-
'cat_features': ['womens', 'mens', 'channel', 'zip_code', 'history_segment', 'newbie']},
|
| 262 |
-
estimator_ctrl_fit_params={
|
| 263 |
-
'cat_features': ['womens', 'mens', 'channel', 'zip_code', 'history_segment', 'newbie']}
|
| 264 |
-
)
|
| 265 |
-
|
| 266 |
-
uplift_tm_ctrl = tm_ctrl.predict(filtered_dataset)
|
| 267 |
-
|
| 268 |
-
tm_ctrl_score = uplift_at_k(y_true=target_filtered, uplift=uplift_tm_ctrl, treatment=treatment_filtered,
|
| 269 |
-
strategy='by_group', k=k)
|
| 270 |
# считаем метрики для ML
|
| 271 |
-
catboost_uplift_at_k = uplift_at_k(target_filtered,
|
| 272 |
-
catboost_uplift_by_percentile = uplift_by_percentile(target_filtered,
|
| 273 |
-
catboost_qini_auc_score = qini_auc_score(target_filtered,
|
| 274 |
-
catboost_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered,
|
| 275 |
# отображаем метрики
|
| 276 |
col1, col2, col3 = st.columns(3)
|
| 277 |
col1.metric(label=f'Uplift для {k}% пользователей', value=f'{catboost_uplift_at_k:.4f}', delta=f'{catboost_uplift_at_k - user_metric_uplift_at_k:.4f}')
|
|
@@ -282,6 +249,6 @@ if show_ml_reasons:
|
|
| 282 |
|
| 283 |
st.form_submit_button('Обновить графики', help='При изменении флагов')
|
| 284 |
perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
|
| 285 |
-
st.pyplot(plot_qini_curve(target_filtered,
|
| 286 |
prefect_uplift = st.checkbox('Отрисовать идеальную метрику uplift')
|
| 287 |
-
st.pyplot(plot_uplift_curve(target_filtered,
|
|
|
|
| 12 |
# загрузим датасет
|
| 13 |
dataset, target, treatment = tools.get_data()
|
| 14 |
|
| 15 |
+
# загрузим предикты моделей
|
| 16 |
+
ct_cbc = pd.read_csv('src/model_predictions/ct_cbc.csv', index_col='Unnamed: 0')
|
| 17 |
+
sm_cbc = pd.read_csv('src/model_predictions/sm_cbc.csv', index_col='Unnamed: 0')
|
| 18 |
+
tm_dependend_cbc = pd.read_csv('src/model_predictions/tm_dependend_cbc.csv', index_col='Unnamed: 0')
|
| 19 |
+
tm_independend_cbc = pd.read_csv('src/model_predictions/tm_independend_cbc.csv', index_col='Unnamed: 0')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# загрузим данные
|
| 22 |
data_train_index = pd.read_csv('data/data_train_index.csv')
|
|
|
|
| 212 |
st.dataframe(example)
|
| 213 |
res = st.button('Обновить')
|
| 214 |
|
| 215 |
+
with st.expander('Результаты ручной фильтрации', expanded=True):
|
| 216 |
# считаем метрики для пользователя
|
| 217 |
user_metric_uplift_at_k = uplift_at_k(target_filtered, uplift, treatment_filtered, strategy='overall', k=k)
|
| 218 |
user_metric_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift, treatment_filtered)
|
|
|
|
| 225 |
col3.metric(label=f'Weighted average uplift', value=f'{user_metric_weighted_average_uplift:.4f}')
|
| 226 |
st.write('Uplift по процентилям')
|
| 227 |
st.write(user_metric_uplift_by_percentile)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
|
| 229 |
|
| 230 |
show_ml_reasons = st.checkbox('Показать решения с помощью ML')
|
|
|
|
| 232 |
with st.expander('Решение с помощью CatBoost'):
|
| 233 |
with st.form(key='catboost_metricks'):
|
| 234 |
|
| 235 |
+
final_uplift = tm_dependend_cbc.loc[target_filtered.index]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
# считаем метрики для ML
|
| 238 |
+
catboost_uplift_at_k = uplift_at_k(target_filtered, final_uplift, treatment_filtered, strategy='overall', k=k)
|
| 239 |
+
catboost_uplift_by_percentile = uplift_by_percentile(target_filtered, final_uplift, treatment_filtered)
|
| 240 |
+
catboost_qini_auc_score = qini_auc_score(target_filtered, final_uplift, treatment_filtered)
|
| 241 |
+
catboost_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, final_uplift, treatment_filtered)
|
| 242 |
# отображаем метрики
|
| 243 |
col1, col2, col3 = st.columns(3)
|
| 244 |
col1.metric(label=f'Uplift для {k}% пользователей', value=f'{catboost_uplift_at_k:.4f}', delta=f'{catboost_uplift_at_k - user_metric_uplift_at_k:.4f}')
|
|
|
|
| 249 |
|
| 250 |
st.form_submit_button('Обновить графики', help='При изменении флагов')
|
| 251 |
perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
|
| 252 |
+
st.pyplot(plot_qini_curve(target_filtered, final_uplift, treatment_filtered, perfect=perfect_qini).figure_)
|
| 253 |
prefect_uplift = st.checkbox('Отрисовать идеальную метрику uplift')
|
| 254 |
+
st.pyplot(plot_uplift_curve(target_filtered, final_uplift, treatment_filtered, perfect=prefect_uplift).figure_)
|