Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,11 +2,12 @@ import gradio as gr
|
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
| 4 |
from sklearn.cluster import KMeans
|
| 5 |
-
from sklearn.metrics import pairwise_distances_argmin_min # r2_score is not used
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
-
import matplotlib.cm
|
|
|
|
| 8 |
import io
|
| 9 |
-
import os
|
| 10 |
from PIL import Image
|
| 11 |
|
| 12 |
# Define the paths for example data
|
|
@@ -22,98 +23,98 @@ EXAMPLE_FILES = {
|
|
| 22 |
}
|
| 23 |
|
| 24 |
class Clusters:
|
| 25 |
-
def __init__(self,
|
| 26 |
-
#
|
| 27 |
-
#
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
loc_vars_np_float32 = np.ascontiguousarray(loc_vars_df.astype(np.float32).values)
|
| 36 |
-
|
| 37 |
-
# Initialize KMeans with algorithm="elkan" for potential speedup
|
| 38 |
-
# and fit on the float32 data.
|
| 39 |
-
self.kmeans = KMeans(
|
| 40 |
-
n_clusters=1000,
|
| 41 |
-
random_state=0,
|
| 42 |
-
n_init=10,
|
| 43 |
-
algorithm="elkan" # Added for speed optimization
|
| 44 |
-
).fit(loc_vars_np_float32)
|
| 45 |
-
|
| 46 |
-
# cluster_centers_ will be float32 if fitted on float32 data.
|
| 47 |
-
# Pass the same float32 NumPy array for distance calculations.
|
| 48 |
-
closest, _ = pairwise_distances_argmin_min(
|
| 49 |
-
self.kmeans.cluster_centers_,
|
| 50 |
-
loc_vars_np_float32
|
| 51 |
-
)
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
| 56 |
|
| 57 |
-
# policy_count
|
| 58 |
-
|
| 59 |
-
|
| 60 |
|
| 61 |
def agg_by_cluster(self, df, agg=None):
|
| 62 |
"""Aggregate columns by cluster"""
|
| 63 |
temp = df.copy()
|
| 64 |
-
temp['cluster_id'] = self.kmeans.labels_
|
| 65 |
temp = temp.set_index('cluster_id')
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
def extract_reps(self, df):
|
| 70 |
"""Extract the rows of representative policies"""
|
| 71 |
-
#
|
| 72 |
-
# Typically, df here will have 'policy_id' as its index as per original data.
|
| 73 |
-
# If df's index is not 'policy_id', ensure 'policy_id' column exists and has compatible type.
|
| 74 |
-
current_df_index_name = df.index.name
|
| 75 |
-
# If 'policy_id' is not the index, reset it. Otherwise, use the index.
|
| 76 |
if 'policy_id' not in df.columns and df.index.name != 'policy_id':
|
| 77 |
-
|
| 78 |
-
# Forcing index to be named 'policy_id' if it's the policy identifier
|
| 79 |
-
df_indexed = df.copy()
|
| 80 |
-
if df_indexed.index.name is None: # Or some other logic to identify the policy_id column
|
| 81 |
-
gr.Warning("DataFrame passed to extract_reps has no index name, assuming index is policy_id.")
|
| 82 |
-
df_indexed.index.name = 'policy_id'
|
| 83 |
-
|
| 84 |
-
temp = pd.merge(self.rep_ids, df_indexed.reset_index(), how='left', on='policy_id')
|
| 85 |
-
|
| 86 |
-
elif 'policy_id' in df.columns and df.index.name == 'policy_id' and df.index.name in df.columns: # if policy_id is both index and a column
|
| 87 |
-
temp = pd.merge(self.rep_ids, df, how='left', on='policy_id') # Merge on column if available
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
|
| 96 |
-
temp.
|
| 97 |
-
temp
|
| 98 |
-
|
|
|
|
| 99 |
|
| 100 |
|
| 101 |
def extract_and_scale_reps(self, df, agg=None):
|
| 102 |
"""Extract and scale the rows of representative policies"""
|
| 103 |
extracted_df = self.extract_reps(df)
|
| 104 |
if agg:
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
mult
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
else:
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
def compare(self, df, agg=None):
|
| 113 |
"""Returns a multi-indexed Dataframe comparing actual and estimate"""
|
| 114 |
source = self.agg_by_cluster(df, agg)
|
| 115 |
target = self.extract_and_scale_reps(df, agg)
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
def compare_total(self, df, agg=None):
|
| 119 |
"""Aggregate df by columns"""
|
|
@@ -130,37 +131,35 @@ class Clusters:
|
|
| 130 |
estimate_values = {}
|
| 131 |
|
| 132 |
for col in df.columns: # Iterate over original df columns to ensure all are covered
|
| 133 |
-
if col not in reps_unscaled.columns:
|
| 134 |
-
|
| 135 |
-
estimate_values[col] = np.nan # Or some other placeholder like 0, or actual.get(col, 0)
|
| 136 |
-
else:
|
| 137 |
-
estimate_values[col] = 0
|
| 138 |
-
gr.Warning(f"Column '{col}' not found in representative policies output for 'compare_total'. Estimate will be 0/NaN.")
|
| 139 |
continue
|
| 140 |
|
| 141 |
if agg.get(col, 'sum') == 'mean':
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
| 145 |
else: # sum
|
| 146 |
-
estimate_values[col] = (reps_unscaled[col] * self.policy_count).sum()
|
| 147 |
-
|
| 148 |
estimate = pd.Series(estimate_values)
|
| 149 |
-
|
| 150 |
-
else: # Original logic if no agg is specified (all sum)
|
| 151 |
actual = df.sum()
|
| 152 |
estimate = self.extract_and_scale_reps(df).sum()
|
| 153 |
|
| 154 |
-
#
|
| 155 |
-
actual, estimate
|
| 156 |
-
error = np.
|
| 157 |
|
| 158 |
return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
|
| 159 |
|
| 160 |
-
|
| 161 |
-
|
| 162 |
def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
|
| 163 |
-
"""Create cashflow comparison plots"""
|
|
|
|
| 164 |
if not cfs_list or not cluster_obj or not titles:
|
| 165 |
return None
|
| 166 |
num_plots = len(cfs_list)
|
|
@@ -173,20 +172,30 @@ def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
|
|
| 173 |
fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
|
| 174 |
axes = axes.flatten()
|
| 175 |
|
| 176 |
-
for i, (
|
| 177 |
if i < len(axes):
|
| 178 |
-
|
| 179 |
-
#
|
| 180 |
-
|
| 181 |
-
if df.index.name != 'policy_id' and 'policy_id' not in df.columns:
|
| 182 |
-
gr.Warning(f"DataFrame for plot '{title}' does not have 'policy_id' as index or column. Results may be incorrect.")
|
| 183 |
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
-
for j in range(i + 1, len(axes)):
|
| 190 |
fig.delaxes(axes[j])
|
| 191 |
|
| 192 |
plt.tight_layout()
|
|
@@ -198,47 +207,77 @@ def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
|
|
| 198 |
return img
|
| 199 |
|
| 200 |
def plot_scatter_comparison(df_compare_output, title):
|
| 201 |
-
"""Create scatter plot comparison from compare() output"""
|
|
|
|
|
|
|
|
|
|
| 202 |
if df_compare_output is None or df_compare_output.empty:
|
| 203 |
-
fig, ax = plt.subplots(figsize=(12, 8))
|
| 204 |
ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
|
| 205 |
ax.set_title(title)
|
| 206 |
-
|
| 207 |
-
plt.savefig(buf, format='png', dpi=100)
|
| 208 |
-
buf.seek(0)
|
| 209 |
-
img = Image.open(buf)
|
| 210 |
-
plt.close(fig)
|
| 211 |
-
return img
|
| 212 |
-
|
| 213 |
-
fig, ax = plt.subplots(figsize=(12, 8))
|
| 214 |
-
|
| 215 |
-
if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
|
| 216 |
gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
|
| 217 |
-
|
|
|
|
| 218 |
else:
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
|
| 228 |
ax.set_xlabel('Actual')
|
| 229 |
ax.set_ylabel('Estimate')
|
| 230 |
-
ax.
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
buf = io.BytesIO()
|
| 243 |
plt.savefig(buf, format='png', dpi=100)
|
| 244 |
buf.seek(0)
|
|
@@ -246,56 +285,74 @@ def plot_scatter_comparison(df_compare_output, title):
|
|
| 246 |
plt.close(fig)
|
| 247 |
return img
|
| 248 |
|
| 249 |
-
#
|
| 250 |
-
|
| 251 |
def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
|
| 252 |
policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
|
| 253 |
"""Main processing function - now accepts file paths"""
|
| 254 |
try:
|
| 255 |
-
#
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
|
| 263 |
|
| 264 |
-
#
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
if
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
if all(col in
|
| 273 |
-
|
| 274 |
-
cols_to_select = [col for col in required_cols if col in pol_data_full.columns]
|
| 275 |
-
if pol_data_full.index.name in required_cols and pol_data_full.index.name not in cols_to_select:
|
| 276 |
-
# This case is tricky; if an ID is part of required_cols and is the index.
|
| 277 |
-
# For simplicity, assume required_cols are actual data columns.
|
| 278 |
-
pass # Let it proceed, it might be handled by selection or error later.
|
| 279 |
-
|
| 280 |
-
pol_data = pol_data_full[cols_to_select].copy() # Use .copy() to avoid SettingWithCopyWarning
|
| 281 |
-
# If 'policy_id' was the index and required, it's implicitly handled or needs specific logic.
|
| 282 |
-
# For K-Means, policy_id itself is usually not a feature.
|
| 283 |
else:
|
| 284 |
-
|
| 285 |
-
gr.Warning(f"Policy data might be missing required columns: {
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
if 'policy_id' not in df.columns and df.index.name == 'policy_id':
|
| 296 |
-
df.reset_index(inplace=True)
|
| 297 |
-
df.set_index('policy_id', inplace=True)
|
| 298 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
cfs_list = [cfs, cfs_lapse50, cfs_mort15]
|
| 300 |
scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
|
| 301 |
|
|
@@ -303,44 +360,44 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
|
|
| 303 |
|
| 304 |
mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
|
| 305 |
|
| 306 |
-
# DataFrames passed to Clusters should be policy_id indexed for .values to exclude it.
|
| 307 |
-
# Or, select only feature columns before passing.
|
| 308 |
-
# The Clusters class now expects a DataFrame and will use .values, so pass only feature columns.
|
| 309 |
-
# If index is policy_id, df.values will not include it. This is good.
|
| 310 |
-
|
| 311 |
# --- 1. Cashflow Calibration ---
|
| 312 |
-
#
|
| 313 |
-
cluster_cfs = Clusters(cfs
|
| 314 |
|
| 315 |
results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
|
| 316 |
results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
|
| 317 |
-
|
| 318 |
results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
|
| 319 |
results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
|
| 320 |
results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
|
| 321 |
-
|
| 322 |
results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
|
| 323 |
results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
|
| 324 |
|
| 325 |
# --- 2. Policy Attribute Calibration ---
|
| 326 |
-
|
| 327 |
-
if not
|
| 328 |
-
#
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
|
|
|
| 334 |
else:
|
| 335 |
-
|
| 336 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
else:
|
| 338 |
-
gr.Warning("Policy data is empty. Skipping attribute calibration
|
|
|
|
| 339 |
|
| 340 |
-
if not
|
| 341 |
-
cluster_attrs = Clusters(
|
| 342 |
results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
|
| 343 |
-
results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
|
| 344 |
results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
|
| 345 |
results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
|
| 346 |
results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
|
|
@@ -348,41 +405,39 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
|
|
| 348 |
results['attr_total_cf_base'] = pd.DataFrame()
|
| 349 |
results['attr_policy_attrs_total'] = pd.DataFrame()
|
| 350 |
results['attr_total_pv_base'] = pd.DataFrame()
|
| 351 |
-
results['attr_cashflow_plot'] =
|
| 352 |
-
results['attr_scatter_cashflows_base'] = plot_scatter_comparison(pd.DataFrame(), 'Policy Attr. Calib. -
|
| 353 |
|
| 354 |
|
| 355 |
# --- 3. Present Value Calibration ---
|
| 356 |
-
cluster_pvs = Clusters(pvs
|
| 357 |
|
| 358 |
results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
|
| 359 |
results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
|
| 360 |
-
|
| 361 |
results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
|
| 362 |
results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
|
| 363 |
results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
|
| 364 |
-
|
| 365 |
results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
|
| 366 |
results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
|
| 367 |
|
| 368 |
# --- Summary Comparison Plot Data ---
|
| 369 |
error_data = {}
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
if compare_result is None or compare_result.empty or 'error' not in compare_result.columns: # Check if None
|
| 373 |
return np.nan
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
|
| 379 |
key_pv_col = None
|
| 380 |
-
#
|
| 381 |
-
# Or, use the original pvs DataFrame if it's guaranteed to have the PV_NetCF column.
|
| 382 |
-
# For safety, check in the original pvs DataFrame which has not been stripped of columns.
|
| 383 |
-
original_pvs_cols = pd.read_excel(pv_base_path).columns # Quick read just for columns
|
| 384 |
for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']:
|
| 385 |
-
if potential_col in
|
| 386 |
key_pv_col = potential_col
|
| 387 |
break
|
| 388 |
|
|
@@ -392,11 +447,12 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
|
|
| 392 |
get_error_safe(results.get('cf_pv_total_mort'), key_pv_col)
|
| 393 |
]
|
| 394 |
|
| 395 |
-
if not
|
| 396 |
-
|
| 397 |
-
get_error_safe(results.get('attr_total_pv_base'), key_pv_col),
|
| 398 |
-
|
| 399 |
-
get_error_safe(cluster_attrs.compare_total(
|
|
|
|
| 400 |
]
|
| 401 |
else:
|
| 402 |
error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
|
|
@@ -407,17 +463,26 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
|
|
| 407 |
get_error_safe(results.get('pv_total_pv_mort'), key_pv_col)
|
| 408 |
]
|
| 409 |
|
| 410 |
-
summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
|
| 411 |
|
| 412 |
fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
|
| 413 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
ax_summary.set_ylabel('Absolute Error Rate')
|
| 415 |
-
title_suffix = f'
|
| 416 |
ax_summary.set_title(f'Calibration Method Comparison - Error in Total PV{title_suffix}')
|
| 417 |
ax_summary.tick_params(axis='x', rotation=0)
|
| 418 |
-
ax_summary.
|
|
|
|
|
|
|
|
|
|
| 419 |
plt.tight_layout()
|
| 420 |
-
|
| 421 |
buf_summary = io.BytesIO()
|
| 422 |
plt.savefig(buf_summary, format='png', dpi=100)
|
| 423 |
buf_summary.seek(0)
|
|
@@ -430,15 +495,22 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
|
|
| 430 |
gr.Error(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded.")
|
| 431 |
return {"error": f"File not found: {e.filename}"}
|
| 432 |
except KeyError as e:
|
| 433 |
-
|
| 434 |
-
gr.Error(f"A required column or index is missing or misnamed: {e}. Please check data format and ensure 'policy_id' is correctly handled as index for feature dataframes.")
|
| 435 |
return {"error": f"Missing column/index: {e}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
except Exception as e:
|
|
|
|
| 437 |
import traceback
|
| 438 |
-
|
| 439 |
-
return {"error": f"
|
| 440 |
|
| 441 |
-
# --- Gradio interface creation (create_interface, etc.)
|
|
|
|
|
|
|
| 442 |
def create_interface():
|
| 443 |
with gr.Blocks(title="Cluster Model Points Analysis") as demo:
|
| 444 |
gr.Markdown("""
|
|
@@ -448,15 +520,13 @@ def create_interface():
|
|
| 448 |
Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
|
| 449 |
|
| 450 |
**Required Files (Excel .xlsx):**
|
| 451 |
-
- Cashflows - Base Scenario (
|
| 452 |
-
- Cashflows - Lapse Stress (+50%) (
|
| 453 |
-
- Cashflows - Mortality Stress (+15%) (
|
| 454 |
-
- Policy Data (
|
| 455 |
-
- Present Values - Base Scenario (
|
| 456 |
-
- Present Values - Lapse Stress (
|
| 457 |
-
- Present Values - Mortality Stress (
|
| 458 |
-
|
| 459 |
-
*Note: Ensure 'policy_id' is the index for all input files for correct processing.*
|
| 460 |
""")
|
| 461 |
|
| 462 |
with gr.Row():
|
|
@@ -503,11 +573,7 @@ def create_interface():
|
|
| 503 |
attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
|
| 504 |
attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
|
| 505 |
with gr.Accordion("Present Value Comparisons (Total)", open=False):
|
| 506 |
-
|
| 507 |
-
attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")
|
| 508 |
-
# Added placeholders for other scenarios if they were intended
|
| 509 |
-
# attr_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
|
| 510 |
-
# attr_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
|
| 511 |
|
| 512 |
with gr.TabItem("💰 Present Value Calibration"):
|
| 513 |
gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
|
|
@@ -522,62 +588,46 @@ def create_interface():
|
|
| 522 |
pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
|
| 523 |
pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
|
| 524 |
|
| 525 |
-
# --- Helper function to prepare outputs ---
|
| 526 |
def get_all_output_components():
|
| 527 |
return [
|
| 528 |
summary_plot_output,
|
| 529 |
-
# Cashflow Calib Outputs
|
| 530 |
cf_total_base_table_out, cf_policy_attrs_total_out,
|
| 531 |
cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
|
| 532 |
cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
|
| 533 |
-
# Attribute Calib Outputs
|
| 534 |
attr_total_cf_base_out, attr_policy_attrs_total_out,
|
| 535 |
attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
|
| 536 |
-
# PV Calib Outputs
|
| 537 |
pv_total_cf_base_out, pv_policy_attrs_total_out,
|
| 538 |
pv_cashflow_plot_out, pv_scatter_pvs_base_out,
|
| 539 |
pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
|
| 540 |
]
|
| 541 |
|
| 542 |
-
# --- Action for Analyze Button ---
|
| 543 |
def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
|
| 544 |
files = [f1, f2, f3, f4, f5, f6, f7]
|
| 545 |
-
|
| 546 |
file_paths = []
|
| 547 |
-
#
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
# For now, strict: all files must be present.
|
| 552 |
-
gr.Error("Missing file input for one or more fields. Please upload all required files or load the complete example dataset.")
|
| 553 |
-
return [None] * len(get_all_output_components())
|
| 554 |
-
|
| 555 |
-
for i, f_obj in enumerate(files):
|
| 556 |
-
# f_obj is TempFilePath (older Gradio) or FileData (newer) or str (from example load)
|
| 557 |
-
if hasattr(f_obj, 'name') and isinstance(f_obj.name, str): # Gradio FileData or similar
|
| 558 |
-
file_paths.append(f_obj.name)
|
| 559 |
-
elif isinstance(f_obj, str): # Path from example load
|
| 560 |
-
file_paths.append(f_obj)
|
| 561 |
-
else: # Should not happen if inputs are Files or paths
|
| 562 |
-
gr.Error(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")
|
| 563 |
return [None] * len(get_all_output_components())
|
|
|
|
|
|
|
|
|
|
|
|
|
| 564 |
|
| 565 |
results = process_files(*file_paths)
|
| 566 |
|
| 567 |
-
if "error" in results :
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
|
|
|
| 571 |
return [
|
| 572 |
results.get('summary_plot'),
|
| 573 |
-
# CF Calib
|
| 574 |
results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
|
| 575 |
results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
|
| 576 |
results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
|
| 577 |
-
# Attr Calib
|
| 578 |
results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
|
| 579 |
results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
|
| 580 |
-
# PV Calib
|
| 581 |
results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
|
| 582 |
results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
|
| 583 |
results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
|
|
@@ -590,50 +640,50 @@ def create_interface():
|
|
| 590 |
outputs=get_all_output_components()
|
| 591 |
)
|
| 592 |
|
| 593 |
-
# --- Action for Load Example Data Button ---
|
| 594 |
def load_example_files():
|
| 595 |
-
# Create dummy example files if they don't exist
|
| 596 |
-
|
| 597 |
-
os.makedirs(EXAMPLE_DATA_DIR, exist_ok=True) # Ensure dir exists
|
| 598 |
-
|
| 599 |
-
missing_files = []
|
| 600 |
for key, fp in EXAMPLE_FILES.items():
|
| 601 |
if not os.path.exists(fp):
|
| 602 |
-
|
| 603 |
-
# Create a minimal dummy Excel file if it's missing
|
| 604 |
try:
|
| 605 |
-
|
| 606 |
-
if "cashflow" in key or "pv" in key:
|
| 607 |
-
|
|
|
|
|
|
|
| 608 |
elif "policy_data" in key:
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 615 |
except Exception as e:
|
| 616 |
-
gr.
|
|
|
|
| 617 |
|
| 618 |
-
|
| 619 |
-
if missing_files
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
return [None] * 7 # Return None for all file inputs
|
| 623 |
|
| 624 |
gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
|
| 625 |
-
# Return
|
| 626 |
return [
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
gr.File(value=EXAMPLE_FILES["policy_data"], Labeled_input=policy_data_input.label),
|
| 631 |
-
gr.File(value=EXAMPLE_FILES["pv_base"], Labeled_input=pv_base_input.label),
|
| 632 |
-
gr.File(value=EXAMPLE_FILES["pv_lapse"], Labeled_input=pv_lapse_input.label),
|
| 633 |
-
gr.File(value=EXAMPLE_FILES["pv_mort"], Labeled_input=pv_mort_input.label)
|
| 634 |
]
|
| 635 |
|
| 636 |
-
|
| 637 |
load_example_btn.click(
|
| 638 |
load_example_files,
|
| 639 |
inputs=[],
|
|
@@ -646,30 +696,7 @@ def create_interface():
|
|
| 646 |
if __name__ == "__main__":
|
| 647 |
if not os.path.exists(EXAMPLE_DATA_DIR):
|
| 648 |
os.makedirs(EXAMPLE_DATA_DIR)
|
| 649 |
-
print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there or
|
| 650 |
-
|
| 651 |
-
# Simple check and dummy file creation for example data if not present
|
| 652 |
-
for key, fp in EXAMPLE_FILES.items():
|
| 653 |
-
if not os.path.exists(fp):
|
| 654 |
-
print(f"Example file {fp} not found. Attempting to create a dummy file.")
|
| 655 |
-
try:
|
| 656 |
-
dummy_df_data = {'policy_id': [1,2,3], 'col1': [0.1,0.2,0.3], 'col2':[10,20,30]}
|
| 657 |
-
if "cashflow" in key or "pv" in key:
|
| 658 |
-
dummy_df_data = {f'{i}':np.random.rand(3) for i in range(10)} # 10 time periods
|
| 659 |
-
dummy_df_data['policy_id'] = [f'P{j}' for j in range(3)]
|
| 660 |
-
elif "policy_data" in key:
|
| 661 |
-
dummy_df_data = {'policy_id': [f'P{j}' for j in range(3)],
|
| 662 |
-
'age_at_entry': np.random.randint(20, 50, 3),
|
| 663 |
-
'policy_term': np.random.randint(10, 30, 3),
|
| 664 |
-
'sum_assured': np.random.randint(10000, 50000, 3),
|
| 665 |
-
'duration_mth': np.random.randint(1, 120, 3)}
|
| 666 |
-
|
| 667 |
-
dummy_df = pd.DataFrame(dummy_df_data).set_index('policy_id')
|
| 668 |
-
dummy_df.to_excel(fp)
|
| 669 |
-
print(f"Dummy file for '{fp}' created.")
|
| 670 |
-
except Exception as e:
|
| 671 |
-
print(f"Could not create dummy file for {fp}: {e}")
|
| 672 |
-
|
| 673 |
|
| 674 |
demo_app = create_interface()
|
| 675 |
demo_app.launch()
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
| 4 |
from sklearn.cluster import KMeans
|
| 5 |
+
from sklearn.metrics import pairwise_distances_argmin_min, r2_score # r2_score is not used but kept from original
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
+
# import matplotlib.cm # No longer explicitly needed for rainbow
|
| 8 |
+
import seaborn as sns # Added Seaborn
|
| 9 |
import io
|
| 10 |
+
import os
|
| 11 |
from PIL import Image
|
| 12 |
|
| 13 |
# Define the paths for example data
|
|
|
|
| 23 |
}
|
| 24 |
|
| 25 |
class Clusters:
|
| 26 |
+
def __init__(self, loc_vars):
|
| 27 |
+
# loc_vars is expected to be a DataFrame for cfs, loc_vars_attrs, pvs
|
| 28 |
+
# For KMeans, we need a NumPy array. If loc_vars is a DataFrame, .values extracts the data.
|
| 29 |
+
if isinstance(loc_vars, pd.DataFrame):
|
| 30 |
+
loc_vars_np = np.ascontiguousarray(loc_vars.values)
|
| 31 |
+
else: # If it's already a NumPy array (e.g. from previous processing not shown)
|
| 32 |
+
loc_vars_np = np.ascontiguousarray(loc_vars)
|
| 33 |
+
|
| 34 |
+
self.kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(loc_vars_np)
|
| 35 |
+
closest, _ = pairwise_distances_argmin_min(self.kmeans.cluster_centers_, loc_vars_np)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
+
rep_ids = pd.Series(data=(closest + 1)) # 0-based to 1-based indexes
|
| 38 |
+
rep_ids.name = 'policy_id'
|
| 39 |
+
rep_ids.index.name = 'cluster_id' # This index represents the cluster number (0 to 999)
|
| 40 |
+
self.rep_ids = rep_ids
|
| 41 |
|
| 42 |
+
# policy_count should be based on the length of the input data used for clustering
|
| 43 |
+
self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars_np)}))['policy_count']
|
| 44 |
+
|
| 45 |
|
| 46 |
def agg_by_cluster(self, df, agg=None):
|
| 47 |
"""Aggregate columns by cluster"""
|
| 48 |
temp = df.copy()
|
| 49 |
+
temp['cluster_id'] = self.kmeans.labels_ # labels_ are 0-indexed cluster assignments
|
| 50 |
temp = temp.set_index('cluster_id')
|
| 51 |
+
agg_dict = {c: (agg[c] if agg and c in agg else 'sum') for c in temp.columns if c != 'cluster_id'} if agg else "sum"
|
| 52 |
+
if not agg_dict: # handles case where temp has only cluster_id or agg makes agg_dict empty
|
| 53 |
+
return pd.DataFrame(index=temp.index.unique()) # return empty DF with cluster_id index
|
| 54 |
+
return temp.groupby(level='cluster_id').agg(agg_dict)
|
| 55 |
+
|
| 56 |
|
| 57 |
def extract_reps(self, df):
|
| 58 |
"""Extract the rows of representative policies"""
|
| 59 |
+
# df is expected to have policy_id as its index or as a column
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
if 'policy_id' not in df.columns and df.index.name != 'policy_id':
|
| 61 |
+
raise ValueError("DataFrame for extract_reps must have 'policy_id' as index or column.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
+
df_to_merge = df.reset_index() if df.index.name == 'policy_id' else df.copy()
|
| 64 |
+
|
| 65 |
+
# Ensure policy_id column exists after reset_index or in copy
|
| 66 |
+
if 'policy_id' not in df_to_merge.columns:
|
| 67 |
+
# This case implies policy_id was the index but reset_index didn't create it (e.g. unnamed index)
|
| 68 |
+
# This should be handled by input data prep: ensure policy_id is a named index or a column.
|
| 69 |
+
# For robustness, if original df had named index 'policy_id', reset_index works.
|
| 70 |
+
# If it was an unnamed index that is policy_id, it's more problematic.
|
| 71 |
+
# Assuming 'policy_id' is present in df_to_merge now.
|
| 72 |
+
pass
|
| 73 |
|
| 74 |
|
| 75 |
+
temp = pd.merge(self.rep_ids.reset_index(), df_to_merge, how='left', on='policy_id')
|
| 76 |
+
# temp now has 'cluster_id' from rep_ids and other columns from df_to_merge
|
| 77 |
+
temp = temp.set_index('cluster_id')
|
| 78 |
+
return temp.drop(columns=['policy_id'], errors='ignore')
|
| 79 |
|
| 80 |
|
| 81 |
def extract_and_scale_reps(self, df, agg=None):
|
| 82 |
"""Extract and scale the rows of representative policies"""
|
| 83 |
extracted_df = self.extract_reps(df)
|
| 84 |
if agg:
|
| 85 |
+
# Ensure we only try to multiply columns that exist in extracted_df
|
| 86 |
+
cols_to_multiply = [col for col in df.columns if col in extracted_df.columns]
|
| 87 |
+
mult = pd.DataFrame({
|
| 88 |
+
c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1)
|
| 89 |
+
for c in cols_to_multiply
|
| 90 |
+
})
|
| 91 |
+
mult.index = extracted_df.index # Align index for multiplication
|
| 92 |
+
|
| 93 |
+
# Only multiply existing columns
|
| 94 |
+
result_df = extracted_df.copy()
|
| 95 |
+
for col in cols_to_multiply:
|
| 96 |
+
result_df[col] = extracted_df[col].mul(mult[col])
|
| 97 |
+
return result_df
|
| 98 |
else:
|
| 99 |
+
# Scale all numeric columns in extracted_df
|
| 100 |
+
numeric_cols = extracted_df.select_dtypes(include=np.number).columns
|
| 101 |
+
result_df = extracted_df.copy()
|
| 102 |
+
for col in numeric_cols:
|
| 103 |
+
result_df[col] = extracted_df[col].mul(self.policy_count, axis=0)
|
| 104 |
+
return result_df
|
| 105 |
+
|
| 106 |
|
| 107 |
def compare(self, df, agg=None):
|
| 108 |
"""Returns a multi-indexed Dataframe comparing actual and estimate"""
|
| 109 |
source = self.agg_by_cluster(df, agg)
|
| 110 |
target = self.extract_and_scale_reps(df, agg)
|
| 111 |
+
|
| 112 |
+
# Ensure consistent columns for stacking, could be an issue if agg is selective
|
| 113 |
+
common_columns = source.columns.intersection(target.columns)
|
| 114 |
+
source_stacked = source[common_columns].stack()
|
| 115 |
+
target_stacked = target[common_columns].stack()
|
| 116 |
+
|
| 117 |
+
return pd.DataFrame({'actual': source_stacked, 'estimate': target_stacked})
|
| 118 |
|
| 119 |
def compare_total(self, df, agg=None):
|
| 120 |
"""Aggregate df by columns"""
|
|
|
|
| 131 |
estimate_values = {}
|
| 132 |
|
| 133 |
for col in df.columns: # Iterate over original df columns to ensure all are covered
|
| 134 |
+
if col not in reps_unscaled.columns:
|
| 135 |
+
estimate_values[col] = np.nan # Column not in representative policies
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
continue
|
| 137 |
|
| 138 |
if agg.get(col, 'sum') == 'mean':
|
| 139 |
+
if self.policy_count.sum() > 0:
|
| 140 |
+
weighted_sum = (reps_unscaled[col].astype(float) * self.policy_count.astype(float)).sum()
|
| 141 |
+
total_weight = self.policy_count.sum()
|
| 142 |
+
estimate_values[col] = weighted_sum / total_weight
|
| 143 |
+
else:
|
| 144 |
+
estimate_values[col] = np.nan # Avoid division by zero
|
| 145 |
else: # sum
|
| 146 |
+
estimate_values[col] = (reps_unscaled[col].astype(float) * self.policy_count.astype(float)).sum()
|
|
|
|
| 147 |
estimate = pd.Series(estimate_values)
|
| 148 |
+
else:
|
|
|
|
| 149 |
actual = df.sum()
|
| 150 |
estimate = self.extract_and_scale_reps(df).sum()
|
| 151 |
|
| 152 |
+
actual, estimate = actual.align(estimate, fill_value=0) # Align before calculating error
|
| 153 |
+
error = np.where(actual != 0, (estimate / actual) - 1, 0) # estimate/actual can be NaN if actual is 0
|
| 154 |
+
error = np.nan_to_num(error, nan=0.0) # Replace NaNs from 0/0 with 0
|
| 155 |
|
| 156 |
return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
|
| 157 |
|
| 158 |
+
## Plotting Functions (Modified for Seaborn)
|
| 159 |
+
---
|
| 160 |
def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
|
| 161 |
+
"""Create cashflow comparison plots using Seaborn"""
|
| 162 |
+
sns.set_style("whitegrid") # Apply Seaborn styling
|
| 163 |
if not cfs_list or not cluster_obj or not titles:
|
| 164 |
return None
|
| 165 |
num_plots = len(cfs_list)
|
|
|
|
| 172 |
fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
|
| 173 |
axes = axes.flatten()
|
| 174 |
|
| 175 |
+
for i, (df_orig, title) in enumerate(zip(cfs_list, titles)):
|
| 176 |
if i < len(axes):
|
| 177 |
+
ax = axes[i]
|
| 178 |
+
# Assuming df_orig has policy_id as index, or it's handled before compare_total
|
| 179 |
+
comparison_df = cluster_obj.compare_total(df_orig)
|
|
|
|
|
|
|
| 180 |
|
| 181 |
+
# Prepare data for Seaborn lineplot (long format)
|
| 182 |
+
plot_data = comparison_df[['actual', 'estimate']].copy()
|
| 183 |
+
# Assuming the index of comparison_df represents 'Time'
|
| 184 |
+
plot_data['Time'] = plot_data.index.astype(str) # Ensure Time is string for categorical plotting if not truly numeric
|
| 185 |
+
try: # If Time can be numeric, use it as such.
|
| 186 |
+
plot_data['Time'] = pd.to_numeric(plot_data['Time'])
|
| 187 |
+
except ValueError:
|
| 188 |
+
pass # Keep as string if not convertible
|
| 189 |
+
|
| 190 |
+
plot_data_melted = plot_data.melt(id_vars='Time', var_name='Legend', value_name='Value')
|
| 191 |
+
|
| 192 |
+
sns.lineplot(x='Time', y='Value', hue='Legend', data=plot_data_melted, ax=ax, errorbar=None)
|
| 193 |
+
ax.set_title(title)
|
| 194 |
+
ax.set_xlabel('Time')
|
| 195 |
+
ax.set_ylabel('Value')
|
| 196 |
+
# ax.grid(True) # whitegrid style includes a grid
|
| 197 |
|
| 198 |
+
for j in range(i + 1, len(axes)): # Hide any unused subplots
|
| 199 |
fig.delaxes(axes[j])
|
| 200 |
|
| 201 |
plt.tight_layout()
|
|
|
|
| 207 |
return img
|
| 208 |
|
| 209 |
def plot_scatter_comparison(df_compare_output, title):
    """Create an actual-vs-estimate scatter plot from a compare() result using Seaborn.

    Parameters
    ----------
    df_compare_output : pd.DataFrame or None
        Output of Clusters.compare(): columns 'actual' and 'estimate'.
        Ideally carries a 2-level MultiIndex whose second level is used as
        the hue.  # assumes compare() stacks per-column values -- TODO confirm
    title : str
        Plot title.

    Returns
    -------
    PIL.Image.Image
        The rendered figure as a PNG image.
    """
    sns.set_style("whitegrid")
    fig, ax = plt.subplots(figsize=(12, 8))

    # Sentinel instead of the original fragile `'plot_data' in locals()` check:
    # plot_data is only bound when the MultiIndex branch actually runs.
    plot_data = None
    if df_compare_output is None or df_compare_output.empty:
        ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
        ax.set_title(title)
    elif not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
        gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
        sns.scatterplot(x='actual', y='estimate', data=df_compare_output, s=25, alpha=0.7, ax=ax, legend=False)
        ax.set_title(title)
    else:
        plot_data = df_compare_output.reset_index()
        hue_col_name = df_compare_output.index.names[1]
        # Ensure the hue column is treated as categorical by converting to string.
        plot_data[hue_col_name] = plot_data[hue_col_name].astype(str)

        unique_levels = plot_data[hue_col_name].nunique()
        show_legend_flag = "auto"
        if unique_levels == 1:
            show_legend_flag = False
        elif unique_levels > 10:
            show_legend_flag = False
            gr.Warning(f"Warning: Too many unique values ({unique_levels}) in '{hue_col_name}' for scatter plot legend. Legend hidden.")

        sns.scatterplot(x='actual', y='estimate', hue=hue_col_name, data=plot_data,
                        s=25, alpha=0.7, ax=ax, legend=show_legend_flag)
        ax.set_title(title)
        # The original duplicated this for show_legend_flag == True and
        # == "auto"; if a legend was drawn at all, retitle it (when the flag
        # is False seaborn draws no legend, so get_legend() is None).
        if ax.get_legend() is not None:
            ax.get_legend().set_title(str(hue_col_name))

    ax.set_xlabel('Actual')
    ax.set_ylabel('Estimate')

    # Identity (y = x) reference line: drawn after the scatter so the data
    # range is known, and only when there is data to relate it to.
    if not (df_compare_output is None or df_compare_output.empty):
        source_df = plot_data if plot_data is not None else df_compare_output
        all_values = pd.concat([source_df['actual'], source_df['estimate']]).dropna()

        if not all_values.empty:
            min_val = all_values.min()
            max_val = all_values.max()

            # Widen to whatever the axes already show so the line spans the plot.
            ax_xlims = ax.get_xlim()
            ax_ylims = ax.get_ylim()
            plot_min = np.nanmin([min_val, ax_xlims[0], ax_ylims[0]])
            plot_max = np.nanmax([max_val, ax_xlims[1], ax_ylims[1]])

            if np.isfinite(plot_min) and np.isfinite(plot_max) and plot_min < plot_max:
                ax.plot([plot_min, plot_max], [plot_min, plot_max], 'r-', linewidth=0.7, alpha=0.8, zorder=0)
                ax.set_xlim(plot_min, plot_max)
                ax.set_ylim(plot_min, plot_max)
            elif np.isfinite(plot_min) and np.isfinite(plot_max) and plot_min == plot_max:
                # Degenerate single-value range: mark the point and pad the axes.
                margin = abs(plot_min * 0.1) if plot_min != 0 else 0.1
                ax.plot([plot_min], [plot_min], 'ro')
                ax.set_xlim(plot_min - margin, plot_min + margin)
                ax.set_ylim(plot_min - margin, plot_min + margin)

    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100)
    buf.seek(0)
    img = Image.open(buf)  # NOTE(review): this line was lost in the diff view; reconstructed from the sibling plot function -- confirm
    plt.close(fig)
    return img
|
| 287 |
|
| 288 |
+
# --- Main Processing and Gradio UI (largely unchanged) ---
|
| 290 |
def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
|
| 291 |
policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
|
| 292 |
"""Main processing function - now accepts file paths"""
|
| 293 |
try:
|
| 294 |
+
# Ensure 'policy_id' is the index for dataframes used in clustering/comparison
|
| 295 |
+
def read_and_prep_excel(path, set_policy_id_index=True):
|
| 296 |
+
df = pd.read_excel(path) # Read first, then set index
|
| 297 |
+
if 'policy_id' not in df.columns:
|
| 298 |
+
# Try to find it in unnamed index columns if any, or assume first column
|
| 299 |
+
# This is risky; ideally, 'policy_id' is an explicit column name
|
| 300 |
+
gr.Warning(f"'policy_id' column not found in {os.path.basename(path)}. Attempting to use first column or existing index.")
|
| 301 |
+
if df.columns[0].lower() == 'policy_id' or 'policyid' in df.columns[0].lower():
|
| 302 |
+
df.rename(columns={df.columns[0]: 'policy_id'}, inplace=True)
|
| 303 |
+
# Or if it is in the index already but unnamed
|
| 304 |
+
elif df.index.name is None and len(df.index) == len(df): # A heuristic
|
| 305 |
+
pass # keep as is, will try to use index later
|
| 306 |
+
else: # Fallback if no clear policy_id column found and index is not it
|
| 307 |
+
gr.Error(f"Cannot reliably find 'policy_id' in {os.path.basename(path)}.")
|
| 308 |
+
# For this example, let's assume files WILL have policy_id column or as first column
|
| 309 |
+
# This part needs robust handling based on expected file structures.
|
| 310 |
+
# If it's always index_col=0 as in original:
|
| 311 |
+
df = pd.read_excel(path, index_col=0)
|
| 312 |
+
if df.index.name != 'policy_id': # if index_col=0 was not named 'policy_id'
|
| 313 |
+
df.index.name = 'policy_id' # Name it 'policy_id'
|
| 314 |
+
return df.reset_index() # Make policy_id a column then set as index
|
| 315 |
+
|
| 316 |
+
if set_policy_id_index:
|
| 317 |
+
return df.set_index('policy_id')
|
| 318 |
+
return df
|
| 319 |
+
|
| 320 |
+
cfs = read_and_prep_excel(cashflow_base_path).select_dtypes(include=np.number)
|
| 321 |
+
cfs_lapse50 = read_and_prep_excel(cashflow_lapse_path).select_dtypes(include=np.number)
|
| 322 |
+
cfs_mort15 = read_and_prep_excel(cashflow_mort_path).select_dtypes(include=np.number)
|
| 323 |
+
|
| 324 |
+
pol_data_full_raw = read_and_prep_excel(policy_data_path, set_policy_id_index=False)
|
| 325 |
+
# Ensure the correct columns are selected for pol_data
|
| 326 |
required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
|
| 327 |
|
| 328 |
+
# Check if required_cols exist, case-insensitively, and normalize names
|
| 329 |
+
rename_map = {}
|
| 330 |
+
available_cols_lower = {col.lower(): col for col in pol_data_full_raw.columns}
|
| 331 |
+
for req_col in required_cols:
|
| 332 |
+
if req_col.lower() in available_cols_lower:
|
| 333 |
+
rename_map[available_cols_lower[req_col.lower()]] = req_col # Map original to standardized
|
| 334 |
+
pol_data_full_renamed = pol_data_full_raw.rename(columns=rename_map)
|
| 335 |
+
|
| 336 |
+
if all(col in pol_data_full_renamed.columns for col in required_cols):
|
| 337 |
+
pol_data = pol_data_full_renamed.set_index('policy_id')[required_cols].select_dtypes(include=np.number)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
else:
|
| 339 |
+
missing = [col for col in required_cols if col not in pol_data_full_renamed.columns]
|
| 340 |
+
gr.Warning(f"Policy data might be missing required columns: {missing}. Found: {pol_data_full_renamed.columns.tolist()}")
|
| 341 |
+
# Fallback: use all numeric columns if required are missing, set policy_id as index
|
| 342 |
+
pol_data = pol_data_full_renamed.set_index('policy_id').select_dtypes(include=np.number)
|
| 343 |
+
if pol_data.empty and not pol_data_full_renamed.select_dtypes(include=np.number).empty:
|
| 344 |
+
gr.Warning("Policy data became empty after trying to select numeric types with policy_id index. Check input.")
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
pvs = read_and_prep_excel(pv_base_path).select_dtypes(include=np.number)
|
| 348 |
+
pvs_lapse50 = read_and_prep_excel(pv_lapse_path).select_dtypes(include=np.number)
|
| 349 |
+
pvs_mort15 = read_and_prep_excel(pv_mort_path).select_dtypes(include=np.number)
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
+
# DataFrames for Clusters class should not include the policy_id if it's an index
|
| 352 |
+
# The class constructor expects features only (typically a DataFrame where .values gives numeric data)
|
| 353 |
+
# The current read_and_prep_excel sets policy_id as index. This is fine.
|
| 354 |
+
# KMeans will be called on df.values implicitly.
|
| 355 |
+
|
| 356 |
cfs_list = [cfs, cfs_lapse50, cfs_mort15]
|
| 357 |
scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
|
| 358 |
|
|
|
|
| 360 |
|
| 361 |
mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
|
| 362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
# --- 1. Cashflow Calibration ---
|
| 364 |
+
# Pass DataFrame with features only. If policy_id is index, df.values is correct.
|
| 365 |
+
cluster_cfs = Clusters(cfs)
|
| 366 |
|
| 367 |
results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
|
| 368 |
results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
|
|
|
|
| 369 |
results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
|
| 370 |
results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
|
| 371 |
results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
|
|
|
|
| 372 |
results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
|
| 373 |
results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
|
| 374 |
|
| 375 |
# --- 2. Policy Attribute Calibration ---
|
| 376 |
+
loc_vars_attrs_input = pol_data # pol_data is already features with policy_id as index
|
| 377 |
+
if not loc_vars_attrs_input.empty:
|
| 378 |
+
# Standardize policy attributes if there's variance
|
| 379 |
+
min_vals = loc_vars_attrs_input.min()
|
| 380 |
+
max_vals = loc_vars_attrs_input.max()
|
| 381 |
+
range_vals = max_vals - min_vals
|
| 382 |
+
if (range_vals == 0).all(): # No variance
|
| 383 |
+
gr.Warning("Policy data for attribute calibration has no variance. Using original values (may lead to poor clustering if scales differ).")
|
| 384 |
+
loc_vars_attrs_scaled = loc_vars_attrs_input
|
| 385 |
else:
|
| 386 |
+
# Scale only columns with variance, keep others as is (or handle as 0 if appropriate)
|
| 387 |
+
loc_vars_attrs_scaled = loc_vars_attrs_input.copy()
|
| 388 |
+
for col in range_vals.index:
|
| 389 |
+
if range_vals[col] > 1e-9: # Check for non-zero range with tolerance
|
| 390 |
+
loc_vars_attrs_scaled[col] = (loc_vars_attrs_input[col] - min_vals[col]) / range_vals[col]
|
| 391 |
+
else: # if no variance, scaled value is 0 or 0.5 (or original)
|
| 392 |
+
loc_vars_attrs_scaled[col] = 0.0 # Or np.nan, or keep original: loc_vars_attrs_input[col]
|
| 393 |
else:
|
| 394 |
+
gr.Warning("Policy data for attribute calibration is empty. Skipping attribute calibration plots.")
|
| 395 |
+
loc_vars_attrs_scaled = pd.DataFrame(index=pol_data.index) # Empty DF with correct index
|
| 396 |
|
| 397 |
+
if not loc_vars_attrs_scaled.empty:
|
| 398 |
+
cluster_attrs = Clusters(loc_vars_attrs_scaled) # Pass the scaled data
|
| 399 |
results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
|
| 400 |
+
results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs) # Compare against original pol_data
|
| 401 |
results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
|
| 402 |
results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
|
| 403 |
results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
|
|
|
|
| 405 |
results['attr_total_cf_base'] = pd.DataFrame()
|
| 406 |
results['attr_policy_attrs_total'] = pd.DataFrame()
|
| 407 |
results['attr_total_pv_base'] = pd.DataFrame()
|
| 408 |
+
results['attr_cashflow_plot'] = plot_cashflows_comparison([], None, []) # Empty plot
|
| 409 |
+
results['attr_scatter_cashflows_base'] = plot_scatter_comparison(pd.DataFrame(), 'Policy Attr. Calib. - No Data')
|
| 410 |
|
| 411 |
|
| 412 |
# --- 3. Present Value Calibration ---
|
| 413 |
+
cluster_pvs = Clusters(pvs)
|
| 414 |
|
| 415 |
results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
|
| 416 |
results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
|
|
|
|
| 417 |
results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
|
| 418 |
results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
|
| 419 |
results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
|
|
|
|
| 420 |
results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
|
| 421 |
results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
|
| 422 |
|
| 423 |
# --- Summary Comparison Plot Data ---
|
| 424 |
error_data = {}
|
| 425 |
+
def get_error_safe(compare_result_df, col_name=None):
|
| 426 |
+
if compare_result_df is None or compare_result_df.empty or 'error' not in compare_result_df.columns:
|
|
|
|
| 427 |
return np.nan
|
| 428 |
+
# Ensure col_name, if provided, is actually an index in the DataFrame
|
| 429 |
+
# compare_result_df has an index of column names of the original data (e.g. PV_NetCF)
|
| 430 |
+
if col_name and col_name in compare_result_df.index:
|
| 431 |
+
error_val = compare_result_df.loc[col_name, 'error']
|
| 432 |
+
return abs(error_val) if pd.notna(error_val) else np.nan
|
| 433 |
+
else: # Mean absolute error of all error column values
|
| 434 |
+
valid_errors = compare_result_df['error'].dropna()
|
| 435 |
+
return abs(valid_errors).mean() if not valid_errors.empty else np.nan
|
| 436 |
|
| 437 |
key_pv_col = None
|
| 438 |
+
# pvs dataframe here has policy_id as index, columns are features.
|
|
|
|
|
|
|
|
|
|
| 439 |
for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']:
|
| 440 |
+
if potential_col in pvs.columns: # pvs is already loaded and indexed
|
| 441 |
key_pv_col = potential_col
|
| 442 |
break
|
| 443 |
|
|
|
|
| 447 |
get_error_safe(results.get('cf_pv_total_mort'), key_pv_col)
|
| 448 |
]
|
| 449 |
|
| 450 |
+
if not loc_vars_attrs_scaled.empty: # Check if attribute calibration was performed
|
| 451 |
+
error_data['Attr Calib.'] = [
|
| 452 |
+
get_error_safe(results.get('attr_total_pv_base'), key_pv_col),
|
| 453 |
+
# For stressed PVs under Attr Calib, we need to call compare_total from cluster_attrs
|
| 454 |
+
get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col),
|
| 455 |
+
get_error_safe(cluster_attrs.compare_total(pvs_mort15), key_pv_col)
|
| 456 |
]
|
| 457 |
else:
|
| 458 |
error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
|
|
|
|
| 463 |
get_error_safe(results.get('pv_total_pv_mort'), key_pv_col)
|
| 464 |
]
|
| 465 |
|
| 466 |
+
summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%']).astype(float) # Ensure float for plotting
|
| 467 |
|
| 468 |
fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
|
| 469 |
+
sns.set_style("whitegrid")
|
| 470 |
+
|
| 471 |
+
# Melt for Seaborn barplot
|
| 472 |
+
summary_df_melted = summary_df.reset_index().rename(columns={'index': 'Scenario'})
|
| 473 |
+
summary_df_melted = summary_df_melted.melt(id_vars='Scenario', var_name='Calibration Method', value_name='Absolute Error Rate')
|
| 474 |
+
|
| 475 |
+
sns.barplot(x='Scenario', y='Absolute Error Rate', hue='Calibration Method', data=summary_df_melted, ax=ax_summary)
|
| 476 |
+
|
| 477 |
ax_summary.set_ylabel('Absolute Error Rate')
|
| 478 |
+
title_suffix = f' for {key_pv_col}' if key_pv_col else ' (Mean Absolute Error)'
|
| 479 |
ax_summary.set_title(f'Calibration Method Comparison - Error in Total PV{title_suffix}')
|
| 480 |
ax_summary.tick_params(axis='x', rotation=0)
|
| 481 |
+
if ax_summary.get_legend():
|
| 482 |
+
ax_summary.get_legend().set_title('Calibration Method')
|
| 483 |
+
ax_summary.grid(True, axis='y') # Horizontal grid lines for bar plot
|
| 484 |
+
|
| 485 |
plt.tight_layout()
|
|
|
|
| 486 |
buf_summary = io.BytesIO()
|
| 487 |
plt.savefig(buf_summary, format='png', dpi=100)
|
| 488 |
buf_summary.seek(0)
|
|
|
|
| 495 |
gr.Error(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded.")
|
| 496 |
return {"error": f"File not found: {e.filename}"}
|
| 497 |
except KeyError as e:
|
| 498 |
+
gr.Error(f"A required column/index ('policy_id' or feature column) is missing or misnamed: {e}. Please check data format.")
|
|
|
|
| 499 |
return {"error": f"Missing column/index: {e}"}
|
| 500 |
+
except ValueError as e: # Catch other value errors like from plotting or data prep
|
| 501 |
+
gr.Error(f"Data processing or plotting error: {str(e)}")
|
| 502 |
+
import traceback
|
| 503 |
+
traceback.print_exc()
|
| 504 |
+
return {"error": f"Data error: {str(e)}"}
|
| 505 |
except Exception as e:
|
| 506 |
+
gr.Error(f"An unexpected error occurred: {str(e)}")
|
| 507 |
import traceback
|
| 508 |
+
traceback.print_exc()
|
| 509 |
+
return {"error": f"Unexpected error: {str(e)}"}
|
| 510 |
|
| 511 |
+
# --- Gradio interface creation (create_interface, etc.) ---
|
| 512 |
+
# This part remains unchanged from your original script.
|
| 513 |
+
# Ensure dummy file creation in if __name__ == "__main__": handles policy_id correctly.
|
| 514 |
def create_interface():
|
| 515 |
with gr.Blocks(title="Cluster Model Points Analysis") as demo:
|
| 516 |
gr.Markdown("""
|
|
|
|
| 520 |
Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
|
| 521 |
|
| 522 |
**Required Files (Excel .xlsx):**
|
| 523 |
+
- Cashflows - Base Scenario (should contain a 'policy_id' column, or it's the first column/index)
|
| 524 |
+
- Cashflows - Lapse Stress (+50%) (similar structure)
|
| 525 |
+
- Cashflows - Mortality Stress (+15%) (similar structure)
|
| 526 |
+
- Policy Data (should contain 'policy_id', 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth')
|
| 527 |
+
- Present Values - Base Scenario (should contain 'policy_id' and PV columns like 'PV_NetCF')
|
| 528 |
+
- Present Values - Lapse Stress (similar structure)
|
| 529 |
+
- Present Values - Mortality Stress (similar structure)
|
|
|
|
|
|
|
| 530 |
""")
|
| 531 |
|
| 532 |
with gr.Row():
|
|
|
|
| 573 |
attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
|
| 574 |
attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
|
| 575 |
with gr.Accordion("Present Value Comparisons (Total)", open=False):
|
| 576 |
+
attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total") # Only one PV table shown in original UI for this tab
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
|
| 578 |
with gr.TabItem("💰 Present Value Calibration"):
|
| 579 |
gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
|
|
|
|
| 588 |
pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
|
| 589 |
pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
|
| 590 |
|
|
|
|
| 591 |
def get_all_output_components():
    """Return every result widget, in the exact order produced by handle_analysis."""
    summary_section = [summary_plot_output]
    cashflow_calib = [
        cf_total_base_table_out, cf_policy_attrs_total_out,
        cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
        cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
    ]
    attr_calib = [
        attr_total_cf_base_out, attr_policy_attrs_total_out,
        attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
    ]
    pv_calib = [
        pv_total_cf_base_out, pv_policy_attrs_total_out,
        pv_cashflow_plot_out, pv_scatter_pvs_base_out,
        pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out,
    ]
    return summary_section + cashflow_calib + attr_calib + pv_calib
|
| 603 |
|
|
|
|
| 604 |
def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
|
| 605 |
files = [f1, f2, f3, f4, f5, f6, f7]
|
|
|
|
| 606 |
file_paths = []
|
| 607 |
+
# Gradio File component now passes full path for temporary files
|
| 608 |
+
for i, f_obj_path in enumerate(files): # f_obj is now a path string or None
|
| 609 |
+
if f_obj_path is None:
|
| 610 |
+
gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 611 |
return [None] * len(get_all_output_components())
|
| 612 |
+
if not isinstance(f_obj_path, str): # Should be a path string
|
| 613 |
+
gr.Error(f"Invalid file input for argument {i+1}. Expected path, got {type(f_obj_path)}")
|
| 614 |
+
return [None] * len(get_all_output_components())
|
| 615 |
+
file_paths.append(f_obj_path)
|
| 616 |
|
| 617 |
results = process_files(*file_paths)
|
| 618 |
|
| 619 |
+
if "error" in results :
|
| 620 |
+
return [gr.Plot.update(None)] * len(get_all_output_components()) # Clear plots on error
|
| 621 |
+
|
| 622 |
+
# Ensure DataFrames are converted to a format Gradio can display (e.g. List of Lists or pandas)
|
| 623 |
+
# For Dataframe components, pandas DataFrames are fine. For Image, PIL Image is fine.
|
| 624 |
return [
|
| 625 |
results.get('summary_plot'),
|
|
|
|
| 626 |
results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
|
| 627 |
results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
|
| 628 |
results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
|
|
|
|
| 629 |
results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
|
| 630 |
results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
|
|
|
|
| 631 |
results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
|
| 632 |
results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
|
| 633 |
results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
|
|
|
|
| 640 |
outputs=get_all_output_components()
|
| 641 |
)
|
| 642 |
|
|
|
|
| 643 |
def load_example_files():
    """Fill the file inputs with the bundled example data paths.

    Missing example files are replaced by generated dummy Excel files so the
    demo can run end-to-end.  Returns the 7 paths in the order expected by
    the analysis inputs, or [None] * 7 when files cannot be provided.
    """
    os.makedirs(EXAMPLE_DATA_DIR, exist_ok=True)

    for key, fp in EXAMPLE_FILES.items():
        if os.path.exists(fp):
            continue  # real example file already present

        gr.Info(f"Example file {fp} not found. Attempting to create a dummy file.")
        try:
            num_policies = 50  # size of the synthetic dataset
            policy_ids = [f'P{j:03d}' for j in range(num_policies)]

            if "cashflow" in key or "pv" in key:
                dummy_data = {'policy_id': policy_ids}
                for i in range(10):  # 10 time periods / PV components
                    dummy_data[f't{i}'] = np.random.rand(num_policies) * 1000
            elif "policy_data" in key:
                dummy_data = {
                    'policy_id': policy_ids,
                    'age_at_entry': np.random.randint(20, 50, num_policies),
                    'policy_term': np.random.randint(10, 30, num_policies),
                    'sum_assured': np.random.randint(10000, 50000, num_policies),
                    'duration_mth': np.random.randint(1, 240, num_policies),
                }
            else:
                dummy_data = {'policy_id': policy_ids, 'feature1': np.random.rand(num_policies)}

            # Save without the pandas index; read_and_prep_excel handles indexing.
            pd.DataFrame(dummy_data).to_excel(fp, index=False)
            gr.Info(f"Dummy file for '{os.path.basename(fp)}' created in '{EXAMPLE_DATA_DIR}'.")
        except Exception as e:
            gr.Error(f"Could not create dummy file for {fp}: {e}")
            return [None] * 7  # abort loading if dummy creation fails

    missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
    if missing_files:
        gr.Error(f"Still missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")
        return [None] * 7

    gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
    # Return file paths directly to the File components, in input order.
    return [
        EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
        EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
        EXAMPLE_FILES["pv_mort"]
    ]
|
| 686 |
|
|
|
|
| 687 |
load_example_btn.click(
|
| 688 |
load_example_files,
|
| 689 |
inputs=[],
|
|
|
|
| 696 |
if __name__ == "__main__":
|
| 697 |
if not os.path.exists(EXAMPLE_DATA_DIR):
|
| 698 |
os.makedirs(EXAMPLE_DATA_DIR)
|
| 699 |
+
print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there or dummy files will be generated on 'Load Example Data'.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 700 |
|
| 701 |
demo_app = create_interface()
|
| 702 |
demo_app.launch()
|