Spaces:
Runtime error
Runtime error
Commit
·
e46d7eb
1
Parent(s):
938a6c1
Initial commit
Browse files- src/tools.py +7 -0
- src/web_app.py +54 -47
src/tools.py
CHANGED
|
@@ -115,6 +115,13 @@ def filter_by_recency(data: pd.DataFrame, recency_filter: list) -> pd.DataFrame:
|
|
| 115 |
|
| 116 |
|
| 117 |
def filter_data(data: pd.DataFrame, filters: dict) -> pd.DataFrame or None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
data = filter_by_newbie(data, filters['newbie_filter'])
|
| 119 |
if data.shape[0] == 0:
|
| 120 |
return None
|
|
|
|
| 115 |
|
| 116 |
|
| 117 |
def filter_data(data: pd.DataFrame, filters: dict) -> pd.DataFrame or None:
|
| 118 |
+
"""
|
| 119 |
+
Filter data by user filters
|
| 120 |
+
|
| 121 |
+
:param data: filtered data
|
| 122 |
+
:param filters: dict of filters
|
| 123 |
+
:return: filtered data
|
| 124 |
+
"""
|
| 125 |
data = filter_by_newbie(data, filters['newbie_filter'])
|
| 126 |
if data.shape[0] == 0:
|
| 127 |
return None
|
src/web_app.py
CHANGED
|
@@ -9,25 +9,27 @@ import catboost
|
|
| 9 |
|
| 10 |
import tools
|
| 11 |
|
| 12 |
-
|
| 13 |
dataset, target, treatment = tools.get_data()
|
| 14 |
|
|
|
|
| 15 |
ct_cbc_model = catboost.CatBoostClassifier()
|
| 16 |
ct_cbc_model.load_model('src/models/ct_cbc.cbm')
|
| 17 |
-
|
| 18 |
sm_cbc_model = catboost.CatBoostClassifier()
|
| 19 |
sm_cbc_model.load_model('src/models/sm_cbc.cbm')
|
| 20 |
-
|
| 21 |
tm_ctrl_cbc_model = catboost.CatBoostClassifier()
|
| 22 |
tm_ctrl_cbc_model.load_model('src/models/tm_ctrl_cbc.cbm')
|
| 23 |
tm_trmnt_cbc_model = catboost.CatBoostClassifier()
|
| 24 |
tm_ctrl_cbc_model.load_model('src/models/tm_trmnt_cbc.cbm')
|
| 25 |
-
|
| 26 |
tm_dependend_ctrl_cbc = catboost.CatBoostClassifier()
|
| 27 |
tm_ctrl_cbc_model.load_model('src/models/tm_dependend_ctrl_cbc.cbm')
|
| 28 |
tm_dependend_trmntl_cbc = catboost.CatBoostClassifier()
|
| 29 |
tm_dependend_trmntl_cbc.load_model('src/models/tm_dependend_trmnt_cbc.cbm')
|
| 30 |
|
|
|
|
| 31 |
data_train_index = pd.read_csv('data/data_train_index.csv')
|
| 32 |
data_test_index = pd.read_csv('data/data_test_index.csv')
|
| 33 |
treatment_train_index = pd.read_csv('data/treatment_train_index.csv')
|
|
@@ -35,7 +37,6 @@ treatment_test_index = pd.read_csv('data/treatment_test_index.csv')
|
|
| 35 |
target_train_index = pd.read_csv('data/target_train_index.csv')
|
| 36 |
target_test_index = pd.read_csv('data/target_test_index.csv')
|
| 37 |
|
| 38 |
-
|
| 39 |
# фиксируем выборки, чтобы результат работы ML был предсказуем
|
| 40 |
data_train = dataset.loc[data_train_index['0']]
|
| 41 |
data_test = dataset.loc[data_test_index['0']]
|
|
@@ -44,9 +45,6 @@ treatment_test = treatment.loc[treatment_test_index['0']]
|
|
| 44 |
target_train = target.loc[target_train_index['0']]
|
| 45 |
target_test = target.loc[target_test_index['0']]
|
| 46 |
|
| 47 |
-
if 'filter_data' not in st.session_state.keys():
|
| 48 |
-
st.session_state.filter_data = True
|
| 49 |
-
|
| 50 |
st.title('Uplift lab')
|
| 51 |
|
| 52 |
st.markdown(
|
|
@@ -69,6 +67,7 @@ st.markdown(
|
|
| 69 |
Пример данных приведен ниже.
|
| 70 |
"""
|
| 71 |
)
|
|
|
|
| 72 |
refresh = st.button('Обновить выборку')
|
| 73 |
title_subsample = data_train.sample(7)
|
| 74 |
if refresh:
|
|
@@ -132,6 +131,7 @@ with st.expander('Развернуть блок анализа данных'):
|
|
| 132 |
|
| 133 |
filters = {}
|
| 134 |
|
|
|
|
| 135 |
with st.form(key='filter-clients'):
|
| 136 |
st.subheader('Выберем клиентов, которым отправим рекламу.')
|
| 137 |
|
|
@@ -195,7 +195,7 @@ with st.form(key='filter-clients'):
|
|
| 195 |
|
| 196 |
filter_form_submit_button = st.form_submit_button('Применить фильтр')
|
| 197 |
|
| 198 |
-
|
| 199 |
if not first_group and not second_group and not third_group and not fourth_group and not fifth_group and not sixth_group and not seventh_group:
|
| 200 |
st.error('Необходимо выбрать хотя бы один класс')
|
| 201 |
st.stop()
|
|
@@ -203,7 +203,10 @@ elif not surburban and not urban and not rural:
|
|
| 203 |
st.error('Необходимо выбрать хотя бы один почтовый индекс')
|
| 204 |
st.stop()
|
| 205 |
|
|
|
|
| 206 |
filtered_dataset = tools.filter_data(data_test, filters)
|
|
|
|
|
|
|
| 207 |
if filtered_dataset is None:
|
| 208 |
st.error('Не найдено пользователей для данных фильтров. Попробуйте изменить фильтры.')
|
| 209 |
st.stop()
|
|
@@ -213,25 +216,27 @@ uplift = [1 for _ in filtered_dataset.index]
|
|
| 213 |
target_filtered = target_test.loc[filtered_dataset.index]
|
| 214 |
treatment_filtered = treatment_test.loc[filtered_dataset.index]
|
| 215 |
|
|
|
|
| 216 |
with st.expander(label='Посмотреть пример пользователей, которым будет отправлена реклама'):
|
| 217 |
sample_size = 7 if filtered_dataset.shape[0] >= 7 else filtered_dataset.shape[0]
|
| 218 |
example = filtered_dataset.sample(sample_size)
|
| 219 |
st.dataframe(example)
|
| 220 |
res = st.button('Обновить')
|
| 221 |
|
| 222 |
-
|
| 223 |
with st.form(key='user_metricks'):
|
|
|
|
| 224 |
user_metric_uplift_at_k = uplift_at_k(target_filtered, uplift, treatment_filtered, strategy='overall', k=k)
|
| 225 |
user_metric_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift, treatment_filtered)
|
| 226 |
user_metric_qini_auc_score = qini_auc_score(target_filtered, uplift, treatment_filtered)
|
| 227 |
user_metric_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, uplift, treatment_filtered)
|
|
|
|
| 228 |
col1, col2, col3 = st.columns(3)
|
| 229 |
col1.metric(label=f'Uplift для {k}% пользователей', value=f'{user_metric_uplift_at_k:.4f}')
|
| 230 |
col2.metric(label=f'Qini AUC score', value=f'{user_metric_qini_auc_score:.4f}', help='Всегда будет 0 для пользова��еля')
|
| 231 |
col3.metric(label=f'Weighted average uplift', value=f'{user_metric_weighted_average_uplift:.4f}')
|
| 232 |
st.write('Uplift по процентилям')
|
| 233 |
st.write(user_metric_uplift_by_percentile)
|
| 234 |
-
|
| 235 |
st.form_submit_button('Обновить графики', help='При изменении флагов')
|
| 236 |
perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
|
| 237 |
st.pyplot(plot_qini_curve(target_filtered, uplift, treatment_filtered, perfect=perfect_qini).figure_)
|
|
@@ -242,39 +247,41 @@ with st.form(key='user_metricks'):
|
|
| 242 |
show_ml_reasons = st.checkbox('Показать решения с помощью ML')
|
| 243 |
if show_ml_reasons:
|
| 244 |
with st.expander('Решение с помощью CatBoost'):
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
import tools
|
| 11 |
|
| 12 |
+
# загрузим датасет
|
| 13 |
dataset, target, treatment = tools.get_data()
|
| 14 |
|
| 15 |
+
# загрузим модель для ClassTransform
|
| 16 |
ct_cbc_model = catboost.CatBoostClassifier()
|
| 17 |
ct_cbc_model.load_model('src/models/ct_cbc.cbm')
|
| 18 |
+
# загрузим модель для SingleMod
|
| 19 |
sm_cbc_model = catboost.CatBoostClassifier()
|
| 20 |
sm_cbc_model.load_model('src/models/sm_cbc.cbm')
|
| 21 |
+
# загрузим модели для независимого класификатора
|
| 22 |
tm_ctrl_cbc_model = catboost.CatBoostClassifier()
|
| 23 |
tm_ctrl_cbc_model.load_model('src/models/tm_ctrl_cbc.cbm')
|
| 24 |
tm_trmnt_cbc_model = catboost.CatBoostClassifier()
|
| 25 |
tm_ctrl_cbc_model.load_model('src/models/tm_trmnt_cbc.cbm')
|
| 26 |
+
# загрузим модели для зависимого класификатора
|
| 27 |
tm_dependend_ctrl_cbc = catboost.CatBoostClassifier()
|
| 28 |
tm_ctrl_cbc_model.load_model('src/models/tm_dependend_ctrl_cbc.cbm')
|
| 29 |
tm_dependend_trmntl_cbc = catboost.CatBoostClassifier()
|
| 30 |
tm_dependend_trmntl_cbc.load_model('src/models/tm_dependend_trmnt_cbc.cbm')
|
| 31 |
|
| 32 |
+
# загрузим данные
|
| 33 |
data_train_index = pd.read_csv('data/data_train_index.csv')
|
| 34 |
data_test_index = pd.read_csv('data/data_test_index.csv')
|
| 35 |
treatment_train_index = pd.read_csv('data/treatment_train_index.csv')
|
|
|
|
| 37 |
target_train_index = pd.read_csv('data/target_train_index.csv')
|
| 38 |
target_test_index = pd.read_csv('data/target_test_index.csv')
|
| 39 |
|
|
|
|
| 40 |
# фиксируем выборки, чтобы результат работы ML был предсказуем
|
| 41 |
data_train = dataset.loc[data_train_index['0']]
|
| 42 |
data_test = dataset.loc[data_test_index['0']]
|
|
|
|
| 45 |
target_train = target.loc[target_train_index['0']]
|
| 46 |
target_test = target.loc[target_test_index['0']]
|
| 47 |
|
|
|
|
|
|
|
|
|
|
| 48 |
st.title('Uplift lab')
|
| 49 |
|
| 50 |
st.markdown(
|
|
|
|
| 67 |
Пример данных приведен ниже.
|
| 68 |
"""
|
| 69 |
)
|
| 70 |
+
|
| 71 |
refresh = st.button('Обновить выборку')
|
| 72 |
title_subsample = data_train.sample(7)
|
| 73 |
if refresh:
|
|
|
|
| 131 |
|
| 132 |
filters = {}
|
| 133 |
|
| 134 |
+
# блок фильтров
|
| 135 |
with st.form(key='filter-clients'):
|
| 136 |
st.subheader('Выберем клиентов, которым отправим рекламу.')
|
| 137 |
|
|
|
|
| 195 |
|
| 196 |
filter_form_submit_button = st.form_submit_button('Применить фильтр')
|
| 197 |
|
| 198 |
+
# проверка корректности заполнения форм
|
| 199 |
if not first_group and not second_group and not third_group and not fourth_group and not fifth_group and not sixth_group and not seventh_group:
|
| 200 |
st.error('Необходимо выбрать хотя бы один класс')
|
| 201 |
st.stop()
|
|
|
|
| 203 |
st.error('Необходимо выбрать хотя бы один почтовый индекс')
|
| 204 |
st.stop()
|
| 205 |
|
| 206 |
+
# фильтруем тестовые данные по пользовательскому выбору
|
| 207 |
filtered_dataset = tools.filter_data(data_test, filters)
|
| 208 |
+
|
| 209 |
+
# проверяем, что данные отфильтровались
|
| 210 |
if filtered_dataset is None:
|
| 211 |
st.error('Не найдено пользователей для данных фильтров. Попробуйте изменить фильтры.')
|
| 212 |
st.stop()
|
|
|
|
| 216 |
target_filtered = target_test.loc[filtered_dataset.index]
|
| 217 |
treatment_filtered = treatment_test.loc[filtered_dataset.index]
|
| 218 |
|
| 219 |
+
# блок с демонстрацией отфильтрованных данных
|
| 220 |
with st.expander(label='Посмотреть пример пользователей, которым будет отправлена реклама'):
|
| 221 |
sample_size = 7 if filtered_dataset.shape[0] >= 7 else filtered_dataset.shape[0]
|
| 222 |
example = filtered_dataset.sample(sample_size)
|
| 223 |
st.dataframe(example)
|
| 224 |
res = st.button('Обновить')
|
| 225 |
|
|
|
|
| 226 |
with st.form(key='user_metricks'):
|
| 227 |
+
# считаем метрики для пользователя
|
| 228 |
user_metric_uplift_at_k = uplift_at_k(target_filtered, uplift, treatment_filtered, strategy='overall', k=k)
|
| 229 |
user_metric_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift, treatment_filtered)
|
| 230 |
user_metric_qini_auc_score = qini_auc_score(target_filtered, uplift, treatment_filtered)
|
| 231 |
user_metric_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, uplift, treatment_filtered)
|
| 232 |
+
# отображаем метрики
|
| 233 |
col1, col2, col3 = st.columns(3)
|
| 234 |
col1.metric(label=f'Uplift для {k}% пользователей', value=f'{user_metric_uplift_at_k:.4f}')
|
| 235 |
col2.metric(label=f'Qini AUC score', value=f'{user_metric_qini_auc_score:.4f}', help='Всегда будет 0 для пользова��еля')
|
| 236 |
col3.metric(label=f'Weighted average uplift', value=f'{user_metric_weighted_average_uplift:.4f}')
|
| 237 |
st.write('Uplift по процентилям')
|
| 238 |
st.write(user_metric_uplift_by_percentile)
|
| 239 |
+
# отображаем графики
|
| 240 |
st.form_submit_button('Обновить графики', help='При изменении флагов')
|
| 241 |
perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
|
| 242 |
st.pyplot(plot_qini_curve(target_filtered, uplift, treatment_filtered, perfect=perfect_qini).figure_)
|
|
|
|
| 247 |
show_ml_reasons = st.checkbox('Показать решения с помощью ML')
|
| 248 |
if show_ml_reasons:
|
| 249 |
with st.expander('Решение с помощью CatBoost'):
|
| 250 |
+
with st.form(key='catboost_metricks'):
|
| 251 |
+
|
| 252 |
+
tm_ctrl = TwoModels(
|
| 253 |
+
estimator_trmnt=tm_dependend_trmntl_cbc,
|
| 254 |
+
estimator_ctrl=tm_dependend_ctrl_cbc,
|
| 255 |
+
method='ddr_control'
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
tm_ctrl = tm_ctrl.fit(
|
| 259 |
+
data_train, target_train, treatment_train,
|
| 260 |
+
estimator_trmnt_fit_params={
|
| 261 |
+
'cat_features': ['womens', 'mens', 'channel', 'zip_code', 'history_segment', 'newbie']},
|
| 262 |
+
estimator_ctrl_fit_params={
|
| 263 |
+
'cat_features': ['womens', 'mens', 'channel', 'zip_code', 'history_segment', 'newbie']}
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
uplift_tm_ctrl = tm_ctrl.predict(filtered_dataset)
|
| 267 |
+
|
| 268 |
+
tm_ctrl_score = uplift_at_k(y_true=target_filtered, uplift=uplift_tm_ctrl, treatment=treatment_filtered,
|
| 269 |
+
strategy='by_group', k=k)
|
| 270 |
+
# считаем метрики для ML
|
| 271 |
+
catboost_uplift_at_k = uplift_at_k(target_filtered, uplift_tm_ctrl, treatment_filtered, strategy='overall', k=k)
|
| 272 |
+
catboost_uplift_by_percentile = uplift_by_percentile(target_filtered, uplift_tm_ctrl, treatment_filtered)
|
| 273 |
+
catboost_qini_auc_score = qini_auc_score(target_filtered, uplift_tm_ctrl, treatment_filtered)
|
| 274 |
+
catboost_weighted_average_uplift = tools.get_weighted_average_uplift(target_filtered, uplift_tm_ctrl, treatment_filtered)
|
| 275 |
+
# отображаем метрики
|
| 276 |
+
col1, col2, col3 = st.columns(3)
|
| 277 |
+
col1.metric(label=f'Uplift для {k}% пользователей', value=f'{catboost_uplift_at_k:.4f}', delta=f'{catboost_uplift_at_k - user_metric_uplift_at_k:.4f}')
|
| 278 |
+
col2.metric(label=f'Qini AUC score', value=f'{catboost_qini_auc_score:.4f}', help='Всегда будет 0 для пользователя', delta=f'{catboost_qini_auc_score - user_metric_qini_auc_score:.4f}')
|
| 279 |
+
col3.metric(label=f'Weighted average uplift', value=f'{catboost_weighted_average_uplift:.4f}', delta=f'{catboost_weighted_average_uplift - user_metric_weighted_average_uplift:.4f}')
|
| 280 |
+
st.write('Uplift по процентилям')
|
| 281 |
+
st.write(catboost_uplift_by_percentile)
|
| 282 |
+
|
| 283 |
+
st.form_submit_button('Обновить графики', help='При изменении флагов')
|
| 284 |
+
perfect_qini = st.checkbox('Отрисовать идеальную метрику qini')
|
| 285 |
+
st.pyplot(plot_qini_curve(target_filtered, uplift_tm_ctrl, treatment_filtered, perfect=perfect_qini).figure_)
|
| 286 |
+
prefect_uplift = st.checkbox('Отрисовать идеальную метрику uplift')
|
| 287 |
+
st.pyplot(plot_uplift_curve(target_filtered, uplift_tm_ctrl, treatment_filtered, perfect=prefect_uplift).figure_)
|