Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,15 +2,15 @@ import gradio as gr
|
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
| 4 |
from sklearn.cluster import KMeans
|
| 5 |
-
from sklearn.metrics import pairwise_distances_argmin_min
|
| 6 |
-
import matplotlib.pyplot as plt
|
| 7 |
-
import
|
|
|
|
| 8 |
import io
|
| 9 |
import os
|
| 10 |
from PIL import Image
|
| 11 |
|
| 12 |
# Define the paths for example data
|
| 13 |
-
# For Hugging Face Spaces, these paths will be relative to the app's root
|
| 14 |
EXAMPLE_DATA_DIR = "eg_data"
|
| 15 |
EXAMPLE_FILES = {
|
| 16 |
"cashflow_base": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"),
|
|
@@ -24,455 +24,306 @@ EXAMPLE_FILES = {
|
|
| 24 |
|
| 25 |
class Clusters:
|
| 26 |
def __init__(self, loc_vars):
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
else:
|
| 30 |
-
loc_vars_np = np.ascontiguousarray(loc_vars)
|
| 31 |
-
|
| 32 |
-
self.kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(loc_vars_np)
|
| 33 |
-
closest, _ = pairwise_distances_argmin_min(self.kmeans.cluster_centers_, loc_vars_np)
|
| 34 |
|
| 35 |
-
rep_ids = pd.Series(data=(closest
|
| 36 |
rep_ids.name = 'policy_id'
|
| 37 |
rep_ids.index.name = 'cluster_id'
|
| 38 |
self.rep_ids = rep_ids
|
| 39 |
|
| 40 |
-
self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(
|
| 41 |
|
| 42 |
def agg_by_cluster(self, df, agg=None):
|
|
|
|
| 43 |
temp = df.copy()
|
| 44 |
temp['cluster_id'] = self.kmeans.labels_
|
| 45 |
temp = temp.set_index('cluster_id')
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
return pd.DataFrame(index=temp.index.unique())
|
| 49 |
-
return temp.groupby(level='cluster_id').agg(agg_dict)
|
| 50 |
|
| 51 |
def extract_reps(self, df):
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
else:
|
| 57 |
-
raise ValueError("DataFrame for extract_reps must have 'policy_id' as a named index or a column.")
|
| 58 |
-
|
| 59 |
-
df_to_merge = df.reset_index() if df.index.name == 'policy_id' or (isinstance(df.index, pd.MultiIndex) and 'policy_id' in df.index.names) else df.copy()
|
| 60 |
-
|
| 61 |
-
if 'policy_id' not in df_to_merge.columns:
|
| 62 |
-
# This is a fallback if policy_id was expected but still not a column.
|
| 63 |
-
# This might happen if the index was unnamed and thought to be policy_id.
|
| 64 |
-
# A robust solution depends on stricter input guarantees.
|
| 65 |
-
gr.Warning("extract_reps: 'policy_id' column not found after attempting to reset index. Merging may fail or be incorrect.")
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
temp = pd.merge(self.rep_ids.reset_index(), df_to_merge, how='left', on='policy_id')
|
| 69 |
-
temp = temp.set_index('cluster_id')
|
| 70 |
-
return temp.drop(columns=['policy_id'], errors='ignore')
|
| 71 |
|
| 72 |
def extract_and_scale_reps(self, df, agg=None):
|
| 73 |
-
|
| 74 |
if agg:
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
policy_count_for_col = self.policy_count.reindex(extracted_df.index).fillna(1) # Default to 1 if cluster missing
|
| 81 |
-
else: # Should be a scalar or array-like usable directly
|
| 82 |
-
policy_count_for_col = self.policy_count
|
| 83 |
-
|
| 84 |
-
mult_data[c] = policy_count_for_col if (c not in agg or agg[c] == 'sum') else 1
|
| 85 |
-
|
| 86 |
-
mult = pd.DataFrame(mult_data, index=extracted_df.index)
|
| 87 |
-
|
| 88 |
-
result_df = extracted_df.copy()
|
| 89 |
-
for col in cols_to_multiply:
|
| 90 |
-
if col in mult.columns: # Ensure column exists in multiplier
|
| 91 |
-
result_df[col] = extracted_df[col].mul(mult[col])
|
| 92 |
-
return result_df
|
| 93 |
else:
|
| 94 |
-
|
| 95 |
-
result_df = extracted_df.copy()
|
| 96 |
-
for col in numeric_cols:
|
| 97 |
-
if isinstance(self.policy_count, pd.Series):
|
| 98 |
-
policy_count_for_col = self.policy_count.reindex(extracted_df.index).fillna(0) # Fill with 0 if not found
|
| 99 |
-
result_df[col] = extracted_df[col].mul(policy_count_for_col, axis=0)
|
| 100 |
-
else: # Assuming self.policy_count is a scalar or compatible array
|
| 101 |
-
result_df[col] = extracted_df[col].mul(self.policy_count, axis=0)
|
| 102 |
-
return result_df
|
| 103 |
|
| 104 |
def compare(self, df, agg=None):
|
|
|
|
| 105 |
source = self.agg_by_cluster(df, agg)
|
| 106 |
target = self.extract_and_scale_reps(df, agg)
|
| 107 |
-
|
| 108 |
-
common_columns = source.columns.intersection(target.columns)
|
| 109 |
-
if common_columns.empty and (not source.empty or not target.empty):
|
| 110 |
-
gr.Warning("Compare function: No common columns between source and target. Result will be empty.")
|
| 111 |
-
return pd.DataFrame({'actual': pd.Series(dtype=float), 'estimate': pd.Series(dtype=float)})
|
| 112 |
-
|
| 113 |
-
source_stacked = source[common_columns].stack(dropna=False) # keepna=True for older pandas
|
| 114 |
-
target_stacked = target[common_columns].stack(dropna=False)
|
| 115 |
-
|
| 116 |
-
return pd.DataFrame({'actual': source_stacked, 'estimate': target_stacked})
|
| 117 |
|
| 118 |
def compare_total(self, df, agg=None):
|
|
|
|
| 119 |
if agg:
|
| 120 |
actual_values = {}
|
| 121 |
for col in df.columns:
|
| 122 |
if agg.get(col, 'sum') == 'mean':
|
| 123 |
actual_values[col] = df[col].mean()
|
| 124 |
-
else:
|
| 125 |
actual_values[col] = df[col].sum()
|
| 126 |
actual = pd.Series(actual_values)
|
| 127 |
|
| 128 |
reps_unscaled = self.extract_reps(df)
|
| 129 |
estimate_values = {}
|
| 130 |
|
| 131 |
-
for
|
| 132 |
-
if
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
if agg.get(col_orig_df, 'sum') == 'mean':
|
| 140 |
-
weighted_sum = (current_col_data * policy_counts_aligned).sum()
|
| 141 |
-
total_weight = policy_counts_aligned.sum()
|
| 142 |
-
estimate_values[col_orig_df] = weighted_sum / total_weight if total_weight > 0 else np.nan
|
| 143 |
-
else:
|
| 144 |
-
estimate_values[col_orig_df] = (current_col_data * policy_counts_aligned).sum()
|
| 145 |
estimate = pd.Series(estimate_values)
|
| 146 |
-
|
|
|
|
| 147 |
actual = df.sum()
|
| 148 |
-
estimate = self.extract_and_scale_reps(df).sum()
|
| 149 |
|
| 150 |
-
actual, estimate
|
| 151 |
-
error = np.where(actual != 0, (estimate / actual) - 1, 0)
|
| 152 |
-
error = np.nan_to_num(error, nan=0.0)
|
| 153 |
|
| 154 |
return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
|
| 155 |
|
| 156 |
-
# Plotting Functions (Modified for Seaborn)
|
| 157 |
-
def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
|
| 158 |
-
sns.set_style("whitegrid")
|
| 159 |
-
if not cfs_list or not cluster_obj or not titles or not any(cfs_list) : # Check if cfs_list contains any non-None df
|
| 160 |
-
# Return a placeholder image indicating no data
|
| 161 |
-
fig, ax = plt.subplots(figsize=(7.5, 2.5)) # Smaller placeholder
|
| 162 |
-
ax.text(0.5, 0.5, "No cashflow data to plot", ha='center', va='center', fontsize=10)
|
| 163 |
-
ax.set_xticks([])
|
| 164 |
-
ax.set_yticks([])
|
| 165 |
-
buf = io.BytesIO()
|
| 166 |
-
plt.savefig(buf, format='png', dpi=100)
|
| 167 |
-
buf.seek(0)
|
| 168 |
-
img = Image.open(buf)
|
| 169 |
-
plt.close(fig)
|
| 170 |
-
return img
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
if not
|
| 175 |
-
return
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
-
num_plots = len(valid_cfs_data)
|
| 178 |
cols = 2
|
| 179 |
rows = (num_plots + cols - 1) // cols
|
| 180 |
|
| 181 |
-
|
|
|
|
| 182 |
axes = axes.flatten()
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
for
|
| 186 |
-
if
|
| 187 |
-
ax = axes[
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
continue
|
| 195 |
-
|
| 196 |
-
plot_data = comparison_df[['actual', 'estimate']].copy()
|
| 197 |
-
plot_data['Time'] = plot_data.index.astype(str)
|
| 198 |
-
try:
|
| 199 |
-
plot_data['Time'] = pd.to_numeric(plot_data['Time'])
|
| 200 |
-
except ValueError:
|
| 201 |
-
pass
|
| 202 |
-
|
| 203 |
-
plot_data_melted = plot_data.melt(id_vars='Time', var_name='Legend', value_name='Value')
|
| 204 |
-
|
| 205 |
-
sns.lineplot(x='Time', y='Value', hue='Legend', data=plot_data_melted, ax=ax, errorbar=None)
|
| 206 |
ax.set_title(title)
|
| 207 |
ax.set_xlabel('Time')
|
| 208 |
ax.set_ylabel('Value')
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
| 212 |
fig.delaxes(axes[j])
|
| 213 |
|
| 214 |
-
plt.tight_layout()
|
| 215 |
buf = io.BytesIO()
|
| 216 |
-
plt.savefig(buf, format='png', dpi=100)
|
| 217 |
buf.seek(0)
|
| 218 |
img = Image.open(buf)
|
| 219 |
-
plt.close(fig)
|
| 220 |
return img
|
| 221 |
|
| 222 |
def plot_scatter_comparison(df_compare_output, title):
|
| 223 |
-
|
| 224 |
-
fig, ax = plt.subplots(figsize=(12, 8)) # Define fig and ax here for all paths
|
| 225 |
-
|
| 226 |
-
plot_data_available = False # Flag to check if we have data to plot for limits
|
| 227 |
-
|
| 228 |
if df_compare_output is None or df_compare_output.empty:
|
|
|
|
|
|
|
| 229 |
ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
|
| 230 |
ax.set_title(title)
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
ax.text(0.5, 0.5, "Data for scatter plot is empty.", ha='center', va='center', fontsize=15)
|
| 238 |
-
ax.set_title(title)
|
| 239 |
-
else:
|
| 240 |
-
plot_data_internal = df_compare_output.reset_index()
|
| 241 |
-
if plot_data_internal[['actual', 'estimate']].dropna().empty:
|
| 242 |
-
ax.text(0.5, 0.5, "Comparison data (actual/estimate) is empty or all NaN.", ha='center', va='center', fontsize=15)
|
| 243 |
-
ax.set_title(title)
|
| 244 |
-
else:
|
| 245 |
-
hue_col_name = df_compare_output.index.names[1]
|
| 246 |
-
plot_data_internal[hue_col_name] = plot_data_internal[hue_col_name].astype(str)
|
| 247 |
-
|
| 248 |
-
unique_levels = plot_data_internal[hue_col_name].nunique()
|
| 249 |
-
show_legend_flag = "auto"
|
| 250 |
-
if unique_levels == 1:
|
| 251 |
-
show_legend_flag = False
|
| 252 |
-
elif unique_levels > 10: # Max 10 items in legend for clarity
|
| 253 |
-
show_legend_flag = False
|
| 254 |
-
gr.Warning(f"Warning: Too many unique values ({unique_levels}) in '{hue_col_name}' for scatter plot legend. Legend hidden.")
|
| 255 |
-
|
| 256 |
-
sns.scatterplot(x='actual', y='estimate', hue=hue_col_name, data=plot_data_internal,
|
| 257 |
-
s=25, alpha=0.7, ax=ax, legend=show_legend_flag)
|
| 258 |
-
plot_data_available = True
|
| 259 |
-
ax.set_title(title)
|
| 260 |
-
|
| 261 |
-
if ax.get_legend() is not None: # If legend is shown
|
| 262 |
-
ax.get_legend().set_title(str(hue_col_name))
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
ax.set_xlabel('Actual')
|
| 266 |
ax.set_ylabel('Estimate')
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
if min_val == max_val:
|
| 280 |
-
margin = abs(min_val * 0.1) if min_val != 0 else 0.1 # 10% margin or 0.1 if value is 0
|
| 281 |
-
plot_min, plot_max = min_val - margin, max_val + margin
|
| 282 |
-
else:
|
| 283 |
-
plot_min, plot_max = min_val, max_val
|
| 284 |
-
|
| 285 |
-
# Ensure plot_min and plot_max are finite and distinct
|
| 286 |
-
if np.isfinite(plot_min) and np.isfinite(plot_max) and plot_min < plot_max:
|
| 287 |
-
ax.plot([plot_min, plot_max], [plot_min, plot_max], 'r-', linewidth=0.7, alpha=0.8, zorder=0)
|
| 288 |
-
ax.set_xlim(plot_min, plot_max)
|
| 289 |
-
ax.set_ylim(plot_min, plot_max)
|
| 290 |
-
elif np.isfinite(plot_min) and np.isfinite(plot_max) and plot_min == plot_max: # Handles single point case after margin
|
| 291 |
-
ax.plot([plot_min], [plot_min], 'ro', markersize=5) # Mark the point
|
| 292 |
-
ax.set_xlim(plot_min - (abs(plot_min*0.1) if plot_min !=0 else 0.1), plot_min + (abs(plot_min*0.1) if plot_min !=0 else 0.1))
|
| 293 |
-
ax.set_ylim(plot_min - (abs(plot_min*0.1) if plot_min !=0 else 0.1), plot_min + (abs(plot_min*0.1) if plot_min !=0 else 0.1))
|
| 294 |
-
|
| 295 |
-
|
| 296 |
buf = io.BytesIO()
|
| 297 |
-
plt.savefig(buf, format='png', dpi=100)
|
| 298 |
buf.seek(0)
|
| 299 |
img = Image.open(buf)
|
| 300 |
-
plt.close(fig)
|
| 301 |
return img
|
| 302 |
|
| 303 |
-
|
| 304 |
def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
|
| 305 |
policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
|
|
|
|
| 306 |
try:
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
# For Hugging Face, ensure files are readable.
|
| 312 |
-
# The path provided by gr.File is usually to a temp copy.
|
| 313 |
-
df = pd.read_excel(file_path)
|
| 314 |
-
|
| 315 |
-
# Try to identify policy_id:
|
| 316 |
-
# 1. Explicit 'policy_id' column (case-insensitive)
|
| 317 |
-
# 2. First column if no explicit 'policy_id'
|
| 318 |
-
pid_col_name = None
|
| 319 |
-
for col in df.columns:
|
| 320 |
-
if str(col).lower() == 'policy_id':
|
| 321 |
-
pid_col_name = col
|
| 322 |
-
break
|
| 323 |
-
|
| 324 |
-
if pid_col_name:
|
| 325 |
-
df = df.rename(columns={pid_col_name: 'policy_id'})
|
| 326 |
-
df = df.set_index('policy_id')
|
| 327 |
-
elif df.index.name and df.index.name.lower() == 'policy_id': # Already indexed by policy_id
|
| 328 |
-
pass # Keep as is
|
| 329 |
-
else: # Assume first column is policy_id if no explicit one is found
|
| 330 |
-
gr.Warning(f"No explicit 'policy_id' column/index in {os.path.basename(file_path)}. Assuming first column is policy_id.")
|
| 331 |
-
df = df.rename(columns={df.columns[0]: 'policy_id'})
|
| 332 |
-
df = df.set_index('policy_id')
|
| 333 |
-
|
| 334 |
-
if is_policy_data:
|
| 335 |
-
return df # Return all columns for policy data, selection happens next
|
| 336 |
-
return df.select_dtypes(include=np.number)
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
cfs = read_and_prep_excel(cashflow_base_path)
|
| 340 |
-
cfs_lapse50 = read_and_prep_excel(cashflow_lapse_path)
|
| 341 |
-
cfs_mort15 = read_and_prep_excel(cashflow_mort_path)
|
| 342 |
-
|
| 343 |
-
pol_data_full = read_and_prep_excel(policy_data_path, is_policy_data=True)
|
| 344 |
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
cols_to_select = []
|
| 350 |
-
final_rename_map = {}
|
| 351 |
-
|
| 352 |
-
for req_col_std in required_cols_std:
|
| 353 |
-
req_col_norm = req_col_std.lower().replace("_", "").replace(" ", "")
|
| 354 |
-
if req_col_norm in available_cols_map:
|
| 355 |
-
original_name = available_cols_map[req_col_norm]
|
| 356 |
-
cols_to_select.append(original_name)
|
| 357 |
-
if original_name != req_col_std: # if original name was 'Age At Entry' map to 'age_at_entry'
|
| 358 |
-
final_rename_map[original_name] = req_col_std
|
| 359 |
-
else: # If after normalization, it's still not found.
|
| 360 |
-
gr.Warning(f"Required policy data column '{req_col_std}' not found or could not be matched.")
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
if len(cols_to_select) == len(required_cols_std):
|
| 364 |
-
pol_data = pol_data_full[cols_to_select].rename(columns=final_rename_map)
|
| 365 |
-
pol_data = pol_data.select_dtypes(include=np.number) # Ensure numeric after selection
|
| 366 |
else:
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
# Attempt to recover if 'policy_id' is a column
|
| 374 |
-
if 'policy_id' in pol_data.columns:
|
| 375 |
-
pol_data = pol_data.set_index('policy_id')
|
| 376 |
-
else: # cannot proceed with pol_data
|
| 377 |
-
pol_data = pd.DataFrame() # Make it empty to signal issues later
|
| 378 |
-
|
| 379 |
-
pvs = read_and_prep_excel(pv_base_path)
|
| 380 |
-
pvs_lapse50 = read_and_prep_excel(pv_lapse_path)
|
| 381 |
-
pvs_mort15 = read_and_prep_excel(pv_mort_path)
|
| 382 |
|
| 383 |
cfs_list = [cfs, cfs_lapse50, cfs_mort15]
|
| 384 |
scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
|
| 385 |
|
| 386 |
results = {}
|
|
|
|
| 387 |
mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
|
| 388 |
|
| 389 |
# --- 1. Cashflow Calibration ---
|
| 390 |
-
|
| 391 |
-
|
| 392 |
results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
|
| 393 |
-
results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
|
| 394 |
-
|
| 395 |
-
results['
|
| 396 |
-
results['
|
|
|
|
|
|
|
| 397 |
results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
|
| 398 |
results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
|
| 399 |
|
| 400 |
-
|
| 401 |
# --- 2. Policy Attribute Calibration ---
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
range_vals = max_vals - min_vals
|
| 406 |
-
if (range_vals.abs() < 1e-9).all(): # Check if all ranges are effectively zero
|
| 407 |
-
gr.Warning("Policy data for attribute calibration has no variance. Using unscaled data (0s).")
|
| 408 |
-
loc_vars_attrs_scaled = pd.DataFrame(0, index=pol_data.index, columns=pol_data.columns)
|
| 409 |
-
else:
|
| 410 |
-
loc_vars_attrs_scaled = pol_data.copy()
|
| 411 |
-
for col in range_vals.index:
|
| 412 |
-
if range_vals[col] > 1e-9:
|
| 413 |
-
loc_vars_attrs_scaled[col] = (pol_data[col] - min_vals[col]) / range_vals[col]
|
| 414 |
-
else:
|
| 415 |
-
loc_vars_attrs_scaled[col] = 0.0 # Column with no variance becomes 0
|
| 416 |
-
loc_vars_attrs_scaled = loc_vars_attrs_scaled.fillna(0) # Handle any NaNs from division by zero if range_vals was exactly 0
|
| 417 |
else:
|
| 418 |
-
gr.Warning("Policy data is empty. Skipping attribute calibration.")
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
results['
|
| 424 |
-
results['
|
|
|
|
| 425 |
results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
|
| 426 |
results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
|
| 427 |
else:
|
| 428 |
results['attr_total_cf_base'] = pd.DataFrame()
|
| 429 |
results['attr_policy_attrs_total'] = pd.DataFrame()
|
| 430 |
results['attr_total_pv_base'] = pd.DataFrame()
|
| 431 |
-
results['attr_cashflow_plot'] =
|
| 432 |
-
results['attr_scatter_cashflows_base'] = plot_scatter_comparison(
|
|
|
|
| 433 |
|
| 434 |
# --- 3. Present Value Calibration ---
|
| 435 |
-
if pvs.empty: gr.Warning("Base Present Value data (pvs) is empty. PV Calib may fail or produce no results.")
|
| 436 |
cluster_pvs = Clusters(pvs)
|
| 437 |
-
|
| 438 |
-
results['
|
|
|
|
|
|
|
| 439 |
results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
|
| 440 |
-
results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
|
| 441 |
-
results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
|
|
|
|
| 442 |
results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
|
| 443 |
results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
|
| 444 |
|
| 445 |
-
|
| 446 |
# --- Summary Comparison Plot Data ---
|
| 447 |
error_data = {}
|
| 448 |
-
|
| 449 |
-
|
|
|
|
| 450 |
return np.nan
|
| 451 |
-
if col_name and col_name in
|
| 452 |
-
|
| 453 |
-
|
|
|
|
| 454 |
else:
|
| 455 |
-
|
| 456 |
-
return abs(valid_errors).mean() if not valid_errors.empty else np.nan
|
| 457 |
|
| 458 |
key_pv_col = None
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
break
|
| 464 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 465 |
error_data['CF Calib.'] = [
|
| 466 |
get_error_safe(results.get('cf_pv_total_base'), key_pv_col),
|
| 467 |
get_error_safe(results.get('cf_pv_total_lapse'), key_pv_col),
|
| 468 |
get_error_safe(results.get('cf_pv_total_mort'), key_pv_col)
|
| 469 |
]
|
| 470 |
|
| 471 |
-
if not
|
| 472 |
-
|
| 473 |
-
get_error_safe(results.get('attr_total_pv_base'), key_pv_col),
|
| 474 |
-
get_error_safe(cluster_attrs.compare_total(pvs_lapse50)
|
| 475 |
-
get_error_safe(cluster_attrs.compare_total(pvs_mort15)
|
| 476 |
]
|
| 477 |
else:
|
| 478 |
error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
|
|
@@ -483,72 +334,74 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
|
|
| 483 |
get_error_safe(results.get('pv_total_pv_mort'), key_pv_col)
|
| 484 |
]
|
| 485 |
|
| 486 |
-
summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
|
| 487 |
|
| 488 |
-
fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
|
| 489 |
sns.set_style("whitegrid")
|
| 490 |
-
|
| 491 |
-
summary_df_melted = summary_df.reset_index().rename(columns={'index': 'Scenario'})
|
| 492 |
-
summary_df_melted = summary_df_melted.melt(id_vars='Scenario', var_name='Calibration Method', value_name='Absolute Error Rate')
|
| 493 |
-
|
| 494 |
-
sns.barplot(x='Scenario', y='Absolute Error Rate', hue='Calibration Method', data=summary_df_melted, ax=ax_summary)
|
| 495 |
|
| 496 |
-
|
| 497 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
ax_summary.set_title(f'Calibration Method Comparison - Error in Total PV{title_suffix}')
|
|
|
|
| 499 |
ax_summary.tick_params(axis='x', rotation=0)
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
plt.tight_layout()
|
| 504 |
buf_summary = io.BytesIO()
|
| 505 |
-
plt.savefig(buf_summary, format='png', dpi=100)
|
| 506 |
buf_summary.seek(0)
|
| 507 |
results['summary_plot'] = Image.open(buf_summary)
|
| 508 |
-
plt.close(fig_summary)
|
| 509 |
|
| 510 |
return results
|
| 511 |
|
| 512 |
except FileNotFoundError as e:
|
| 513 |
-
gr.Error(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded
|
| 514 |
return {"error": f"File not found: {e.filename}"}
|
| 515 |
except KeyError as e:
|
| 516 |
-
gr.Error(f"A required column
|
| 517 |
-
|
| 518 |
-
traceback.print_exc()
|
| 519 |
-
return {"error": f"Missing column/index: {e}"}
|
| 520 |
except ValueError as e:
|
| 521 |
-
gr.Error(f"
|
| 522 |
-
|
| 523 |
-
traceback.print_exc()
|
| 524 |
-
return {"error": f"Data error: {str(e)}"}
|
| 525 |
except Exception as e:
|
| 526 |
-
gr.Error(f"An unexpected error occurred: {str(e)}. Check logs for details.")
|
| 527 |
import traceback
|
| 528 |
-
traceback.
|
| 529 |
-
|
|
|
|
|
|
|
| 530 |
|
| 531 |
def create_interface():
|
| 532 |
-
with gr.Blocks(title="Cluster Model Points Analysis") as demo:
|
| 533 |
gr.Markdown("""
|
| 534 |
-
# Cluster Model Points Analysis
|
|
|
|
| 535 |
This application applies cluster analysis to model point selection for insurance portfolios.
|
| 536 |
Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
|
|
|
|
| 537 |
**Required Files (Excel .xlsx):**
|
| 538 |
-
- Cashflows - Base Scenario
|
| 539 |
-
- Cashflows - Lapse Stress (+50%)
|
| 540 |
-
- Cashflows - Mortality Stress (+15%)
|
| 541 |
-
- Policy Data (
|
| 542 |
-
- Present Values - Base Scenario
|
| 543 |
-
- Present Values - Lapse Stress
|
| 544 |
-
- Present Values - Mortality Stress
|
| 545 |
-
*Note: Ensure your files are in the `eg_data` directory in your Hugging Face Space if using 'Load Example Data'.*
|
| 546 |
""")
|
| 547 |
|
| 548 |
with gr.Row():
|
| 549 |
with gr.Column(scale=1):
|
| 550 |
-
gr.Markdown("### Upload Files or Load Examples")
|
| 551 |
-
|
|
|
|
|
|
|
| 552 |
with gr.Row():
|
| 553 |
cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
|
| 554 |
cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
|
|
@@ -559,86 +412,115 @@ def create_interface():
|
|
| 559 |
pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
|
| 560 |
with gr.Row():
|
| 561 |
pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
|
| 562 |
-
|
|
|
|
| 563 |
|
| 564 |
with gr.Tabs():
|
| 565 |
with gr.TabItem("📊 Summary"):
|
| 566 |
-
summary_plot_output = gr.Image(label="Calibration Methods Comparison"
|
|
|
|
| 567 |
with gr.TabItem("💸 Cashflow Calibration"):
|
| 568 |
gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
|
| 569 |
with gr.Row():
|
| 570 |
-
cf_total_base_table_out = gr.
|
| 571 |
-
cf_policy_attrs_total_out = gr.
|
| 572 |
-
cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios"
|
| 573 |
-
cf_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)"
|
| 574 |
with gr.Accordion("Present Value Comparisons (Total)", open=False):
|
| 575 |
with gr.Row():
|
| 576 |
-
cf_pv_total_base_out = gr.
|
| 577 |
-
cf_pv_total_lapse_out = gr.
|
| 578 |
-
cf_pv_total_mort_out = gr.
|
|
|
|
| 579 |
with gr.TabItem("👤 Policy Attribute Calibration"):
|
| 580 |
gr.Markdown("### Results: Using Policy Attributes as Calibration Variables")
|
| 581 |
with gr.Row():
|
| 582 |
-
attr_total_cf_base_out = gr.
|
| 583 |
-
attr_policy_attrs_total_out = gr.
|
| 584 |
-
attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios"
|
| 585 |
-
attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)"
|
| 586 |
with gr.Accordion("Present Value Comparisons (Total)", open=False):
|
| 587 |
-
|
|
|
|
| 588 |
with gr.TabItem("💰 Present Value Calibration"):
|
| 589 |
gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
|
| 590 |
with gr.Row():
|
| 591 |
-
pv_total_cf_base_out = gr.
|
| 592 |
-
pv_policy_attrs_total_out = gr.
|
| 593 |
-
pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios"
|
| 594 |
-
pv_scatter_pvs_base_out = gr.Image(label="Scatter Plot - Per-Cluster Present Values (Base Scenario)"
|
| 595 |
with gr.Accordion("Present Value Comparisons (Total)", open=False):
|
| 596 |
with gr.Row():
|
| 597 |
-
pv_total_pv_base_out = gr.
|
| 598 |
-
pv_total_pv_lapse_out = gr.
|
| 599 |
-
pv_total_pv_mort_out = gr.
|
| 600 |
|
|
|
|
| 601 |
def get_all_output_components():
|
| 602 |
return [
|
| 603 |
-
summary_plot_output,
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 610 |
]
|
| 611 |
|
| 612 |
-
|
|
|
|
| 613 |
files = [f1, f2, f3, f4, f5, f6, f7]
|
|
|
|
| 614 |
file_paths = []
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 619 |
return [None] * len(get_all_output_components())
|
| 620 |
-
# The object from gr.File is already the path string
|
| 621 |
-
file_paths.append(file_obj)
|
| 622 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 623 |
|
| 624 |
results = process_files(*file_paths)
|
|
|
|
| 625 |
|
| 626 |
-
if "error" in results: #
|
| 627 |
-
# Error message already shown by gr.Error in process_files
|
| 628 |
-
# Return Nones to clear outputs
|
| 629 |
return [None] * len(get_all_output_components())
|
| 630 |
|
| 631 |
return [
|
| 632 |
-
results.get('summary_plot'),
|
| 633 |
-
|
| 634 |
-
results.get('
|
| 635 |
-
results.get('
|
|
|
|
|
|
|
| 636 |
results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
|
| 637 |
-
results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'),
|
| 638 |
-
|
| 639 |
-
results.get('
|
| 640 |
-
results.get('
|
| 641 |
-
results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
|
| 642 |
]
|
| 643 |
|
| 644 |
analyze_btn.click(
|
|
@@ -648,41 +530,31 @@ def create_interface():
|
|
| 648 |
outputs=get_all_output_components()
|
| 649 |
)
|
| 650 |
|
|
|
|
| 651 |
def load_example_files():
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
# Verify all files exist after potential dummy creation
|
| 677 |
-
if any(not os.path.exists(f) for f in EXAMPLE_FILES.values()):
|
| 678 |
-
gr.Error(f"One or more example files are still missing from '{EXAMPLE_DATA_DIR}' after attempting to create dummies. Please check permissions or provide the files.")
|
| 679 |
-
return [None] * 7
|
| 680 |
-
|
| 681 |
-
gr.Info("Example data loaded. Click 'Analyze Dataset'.")
|
| 682 |
-
return [
|
| 683 |
-
EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
|
| 684 |
-
EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
|
| 685 |
-
EXAMPLE_FILES["pv_mort"]
|
| 686 |
]
|
| 687 |
|
| 688 |
load_example_btn.click(
|
|
@@ -695,11 +567,12 @@ def create_interface():
|
|
| 695 |
return demo
|
| 696 |
|
| 697 |
if __name__ == "__main__":
|
| 698 |
-
# When running locally, ensure eg_data exists.
|
| 699 |
-
# Dummy file creation is now handled by load_example_files if needed.
|
| 700 |
if not os.path.exists(EXAMPLE_DATA_DIR):
|
| 701 |
os.makedirs(EXAMPLE_DATA_DIR)
|
| 702 |
-
print(f"
|
|
|
|
|
|
|
|
|
|
| 703 |
|
| 704 |
demo_app = create_interface()
|
| 705 |
demo_app.launch()
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
| 4 |
from sklearn.cluster import KMeans
|
| 5 |
+
from sklearn.metrics import pairwise_distances_argmin_min
|
| 6 |
+
# import matplotlib.pyplot as plt # Replaced with seaborn
|
| 7 |
+
# import matplotlib.cm # Replaced with seaborn palettes
|
| 8 |
+
import seaborn as sns # Added Seaborn
|
| 9 |
import io
|
| 10 |
import os
|
| 11 |
from PIL import Image
|
| 12 |
|
| 13 |
# Define the paths for example data
|
|
|
|
| 14 |
EXAMPLE_DATA_DIR = "eg_data"
|
| 15 |
EXAMPLE_FILES = {
|
| 16 |
"cashflow_base": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"),
|
|
|
|
| 24 |
|
| 25 |
class Clusters:
|
| 26 |
def __init__(self, loc_vars):
|
| 27 |
+
self.kmeans = kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
|
| 28 |
+
closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
rep_ids = pd.Series(data=(closest+1)) # 0-based to 1-based indexes
|
| 31 |
rep_ids.name = 'policy_id'
|
| 32 |
rep_ids.index.name = 'cluster_id'
|
| 33 |
self.rep_ids = rep_ids
|
| 34 |
|
| 35 |
+
self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']
|
| 36 |
|
| 37 |
def agg_by_cluster(self, df, agg=None):
|
| 38 |
+
"""Aggregate columns by cluster"""
|
| 39 |
temp = df.copy()
|
| 40 |
temp['cluster_id'] = self.kmeans.labels_
|
| 41 |
temp = temp.set_index('cluster_id')
|
| 42 |
+
agg = {c: (agg[c] if agg and c in agg else 'sum') for c in temp.columns} if agg else "sum"
|
| 43 |
+
return temp.groupby(temp.index).agg(agg)
|
|
|
|
|
|
|
| 44 |
|
| 45 |
def extract_reps(self, df):
|
| 46 |
+
"""Extract the rows of representative policies"""
|
| 47 |
+
temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
|
| 48 |
+
temp.index.name = 'cluster_id'
|
| 49 |
+
return temp.drop('policy_id', axis=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
def extract_and_scale_reps(self, df, agg=None):
|
| 52 |
+
"""Extract and scale the rows of representative policies"""
|
| 53 |
if agg:
|
| 54 |
+
cols = df.columns
|
| 55 |
+
mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
|
| 56 |
+
extracted_df = self.extract_reps(df)
|
| 57 |
+
mult.index = extracted_df.index
|
| 58 |
+
return extracted_df.mul(mult)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
else:
|
| 60 |
+
return self.extract_reps(df).mul(self.policy_count, axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
def compare(self, df, agg=None):
    """Compare actual vs. estimated values per cluster.

    Returns a DataFrame with a (cluster_id, column) MultiIndex and two
    columns: 'actual' (true per-cluster aggregate) and 'estimate'
    (scaled representative values).
    """
    actual = self.agg_by_cluster(df, agg).stack()
    estimate = self.extract_and_scale_reps(df, agg).stack()
    return pd.DataFrame({'actual': actual, 'estimate': estimate})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
def compare_total(self, df, agg=None):
    """Compare portfolio-level totals of actual vs. estimated values.

    Columns are reduced with 'sum' unless *agg* maps them to 'mean'.
    For 'mean' columns the estimate is the policy-count-weighted average
    of the representative rows; for 'sum' columns it is the weighted sum.

    Returns:
        DataFrame indexed by column name with 'actual', 'estimate' and
        relative 'error' (0 where the actual total is 0).
    """
    if agg:
        def _reduce(series, how):
            # Only 'mean' deviates from the default column sum.
            return series.mean() if how == 'mean' else series.sum()

        actual = pd.Series({c: _reduce(df[c], agg.get(c, 'sum')) for c in df.columns})

        reps = self.extract_reps(df)
        weights = self.policy_count
        total_weight = weights.sum()
        estimate_values = {}
        for c in df.columns:
            weighted = (reps[c] * weights).sum()
            if agg.get(c, 'sum') == 'mean':
                estimate_values[c] = weighted / total_weight if total_weight > 0 else 0
            else:
                estimate_values[c] = weighted
        estimate = pd.Series(estimate_values)
    else:
        # No per-column spec: plain sums on both sides.
        actual = df.sum()
        estimate = self.extract_and_scale_reps(df).sum()

    # Relative error, guarded against division by a zero actual total.
    error = np.where(actual != 0, estimate / actual - 1, 0)
    return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
+
def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
    """Create cashflow comparison plots (actual vs. estimate) with Seaborn.

    Args:
        cfs_list: list of cashflow DataFrames, one per scenario.
        cluster_obj: fitted Clusters instance used to estimate totals.
        titles: list of per-scenario plot titles, parallel to cfs_list.

    Returns:
        PIL Image of the figure, or None when there is nothing to plot.
    """
    # BUG FIX: the original called sns.plt.subplots/tight_layout/savefig/close,
    # but seaborn does not expose a `plt` attribute -> AttributeError at runtime.
    # matplotlib is a hard dependency of seaborn, so import pyplot directly.
    import matplotlib.pyplot as plt

    if not cfs_list or not cluster_obj or not titles:
        return None
    num_plots = len(cfs_list)
    if num_plots == 0:
        return None

    cols = 2
    rows = (num_plots + cols - 1) // cols

    # Matplotlib handles the grid layout; Seaborn draws on the axes.
    fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
    axes = axes.flatten()
    sns.set_style("whitegrid")

    for i, (df, title) in enumerate(zip(cfs_list, titles)):
        if i < len(axes):
            ax = axes[i]
            comparison = cluster_obj.compare_total(df)
            # Long-form data for sns.lineplot: one row per (time, category).
            plot_data = comparison[['actual', 'estimate']].reset_index().melt(
                id_vars='index', var_name='Category', value_name='Value'
            )
            sns.lineplot(x='index', y='Value', hue='Category', data=plot_data, ax=ax, marker="o")
            ax.set_title(title)
            ax.set_xlabel('Time')
            ax.set_ylabel('Value')
            if not plot_data.empty:  # add legend only when there is data
                ax.legend(title='Category')

    # Remove unused axes from the grid (e.g. 3 plots on a 2x2 grid).
    for j in range(i + 1, len(axes)):
        fig.delaxes(axes[j])

    plt.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
    return img
|
| 142 |
|
| 143 |
def plot_scatter_comparison(df_compare_output, title):
    """Create an actual-vs-estimate scatter plot from compare() output.

    Args:
        df_compare_output: DataFrame with 'actual'/'estimate' columns,
            ideally with a (cluster_id, item) MultiIndex; the second
            level is used as the hue. May be None/empty.
        title: plot title.

    Returns:
        PIL Image of the figure (a placeholder image when there is no data).
    """
    # BUG FIX: the original called sns.plt.subplots/savefig/close, but seaborn
    # does not expose a `plt` attribute -> AttributeError at runtime.
    # matplotlib is a hard dependency of seaborn, so import pyplot directly.
    import matplotlib.pyplot as plt

    if df_compare_output is None or df_compare_output.empty:
        fig, ax = plt.subplots(figsize=(12, 8))
        sns.set_style("whitegrid")
        ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
        ax.set_title(title)
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100)
        buf.seek(0)
        img = Image.open(buf)
        plt.close(fig)
        return img

    fig, ax = plt.subplots(figsize=(12, 8))
    sns.set_style("whitegrid")

    hue_col = None
    plot_data = df_compare_output.copy()

    if isinstance(df_compare_output.index, pd.MultiIndex) and df_compare_output.index.nlevels >= 2:
        gr.Info("Plotting with multiple item levels.")
        # Move index levels into columns so Seaborn can use them.
        plot_data = df_compare_output.reset_index()
        hue_col = df_compare_output.index.names[1]  # second level drives the hue
        if hue_col is None or hue_col == "":  # handle unnamed index level
            hue_col = "item_level_1"
            plot_data.rename(columns={plot_data.columns[1]: hue_col}, inplace=True)

        num_unique_hue = plot_data[hue_col].nunique()
        palette = "viridis"
        if num_unique_hue > 10:
            # Many categories: switch to a large cyclic palette.
            palette = sns.color_palette("husl", num_unique_hue)

        sns.scatterplot(x='actual', y='estimate', hue=hue_col if num_unique_hue <= 20 else None,
                        data=plot_data, ax=ax, s=20, alpha=0.7, palette=palette)
        if hue_col and num_unique_hue > 1 and num_unique_hue <= 10:
            ax.legend(title=hue_col)
        elif num_unique_hue > 10:
            ax.legend().set_visible(False)  # legend would be unreadable
    else:
        gr.Warning("Scatter plot data is not in the expected multi-index format or has fewer than 2 levels. Plotting raw actual vs estimate without hue.")
        sns.scatterplot(x='actual', y='estimate', data=plot_data, ax=ax, s=20, alpha=0.7)

    ax.set_xlabel('Actual')
    ax.set_ylabel('Estimate')
    ax.set_title(title)

    # Draw the y = x identity line for visual calibration of the estimates.
    lims = [
        np.min([ax.get_xlim(), ax.get_ylim()]),
        np.max([ax.get_xlim(), ax.get_ylim()]),
    ]
    if lims[0] != lims[1] and np.isfinite(lims[0]) and np.isfinite(lims[1]):
        ax.plot(lims, lims, 'r-', linewidth=0.7, alpha=0.8, zorder=0)
        ax.set_xlim(lims)
        ax.set_ylim(lims)

    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
    return img
|
| 208 |
|
| 209 |
+
|
| 210 |
def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
|
| 211 |
policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
|
| 212 |
+
"""Main processing function - now accepts file paths"""
|
| 213 |
try:
|
| 214 |
+
cfs = pd.read_excel(cashflow_base_path, index_col=0)
|
| 215 |
+
cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
|
| 216 |
+
cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
+
pol_data_full = pd.read_excel(policy_data_path, index_col=0)
|
| 219 |
+
required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
|
| 220 |
+
if all(col in pol_data_full.columns for col in required_cols):
|
| 221 |
+
pol_data = pol_data_full[required_cols]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
else:
|
| 223 |
+
gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
|
| 224 |
+
pol_data = pol_data_full
|
| 225 |
+
|
| 226 |
+
pvs = pd.read_excel(pv_base_path, index_col=0)
|
| 227 |
+
pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
|
| 228 |
+
pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
cfs_list = [cfs, cfs_lapse50, cfs_mort15]
|
| 231 |
scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
|
| 232 |
|
| 233 |
results = {}
|
| 234 |
+
|
| 235 |
mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
|
| 236 |
|
| 237 |
# --- 1. Cashflow Calibration ---
|
| 238 |
+
cluster_cfs = Clusters(cfs)
|
| 239 |
+
|
| 240 |
results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
|
| 241 |
+
results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
|
| 242 |
+
|
| 243 |
+
results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
|
| 244 |
+
results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
|
| 245 |
+
results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
|
| 246 |
+
|
| 247 |
results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
|
| 248 |
results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
|
| 249 |
|
|
|
|
| 250 |
# --- 2. Policy Attribute Calibration ---
|
| 251 |
+
if not pol_data.empty and not pol_data.isnull().all().all() and (pol_data.max(numeric_only=True) - pol_data.min(numeric_only=True)).sum() != 0: # Check for actual variance
|
| 252 |
+
loc_vars_attrs = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
|
| 253 |
+
loc_vars_attrs = loc_vars_attrs.fillna(0) # Handle potential NaNs after division if a column is constant
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
else:
|
| 255 |
+
gr.Warning("Policy data for attribute calibration is empty, all NaNs, or has no variance. Skipping attribute calibration plots.")
|
| 256 |
+
loc_vars_attrs = pol_data # or pd.DataFrame() if you want to ensure it's empty
|
| 257 |
+
|
| 258 |
+
if not loc_vars_attrs.empty:
|
| 259 |
+
cluster_attrs = Clusters(loc_vars_attrs)
|
| 260 |
+
results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
|
| 261 |
+
results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
|
| 262 |
+
results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
|
| 263 |
results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
|
| 264 |
results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
|
| 265 |
else:
|
| 266 |
results['attr_total_cf_base'] = pd.DataFrame()
|
| 267 |
results['attr_policy_attrs_total'] = pd.DataFrame()
|
| 268 |
results['attr_total_pv_base'] = pd.DataFrame()
|
| 269 |
+
results['attr_cashflow_plot'] = plot_scatter_comparison(None, "Policy Attr. Calib. - Cashflows (Base) - No Data") # Generate blank plot
|
| 270 |
+
results['attr_scatter_cashflows_base'] = plot_scatter_comparison(None, "Policy Attr. Calib. - Scatter - No Data")
|
| 271 |
+
|
| 272 |
|
| 273 |
# --- 3. Present Value Calibration ---
|
|
|
|
| 274 |
cluster_pvs = Clusters(pvs)
|
| 275 |
+
|
| 276 |
+
results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
|
| 277 |
+
results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
|
| 278 |
+
|
| 279 |
results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
|
| 280 |
+
results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
|
| 281 |
+
results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
|
| 282 |
+
|
| 283 |
results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
|
| 284 |
results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
|
| 285 |
|
|
|
|
| 286 |
# --- Summary Comparison Plot Data ---
|
| 287 |
error_data = {}
|
| 288 |
+
|
| 289 |
+
def get_error_safe(compare_result, col_name=None):
|
| 290 |
+
if compare_result is None or compare_result.empty:
|
| 291 |
return np.nan
|
| 292 |
+
if col_name and col_name in compare_result.index:
|
| 293 |
+
return abs(compare_result.loc[col_name, 'error'])
|
| 294 |
+
elif 'error' in compare_result.columns:
|
| 295 |
+
return abs(compare_result['error']).mean()
|
| 296 |
else:
|
| 297 |
+
return np.nan # Should not happen if compare_result is valid
|
|
|
|
| 298 |
|
| 299 |
key_pv_col = None
|
| 300 |
+
for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF', 'PV NET CF']: # Added more common names
|
| 301 |
+
if potential_col in pvs.columns:
|
| 302 |
+
key_pv_col = potential_col
|
| 303 |
+
break
|
| 304 |
+
# Case insensitive check
|
| 305 |
+
for col in pvs.columns:
|
| 306 |
+
if col.lower() == potential_col.lower():
|
| 307 |
+
key_pv_col = col
|
| 308 |
break
|
| 309 |
+
if key_pv_col:
|
| 310 |
+
break
|
| 311 |
+
|
| 312 |
+
if not key_pv_col and not pvs.empty:
|
| 313 |
+
gr.Warning(f"Could not find a standard PV Net CF column in PV data. Using mean absolute error for all PV columns for summary. Columns available: {pvs.columns.tolist()}")
|
| 314 |
+
|
| 315 |
+
|
| 316 |
error_data['CF Calib.'] = [
|
| 317 |
get_error_safe(results.get('cf_pv_total_base'), key_pv_col),
|
| 318 |
get_error_safe(results.get('cf_pv_total_lapse'), key_pv_col),
|
| 319 |
get_error_safe(results.get('cf_pv_total_mort'), key_pv_col)
|
| 320 |
]
|
| 321 |
|
| 322 |
+
if not loc_vars_attrs.empty:
|
| 323 |
+
error_data['Attr Calib.'] = [
|
| 324 |
+
get_error_safe(results.get('attr_total_pv_base'), key_pv_col), # Assuming pvs is the right df here
|
| 325 |
+
get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col), # Recalculate for lapse scenario with attr cluster
|
| 326 |
+
get_error_safe(cluster_attrs.compare_total(pvs_mort15), key_pv_col) # Recalculate for mort scenario with attr cluster
|
| 327 |
]
|
| 328 |
else:
|
| 329 |
error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
|
|
|
|
| 334 |
get_error_safe(results.get('pv_total_pv_mort'), key_pv_col)
|
| 335 |
]
|
| 336 |
|
| 337 |
+
summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
|
| 338 |
|
| 339 |
+
fig_summary, ax_summary = sns.plt.subplots(figsize=(10, 6)) # Use sns.plt
|
| 340 |
sns.set_style("whitegrid")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
|
| 342 |
+
# Melt the DataFrame for Seaborn barplot
|
| 343 |
+
summary_plot_data = summary_df.reset_index().melt(
|
| 344 |
+
id_vars='index', var_name='Calibration Method', value_name='Absolute Error Rate'
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
sns.barplot(x='index', y='Absolute Error Rate', hue='Calibration Method', data=summary_plot_data, ax=ax_summary, palette="muted")
|
| 348 |
+
|
| 349 |
+
ax_summary.set_ylabel('Absolute Error Rate (0.1 = 10%)')
|
| 350 |
+
title_suffix = f' (Key PV Column: {key_pv_col})' if key_pv_col else ' (Mean Absolute Error of PVs)'
|
| 351 |
ax_summary.set_title(f'Calibration Method Comparison - Error in Total PV{title_suffix}')
|
| 352 |
+
ax_summary.set_xlabel('Scenario')
|
| 353 |
ax_summary.tick_params(axis='x', rotation=0)
|
| 354 |
+
ax_summary.legend(title='Calibration Method')
|
| 355 |
+
sns.plt.tight_layout()
|
| 356 |
+
|
|
|
|
| 357 |
buf_summary = io.BytesIO()
|
| 358 |
+
sns.plt.savefig(buf_summary, format='png', dpi=100)
|
| 359 |
buf_summary.seek(0)
|
| 360 |
results['summary_plot'] = Image.open(buf_summary)
|
| 361 |
+
sns.plt.close(fig_summary)
|
| 362 |
|
| 363 |
return results
|
| 364 |
|
| 365 |
except FileNotFoundError as e:
|
| 366 |
+
gr.Error(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded.")
|
| 367 |
return {"error": f"File not found: {e.filename}"}
|
| 368 |
except KeyError as e:
|
| 369 |
+
gr.Error(f"A required column is missing from one of the excel files: {e}. Please check data format.")
|
| 370 |
+
return {"error": f"Missing column: {e}"}
|
|
|
|
|
|
|
| 371 |
except ValueError as e:
|
| 372 |
+
gr.Error(f"ValueError during processing: {str(e)}. This might be due to empty data or data format issues (e.g. non-numeric data for clustering).")
|
| 373 |
+
return {"error": f"ValueError: {str(e)}"}
|
|
|
|
|
|
|
| 374 |
except Exception as e:
|
|
|
|
| 375 |
import traceback
|
| 376 |
+
print(traceback.format_exc()) # Print full traceback to console for debugging
|
| 377 |
+
gr.Error(f"An unexpected error occurred: {str(e)}. Check console for details.")
|
| 378 |
+
return {"error": f"Error processing files: {str(e)}"}
|
| 379 |
+
|
| 380 |
|
| 381 |
def create_interface():
|
| 382 |
+
with gr.Blocks(theme=gr.themes.Soft(), title="Cluster Model Points Analysis") as demo: # Added a theme
|
| 383 |
gr.Markdown("""
|
| 384 |
+
# Cluster Model Points Analysis 📊
|
| 385 |
+
|
| 386 |
This application applies cluster analysis to model point selection for insurance portfolios.
|
| 387 |
Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
|
| 388 |
+
|
| 389 |
**Required Files (Excel .xlsx):**
|
| 390 |
+
- Cashflows - Base Scenario
|
| 391 |
+
- Cashflows - Lapse Stress (+50%)
|
| 392 |
+
- Cashflows - Mortality Stress (+15%)
|
| 393 |
+
- Policy Data (including 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth')
|
| 394 |
+
- Present Values - Base Scenario
|
| 395 |
+
- Present Values - Lapse Stress
|
| 396 |
+
- Present Values - Mortality Stress
|
|
|
|
| 397 |
""")
|
| 398 |
|
| 399 |
with gr.Row():
|
| 400 |
with gr.Column(scale=1):
|
| 401 |
+
gr.Markdown("### 📁 Upload Files or Load Examples")
|
| 402 |
+
|
| 403 |
+
load_example_btn = gr.Button("Load Example Data ✨", variant="secondary")
|
| 404 |
+
|
| 405 |
with gr.Row():
|
| 406 |
cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
|
| 407 |
cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
|
|
|
|
| 412 |
pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
|
| 413 |
with gr.Row():
|
| 414 |
pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
|
| 415 |
+
|
| 416 |
+
analyze_btn = gr.Button("Analyze Dataset 🚀", variant="primary", size="lg")
|
| 417 |
|
| 418 |
with gr.Tabs():
|
| 419 |
with gr.TabItem("📊 Summary"):
|
| 420 |
+
summary_plot_output = gr.Image(label="Calibration Methods Comparison")
|
| 421 |
+
|
| 422 |
with gr.TabItem("💸 Cashflow Calibration"):
|
| 423 |
gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
|
| 424 |
with gr.Row():
|
| 425 |
+
cf_total_base_table_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)", wrap=True, height=300)
|
| 426 |
+
cf_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True, height=300)
|
| 427 |
+
cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
|
| 428 |
+
cf_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
|
| 429 |
with gr.Accordion("Present Value Comparisons (Total)", open=False):
|
| 430 |
with gr.Row():
|
| 431 |
+
cf_pv_total_base_out = gr.Dataframe(label="PVs - Base Total", wrap=True)
|
| 432 |
+
cf_pv_total_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total", wrap=True)
|
| 433 |
+
cf_pv_total_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total", wrap=True)
|
| 434 |
+
|
| 435 |
with gr.TabItem("👤 Policy Attribute Calibration"):
|
| 436 |
gr.Markdown("### Results: Using Policy Attributes as Calibration Variables")
|
| 437 |
with gr.Row():
|
| 438 |
+
attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)", wrap=True, height=300)
|
| 439 |
+
attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True, height=300)
|
| 440 |
+
attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
|
| 441 |
+
attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
|
| 442 |
with gr.Accordion("Present Value Comparisons (Total)", open=False):
|
| 443 |
+
attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total (All Shocks)", wrap=True) # Changed label for clarity
|
| 444 |
+
|
| 445 |
with gr.TabItem("💰 Present Value Calibration"):
|
| 446 |
gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
|
| 447 |
with gr.Row():
|
| 448 |
+
pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)", wrap=True, height=300)
|
| 449 |
+
pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True, height=300)
|
| 450 |
+
pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
|
| 451 |
+
pv_scatter_pvs_base_out = gr.Image(label="Scatter Plot - Per-Cluster Present Values (Base Scenario)")
|
| 452 |
with gr.Accordion("Present Value Comparisons (Total)", open=False):
|
| 453 |
with gr.Row():
|
| 454 |
+
pv_total_pv_base_out = gr.Dataframe(label="PVs - Base Total", wrap=True)
|
| 455 |
+
pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total", wrap=True)
|
| 456 |
+
pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total", wrap=True)
|
| 457 |
|
| 458 |
+
# --- Helper function to prepare outputs ---
|
| 459 |
def get_all_output_components():
|
| 460 |
return [
|
| 461 |
+
summary_plot_output,
|
| 462 |
+
# Cashflow Calib Outputs
|
| 463 |
+
cf_total_base_table_out, cf_policy_attrs_total_out,
|
| 464 |
+
cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
|
| 465 |
+
cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
|
| 466 |
+
# Attribute Calib Outputs
|
| 467 |
+
attr_total_cf_base_out, attr_policy_attrs_total_out,
|
| 468 |
+
attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
|
| 469 |
+
# PV Calib Outputs
|
| 470 |
+
pv_total_cf_base_out, pv_policy_attrs_total_out,
|
| 471 |
+
pv_cashflow_plot_out, pv_scatter_pvs_base_out,
|
| 472 |
+
pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
|
| 473 |
]
|
| 474 |
|
| 475 |
+
# --- Action for Analyze Button ---
|
| 476 |
+
def handle_analysis(f1, f2, f3, f4, f5, f6, f7, progress=gr.Progress(track_tqdm=True)):
|
| 477 |
files = [f1, f2, f3, f4, f5, f6, f7]
|
| 478 |
+
|
| 479 |
file_paths = []
|
| 480 |
+
file_labels = ["Cashflows - Base", "Cashflows - Lapse", "Cashflows - Mort",
|
| 481 |
+
"Policy Data", "PVs - Base", "PVs - Lapse", "PVs - Mort"]
|
| 482 |
+
|
| 483 |
+
for i, f_obj in enumerate(files):
|
| 484 |
+
if f_obj is None:
|
| 485 |
+
gr.Error(f"Missing file input for: {file_labels[i]}. Please upload all files or load examples.")
|
| 486 |
+
# Return empty/None for all outputs
|
| 487 |
+
return [None] * len(get_all_output_components())
|
| 488 |
+
|
| 489 |
+
if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
|
| 490 |
+
file_paths.append(f_obj.name)
|
| 491 |
+
elif isinstance(f_obj, str): # Already a path (from example load)
|
| 492 |
+
file_paths.append(f_obj)
|
| 493 |
+
else:
|
| 494 |
+
gr.Error(f"Invalid file input for {file_labels[i]}. Type: {type(f_obj)}")
|
| 495 |
return [None] * len(get_all_output_components())
|
|
|
|
|
|
|
| 496 |
|
| 497 |
+
progress(0, desc="Starting Analysis...")
|
| 498 |
+
# This is a placeholder for actual progress tracking if process_files were to support it.
|
| 499 |
+
# For now, it just shows activity.
|
| 500 |
+
# You could break down process_files and update progress more granularly if needed.
|
| 501 |
+
for i in range(1, 6):
|
| 502 |
+
progress(i/5, desc=f"Processing Data Step {i}/5...") # Simulate progress
|
| 503 |
+
# time.sleep(0.2) # if you want to see the progress bar update
|
| 504 |
|
| 505 |
results = process_files(*file_paths)
|
| 506 |
+
progress(1, desc="Analysis Complete!")
|
| 507 |
|
| 508 |
+
if "error" in results: # Error handled by process_files with gr.Error
|
|
|
|
|
|
|
| 509 |
return [None] * len(get_all_output_components())
|
| 510 |
|
| 511 |
return [
|
| 512 |
+
results.get('summary_plot'),
|
| 513 |
+
# CF Calib
|
| 514 |
+
results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
|
| 515 |
+
results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
|
| 516 |
+
results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
|
| 517 |
+
# Attr Calib
|
| 518 |
results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
|
| 519 |
+
results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
|
| 520 |
+
# PV Calib
|
| 521 |
+
results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
|
| 522 |
+
results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
|
| 523 |
+
results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
|
| 524 |
]
|
| 525 |
|
| 526 |
analyze_btn.click(
|
|
|
|
| 530 |
outputs=get_all_output_components()
|
| 531 |
)
|
| 532 |
|
| 533 |
+
# --- Action for Load Example Data Button ---
|
| 534 |
def load_example_files():
|
| 535 |
+
# Create eg_data directory if it doesn't exist
|
| 536 |
+
if not os.path.exists(EXAMPLE_DATA_DIR):
|
| 537 |
+
os.makedirs(EXAMPLE_DATA_DIR)
|
| 538 |
+
gr.Warning(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there. App will likely fail analysis if files are missing.")
|
| 539 |
+
|
| 540 |
+
missing_files_info = []
|
| 541 |
+
for key, fp in EXAMPLE_FILES.items():
|
| 542 |
+
if not os.path.exists(fp):
|
| 543 |
+
missing_files_info.append(f"'{key}' (expected at '{fp}')")
|
| 544 |
+
|
| 545 |
+
if missing_files_info:
|
| 546 |
+
gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files_info)}. Please ensure they exist or upload files manually.")
|
| 547 |
+
return [None] * 7 # Return None for all file inputs
|
| 548 |
+
|
| 549 |
+
gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
|
| 550 |
+
return [ # Return the paths for the File components
|
| 551 |
+
gr.File(value=EXAMPLE_FILES["cashflow_base"]),
|
| 552 |
+
gr.File(value=EXAMPLE_FILES["cashflow_lapse"]),
|
| 553 |
+
gr.File(value=EXAMPLE_FILES["cashflow_mort"]),
|
| 554 |
+
gr.File(value=EXAMPLE_FILES["policy_data"]),
|
| 555 |
+
gr.File(value=EXAMPLE_FILES["pv_base"]),
|
| 556 |
+
gr.File(value=EXAMPLE_FILES["pv_lapse"]),
|
| 557 |
+
gr.File(value=EXAMPLE_FILES["pv_mort"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 558 |
]
|
| 559 |
|
| 560 |
load_example_btn.click(
|
|
|
|
| 567 |
return demo
|
| 568 |
|
| 569 |
if __name__ == "__main__":
    # First run on a fresh checkout: create the example-data folder and
    # tell the user which files it should contain.
    if not os.path.exists(EXAMPLE_DATA_DIR):
        os.makedirs(EXAMPLE_DATA_DIR)
        print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.")
        print(f"Expected files in '{EXAMPLE_DATA_DIR}':")
        for key, path in EXAMPLE_FILES.items():
            print(f"  - {key}: {os.path.basename(path)}")  # Print just file name for cleaner output

    demo_app = create_interface()
    demo_app.launch()
|