Spaces:
Running
Running
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import xgboost | |
| import pickle | |
| from sksurv.linear_model import CoxPHSurvivalAnalysis | |
| from sksurv.util import Surv | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| import survival_analysis as ttc | |
| import xai as xai | |
print("Demo Initialization...")

# ---------------------------------------------------------------------------
# Validation data: churners and non-churners are read from two CSV files
# (index taken from the first column) and stacked into a single frame.
# ---------------------------------------------------------------------------
print("Download the dataset")
hundred_churners_val, hundred_non_churners_val = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in ("data/hundred_val_churners.csv", "data/hundred_val_non_churners.csv")
)
val_df = pd.concat((hundred_churners_val, hundred_non_churners_val), axis=0)
y_val_df = val_df[["Exited", "Tenure"]]

# ---------------------------------------------------------------------------
# XAI pipeline: scaling/encoding artefacts, XGBoost classifier, SHAP values
# ---------------------------------------------------------------------------
print("Preprocessing for XAI on Churn Prediction...")
scaler_xai, label_encs_xai, train_cols_xai = xai.obtain_scaler_and_label_enc()
val_ordered_df_xai, val_scaled_df_xai = xai.scale_dataset(
    val_df, "Exited", train_cols_xai, scaler_xai
)

print("Loading XGBoost...")
xgb_model = xgboost.XGBClassifier()
xgb_model.load_model("models/xgb_churn_model.json")
booster = xgb_model.get_booster()

print("Preparing data for XAI...")
# Unscaled frame is kept for display; the scaled one gets the same labels.
val_unscaled_pd_xai = val_ordered_df_xai
val_pd_X_xai = pd.DataFrame(
    data=val_scaled_df_xai,
    index=val_ordered_df_xai.index,
    columns=val_ordered_df_xai.columns,
)
val_pd_y_xai = val_df["Exited"]

print("Initialization SHAP explainer and obtaining classification predictions...")
explainer, shap_values, shap_values_df, explaination = xai.obtain_explanations(
    booster, val_pd_X_xai
)
churners_and_non, churners_and_non_X, new_y = xai.obtain_predictions(
    booster, val_pd_X_xai, val_pd_y_xai
)
# ---------------------------------------------------------------------------
# Survival-analysis pipeline: scaling artefacts, validation targets, Cox model
# ---------------------------------------------------------------------------
print("Preprocessing for Survival Analysis...")
scaler_ttc, label_enc_ttc, train_cols_ttc, X_train_ttc, y_train_ttc = ttc.obtain_scaler_and_label_enc()
val_ordered_df_ttc, val_scaled_df_ttc = ttc.scale_dataset(
    val_df, ["Exited", "Tenure"], train_cols_ttc, scaler_ttc
)
# Structured (event, time) target array built from the Exited/Tenure columns
y_val_ttc = Surv.from_dataframe("Exited", "Tenure", y_val_df)
val_unscaled_pd = val_ordered_df_ttc
val_pd_X_ttc = pd.DataFrame(
    val_scaled_df_ttc,
    columns=val_ordered_df_ttc.columns,
    index=val_ordered_df_ttc.index,
)
val_pd_y_ttc = val_df['Exited']

print("Loading Cox Proportional Hazards model...")
# The fitted model is taken entirely from the pickle; no fresh estimator is
# needed. NOTE(security): pickle.load can execute arbitrary code stored in the
# file, so only ship/load model files from a trusted source.
with open('models/cox_model.pkl', 'rb') as f:
    cph = pickle.load(f)
print("Obtaining time to churn predictions...")
# Cox risk score for every validation customer (targets dropped from the input)
prediction = cph.predict(val_scaled_df_ttc.drop(['Exited', 'Tenure'], axis=1))
val_scaled_df_ttc['preds'] = prediction

# Predict survival functions: one row per customer, one column per time point
surv_func = cph.predict_survival_function(
    val_scaled_df_ttc.drop(['Exited', 'Tenure', 'preds'], axis=1), return_array=True
)

# Label each row of df_surv with the time point it refers to, so idxmax() below
# yields a time instead of a row position. Fall back to positional steps (the
# previous behaviour) when the model exposes no time grid of matching length.
time_grid = getattr(cph, 'unique_times_', None)
if time_grid is None or len(time_grid) != surv_func.shape[1]:
    time_grid = np.arange(surv_func.shape[1])
df_surv = pd.DataFrame(surv_func.T, index=time_grid, columns=val_scaled_df_ttc.index)

# Predicted time to churn = first time at which the survival probability drops
# to the threshold; NaN for customers whose curve never reaches it.
threshold = 0.5
predicted_time_to_churn = (df_surv <= threshold)
churns = predicted_time_to_churn.idxmax().where(predicted_time_to_churn.any())
val_scaled_df_ttc['absolute_time_to_churn'] = churns
# Customers that never cross the threshold get the sentinel 11, i.e. beyond the
# horizon of 10 used for Churn_Prediction. The result is assigned back instead
# of calling fillna(inplace=True) on the column selection: that chained in-place
# call does not update the frame under pandas Copy-on-Write.
val_scaled_df_ttc['absolute_time_to_churn'] = val_scaled_df_ttc['absolute_time_to_churn'].fillna(11)
val_scaled_df_ttc['Churn_Prediction'] = (val_scaled_df_ttc['absolute_time_to_churn'] <= 10).astype(int)

X_val_final = val_scaled_df_ttc
df_train = pd.concat([X_train_ttc, y_train_ttc], axis=1)
# Feature matrix passed to the Cox model: targets and the prediction columns
# added above are removed.
test_features = val_scaled_df_ttc.drop(
    columns=['preds', 'absolute_time_to_churn', 'Churn_Prediction', 'Exited', 'Tenure']
)

print("Setup completed!")
print(f"Dataset: {len(val_df)} customers")
print(f" - Churners: {val_df['Exited'].sum()}")
print(f" - Non-churners: {(1-val_df['Exited']).sum()}")
def prepare_customer_table():
    """Build the customer summary table used for interactive selection.

    The table has one row per customer in ``val_df``. It keeps ``val_df``'s
    index and also exposes it as the ``CustomerID`` column. Categorical flags
    are rendered as text.

    When ``new_y`` has a ``predicted_prediction`` column, it is added as a
    ``ChurnProb`` percentage string. That column is matched to each row by
    customer ID (index), not by position. Rows are shuffled with a fixed seed,
    so the order is reproducible.

    Returns:
        pandas.DataFrame: the shuffled summary table.
    """
    table_data = pd.DataFrame({
        'CustomerID': val_df.index,
        'Age': val_df['Age'].astype(int),
        'Tenure': val_df['Tenure'].astype(int),
        'Balance': val_df['Balance'].round(0),
        'NumOfProducts': val_df['NumOfProducts'].astype(int),
        'IsActiveMember': val_df['IsActiveMember'].map({1: 'Yes', 0: 'No'}),
        'Exited': val_df['Exited'].map({1: '๐ด Churner', 0: '๐ข Non-Churner'})
    })
    # Add classifier predictions when available. Align on the customer-ID index:
    # assigning ``.values`` would silently attach probabilities to the wrong
    # customers whenever new_y's row order differs from val_df's.
    if 'predicted_prediction' in new_y.columns:
        churn_prob = new_y['predicted_prediction'].reindex(table_data.index)
        table_data['ChurnProb'] = (churn_prob * 100).round(1).astype(str) + '%'
    return table_data.sample(frac=1, random_state=42)
# Build the customer selection table once, at start-up.
customer_table_df = prepare_customer_table()

# ===== PAGINATION =====
ROWS_PER_PAGE = 10
# Ceiling division: a partially filled last page still counts as a page.
total_pages = -(-len(customer_table_df) // ROWS_PER_PAGE)
def get_page_data(page_num):
    """Return the rows and navigation state for the zero-based page ``page_num``.

    Returns:
        tuple: (slice of ``customer_table_df`` for the page, Markdown caption,
        update for the "Previous" button, update for the "Next" button).
        Each button is enabled only when a page exists in its direction.
    """
    n_rows = len(customer_table_df)
    first = page_num * ROWS_PER_PAGE
    last = min(first + ROWS_PER_PAGE, n_rows)
    rows = customer_table_df.iloc[first:last]
    caption = (
        f"**Page {page_num + 1} of {total_pages}** | "
        f"Showing customers {first + 1}-{last} of {n_rows}"
    )
    has_previous = page_num > 0
    has_next = page_num < total_pages - 1
    return (
        rows,
        caption,
        gr.update(interactive=has_previous),
        gr.update(interactive=has_next),
    )
def next_page(current_page):
    """Advance one page, never past the last one.

    Returns the outputs of ``get_page_data`` for the new page, followed by the
    new page number (stored back into the page state).
    """
    target = current_page + 1
    if target > total_pages - 1:
        target = total_pages - 1
    return (*get_page_data(target), target)


def prev_page(current_page):
    """Go back one page, never before page 0.

    Returns the outputs of ``get_page_data`` for the new page, followed by the
    new page number (stored back into the page state).
    """
    target = current_page - 1
    if target < 0:
        target = 0
    return (*get_page_data(target), target)
def _decode_label_encoded(customer_series):
    """Return a copy of ``customer_series`` with label-encoded columns decoded.

    Each column in ``label_encs_xai`` that is present in the series is mapped
    back to its original label through the encoder's ``inverse_transform``.
    A column that fails to decode is left unchanged and the error is printed,
    so one bad value does not abort the whole analysis.
    """
    decoded = customer_series.copy()
    for col, le in label_encs_xai.items():
        if col in decoded.index:
            try:
                val = int(decoded.loc[col])
                decoded.loc[col] = le.inverse_transform([val])[0]
            except Exception as e:
                print(f"Error on column {col}: {e}")
    return decoded


def analyze_customer(sample_mode, customer_id_input, sample_size=30):
    """
    Analyse one customer with XAI (SHAP) and Survival Analysis (Cox model).

    Args:
        sample_mode: "Random Customer" selects a customer through
            ``xai.extract_customer``. Any other value (the UI uses
            "Specific Customer") analyses ``customer_id_input``.
        customer_id_input: ID of the customer to analyse. It is converted with
            ``int()`` and ignored in random mode.
        sample_size: number of rows sampled (fixed seed) from the classified
            customers before extraction in random mode. It is capped at the
            number of available rows.

    Returns:
        tuple: (fig_xai, fig_survival, info_text, customer_details).
        On failure both figures are None, ``info_text`` holds a Markdown error
        message and ``customer_details`` is an empty string.
    """
    try:
        plt.close('all')  # Close figures left over from previous analyses

        # ----- Select the customer -----
        if sample_mode == "Random Customer":
            # sample() raises if asked for more rows than exist, so cap the size.
            # Both frames are sampled with the same seed and size, presumably so
            # that features and labels refer to the same customers.
            n_rows = min(sample_size, len(churners_and_non_X), len(churners_and_non))
            _pos, customer_idx, customer_x, customer_y, customer_x_original, _record = xai.extract_customer(
                val_pd_X_xai,
                churners_and_non_X.sample(n_rows, random_state=42),
                churners_and_non.sample(n_rows, random_state=42),
                val_unscaled_pd_xai
            )
        else:
            # Specific customer
            customer_id_input = int(customer_id_input)
            if customer_id_input not in val_pd_X_xai.index:
                available_ids = val_pd_X_xai.index.tolist()[:20]
                return None, None, f"โ **Error**: Customer ID {customer_id_input} not found!\n\n**IDs available (top 20):** {available_ids}", ""
            customer_idx = customer_id_input
            customer_x = val_pd_X_xai.loc[[customer_idx]]
            customer_y = new_y[new_y.index == customer_idx]
            if customer_y.empty:
                # Without this check the probability lookup below fails with an
                # opaque IndexError.
                return None, None, f"**Error**: no classifier prediction available for customer {customer_idx}", ""
            customer_x_original = val_unscaled_pd_xai.loc[customer_idx]

        # Readable copy of the unscaled features (categorical codes decoded)
        customer_x_display = _decode_label_encoded(customer_x_original)

        # The customer must also exist in the survival feature matrix
        if customer_idx not in test_features.index:
            return None, None, f"โ **Error**: Customer {customer_idx} not found in the survival dataset", ""

        # ----- Summary metrics -----
        actual_churn = customer_y['Exited']
        actual_tenure = customer_x_original['Tenure']
        churn_prob = customer_y['predicted_prediction'].values[0]
        customer_features_ttc = test_features.loc[[customer_idx]]
        # Exponentiated Cox prediction, shown to the user as the risk score
        risk_score = np.exp(cph.predict(customer_features_ttc))[0]
        status = "๐ด CHURNER" if actual_churn.values == 1 else "๐ข NON-CHURNER"
        if risk_score > 1.8:
            risk_level = "๐ฅ๐ฅ CRITICAL"
        elif risk_score > 1.2:
            risk_level = "๐ฅ HIGH"
        elif risk_score > 0.8:
            risk_level = "โ ๏ธ MEDIUM"
        else:
            risk_level = "โ LOW"

        info_text = f"""
## ๐ Customer #{customer_idx}
| Metric | Value |
|---------|--------|
| **Actual Status** | {status} |
| **Tenure** | {actual_tenure:.0f} years |
| **Churn Probability (Classifier)** | {churn_prob:.1%} |
| **Risk Score (Cox)** | {risk_score:.2f} |
| **Risk Level** | {risk_level} |
---
"""

        # ----- Customer details table -----
        int_cols = ["Age", "Tenure", "NumOfProducts", "HasCrCard", "IsActiveMember"]
        customer_x_display[int_cols] = customer_x_display[int_cols].astype("Int64")
        details_df = pd.DataFrame(customer_x_display.T, index=customer_x_display.T.index)
        details_df.columns = ['Value']
        customer_details = details_df.to_markdown(floatfmt=".2f")

        # ----- XAI plot -----
        print(f"๐จ Generating XAI plot for customer {customer_idx}...")
        customer_shap_values = np.array(shap_values_df[shap_values_df.index == customer_idx])
        fig_xai = xai.plot_waterfall(
            customer_shap_values,
            explainer.expected_value,
            customer_x,
            customer_x_display,
            customer_y['predicted_prediction'],
            'Yes',
            customer_idx
        )

        # ----- Survival plot -----
        print(f"๐ Generating Survival plot for customer {customer_idx}...")
        fig_survival = ttc.plot_single_customer_complete(
            customer_idx,
            X_val_final,
            test_features,
            cph,
            df_train,
            max_time=10
        )
        print(f"โ Analysis completed for customer {customer_idx}")
        return fig_xai, fig_survival, info_text, f"### ๐ Features Details\n\n{customer_details}"
    except Exception as e:
        # UI boundary: report any failure in the interface instead of crashing
        import traceback
        error_msg = f"โ **Error during the analysis:**\n\n```\n{str(e)}\n```\n\n**Traceback:**\n```\n{traceback.format_exc()}\n```"
        print(error_msg)
        return None, None, error_msg, ""
# ===== Table selection handler =====
def on_table_select(current_page, evt: gr.SelectData):
    """Analyse the customer whose table row was clicked.

    ``evt.index[0]`` is the row position within the page currently shown, so
    it is offset by ``current_page * ROWS_PER_PAGE`` to address
    ``customer_table_df``. Gradio supplies ``evt`` because of the
    ``gr.SelectData`` type hint.

    Returns the same 4-tuple as ``analyze_customer``, or a placeholder message
    when the event carries no index.
    """
    if evt.index is None:
        return None, None, "Select a customer from the table", ""
    position = current_page * ROWS_PER_PAGE + evt.index[0]
    selected_id = customer_table_df.iloc[position]['CustomerID']
    return analyze_customer("Specific Customer", selected_id)
# ===================== GRADIO INTERFACE =====================
with gr.Blocks(title="Customer Churn Analysis", theme=gr.themes.Soft()) as demo:
    # The Markdown strings keep blank lines between paragraphs and sections.
    # Without them, consecutive lines merge into one paragraph, and a "---"
    # directly under a paragraph turns that paragraph into a heading. Bold
    # sub-titles that follow a bullet list would also be absorbed into the
    # last list item.
    gr.Markdown("""
# ๐ Customer Churn Survival Analysis

Managing their customers is one of the core activities of any business: it is important to keep customers engaged and avoid losing too many of them, in order to keep making profits.

All these activities fall under the name of Customer Retention, which implies a series of analyses to help any business identify the customers at risk of leaving the offered services and the actions to take to counteract the phenomenon.

This demo serves to show a couple of instruments that can be used to identify the customers in need of attention within the banking sector.
Customer Churn in this case implies closing a bank account.

A Churn Classification pipeline is useful to distinguish between Churners and Non-Churners. The Explainable AI framework can be useful to give insights on the phenomenon, from the point of view of the classifier: the SHAP values, in fact, explain *why* the classifier assigned a specific label to a customer, showing how each data feature impacted it.

A Survival Analysis pipeline, on the other hand, is useful to give an indication of *when* the churn is likely to happen and to estimate the risk of the phenomenon occurring.

Thanks to these two instruments, any business (in this specific case, a bank) is able to analyse its customers and understand their current position with respect to the churn phenomenon: having an indication of the risk and of the aspects that may be most relevant for churning is the first step to later choose retention strategies.

---
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### โ๏ธ Configuration")
            sample_mode = gr.Radio(
                choices=["Random Customer", "Specific Customer"],
                value="Random Customer",
                label="Selection Mode",
                info="Choose how to select the customer to analyse"
            )
            gr.Markdown("""
### ๐ก Hint
Use **"Random Customer"** to explore the dataset,
or **"Specific Customer"** to analyze a specific customer.
""")
        with gr.Column(scale=2):
            info_output = gr.Markdown("""
### ๐ Welcome!
Click on **"๐ Analyse the customer"** to begin with the analysis,
or select **"Specific Customer"** and click on a row in the table below.

The system will analyse:
1. The most relevant features for the classification (SHAP)
2. The survival probability over time
3. The churn risk and when it could happen
""")
    gr.Markdown("---")

    # Button for random analysis (shown only in "Random Customer" mode)
    analyze_btn = gr.Button(
        "๐ Analyse the customer",
        variant="primary",
        size="lg",
        visible=True
    )

    # Paginated customer table (shown only in "Specific Customer" mode)
    customer_table_section = gr.Column(visible=False)
    with customer_table_section:
        gr.Markdown("### ๐ Customer Database")
        page_info = gr.Markdown(f"**Page 1 of {total_pages}** | Showing customers 1-{min(ROWS_PER_PAGE, len(customer_table_df))} of {len(customer_table_df)}")
        customer_table = gr.Dataframe(
            value=customer_table_df.iloc[0:ROWS_PER_PAGE],
            headers=list(customer_table_df.columns),
            # CustomerID, Age, Tenure, Balance and NumOfProducts are numeric.
            # The remaining columns (IsActiveMember, Exited and the optional
            # ChurnProb) are text. The list is derived from the column count so
            # it stays in sync when ChurnProb is absent.
            datatype=["number"] * 5 + ["str"] * (len(customer_table_df.columns) - 5),
            interactive=False,
            wrap=False,
            label="Click on a row to analyse that customer"
        )
        with gr.Row():
            # The table starts on page 0, so there is no previous page. A next
            # page exists only when there is more than one page (same rule as
            # get_page_data).
            prev_btn = gr.Button("โฌ ๏ธ Previous", interactive=False, size="sm")
            next_btn = gr.Button("Next โก๏ธ", interactive=total_pages > 1, size="sm")
        # Current page number (hidden state)
        current_page_state = gr.State(0)
    gr.Markdown("---")

    # Tabs organising the plots
    with gr.Tabs():
        with gr.Tab("๐ง Explainable AI (XAI)"):
            gr.Markdown("""
### ๐ How to read the SHAP Plot
With respect to the average prediction, reported at the bottom, it shows the impact each feature had on the classification.
Each feature can have a positive impact, towards the classification as **"Churner"**, or a negative one, towards the classification as **"Non-Churner"**.

- **Red Bars**: Features that **increase** the churn probability
- **Blue Bars**: Features that **reduce** the churn probability
- **Bar Length**: Indicates how much the feature impacts the final classification
- **Feature Value**: Shown beside the feature name
""")
            shap_plot = gr.Plot(label="SHAP Waterfall Plot")
        with gr.Tab("โฑ๏ธ Survival Analysis"):
            gr.Markdown("""
### ๐ How to read the Plots

**๐น Survival Probability** (at the top):
- Curve showing the probability that a customer stays throughout an observation time window
- Green Dots: when the customer reaches 75%, 50%, 25% chance of survival
- Blue Star: Actual position of the customer

**๐น Risk Timeline** (bottom left):
- Colored areas indicate increasing risk levels
- Red curve shows the cumulative risk through time

**๐น Churn Probability at Intervals** (bottom right):
- Shows the risk at fixed intervals (2, 4, 6, 8, 10 years)
""")
            survival_plot = gr.Plot(label="Complete Survival Analysis")
        with gr.Tab("๐ Customer Details"):
            customer_details_output = gr.Markdown("Select a customer to view the details")

    # ===== EVENT HANDLERS =====
    def toggle_ui_elements(mode):
        """Show the analyse button in random mode and the customer table otherwise."""
        if mode == "Random Customer":
            return (
                gr.Button(visible=True),   # show button
                gr.Column(visible=False)   # hide table
            )
        # Specific Customer
        return (
            gr.Button(visible=False),  # hide button
            gr.Column(visible=True)    # show table
        )

    sample_mode.change(
        fn=toggle_ui_elements,
        inputs=[sample_mode],
        outputs=[analyze_btn, customer_table_section]
    )

    # Pagination
    next_btn.click(
        fn=next_page,
        inputs=[current_page_state],
        outputs=[customer_table, page_info, prev_btn, next_btn, current_page_state]
    )
    prev_btn.click(
        fn=prev_page,
        inputs=[current_page_state],
        outputs=[customer_table, page_info, prev_btn, next_btn, current_page_state]
    )

    # Selecting a table row analyses that customer immediately
    customer_table.select(
        fn=on_table_select,
        inputs=[current_page_state],
        outputs=[shap_plot, survival_plot, info_output, customer_details_output]
    )

    # Analyse a random customer
    def analyze_random_customer():
        return analyze_customer("Random Customer", None)

    analyze_btn.click(
        fn=analyze_random_customer,
        outputs=[shap_plot, survival_plot, info_output, customer_details_output]
    )

    gr.Markdown("""
---
### ๐ Technical Information

**Used Models:**
- **XGBoost Classifier**: Binary Prediction Churn/Non-Churn
- **Cox Proportional Hazards**: Analysis of Time-To-Churn
- **Kaplan-Meier Estimator**: Survival baseline

**Metrics:**
- **Risk Score**: Relative risk with respect to the average population (Cox model)
- **SHAP Values**: Feature contribution on the classification prediction
- **Churn Probability**: Classifier output (0-100%)

---
๐ **Dataset**: Bank Customer Churn | ๐ค **Framework**: Gradio + scikit-survival + SHAP
""")
# Entry point: start the Gradio server with its default host, port and
# sharing settings.
if __name__ == "__main__":
    demo.launch()