# HYPE_Churn_Analysis / survival_analysis.py
# cmmedoro
# Upload demo code
# f8da90e
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.nonparametric import kaplan_meier_estimator
from sksurv.util import Surv
import pickle
from preprocessing import *
def obtain_scaler_and_label_enc():
    """Run the training-time preprocessing and return its fitted artefacts.

    Builds a ``Preprocessor`` on ``data/train.csv`` (target ``Exited``,
    time-to-churn targets ``Exited``/``Tenure``, undersampling, min-max
    scaling) and calls ``process_ttcp``. The scaler and label encoders are
    read from the preprocessor after that call.

    Returns:
        tuple: ``(scaler, label_encoders, train_columns, X_train_df, y_train_df)``
        where ``train_columns`` is ``.columns`` of the first value returned by
        ``process_ttcp`` and the two frames are its 5th and 7th values.
        NOTE(review): assumes ``process_ttcp`` returns exactly 8 values, as the
        unpacking below requires.
    """
    prep = Preprocessor(
        "data/train.csv",
        "Exited",
        ["Exited", "Tenure"],
        resampling="under",
        scaling='minmax',
    )
    (train_x, _, _, _,
     train_x_df, _, train_y_df, _) = prep.process_ttcp()
    return (
        prep.scaler,
        prep.label_encoders,
        train_x.columns,
        train_x_df,
        train_y_df,
    )
def scale_dataset(test_df, target_column, train_cols, scaler):
    """Scale the feature columns of ``test_df`` with an already fitted scaler.

    Args:
        test_df: DataFrame holding both the features and the target column(s).
        target_column: iterable of target column names. They are excluded
            from scaling and re-attached, unscaled, to the scaled output.
        train_cols: feature column order used at training time. Features are
            reordered to it before ``scaler.transform`` is called.
        scaler: fitted object exposing ``transform``.

    Returns:
        tuple: ``(features, scaled_with_targets)``. ``features`` is the
        reordered, *unscaled* feature frame. ``scaled_with_targets`` holds the
        scaled features followed by the target columns. Both keep
        ``test_df``'s index.
    """
    features = test_df.drop(list(target_column), axis=1)[train_cols]
    scaled_features = pd.DataFrame(
        scaler.transform(features),
        columns=features.columns,
        index=features.index,
    )
    scaled_with_targets = pd.concat(
        [scaled_features, test_df[target_column]], axis=1
    )
    return features, scaled_with_targets
def extract_customer(test, test_pd_X_tmp, test_pd_y_tmp, test_unscaled_pd):
    """Pick a random customer and return that customer's rows from each frame.

    The customer label is drawn uniformly with ``random.choice`` from the index
    of ``test_pd_y_tmp``. Every returned row is then looked up by that *label*,
    so ``test_pd_X_tmp`` and ``test_pd_y_tmp`` do not need to share row order.

    Args:
        test: frame from which all rows whose index equals the label are
            returned as ``customer_record``.
        test_pd_X_tmp: feature frame; must contain the chosen label.
        test_pd_y_tmp: target frame (or Series) the customer is drawn from.
        test_unscaled_pd: unscaled features indexed by the same labels.

    Returns:
        tuple: ``(customer_pos, customer_idx, customer_x, customer_y,
        customer_x_original, customer_record)``:

        - ``customer_pos``: positional location of the customer in
          ``test_pd_X_tmp``.
        - ``customer_x`` / ``customer_y``: one-row DataFrames indexed by the
          customer label.
        - ``customer_x_original``: ``test_unscaled_pd.loc[label]``.
        - ``customer_record``: matching rows of ``test``.
    """
    customer_idx = random.choice(test_pd_y_tmp.index.tolist())
    print("Customer Index:", customer_idx)
    customer_pos = test_pd_X_tmp.index.get_loc(customer_idx)
    # Select by label in both frames. Slicing y at the position found in X
    # returns the wrong row (or none) whenever X and y are ordered differently.
    customer_x = test_pd_X_tmp.loc[[customer_idx]]
    customer_y = test_pd_y_tmp.loc[[customer_idx]]
    if isinstance(customer_y, pd.Series):
        # Callers always get a DataFrame for y, even when a Series is passed.
        customer_y = customer_y.to_frame()
    customer_x_original = test_unscaled_pd.loc[customer_idx]
    customer_record = test[test.index == customer_idx]
    return (customer_pos, customer_idx, customer_x, customer_y,
            customer_x_original, customer_record)
def plot_single_customer_survival_curve(customer_idx, df_test, test_features,
                                        cox_model, df_train, max_time=10,
                                        figsize=(12, 6)):
    """Plot one customer's churn-probability curve against the population.

    Curves shown:

    - Population: ``1 - S_KM(t)``, where ``S_KM`` is the Kaplan-Meier
      estimate fitted on ``df_train`` and truncated at ``max_time``.
    - Customer: ``1 - S_KM(t) ** r``, with
      ``r = exp(cox_model.predict(features))``.

    Dotted lines mark 25%, 50% and 75% churn probability. For each level the
    customer's curve reaches, the crossing time is marked and labelled. That
    time is obtained by linear interpolation between the two neighbouring
    time points.

    Parameters
    ----------
    customer_idx : hashable
        Index label of the customer in ``df_test`` and ``test_features``.
    df_test : pandas.DataFrame
        Test data with at least the ``'Tenure'`` and ``'Exited'`` columns.
    test_features : pandas.DataFrame
        Preprocessed features passed to ``cox_model.predict``.
    cox_model : fitted scikit-survival Cox model
        Object exposing ``predict``.
    df_train : pandas.DataFrame
        Training data with ``'Tenure'`` and ``'Exited'`` for the KM baseline.
    max_time : int or float, default 10
        Upper limit of the time axis (years); later KM points are dropped.
    figsize : tuple, default (12, 6)
        Matplotlib figure size.

    Returns
    -------
    matplotlib.figure.Figure
    """
    x_customer = test_features.loc[[customer_idx]]
    row = df_test.loc[customer_idx]
    tenure_now = row['Tenure']
    churned = row['Exited']
    hazard_ratio = np.exp(cox_model.predict(x_customer))[0]

    # Kaplan-Meier population baseline, truncated to the plotting window.
    km_times, km_surv = kaplan_meier_estimator(
        df_train['Exited'].astype(bool).values, df_train['Tenure'].values)
    baseline = pd.DataFrame({'KM_estimate': km_surv}, index=km_times)
    baseline = baseline.loc[baseline.index <= max_time]

    # Customer curve: baseline survival raised to the hazard ratio.
    personal = baseline.copy()
    personal['KM_estimate'] = baseline['KM_estimate'] ** hazard_ratio
    churn_curve = 1 - personal['KM_estimate']

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(baseline.index, 1 - baseline['KM_estimate'],
            'k--', alpha=0.4, linewidth=2, label='Population mean')
    ax.plot(personal.index, churn_curve,
            color='red' if churned == 1 else 'orange', linewidth=3,
            label=f'Customer #{customer_idx} (Risk Score: {hazard_ratio:.2f})')

    # Star at the customer's current tenure.
    if tenure_now <= max_time:
        y_now = np.interp(tenure_now, personal.index, churn_curve.values)
        ax.scatter(tenure_now, y_now, color='blue', s=200,
                   marker='*', edgecolor='black', linewidth=2,
                   label=f'Actual Position ({tenure_now:.1f} years)', zorder=10)

    t_axis = personal.index
    for level in (0.25, 0.5, 0.75):
        ax.axhline(y=level, color='gray', linestyle=':', alpha=0.5)
        ax.text(max_time * 0.85, level + 0.02, f'{level*100:.0f}%',
                color='gray', fontsize=9)

        reached = churn_curve[churn_curve >= level]
        if reached.empty:
            continue
        t_hit = reached.index[0]  # first time point at/above the level
        hit_pos = t_axis.get_loc(t_hit)
        if hit_pos == 0:
            t_cross = t_hit
        else:
            t_prev = t_axis[hit_pos - 1]
            p_prev = churn_curve.loc[t_prev]
            p_hit = churn_curve.loc[t_hit]
            if p_hit == p_prev:
                # Flat segment: interpolation would divide by zero.
                t_cross = t_prev
            else:
                t_cross = t_prev + (level - p_prev) * (t_hit - t_prev) / (p_hit - p_prev)

        ax.scatter(t_cross, level, color='red', s=100,
                   edgecolor='black', linewidth=2, zorder=5)
        # Label placed well above the marker, on an opaque white box.
        ax.annotate(f'{t_cross:.1f}y',
                    xy=(t_cross, level),
                    xytext=(t_cross, level + 0.12),
                    fontsize=9, weight='bold',
                    ha='center', va='bottom',
                    bbox=dict(boxstyle='round,pad=0.4', facecolor='white',
                              edgecolor='black', linewidth=1.5, alpha=1.0),
                    zorder=10)

    ax.set_xlabel('Time (years)', fontsize=12)
    ax.set_ylabel('Churn Probability', fontsize=12)
    ax.set_title(f'Churn Probability over Time - Customer #{customer_idx}',
                 fontsize=14, weight='bold')
    ax.set_xlim(0, max_time)
    ax.set_ylim(0, 1)
    ax.legend(loc='upper left', fontsize=10)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    return fig
def plot_single_customer_risk_timeline(customer_idx, df_test, test_features,
                                       cox_model, df_train, max_time=10,
                                       figsize=(12, 6)):
    """Plot a customer's churn risk over time on top of coloured risk bands.

    Churn risk is ``1 - S_KM(t) ** exp(cox_model.predict(features))``, where
    ``S_KM`` is the Kaplan-Meier estimate fitted on ``df_train`` and truncated
    at ``max_time``. Horizontal bands mark four levels:

    - Low: < 25%
    - Medium: 25-50%
    - High: 50-75%
    - Critical: > 75%

    When the customer's ``Tenure`` lies within the window, a "Today" marker
    shows the risk at that tenure.

    Parameters
    ----------
    customer_idx : hashable
        Index label of the customer in ``df_test`` and ``test_features``.
    df_test : pandas.DataFrame
        Test data with at least a ``'Tenure'`` column.
    test_features : pandas.DataFrame
        Preprocessed features passed to ``cox_model.predict``.
    cox_model : fitted scikit-survival Cox model
        Object exposing ``predict``.
    df_train : pandas.DataFrame
        Training data with ``'Tenure'`` and ``'Exited'`` for the KM baseline.
    max_time : int or float, default 10
        Upper limit of the time axis (years).
    figsize : tuple, default (12, 6)
        Matplotlib figure size.

    Returns
    -------
    matplotlib.figure.Figure
    """
    x_customer = test_features.loc[[customer_idx]]
    tenure_now = df_test.loc[customer_idx]['Tenure']
    hazard_ratio = np.exp(cox_model.predict(x_customer))[0]

    # Kaplan-Meier population baseline, truncated to the plotting window.
    km_times, km_surv = kaplan_meier_estimator(
        df_train['Exited'].astype(bool).values, df_train['Tenure'].values)
    baseline = pd.DataFrame({'KM_estimate': km_surv}, index=km_times)
    baseline = baseline.loc[baseline.index <= max_time]

    # Customer curve: baseline survival raised to the hazard ratio.
    personal = baseline.copy()
    personal['KM_estimate'] = baseline['KM_estimate'] ** hazard_ratio
    risk_curve = 1 - personal['KM_estimate']

    fig, ax = plt.subplots(figsize=figsize)

    risk_bands = (
        (0, 0.25, 'green', 'Low Risk'),
        (0.25, 0.5, 'yellow', 'Medium Risk'),
        (0.5, 0.75, 'orange', 'High Risk'),
        (0.75, 1, 'red', 'Critical Risk'),
    )
    for lower, upper, shade, band_name in risk_bands:
        ax.axhspan(lower, upper, alpha=0.15, color=shade, label=band_name)

    t_axis = personal.index
    ax.fill_between(t_axis, 0, risk_curve, alpha=0.4, color='darkred')
    ax.plot(t_axis, risk_curve, color='darkred', linewidth=3,
            label=f'Customer Risk #{customer_idx}')

    # Vertical line, star and "Today" box at the customer's current tenure.
    if tenure_now <= max_time:
        risk_now = np.interp(tenure_now, t_axis, risk_curve.values)
        ax.axvline(tenure_now, color='blue', linestyle='--', linewidth=2.5)
        ax.scatter(tenure_now, risk_now, color='blue', s=250,
                   marker='*', edgecolor='black', linewidth=2, zorder=10)
        ax.text(tenure_now, 0.95, f'Today\n{risk_now:.1%}',
                ha='center', fontsize=10, weight='bold',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.9,
                          edgecolor='blue', linewidth=2))

    ax.set_xlabel('Time (years)', fontsize=12)
    ax.set_ylabel('Churn Risk', fontsize=12)
    ax.set_title(f'Risk Timeline - Customer #{customer_idx}',
                 fontsize=14, weight='bold')
    ax.set_xlim(0, max_time)
    ax.set_ylim(0, 1)
    ax.legend(loc='upper left', fontsize=10)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    return fig
def plot_single_customer_survival_bars(customer_idx, df_test, test_features,
                                       cox_model, df_train,
                                       time_points=(2, 4, 6, 8, 10),
                                       figsize=(10, 6)):
    """Bar-chart one customer's survival and churn probability at fixed times.

    Panels:

    - Left: side-by-side bars at each time in ``time_points``. One bar is the
      customer's survival ``S_KM(t) ** exp(cox_model.predict(features))``;
      the other is the population Kaplan-Meier survival ``S_KM(t)``.
    - Right: the customer's churn probability ``1 - S(t)`` at the same times.

    Parameters
    ----------
    customer_idx : hashable
        Index label of the customer in ``test_features``.
    df_test : pandas.DataFrame
        Not read here. Kept so the signature matches the other per-customer
        plotting functions.
    test_features : pandas.DataFrame
        Preprocessed features passed to ``cox_model.predict``.
    cox_model : fitted scikit-survival Cox model
        Object exposing ``predict``.
    df_train : pandas.DataFrame
        Training data with ``'Tenure'`` and ``'Exited'`` for the KM baseline.
    time_points : iterable of numbers, default (2, 4, 6, 8, 10)
        Times (years) to show. Values after the last KM time are skipped.
        Values between KM times are linearly interpolated. The default is a
        tuple so it is never shared mutable state.
    figsize : tuple, default (10, 6)
        Matplotlib figure size.

    Returns
    -------
    matplotlib.figure.Figure
    """
    customer_features = test_features.loc[[customer_idx]]
    risk_score = np.exp(cox_model.predict(customer_features))[0]

    # Kaplan-Meier population baseline (not truncated) and customer curve.
    event_observed = df_train['Exited'].astype(bool).values
    time = df_train['Tenure'].values
    time_points_km, survival_prob = kaplan_meier_estimator(event_observed, time)
    baseline_survival = pd.DataFrame({'KM_estimate': survival_prob},
                                     index=time_points_km)
    customer_survival = baseline_survival.copy()
    customer_survival['KM_estimate'] = baseline_survival['KM_estimate'] ** risk_score

    max_time = baseline_survival.index.max()
    time_points = [t for t in time_points if t <= max_time]
    # np.interp returns the exact KM value when t is a KM time and
    # interpolates linearly between neighbouring KM times otherwise.
    km_grid = baseline_survival.index.values
    customer_probs = np.interp(time_points, km_grid,
                               customer_survival['KM_estimate'].values)
    population_probs = np.interp(time_points, km_grid,
                                 baseline_survival['KM_estimate'].values)
    churn_probs = 1 - customer_probs

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    tick_labels = [f'{t}y' for t in time_points]
    x_pos = np.arange(len(time_points))
    width = 0.35

    # Panel 1: survival probability, customer vs. population.
    bars1 = ax1.bar(x_pos - width/2, customer_probs, width,
                    label=f'Customer #{customer_idx}',
                    color='steelblue', alpha=0.8, edgecolor='black')
    bars2 = ax1.bar(x_pos + width/2, population_probs, width,
                    label='Population mean',
                    color='lightgray', alpha=0.8, edgecolor='black')
    ax1.set_xticks(x_pos)
    ax1.set_xticklabels(tick_labels)
    ax1.set_ylabel('Survival Probability', fontsize=11)
    ax1.set_xlabel('Time', fontsize=11)
    ax1.set_title('Survival Probability through Time', fontsize=12, weight='bold')
    ax1.set_ylim(0, 1.1)
    ax1.legend(fontsize=9)
    ax1.grid(True, alpha=0.3, axis='y')
    for bars in (bars1, bars2):
        for bar in bars:
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                     f'{height:.2%}', ha='center', va='bottom', fontsize=8)

    # Panel 2: customer churn probability, coloured green -> red over time.
    colors = plt.cm.RdYlGn_r(np.linspace(0.2, 0.8, len(time_points)))
    bars3 = ax2.bar(x_pos, churn_probs, color=colors, alpha=0.8,
                    edgecolor='black', linewidth=1.5)
    ax2.set_xticks(x_pos)
    ax2.set_xticklabels(tick_labels)
    ax2.set_ylabel('Churn Probability', fontsize=11)
    ax2.set_xlabel('Time', fontsize=11)
    ax2.set_title('Churn Probability Evolution', fontsize=12, weight='bold')
    ax2.set_ylim(0, 1.1)
    ax2.grid(True, alpha=0.3, axis='y')
    for bar, prob in zip(bars3, churn_probs):
        ax2.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.02,
                 f'{prob:.1%}', ha='center', va='bottom', fontsize=9, weight='bold')

    fig.suptitle(f'Temporal Analysis - Customer #{customer_idx}',
                 fontsize=14, weight='bold', y=1.02)
    fig.tight_layout()
    return fig
def plot_single_customer_complete(customer_idx, df_test, test_features,
                                  cox_model, df_train, max_time=10,
                                  time_points=(2, 4, 6, 8, 10),
                                  figsize=(16, 10)):
    """Draw a three-panel survival-analysis dashboard for one customer.

    The customer's survival curve is ``S_KM(t) ** risk_score``, where
    ``S_KM`` is the Kaplan-Meier estimate fitted on ``df_train`` (truncated
    at ``max_time``) and ``risk_score = exp(cox_model.predict(features))``.

    Panels:

    1. Top, full width: customer vs. population survival probability. The
       times at which the customer's survival falls to 75%, 50% and 25% are
       marked.
    2. Bottom left: cumulative hazard ``H(t) = -log S_KM(t)`` for the
       population and ``risk_score * H(t)`` for the customer.
    3. Bottom right: churn probability ``1 - S(t)`` at ``time_points``.

    Parameters
    ----------
    customer_idx : hashable
        Index label of the customer in ``df_test`` and ``test_features``.
    df_test : pandas.DataFrame
        Test data with at least ``'Tenure'`` and ``'Exited'`` columns.
    test_features : pandas.DataFrame
        Preprocessed features passed to ``cox_model.predict``.
    cox_model : fitted scikit-survival Cox model
        Object exposing ``predict``.
    df_train : pandas.DataFrame
        Training data with ``'Tenure'`` and ``'Exited'`` for the KM baseline.
    max_time : int or float, default 10
        Upper limit of the time axes (years); later KM points are dropped.
    time_points : iterable of numbers, default (2, 4, 6, 8, 10)
        Times (years) for the bar chart. Values above ``max_time`` are
        skipped. The default is a tuple so it is never shared mutable state.
    figsize : tuple, default (16, 10)
        Matplotlib figure size.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # --- Customer data and risk score ---
    customer_features = test_features.loc[[customer_idx]]
    customer_data = df_test.loc[customer_idx]
    actual_tenure = customer_data['Tenure']
    actual_churn = customer_data['Exited']
    risk_score = np.exp(cox_model.predict(customer_features))[0]

    # --- Kaplan-Meier baseline and customer survival curve ---
    event_observed = df_train['Exited'].astype(bool).values
    time = df_train['Tenure'].values
    time_points_km, survival_prob = kaplan_meier_estimator(event_observed, time)
    baseline_survival = pd.DataFrame({'KM_estimate': survival_prob},
                                     index=time_points_km)
    baseline_survival = baseline_survival.loc[baseline_survival.index <= max_time]
    customer_survival = baseline_survival.copy()
    customer_survival['KM_estimate'] = baseline_survival['KM_estimate'] ** risk_score
    customer_surv = customer_survival['KM_estimate']

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

    # === Panel 1: survival probability ===
    ax1 = fig.add_subplot(gs[0, :])
    ax1.plot(baseline_survival.index, baseline_survival['KM_estimate'],
             'k--', alpha=0.4, linewidth=2, label='Population Mean')
    color = 'red' if actual_churn == 1 else 'orange'
    ax1.plot(customer_survival.index, customer_surv,
             color=color, linewidth=3,
             label=f'Customer #{customer_idx} (Risk Score: {risk_score:.2f})')
    if actual_tenure <= max_time:
        current_survival = np.interp(actual_tenure, customer_survival.index,
                                     customer_surv.values)
        ax1.scatter(actual_tenure, current_survival, color='blue', s=200,
                    marker='*', edgecolor='black', linewidth=2,
                    label=f'Actual Position ({actual_tenure:.1f} years)', zorder=10)

    # Survival thresholds 75% / 50% / 25% correspond to churn risk 25% / 50% / 75%.
    for surv_prob in (0.75, 0.5, 0.25):
        ax1.axhline(y=surv_prob, color='gray', linestyle=':', alpha=0.5)
        crossed = customer_surv[customer_surv <= surv_prob]
        if crossed.empty:
            continue
        idx_after = crossed.index[0]  # first time point at/below the threshold
        idx_after_pos = customer_survival.index.get_loc(idx_after)
        if idx_after_pos > 0:
            # Linear interpolation between the neighbouring time points.
            idx_before = customer_survival.index[idx_after_pos - 1]
            prob_before = customer_surv.loc[idx_before]
            prob_after = customer_surv.loc[idx_after]
            if prob_after != prob_before:
                time_exact = idx_before + (surv_prob - prob_before) * (idx_after - idx_before) / (prob_after - prob_before)
            else:
                time_exact = idx_before
        else:
            time_exact = idx_after
        ax1.scatter(time_exact, surv_prob, color='green', s=100,
                    edgecolor='black', linewidth=2, zorder=5)
        # Label placed below the marker, on an opaque white box.
        risk_equivalent = (1 - surv_prob) * 100
        ax1.annotate(f'{time_exact:.1f}y\n({risk_equivalent:.0f}% risk)',
                     xy=(time_exact, surv_prob),
                     xytext=(time_exact, surv_prob - 0.12),
                     fontsize=8, weight='bold',
                     ha='center', va='top',
                     bbox=dict(boxstyle='round,pad=0.4', facecolor='white',
                               edgecolor='green', linewidth=1.5, alpha=1.0),
                     zorder=10)
    ax1.set_xlabel('Time (years)', fontsize=11)
    ax1.set_ylabel('Survival Probability', fontsize=11)
    ax1.set_title('Survival Probability through Time', fontsize=12, weight='bold')
    ax1.set_xlim(0, max_time)
    ax1.set_ylim(0, 1)
    ax1.legend(loc='upper right', fontsize=9)
    ax1.grid(True, alpha=0.3)

    # === Panel 2: cumulative hazard ===
    ax2 = fig.add_subplot(gs[1, 0])
    # H(t) = -log S(t) can exceed 1. A KM estimate of exactly 0 would give
    # H = inf, and an infinite value makes axhspan / set_ylim raise. Such
    # points are masked as NaN, which matplotlib leaves undrawn.
    baseline_km = baseline_survival['KM_estimate']
    baseline_cumhaz = -np.log(baseline_km.where(baseline_km > 0))
    customer_cumhaz = baseline_cumhaz * risk_score
    cumhaz_max = customer_cumhaz.max()  # pandas max skips NaN
    if not np.isfinite(cumhaz_max):
        cumhaz_max = 0.0

    # Single source of truth for risk levels: (lower bound, label, band label, colour).
    # NOTE(review): the bands use risk-score thresholds but are drawn on the
    # cumulative-hazard axis.
    risk_levels = (
        (1.8, 'Critical Risk', 'Critical (>1.8)', 'red'),
        (1.2, 'High Risk', 'High (1.2-1.8)', 'orange'),
        (0.8, 'Medium Risk', 'Medium (0.8-1.2)', 'yellow'),
        (0.0, 'Low Risk', 'Low (<0.8)', 'green'),
    )
    ax2.axhspan(0, 0.8, alpha=0.15, color='green', label='Low (<0.8)')
    ax2.axhspan(0.8, 1.2, alpha=0.15, color='yellow', label='Medium (0.8-1.2)')
    ax2.axhspan(1.2, 1.8, alpha=0.15, color='orange', label='High (1.2-1.8)')
    ax2.axhspan(1.8, max(cumhaz_max, 3), alpha=0.15, color='red', label='Critical (>1.8)')

    times = baseline_survival.index
    ax2.fill_between(times, 0, customer_cumhaz, alpha=0.4, color='darkred')
    ax2.plot(times, customer_cumhaz, color='darkred', linewidth=2.5,
             label=f'Customer (RS={risk_score:.2f})')
    # Population reference curve (risk score = 1).
    ax2.plot(times, baseline_cumhaz, color='gray', linewidth=2,
             linestyle='--', alpha=0.6, label='Population Mean (RS=1.0)')
    if actual_tenure <= max_time:
        current_cumhaz = np.interp(actual_tenure, times, customer_cumhaz.values)
        ax2.axvline(actual_tenure, color='blue', linestyle='--', linewidth=2)
        ax2.scatter(actual_tenure, current_cumhaz, color='blue', s=200,
                    marker='*', edgecolor='black', linewidth=2, zorder=10)
        ax2.text(actual_tenure, current_cumhaz + 0.1,
                 f'Today\nCH={current_cumhaz:.2f}',
                 ha='center', fontsize=8, weight='bold',
                 bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))
    ax2.set_xlabel('Time (years)', fontsize=11)
    ax2.set_ylabel('Cumulative Hazard', fontsize=11)
    ax2.set_title('Cumulative Hazard Timeline', fontsize=12, weight='bold')
    ax2.set_xlim(0, max_time)
    ax2.set_ylim(0, max(cumhaz_max * 1.2, 3))
    ax2.legend(loc='upper left', fontsize=8)
    ax2.grid(True, alpha=0.3)

    # Risk-level note: label and border colour come from the same thresholds
    # and the same colours as the bands above.
    risk_label, risk_color = next(
        (label, colour) for lower, label, _, colour in risk_levels
        if risk_score > lower or lower == 0.0
    )
    ax2.text(0.98, 0.02,
             f'Risk Score: {risk_score:.2f}\n' + risk_label,
             transform=ax2.transAxes, fontsize=9, va='bottom', ha='right',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.9,
                       edgecolor=risk_color, linewidth=2))

    # === Panel 3: churn probability at fixed intervals ===
    ax3 = fig.add_subplot(gs[1, 1])
    time_points = [t for t in time_points if t <= max_time]
    # np.interp returns the exact value at a KM time and interpolates between them.
    churn_probs = [
        1 - np.interp(t, customer_survival.index, customer_surv.values)
        for t in time_points
    ]
    colors = plt.cm.RdYlGn_r(np.linspace(0.2, 0.8, len(time_points)))
    bars = ax3.bar(range(len(time_points)), churn_probs, color=colors,
                   alpha=0.8, edgecolor='black', linewidth=1.5)
    ax3.set_xticks(range(len(time_points)))
    ax3.set_xticklabels([f'{t}y' for t in time_points])
    ax3.set_ylabel('Churn Probability', fontsize=11)
    ax3.set_xlabel('Time', fontsize=11)
    ax3.set_title('Churn Probability at Intervals', fontsize=12, weight='bold')
    ax3.set_ylim(0, 1.1)
    ax3.grid(True, alpha=0.3, axis='y')
    for bar, prob in zip(bars, churn_probs):
        ax3.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.02,
                 f'{prob:.1%}', ha='center', va='bottom', fontsize=9, weight='bold')

    status = 'CHURNER' if actual_churn == 1 else 'NON-CHURNER'
    fig.suptitle(f'Survival Analysis - Customer #{customer_idx} ({status})',
                 fontsize=15, weight='bold')
    fig.tight_layout()
    return fig
if __name__ == "__main__":
    import os

    print("Download the dataset")
    hundred_churners_val = pd.read_csv('data/hundred_val_churners.csv', index_col=0)
    hundred_non_churners_val = pd.read_csv('data/hundred_val_non_churners.csv', index_col=0)
    val_df = pd.concat([hundred_churners_val, hundred_non_churners_val], axis=0)
    y_val_df = val_df[['Exited', 'Tenure']]

    print("Scale the dataset")
    scaler, label_encs, train_cols, X_train, y_train = obtain_scaler_and_label_enc()
    val_ordered_df, val_scaled_df = scale_dataset(val_df, ["Exited", "Tenure"], train_cols, scaler)
    y_val = Surv.from_dataframe("Exited", "Tenure", y_val_df)
    val_unscaled_pd = val_ordered_df
    val2 = pd.DataFrame(val_scaled_df, columns=val_ordered_df.columns, index=val_ordered_df.index)
    val_pd_X = val2
    val_pd_y = val_df['Exited']

    # Load the trained model. pickle can execute arbitrary code while loading,
    # so only load a model file produced by this project.
    with open('models/cox_model.pkl', 'rb') as f:
        cph = pickle.load(f)

    # Risk scores and survival functions from the same feature matrix.
    val_features = val_scaled_df.drop(['Exited', 'Tenure'], axis=1)
    prediction = cph.predict(val_features)
    val_scaled_df['preds'] = prediction
    surv_func = cph.predict_survival_function(val_features, return_array=True)

    # Index the rows of the survival matrix by the model's time grid, so that
    # idxmax() below yields a time rather than a row position. Fall back to
    # positions only if the model does not expose a matching grid.
    surv_times = getattr(cph, 'unique_times_', None)
    if surv_times is None or len(surv_times) != surv_func.shape[1]:
        surv_times = np.arange(surv_func.shape[1])
    df_surv = pd.DataFrame(surv_func.T, index=surv_times, columns=val_scaled_df.index)

    # First time at which predicted survival drops to the threshold; NaN if never.
    threshold = 0.5
    predicted_time_to_churn = (df_surv <= threshold)
    churns = predicted_time_to_churn.idxmax().where(predicted_time_to_churn.any())
    # Assign the filled column back. Calling fillna(inplace=True) on a column
    # selection is chained assignment and does not update val_scaled_df under
    # pandas Copy-on-Write.
    # NOTE(review): 11 appears to mean "beyond the 10-year horizon" — confirm.
    val_scaled_df['absolute_time_to_churn'] = churns.fillna(11)
    val_scaled_df['Churn_Prediction'] = (val_scaled_df['absolute_time_to_churn'] <= 10).astype(int)

    churners_and_non = val_scaled_df
    churners_and_non_X = val_pd_X[val_pd_X.index.isin(churners_and_non.index.tolist())]
    X_val_final = val_scaled_df

    # Survival curves plus per-customer outcome/prediction rows, one row per customer.
    df_surv_final = pd.concat(
        [df_surv, X_val_final[['Exited', 'Tenure', 'absolute_time_to_churn', 'Churn_Prediction']].T],
        axis=0,
    )
    df_surv_def = df_surv_final.T

    sample_size = 30
    (customer_pos, customer_idx, customer_x, customer_y,
     customer_x_original, customer_record) = extract_customer(
        val2,
        churners_and_non_X.sample(sample_size, random_state=42),
        churners_and_non.sample(sample_size, random_state=42),
        val_unscaled_pd,
    )
    df_train = pd.concat([X_train, y_train], axis=1)

    # Model input features only: drop targets and the prediction columns added above.
    test_features = pd.DataFrame(
        val_scaled_df,
        columns=val_scaled_df.drop(
            ['preds', 'absolute_time_to_churn', 'Churn_Prediction', 'Exited', 'Tenure'],
            axis=1,
        ).columns,
        index=val_scaled_df.index,
    )

    # Individual plots
    fig1 = plot_single_customer_survival_curve(customer_idx, X_val_final, test_features, cph, df_train)
    fig2 = plot_single_customer_risk_timeline(customer_idx, X_val_final, test_features, cph, df_train)
    fig3 = plot_single_customer_survival_bars(customer_idx, X_val_final, test_features, cph, df_train)
    # Combined dashboard
    fig = plot_single_customer_complete(customer_idx, X_val_final, test_features, cph, df_train, max_time=10)
    os.makedirs('img', exist_ok=True)
    fig.savefig(f'img/survival_customer_{customer_idx}.png', dpi=300, bbox_inches='tight')
    plt.show()