"""Per-customer churn / survival visualisations built on a fitted Cox model.

Each plotting function combines the Kaplan-Meier curve of the training
population with the customer's exponentiated Cox prediction to draw an
individual survival (or churn) curve.
"""
import os
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.nonparametric import kaplan_meier_estimator
from sksurv.util import Surv

# NOTE(review): star import kept as-is; `Preprocessor` is presumably what it
# provides -- confirm and switch to an explicit import.
from preprocessing import *


def obtain_scaler_and_label_enc():
    """Fit the project preprocessor on the training CSV and expose its parts.

    Returns
    -------
    tuple
        ``(scaler, label_encoders, train_columns, X_train_df, y_train_df)``
        where ``train_columns`` is the column index of the first value
        returned by ``Preprocessor.process_ttcp()`` and the two frames are
        the 5th and 7th values returned by that call.
    """
    dataset_url = "data/train.csv"
    target_column = "Exited"
    target_column_ttc = ["Exited", "Tenure"]

    preprocessor = Preprocessor(dataset_url, target_column, target_column_ttc,
                                resampling="under", scaling='minmax')
    (X_train_ttc, _, y_train_ttc, _,
     X_train_df_ttc, _, y_train_df_ttc, _) = preprocessor.process_ttcp()

    scaler = preprocessor.scaler
    label_encoders = preprocessor.label_encoders
    return scaler, label_encoders, X_train_ttc.columns, X_train_df_ttc, y_train_df_ttc


def scale_dataset(test_df, target_column, train_cols, scaler):
    """Scale the feature columns of ``test_df`` with an already fitted scaler.

    Parameters
    ----------
    test_df : pandas.DataFrame
        Frame containing both features and target columns.
    target_column : list of str
        Target column names; they are excluded from scaling and appended
        unchanged to the scaled output.
    train_cols : sequence of str
        Feature columns, in the order the scaler was fitted on.
    scaler : object
        Fitted transformer exposing ``transform``.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(unscaled_features_in_training_order, scaled_features_plus_targets)``.
    """
    X_test = test_df.drop(list(target_column), axis=1)
    X_test_df_ordered = X_test[train_cols]

    X_test_scaled = scaler.transform(X_test_df_ordered)
    X_test_scaled_df = pd.DataFrame(X_test_scaled,
                                    columns=X_test_df_ordered.columns,
                                    index=X_test_df_ordered.index)
    test_df_scaled = pd.concat([X_test_scaled_df, test_df[target_column]], axis=1)
    return X_test_df_ordered, test_df_scaled


def extract_customer(test, test_pd_X_tmp, test_pd_y_tmp, test_unscaled_pd):
    """Pick one random customer and return its rows from every frame.

    The customer is drawn from the index of ``test_pd_y_tmp``; its position
    is looked up in ``test_pd_X_tmp`` and that same position is used to slice
    both frames, so they are expected to be row-aligned.

    Returns
    -------
    tuple
        ``(customer_pos, customer_idx, customer_x, customer_y,
        customer_x_original, customer_record)``.
    """
    customer_idx = random.choice(test_pd_y_tmp.index.tolist())
    print("Customer Index:", customer_idx)

    customer_pos = test_pd_X_tmp.index.get_loc(customer_idx)

    # Positional slicing keeps the original index labels; the previous
    # reset_index()/set_index('index') round trip raised KeyError whenever
    # the index carried a name.  rename_axis keeps the output index name.
    row = slice(customer_pos, customer_pos + 1)
    customer_x = test_pd_X_tmp.iloc[row].rename_axis('index')
    customer_y = test_pd_y_tmp.iloc[row].rename_axis('index')

    customer_x_original = test_unscaled_pd.loc[customer_idx]
    customer_record = test[test.index == customer_idx]
    return (customer_pos, customer_idx, customer_x, customer_y,
            customer_x_original, customer_record)


def _customer_curves(customer_idx, test_features, cox_model, df_train, max_time=None):
    """Return ``(risk_score, baseline_survival, customer_survival)``.

    ``baseline_survival`` is the Kaplan-Meier estimate of the training
    population (column ``'KM_estimate'``, indexed by time), optionally
    truncated to ``time <= max_time``.  ``customer_survival`` is that curve
    raised to the power ``risk_score``, where ``risk_score`` is the
    exponential of ``cox_model.predict`` for the customer's feature row.
    """
    customer_features = test_features.loc[[customer_idx]]
    risk_score = np.exp(cox_model.predict(customer_features))[0]

    event_observed = df_train['Exited'].astype(bool).values
    durations = df_train['Tenure'].values
    km_times, km_survival = kaplan_meier_estimator(event_observed, durations)

    baseline_survival = pd.DataFrame({'KM_estimate': km_survival}, index=km_times)
    if max_time is not None:
        baseline_survival = baseline_survival.loc[baseline_survival.index <= max_time]

    customer_survival = baseline_survival.copy()
    customer_survival['KM_estimate'] = baseline_survival['KM_estimate'] ** risk_score
    return risk_score, baseline_survival, customer_survival


def _first_crossing_time(curve, threshold, rising):
    """Return the interpolated time at which ``curve`` first crosses ``threshold``.

    ``curve`` is a Series indexed by time.  With ``rising=True`` a crossing
    means ``curve >= threshold``, otherwise ``curve <= threshold``.  The time
    is linearly interpolated between the last point before the crossing and
    the first point after it.  Returns ``None`` when the curve never crosses.
    """
    crossed = curve[curve >= threshold] if rising else curve[curve <= threshold]
    if crossed.empty:
        return None

    idx_after = crossed.index[0]
    pos_after = curve.index.get_loc(idx_after)
    if pos_after == 0:
        # Already past the threshold at the first known time point.
        return idx_after

    idx_before = curve.index[pos_after - 1]
    value_before = curve.loc[idx_before]
    value_after = curve.loc[idx_after]
    if value_after == value_before:  # avoid a division by zero
        return idx_before
    return idx_before + ((threshold - value_before) * (idx_after - idx_before)
                         / (value_after - value_before))


def _risk_level(risk_score):
    """Map a risk score to ``(label, colour)`` using the 0.8 / 1.2 / 1.8 bands."""
    if risk_score > 1.8:
        return 'Critical Risk', 'red'
    if risk_score > 1.2:
        return 'High Risk', 'orange'
    if risk_score > 0.8:
        return 'Medium Risk', 'yellow'
    return 'Low Risk', 'green'


def plot_single_customer_survival_curve(customer_idx, df_test, test_features, cox_model,
                                        df_train, max_time=10, figsize=(12, 6)):
    """Plot one customer's churn-probability curve against the population.

    Parameters
    ----------
    customer_idx : int
        Index label of the customer to analyse.
    df_test : pandas.DataFrame
        Test data with ``'Tenure'`` and ``'Exited'`` columns.
    test_features : pandas.DataFrame
        Preprocessed features of the test set, indexed like ``df_test``.
    cox_model : object
        Fitted scikit-survival Cox model (only ``predict`` is used).
    df_train : pandas.DataFrame
        Training data with ``'Tenure'`` and ``'Exited'`` for the baseline.
    max_time : int
        Largest time (years) shown on the x axis.
    figsize : tuple
        Figure size.

    Returns
    -------
    matplotlib.figure.Figure
    """
    customer_data = df_test.loc[customer_idx]
    actual_tenure = customer_data['Tenure']
    actual_churn = customer_data['Exited']

    risk_score, baseline_survival, customer_survival = _customer_curves(
        customer_idx, test_features, cox_model, df_train, max_time)
    customer_churn_prob = 1 - customer_survival['KM_estimate']

    fig, ax = plt.subplots(figsize=figsize)

    # Population baseline, expressed as churn probability.
    ax.plot(baseline_survival.index, 1 - baseline_survival['KM_estimate'],
            'k--', alpha=0.4, linewidth=2, label='Population mean')

    # Customer curve: red for actual churners, orange otherwise.
    color = 'red' if actual_churn == 1 else 'orange'
    ax.plot(customer_survival.index, customer_churn_prob,
            color=color, linewidth=3,
            label=f'Customer #{customer_idx} (Risk Score: {risk_score:.2f})')

    # Mark the customer's current tenure on the curve.
    if actual_tenure <= max_time:
        current_prob = np.interp(actual_tenure, customer_survival.index,
                                 customer_churn_prob.values)
        ax.scatter(actual_tenure, current_prob, color='blue', s=200, marker='*',
                   edgecolor='black', linewidth=2,
                   label=f'Actual Position ({actual_tenure:.1f} years)', zorder=10)

    # Risk thresholds, each annotated with the time the curve reaches it.
    for prob in [0.25, 0.5, 0.75]:
        ax.axhline(y=prob, color='gray', linestyle=':', alpha=0.5)
        ax.text(max_time*0.85, prob + 0.02, f'{prob*100:.0f}%',
                color='gray', fontsize=9)

        time_exact = _first_crossing_time(customer_churn_prob, prob, rising=True)
        if time_exact is None:
            continue

        ax.scatter(time_exact, prob, color='red', s=100,
                   edgecolor='black', linewidth=2, zorder=5)
        # Label placed well above the marker, on an opaque white box.
        ax.annotate(f'{time_exact:.1f}y',
                    xy=(time_exact, prob),
                    xytext=(time_exact, prob + 0.12),
                    fontsize=9, weight='bold', ha='center', va='bottom',
                    bbox=dict(boxstyle='round,pad=0.4', facecolor='white',
                              edgecolor='black', linewidth=1.5, alpha=1.0),
                    zorder=10)

    ax.set_xlabel('Time (years)', fontsize=12)
    ax.set_ylabel('Churn Probability', fontsize=12)
    ax.set_title(f'Churn Probability over Time - Customer #{customer_idx}',
                 fontsize=14, weight='bold')
    ax.set_xlim(0, max_time)
    ax.set_ylim(0, 1)
    ax.legend(loc='upper left', fontsize=10)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    return fig


def plot_single_customer_risk_timeline(customer_idx, df_test, test_features, cox_model,
                                       df_train, max_time=10, figsize=(12, 6)):
    """Plot one customer's churn risk over time on coloured risk bands.

    Parameters are the same as for ``plot_single_customer_survival_curve``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    actual_tenure = df_test.loc[customer_idx]['Tenure']

    _, _, customer_survival = _customer_curves(
        customer_idx, test_features, cox_model, df_train, max_time)
    customer_risk = 1 - customer_survival['KM_estimate']

    fig, ax = plt.subplots(figsize=figsize)

    # Risk bands.
    ax.axhspan(0, 0.25, alpha=0.15, color='green', label='Low Risk')
    ax.axhspan(0.25, 0.5, alpha=0.15, color='yellow', label='Medium Risk')
    ax.axhspan(0.5, 0.75, alpha=0.15, color='orange', label='High Risk')
    ax.axhspan(0.75, 1, alpha=0.15, color='red', label='Critical Risk')

    # Risk curve.
    times = customer_survival.index
    ax.fill_between(times, 0, customer_risk, alpha=0.4, color='darkred')
    ax.plot(times, customer_risk, color='darkred', linewidth=3,
            label=f'Customer Risk #{customer_idx}')

    # Mark the customer's current tenure.
    if actual_tenure <= max_time:
        current_risk = np.interp(actual_tenure, times, customer_risk.values)
        ax.axvline(actual_tenure, color='blue', linestyle='--', linewidth=2.5)
        ax.scatter(actual_tenure, current_risk, color='blue', s=250, marker='*',
                   edgecolor='black', linewidth=2, zorder=10)
        ax.text(actual_tenure, 0.95, f'Today\n{current_risk:.1%}',
                ha='center', fontsize=10, weight='bold',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.9,
                          edgecolor='blue', linewidth=2))

    ax.set_xlabel('Time (years)', fontsize=12)
    ax.set_ylabel('Churn Risk', fontsize=12)
    ax.set_title(f'Risk Timeline - Customer #{customer_idx}', fontsize=14, weight='bold')
    ax.set_xlim(0, max_time)
    ax.set_ylim(0, 1)
    ax.legend(loc='upper left', fontsize=10)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    return fig


def plot_single_customer_survival_bars(customer_idx, df_test, test_features, cox_model,
                                       df_train, time_points=(2, 4, 6, 8, 10),
                                       figsize=(10, 6)):
    """Plot bar charts of survival and churn probability at fixed time points.

    Parameters
    ----------
    customer_idx : int
        Index label of the customer.
    df_test, test_features, cox_model, df_train
        As in ``plot_single_customer_survival_curve`` (``df_test`` is accepted
        for signature symmetry and is not read here).
    time_points : sequence of numbers
        Time points (years) to display; points beyond the last Kaplan-Meier
        time are dropped.
    figsize : tuple
        Figure size.

    Returns
    -------
    matplotlib.figure.Figure
    """
    _, baseline_survival, customer_survival = _customer_curves(
        customer_idx, test_features, cox_model, df_train)

    last_time = baseline_survival.index.max()
    time_points = [t for t in time_points if t <= last_time]

    # np.interp returns the exact curve value at known time points and a
    # linear interpolation in between.
    km_times = baseline_survival.index
    customer_probs = np.interp(time_points, km_times,
                               customer_survival['KM_estimate'].values)
    population_probs = np.interp(time_points, km_times,
                                 baseline_survival['KM_estimate'].values)
    churn_probs = 1 - customer_probs

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    x_pos = np.arange(len(time_points))
    tick_labels = [f'{t}y' for t in time_points]

    # Plot 1: survival probability, customer vs. population.
    width = 0.35
    bars1 = ax1.bar(x_pos - width/2, customer_probs, width,
                    label=f'Customer #{customer_idx}', color='steelblue',
                    alpha=0.8, edgecolor='black')
    bars2 = ax1.bar(x_pos + width/2, population_probs, width,
                    label='Population mean', color='lightgray',
                    alpha=0.8, edgecolor='black')

    ax1.set_xticks(x_pos)
    ax1.set_xticklabels(tick_labels)
    ax1.set_ylabel('Survival Probability', fontsize=11)
    ax1.set_xlabel('Time', fontsize=11)
    ax1.set_title('Survival Probability through Time', fontsize=12, weight='bold')
    ax1.set_ylim(0, 1.1)
    ax1.legend(fontsize=9)
    ax1.grid(True, alpha=0.3, axis='y')

    # Value labels above the bars.
    for bar in [*bars1, *bars2]:
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                 f'{height:.2%}', ha='center', va='bottom', fontsize=8)

    # Plot 2: churn probability, coloured green -> red with time.
    colors = plt.cm.RdYlGn_r(np.linspace(0.2, 0.8, len(time_points)))
    bars3 = ax2.bar(x_pos, churn_probs, color=colors, alpha=0.8,
                    edgecolor='black', linewidth=1.5)

    ax2.set_xticks(x_pos)
    ax2.set_xticklabels(tick_labels)
    ax2.set_ylabel('Churn Probability', fontsize=11)
    ax2.set_xlabel('Time', fontsize=11)
    ax2.set_title('Churn Probability Evolution', fontsize=12, weight='bold')
    ax2.set_ylim(0, 1.1)
    ax2.grid(True, alpha=0.3, axis='y')

    for bar, prob in zip(bars3, churn_probs):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                 f'{prob:.1%}', ha='center', va='bottom', fontsize=9, weight='bold')

    fig.suptitle(f'Temporal Analysis - Customer #{customer_idx}',
                 fontsize=14, weight='bold', y=1.02)
    fig.tight_layout()
    return fig


def plot_single_customer_complete(customer_idx, df_test, test_features, cox_model,
                                  df_train, max_time=10, time_points=(2, 4, 6, 8, 10),
                                  figsize=(16, 10)):
    """Draw the full survival-analysis dashboard for one customer.

    The figure has three panels: the survival curve with threshold
    crossings (top), the cumulative-hazard timeline (bottom left) and a bar
    chart of churn probability at ``time_points`` (bottom right).

    Parameters combine those of the single-plot functions above.

    Returns
    -------
    matplotlib.figure.Figure
    """
    customer_data = df_test.loc[customer_idx]
    actual_tenure = customer_data['Tenure']
    actual_churn = customer_data['Exited']

    risk_score, baseline_survival, customer_survival = _customer_curves(
        customer_idx, test_features, cox_model, df_train, max_time)
    customer_km = customer_survival['KM_estimate']
    times = baseline_survival.index

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

    # === PLOT 1: survival curve ===
    ax1 = fig.add_subplot(gs[0, :])

    # Population baseline (survival, not churn).
    ax1.plot(times, baseline_survival['KM_estimate'],
             'k--', alpha=0.4, linewidth=2, label='Population Mean')

    color = 'red' if actual_churn == 1 else 'orange'
    ax1.plot(customer_survival.index, customer_km,
             color=color, linewidth=3,
             label=f'Customer #{customer_idx} (Risk Score: {risk_score:.2f})')

    if actual_tenure <= max_time:
        current_survival = np.interp(actual_tenure, customer_survival.index,
                                     customer_km.values)
        ax1.scatter(actual_tenure, current_survival, color='blue', s=200, marker='*',
                    edgecolor='black', linewidth=2,
                    label=f'Actual Position ({actual_tenure:.1f} years)', zorder=10)

    # Survival thresholds 75%, 50%, 25% (i.e. 25%, 50%, 75% risk).
    for surv_prob in [0.75, 0.5, 0.25]:
        ax1.axhline(y=surv_prob, color='gray', linestyle=':', alpha=0.5)

        time_exact = _first_crossing_time(customer_km, surv_prob, rising=False)
        if time_exact is None:
            continue

        ax1.scatter(time_exact, surv_prob, color='green', s=100,
                    edgecolor='black', linewidth=2, zorder=5)
        # Label below the marker (the curve is decreasing), opaque white box.
        risk_equivalent = (1 - surv_prob) * 100
        ax1.annotate(f'{time_exact:.1f}y\n({risk_equivalent:.0f}% risk)',
                     xy=(time_exact, surv_prob),
                     xytext=(time_exact, surv_prob - 0.12),
                     fontsize=8, weight='bold', ha='center', va='top',
                     bbox=dict(boxstyle='round,pad=0.4', facecolor='white',
                               edgecolor='green', linewidth=1.5, alpha=1.0),
                     zorder=10)

    ax1.set_xlabel('Time (years)', fontsize=11)
    ax1.set_ylabel('Survival Probability', fontsize=11)
    ax1.set_title('Survival Probability through Time', fontsize=12, weight='bold')
    ax1.set_xlim(0, max_time)
    ax1.set_ylim(0, 1)
    ax1.legend(loc='upper right', fontsize=9)
    ax1.grid(True, alpha=0.3)

    # === PLOT 2: cumulative hazard timeline ===
    ax2 = fig.add_subplot(gs[1, 0])

    # Cumulative hazard can exceed 1; it is infinite where survival hits 0.
    with np.errstate(divide='ignore'):
        baseline_cumhaz = -np.log(baseline_survival['KM_estimate'])
    customer_cumhaz = baseline_cumhaz * risk_score

    # Axis limits must be finite, so size them on the largest finite value.
    finite_cumhaz = customer_cumhaz[np.isfinite(customer_cumhaz)]
    cumhaz_peak = finite_cumhaz.max() if not finite_cumhaz.empty else 0.0

    ax2.axhspan(0, 0.8, alpha=0.15, color='green', label='Low (<0.8)')
    ax2.axhspan(0.8, 1.2, alpha=0.15, color='yellow', label='Medium (0.8-1.2)')
    ax2.axhspan(1.2, 1.8, alpha=0.15, color='orange', label='High (1.2-1.8)')
    ax2.axhspan(1.8, max(cumhaz_peak, 3), alpha=0.15, color='red',
                label='Critical (>1.8)')

    ax2.fill_between(times, 0, customer_cumhaz, alpha=0.4, color='darkred')
    ax2.plot(times, customer_cumhaz, color='darkred', linewidth=2.5,
             label=f'Customer (RS={risk_score:.2f})')

    # Baseline line (risk score = 1).
    ax2.plot(times, baseline_cumhaz, color='gray', linewidth=2, linestyle='--',
             alpha=0.6, label='Population Mean (RS=1.0)')

    if actual_tenure <= max_time:
        current_cumhaz = np.interp(actual_tenure, times, customer_cumhaz.values)
        ax2.axvline(actual_tenure, color='blue', linestyle='--', linewidth=2)
        ax2.scatter(actual_tenure, current_cumhaz, color='blue', s=200, marker='*',
                    edgecolor='black', linewidth=2, zorder=10)
        ax2.text(actual_tenure, current_cumhaz + 0.1,
                 f'Today\nCH={current_cumhaz:.2f}',
                 ha='center', fontsize=8, weight='bold',
                 bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))

    ax2.set_xlabel('Time (years)', fontsize=11)
    ax2.set_ylabel('Cumulative Hazard', fontsize=11)
    ax2.set_title('Cumulative Hazard Timeline', fontsize=12, weight='bold')
    ax2.set_xlim(0, max_time)
    ax2.set_ylim(0, max(cumhaz_peak * 1.2, 3))  # dynamic y axis
    ax2.legend(loc='upper left', fontsize=8)
    ax2.grid(True, alpha=0.3)

    # Explanatory note; label and border colour come from the same bands.
    level_label, level_color = _risk_level(risk_score)
    ax2.text(0.98, 0.02,
             f'Risk Score: {risk_score:.2f}\n' + level_label,
             transform=ax2.transAxes, fontsize=9, va='bottom', ha='right',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.9,
                       edgecolor=level_color, linewidth=2))

    # === PLOT 3: bar chart ===
    ax3 = fig.add_subplot(gs[1, 1])

    time_points = [t for t in time_points if t <= max_time]
    churn_probs = 1 - np.interp(time_points, customer_survival.index, customer_km.values)

    colors = plt.cm.RdYlGn_r(np.linspace(0.2, 0.8, len(time_points)))
    bar_positions = range(len(time_points))
    bars = ax3.bar(bar_positions, churn_probs, color=colors, alpha=0.8,
                   edgecolor='black', linewidth=1.5)

    ax3.set_xticks(bar_positions)
    ax3.set_xticklabels([f'{t}y' for t in time_points])
    ax3.set_ylabel('Churn Probability', fontsize=11)
    ax3.set_xlabel('Time', fontsize=11)
    ax3.set_title('Churn Probability at Intervals', fontsize=12, weight='bold')
    ax3.set_ylim(0, 1.1)
    ax3.grid(True, alpha=0.3, axis='y')

    for bar, prob in zip(bars, churn_probs):
        height = bar.get_height()
        ax3.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                 f'{prob:.1%}', ha='center', va='bottom', fontsize=9, weight='bold')

    status = 'CHURNER' if actual_churn == 1 else 'NON-CHURNER'
    fig.suptitle(f'Survival Analysis - Customer #{customer_idx} ({status})',
                 fontsize=15, weight='bold')
    fig.tight_layout()
    return fig


def main():
    """Load validation data and the pickled Cox model, then plot one customer."""
    print("Download the dataset")
    hundred_churners_val = pd.read_csv('data/hundred_val_churners.csv', index_col=0)
    hundred_non_churners_val = pd.read_csv('data/hundred_val_non_churners.csv',
                                           index_col=0)
    val_df = pd.concat([hundred_churners_val, hundred_non_churners_val], axis=0)

    print("Scale the dataset")
    scaler, _, train_cols, X_train, y_train = obtain_scaler_and_label_enc()
    val_unscaled_df, val_scaled_df = scale_dataset(
        val_df, ["Exited", "Tenure"], train_cols, scaler)

    # Scaled feature columns only (no targets), in training order.
    val_features = val_scaled_df[val_unscaled_df.columns].copy()

    # SECURITY: pickle executes arbitrary code on load; only ever load model
    # files produced by this project.
    with open('models/cox_model.pkl', 'rb') as f:
        cph = pickle.load(f)

    # Risk scores for the validation set.
    val_scaled_df['preds'] = cph.predict(val_features)

    # Survival functions: one column per customer, one row per time step.
    surv_func = cph.predict_survival_function(val_features, return_array=True)
    df_surv = pd.DataFrame(surv_func.T, columns=val_scaled_df.index)

    # NOTE(review): df_surv has a positional row index, so idxmax() yields the
    # row position of the first time step below the threshold, not the time
    # value itself -- confirm this is intended (the model's time grid would
    # be needed to convert it).
    threshold = 0.5
    below_threshold = df_surv <= threshold
    churns = below_threshold.idxmax().where(below_threshold.any())

    # Direct assignment instead of chained `fillna(..., inplace=True)`, which
    # does not update the frame under pandas Copy-on-Write.
    val_scaled_df['absolute_time_to_churn'] = churns.fillna(11)
    val_scaled_df['Churn_Prediction'] = (
        val_scaled_df['absolute_time_to_churn'] <= 10).astype(int)

    # Same size and seed on identically indexed frames -> same sampled rows.
    sample_size = 30
    _, customer_idx, *_ = extract_customer(
        val_features,
        val_features.sample(sample_size, random_state=42),
        val_scaled_df.sample(sample_size, random_state=42),
        val_unscaled_df)

    df_train = pd.concat([X_train, y_train], axis=1)

    # Individual plots.
    plot_single_customer_survival_curve(customer_idx, val_scaled_df, val_features,
                                        cph, df_train)
    plot_single_customer_risk_timeline(customer_idx, val_scaled_df, val_features,
                                       cph, df_train)
    plot_single_customer_survival_bars(customer_idx, val_scaled_df, val_features,
                                       cph, df_train)

    # Or everything in one figure.
    fig = plot_single_customer_complete(customer_idx, val_scaled_df, val_features,
                                        cph, df_train, max_time=10)

    os.makedirs('img', exist_ok=True)
    fig.savefig(f'img/survival_customer_{customer_idx}.png', dpi=300,
                bbox_inches='tight')
    plt.show()


if __name__ == "__main__":
    main()