alidenewade committed on
Commit
450e44d
·
verified ·
1 Parent(s): dd97346

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +355 -387
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  import numpy as np
3
  import pandas as pd
4
  from sklearn.cluster import KMeans
5
- from sklearn.metrics import pairwise_distances_argmin_min # r2_score is not used in the final Gradio app logic
6
  import matplotlib.pyplot as plt
7
  import matplotlib.cm
8
  import io
@@ -15,7 +15,7 @@ EXAMPLE_FILES = {
15
  "cashflow_base": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"),
16
  "cashflow_lapse": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_lapse50.xlsx"),
17
  "cashflow_mort": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_mort15.xlsx"),
18
- "policy_data": os.path.join(EXAMPLE_DATA_DIR, "model_point_table.xlsx"),
19
  "pv_base": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K.xlsx"),
20
  "pv_lapse": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_lapse50.xlsx"),
21
  "pv_mort": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_mort15.xlsx"),
@@ -23,16 +23,10 @@ EXAMPLE_FILES = {
23
 
24
  class Clusters:
25
  def __init__(self, loc_vars):
26
- # Ensure loc_vars is not empty before fitting KMeans
27
- if loc_vars.empty:
28
- raise ValueError("Input data for KMeans (loc_vars) is empty.")
29
- if loc_vars.isnull().all().all():
30
- raise ValueError("Input data for KMeans (loc_vars) contains all NaN values.")
31
-
32
- self.kmeans = KMeans(n_clusters=min(1000, len(loc_vars)), random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
33
- closest, _ = pairwise_distances_argmin_min(self.kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
34
 
35
- rep_ids = pd.Series(data=(closest + 1)) # 0-based to 1-based indexes
36
  rep_ids.name = 'policy_id'
37
  rep_ids.index.name = 'cluster_id'
38
  self.rep_ids = rep_ids
@@ -40,242 +34,177 @@ class Clusters:
40
  self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']
41
 
42
  def agg_by_cluster(self, df, agg=None):
 
43
  temp = df.copy()
44
  temp['cluster_id'] = self.kmeans.labels_
45
  temp = temp.set_index('cluster_id')
46
-
47
- # Ensure agg is a dictionary if not None
48
- if agg is not None and not isinstance(agg, dict):
49
- # Assuming if agg is not a dict, it's the default "sum" for all, which is handled by else.
50
- # This case might need specific handling if agg can be other types.
51
- # For now, if it's not a dict, treat as if no specific agg ops were given for columns.
52
- agg_ops = {col: "sum" for col in temp.columns} # Default to sum if agg format is unexpected
53
- elif isinstance(agg, dict):
54
- agg_ops = {c: (agg[c] if c in agg else 'sum') for c in temp.columns}
55
- else: # agg is None
56
- agg_ops = "sum" # Pandas groupby will apply sum to all numeric columns
57
-
58
- return temp.groupby(temp.index).agg(agg_ops)
59
 
60
  def extract_reps(self, df):
 
61
  temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
62
  temp.index.name = 'cluster_id'
63
  return temp.drop('policy_id', axis=1)
64
 
65
  def extract_and_scale_reps(self, df, agg=None):
66
- extracted_df = self.extract_reps(df)
67
- if extracted_df.empty:
68
- return extracted_df # Return empty if no representatives
69
-
70
- if agg and isinstance(agg, dict):
71
- # mult should be a Series aligned with extracted_df's columns for element-wise multiplication after selection
72
- # This part of the logic seems to intend to scale rows based on policy_count for 'sum' aggs
73
- # and leave 'mean' aggs as is (to be weighted later).
74
- # The original code created a DataFrame `mult` then did .mul(mult).
75
- # A more direct approach for scaling rows:
76
- scaled_df = extracted_df.copy()
77
- for c in extracted_df.columns:
78
- if agg.get(c, 'sum') == 'sum': # Default to 'sum' if column not in agg
79
- scaled_df[c] = extracted_df[c].mul(self.policy_count, axis=0)
80
- # else (it's 'mean'), do not scale by policy_count here.
81
- return scaled_df
82
- else: # Default: scale all columns by policy_count (as if for sum)
83
- return extracted_df.mul(self.policy_count, axis=0)
84
 
85
  def compare(self, df, agg=None):
 
86
  source = self.agg_by_cluster(df, agg)
87
- target = self.extract_and_scale_reps(df, agg) # This target needs to be aggregated like source
88
-
89
- # The target from extract_and_scale_reps is already scaled per cluster for 'sum' ops.
90
- # For 'mean' ops, it's the representative value.
91
- # We need to sum up the 'sum' columns and calculate weighted average for 'mean' columns.
92
- if agg and isinstance(agg, dict):
93
- agg_ops_for_target = {}
94
- for col, method in agg.items():
95
- if method == 'sum':
96
- agg_ops_for_target[col] = 'sum'
97
- elif method == 'mean':
98
- # For mean, we need sum(val*count)/sum(count).
99
- # extract_and_scale_reps DID NOT scale mean columns by policy_count.
100
- # So, target[col] has rep values. We need to weight them.
101
- # This is better handled in compare_total. Here, target is per-cluster.
102
- # This function compares per-cluster values BEFORE final aggregation.
103
- # So target should represent aggregated values per cluster.
104
- pass # 'sum' columns are scaled, 'mean' columns are rep values
105
- else: # all sum
106
- pass # target is already scaled by policy_count, so it's the sum per cluster
107
-
108
- # This function is for per-cluster comparison, not total.
109
- # The 'target' from extract_and_scale_reps already has the representative values scaled by policy_count for sum-like aggregations.
110
- # If a column is meant for 'mean', it's just the representative value.
111
- # This 'compare' function might be misinterpreting 'target' if 'agg' has 'mean'.
112
- # The original notebook's compare function:
113
- # source = self.agg_by_cluster(df, agg) # Actual sums/means per cluster
114
- # target = self.extract_and_scale_reps(df, agg) # Rep values, scaled by count if 'sum', unscaled if 'mean'
115
- # This structure implies 'target' might not be directly comparable if 'mean' is involved without further processing.
116
- # However, the scatter plots it generates plot these per-cluster values.
117
- # For 'sum' variables, target is an estimate of the cluster total.
118
- # For 'mean' variables, target is the rep's value (estimate of cluster mean).
119
-
120
- return pd.DataFrame({'actual': source.stack(), 'estimate': target.stack()})
121
-
122
 
123
  def compare_total(self, df, agg=None):
124
- """Aggregate df by columns and compare actual vs estimate totals."""
125
- if df.empty:
126
- return pd.DataFrame(columns=['actual', 'estimate', 'error'])
127
-
128
- # Determine aggregation operations for each column
129
- op_for_actual = {}
130
- if isinstance(agg, dict):
131
- for c in df.columns:
132
- op_for_actual[c] = agg.get(c, 'sum') # Default to 'sum' if not in agg
133
- else: # agg is None or not a dict, apply sum to all
134
- for c in df.columns:
135
- if pd.api.types.is_numeric_dtype(df[c]):
136
- op_for_actual[c] = 'sum'
137
- # else: non-numeric columns will be ignored by df.agg if op not specified
138
-
139
- actual = df.agg(op_for_actual)
140
- actual = actual.dropna() # Remove non-numeric results if any
141
-
142
- # Calculate estimate
143
- reps_values = self.extract_reps(df) # Get raw representative values (one per cluster)
144
- if reps_values.empty: # No representatives found
145
- estimate = pd.Series(index=actual.index, dtype=float) # Empty or NaN series
146
- else:
147
- estimate_values = {}
148
- for col_name in actual.index: # Iterate over columns that had a valid actual aggregation
149
- col_op = op_for_actual.get(col_name, 'sum')
150
-
151
- if col_name not in reps_values.columns: # Should not happen if df columns match
152
- estimate_values[col_name] = np.nan
153
- continue
154
-
155
- rep_col_values = reps_values[col_name]
156
-
157
- if col_op == 'sum':
158
- # Estimate for sum is sum of (representative_value * policy_count_for_its_cluster)
159
- estimate_values[col_name] = (rep_col_values * self.policy_count).sum()
160
- elif col_op == 'mean':
161
- # Estimate for mean is weighted average: sum(rep_value * policy_count) / sum(policy_count)
162
- weighted_sum = (rep_col_values * self.policy_count).sum()
163
- total_weight = self.policy_count.sum()
164
- estimate_values[col_name] = weighted_sum / total_weight if total_weight != 0 else np.nan
165
- else: # Should not happen given op_for_actual logic
166
- estimate_values[col_name] = np.nan
167
 
168
- estimate = pd.Series(estimate_values, index=actual.index) # Align with actual's index
169
-
170
- # Calculate error
171
- # Align actual and estimate to ensure they cover the same items for error calculation
172
- actual_aligned, estimate_aligned = actual.align(estimate, join='inner')
173
-
174
- error = pd.Series(index=actual_aligned.index, dtype=float)
175
-
176
- # Valid division where actual is not zero and not NaN
177
- valid_mask = (actual_aligned != 0) & (~actual_aligned.isna())
178
- error[valid_mask] = estimate_aligned[valid_mask] / actual_aligned[valid_mask] - 1
179
-
180
- # Where actual is zero (and not NaN)
181
- actual_zero_mask = (actual_aligned == 0) & (~actual_aligned.isna())
182
- # If estimate is also zero, error is 0
183
- error[actual_zero_mask & (estimate_aligned == 0)] = 0
184
- # If estimate is non-zero and actual is zero, error is effectively infinite
185
- error[actual_zero_mask & (estimate_aligned != 0)] = np.inf
186
-
187
- # Replace any infinities with NaN for cleaner results (e.g., for .mean())
188
- error = error.replace([np.inf, -np.inf], np.nan)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
- result_df = pd.DataFrame({'actual': actual_aligned, 'estimate': estimate_aligned, 'error': error})
191
- return result_df
192
 
193
 
194
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
195
- if not cfs_list or not cluster_obj or not titles or len(cfs_list) == 0:
196
- fig, ax = plt.subplots()
197
- ax.text(0.5, 0.5, "No data for cashflow comparison plot.", ha='center', va='center')
198
- buf = io.BytesIO(); plt.savefig(buf, format='png'); buf.seek(0); img = Image.open(buf); plt.close(fig); return img
199
-
200
  num_plots = len(cfs_list)
 
 
 
 
201
  cols = 2
202
  rows = (num_plots + cols - 1) // cols
203
 
204
- fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
205
  axes = axes.flatten()
206
 
207
- plot_made = False
208
- for i, (df_cf, title) in enumerate(zip(cfs_list, titles)):
209
  if i < len(axes):
210
- if df_cf is None or df_cf.empty:
211
- axes[i].text(0.5,0.5, f"No data for {title}", ha='center', va='center')
212
- axes[i].set_title(title)
213
- continue
214
- comparison = cluster_obj.compare_total(df_cf) # Default is sum for all columns
215
- if not comparison.empty and 'actual' in comparison and 'estimate' in comparison:
216
- comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
217
- axes[i].set_xlabel('Time')
218
- axes[i].set_ylabel('Value')
219
- plot_made = True
220
- else:
221
- axes[i].text(0.5,0.5, f"Could not generate comparison for {title}", ha='center', va='center')
222
- axes[i].set_title(title)
223
 
224
- for j in range(i + 1, len(axes)): # Hide unused subplots
 
225
  fig.delaxes(axes[j])
226
 
227
- if not plot_made: # If no plots were actually made (e.g. all data was empty)
228
- plt.close(fig) # Close the figure
229
- fig, ax = plt.subplots() # Create a new one for the message
230
- ax.text(0.5, 0.5, "Insufficient data for any cashflow plots.", ha='center', va='center')
231
-
232
-
233
  plt.tight_layout()
234
  buf = io.BytesIO()
235
- plt.savefig(buf, format='png', dpi=100)
236
  buf.seek(0)
237
  img = Image.open(buf)
238
- plt.close(fig)
239
  return img
240
 
241
  def plot_scatter_comparison(df_compare_output, title):
 
242
  if df_compare_output is None or df_compare_output.empty:
243
- fig, ax = plt.subplots(figsize=(10,6)); ax.text(0.5, 0.5, "No data for scatter plot.", ha='center', va='center'); ax.set_title(title)
244
- buf = io.BytesIO(); plt.savefig(buf, format='png'); buf.seek(0); img = Image.open(buf); plt.close(fig); return img
245
-
246
- fig, ax = plt.subplots(figsize=(10, 6))
 
 
 
 
 
 
 
 
247
 
248
  if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
249
- # This case indicates df_compare_output is not from cluster_obj.compare() as expected
250
- ax.scatter(df_compare_output.get('actual', []), df_compare_output.get('estimate', []), s=9, alpha=0.6)
251
  else:
252
  unique_levels = df_compare_output.index.get_level_values(1).unique()
253
  colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
254
 
255
  for item_level, color_val in zip(unique_levels, colors):
256
  subset = df_compare_output.xs(item_level, level=1)
257
- ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=str(item_level)) # Ensure label is string
258
- if len(unique_levels) > 1 and len(unique_levels) <=10:
259
  ax.legend(title=df_compare_output.index.names[1])
260
 
 
261
  ax.set_xlabel('Actual')
262
  ax.set_ylabel('Estimate')
263
  ax.set_title(title)
264
  ax.grid(True)
265
 
266
- try:
267
- current_xlim = ax.get_xlim()
268
- current_ylim = ax.get_ylim()
269
- lims = [
270
- np.nanmin([current_xlim, current_ylim]),
271
- np.nanmax([current_xlim, current_ylim]),
272
- ]
273
- if lims[0] != lims[1] and not np.isnan(lims[0]) and not np.isnan(lims[1]):
274
- ax.plot(lims, lims, 'r-', linewidth=0.5)
275
- ax.set_xlim(lims)
276
- ax.set_ylim(lims)
277
- except Exception: # Catch errors if lims are problematic (e.g. all NaNs)
278
- pass
279
 
280
  buf = io.BytesIO()
281
  plt.savefig(buf, format='png', dpi=100)
@@ -287,179 +216,197 @@ def plot_scatter_comparison(df_compare_output, title):
287
 
288
  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
289
  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
290
- results = {}
291
  try:
 
292
  cfs = pd.read_excel(cashflow_base_path, index_col=0)
293
  cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
294
  cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
295
 
296
  pol_data_full = pd.read_excel(policy_data_path, index_col=0)
 
297
  required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
298
- missing_policy_cols = [col for col in required_cols if col not in pol_data_full.columns]
299
- if missing_policy_cols:
300
- gr.Warning(f"Policy data is missing required columns: {', '.join(missing_policy_cols)}. Analysis may be affected.")
301
- pol_data = pol_data_full # Use what's available
302
- else:
303
  pol_data = pol_data_full[required_cols]
304
-
 
 
 
 
 
305
  pvs = pd.read_excel(pv_base_path, index_col=0)
306
  pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
307
  pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)
308
 
309
  cfs_list = [cfs, cfs_lapse50, cfs_mort15]
 
310
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
311
 
312
- mean_attrs_agg = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
 
 
313
 
314
  # --- 1. Cashflow Calibration ---
315
- gr.Info("Starting Cashflow Calibration...")
316
- if cfs.empty: gr.Warning("Base cashflow data is empty for Cashflow Calibration.")
317
  cluster_cfs = Clusters(cfs)
 
318
  results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
319
- results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs_agg)
 
 
 
 
320
  results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
321
  results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
322
  results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
 
323
  results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
324
- results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'CF Calib. - Cashflows (Base)')
325
- gr.Info("Cashflow Calibration Done.")
 
326
 
327
  # --- 2. Policy Attribute Calibration ---
328
- gr.Info("Starting Policy Attribute Calibration...")
329
- if pol_data.empty :
330
- gr.Warning("Policy data is empty. Skipping Policy Attribute Calibration.")
331
- loc_vars_attrs = pd.DataFrame() # Empty dataframe
332
  else:
333
- pol_data_min = pol_data.min()
334
- pol_data_range = pol_data.max() - pol_data_min
335
- # Avoid division by zero if a column has no variance (all values are the same)
336
- if (pol_data_range == 0).any():
337
- gr.Warning("Some policy attributes have no variance (all values are the same). Standardization might be affected.")
338
- # For columns with zero range, standardized value becomes 0 or NaN depending on pandas version.
339
- # A common approach is to set them to 0 or handle them separately.
340
- # Here, we proceed, but pandas might produce NaNs if (val - min) / 0 occurs.
341
- # Let's ensure range is not zero for division:
342
- pol_data_range[pol_data_range == 0] = 1 # Avoid division by zero, effectively making constant columns 0 after (x-min)/1
343
- loc_vars_attrs = (pol_data - pol_data_min) / pol_data_range
344
- loc_vars_attrs = loc_vars_attrs.fillna(0) # Handle any NaNs from perfect constant columns
345
-
346
  if not loc_vars_attrs.empty:
347
  cluster_attrs = Clusters(loc_vars_attrs)
348
  results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
349
- results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs_agg)
350
  results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
351
  results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
352
- results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Attr Calib. - Cashflows (Base)')
353
- else:
354
- results.update({k: pd.DataFrame() for k in ['attr_total_cf_base', 'attr_policy_attrs_total', 'attr_total_pv_base']})
355
- results.update({k: None for k in ['attr_cashflow_plot', 'attr_scatter_cashflows_base']})
356
- gr.Info("Policy Attribute Calibration Done.")
 
 
 
 
 
357
 
358
  # --- 3. Present Value Calibration ---
359
- gr.Info("Starting Present Value Calibration...")
360
- if pvs.empty: gr.Warning("Base Present Value data is empty for PV Calibration.")
361
  cluster_pvs = Clusters(pvs)
 
362
  results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
363
- results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs_agg)
 
364
  results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
365
  results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
366
  results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
 
367
  results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
368
  results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
369
- gr.Info("Present Value Calibration Done.")
370
-
 
371
  # --- Summary Comparison Plot Data ---
372
- gr.Info("Generating Summary Plot...")
 
 
 
373
  error_data = {}
374
- pv_col_name = 'PV_NetCF' # Target column for summary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
 
376
- for calib_prefix, cluster_obj, calib_name_display in [
377
- ('CF Calib.', cluster_cfs, "CF Calib."),
378
- ('Attr Calib.', globals().get('cluster_attrs'), "Attr Calib."),
379
- ('PV Calib.', cluster_pvs, "PV Calib.")]:
380
-
381
- current_calib_errors = []
382
- if cluster_obj is None and calib_prefix == 'Attr Calib.': # Attr calib might be skipped
383
- current_calib_errors = [np.nan, np.nan, np.nan]
384
- else:
385
- for pv_df_scenario in [pvs, pvs_lapse50, pvs_mort15]:
386
- if pv_df_scenario.empty:
387
- current_calib_errors.append(np.nan)
388
- continue
389
-
390
- comp_total_df = cluster_obj.compare_total(pv_df_scenario)
391
- if pv_col_name in comp_total_df.index:
392
- error_val = comp_total_df.loc[pv_col_name, 'error']
393
- elif not comp_total_df.empty and 'error' in comp_total_df.columns:
394
- error_val = comp_total_df['error'].mean() # Fallback
395
- if calib_prefix == 'CF Calib.' and pv_df_scenario is pvs: # Only warn once per type if fallback
396
- gr.Warning(f"'{pv_col_name}' not found for summary plot. Using mean error of all PV columns instead for {calib_name_display}.")
397
- else: # comp_total_df is empty or no 'error' column
398
- error_val = np.nan
399
- current_calib_errors.append(abs(error_val))
400
- error_data[calib_name_display] = current_calib_errors
401
 
402
- summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
403
-
404
- fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
 
 
 
 
 
 
 
 
405
 
406
- plot_title = f'Calibration Method Comparison - Abs. Error in Total {pv_col_name}'
407
- if summary_df.isnull().all().all():
408
- ax_summary.text(0.5, 0.5, f"Error data for summary is N/A.\nCheck input PV files for '{pv_col_name}' column and valid numeric data.",
409
- ha='center', va='center', transform=ax_summary.transAxes, wrap=True)
410
- ax_summary.set_title(plot_title)
411
- elif summary_df.empty:
412
- ax_summary.text(0.5, 0.5, "No summary data to plot.", ha='center', va='center')
413
- ax_summary.set_title(plot_title)
414
  else:
415
- summary_df.plot(kind='bar', ax=ax_summary, grid=True)
416
- ax_summary.set_ylabel(f'Mean Absolute Error (of {pv_col_name} or fallback)')
417
- ax_summary.set_title(plot_title)
418
- ax_summary.tick_params(axis='x', rotation=0)
 
419
 
 
 
 
 
 
 
 
 
420
  plt.tight_layout()
421
- buf_summary = io.BytesIO(); plt.savefig(buf_summary, format='png', dpi=100); buf_summary.seek(0)
 
 
 
422
  results['summary_plot'] = Image.open(buf_summary)
423
  plt.close(fig_summary)
424
- gr.Info("All processing complete.")
425
  return results
426
 
427
  except FileNotFoundError as e:
428
- gr.Error(f"File not found: {e.filename}. Ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded correctly.")
429
  return {"error": f"File not found: {e.filename}"}
430
- except ValueError as e: # Catch specific errors like empty data for KMeans
431
- gr.Error(f"Data validation error: {str(e)}")
432
- return {"error": f"Data error: {str(e)}"}
433
  except KeyError as e:
434
- gr.Error(f"A required column is missing: {e}. Please check data formats, especially index columns and expected data columns like 'PV_NetCF'.")
435
  return {"error": f"Missing column: {e}"}
436
  except Exception as e:
437
- gr.Error(f"An unexpected error occurred during processing: {str(e)}")
438
- import traceback
439
- traceback.print_exc() # Print full traceback to console for debugging
440
- return {"error": f"Processing error: {str(e)}"}
441
 
442
 
443
  def create_interface():
444
- with gr.Blocks(title="Cluster Model Points Analysis") as demo:
445
  gr.Markdown("""
446
  # Cluster Model Points Analysis
447
- This application applies k-means cluster analysis to select representative model points from an insurance portfolio.
448
- Upload your Excel files or use the example data to analyze results based on different calibration variable choices.
449
- **Required Excel (.xlsx) Files:**
 
 
450
  - Cashflows - Base Scenario
451
  - Cashflows - Lapse Stress (+50%)
452
  - Cashflows - Mortality Stress (+15%)
453
- - Policy Data (must include 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth', and an index column for `policy_id`)
454
- - Present Values - Base Scenario (ideally with a 'PV_NetCF' column and an index column for `policy_id`)
455
- - Present Values - Lapse Stress (same structure as Base PV)
456
- - Present Values - Mortality Stress (same structure as Base PV)
457
  """)
458
 
459
  with gr.Row():
460
  with gr.Column(scale=1):
461
- gr.Markdown("### 📂 Upload Files or Load Examples")
462
- load_example_btn = gr.Button("Load Example Data", icon="💾")
 
 
463
  with gr.Row():
464
  cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
465
  cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
@@ -470,134 +417,155 @@ def create_interface():
470
  pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
471
  with gr.Row():
472
  pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
473
- span_dummy = gr.File(visible=False) # For layout balance if needed
474
- span_dummy2 = gr.File(visible=False)
475
-
476
-
477
- analyze_btn = gr.Button("Analyze Dataset", variant="primary", icon="🚀", scale=1)
478
 
479
  with gr.Tabs():
480
  with gr.TabItem("📊 Summary"):
481
- summary_plot_output = gr.Image(label="Calibration Methods Comparison")
482
 
483
  with gr.TabItem("💸 Cashflow Calibration"):
484
- gr.Markdown("### Results: Using Annual Cashflows (Base) as Calibration Variables")
485
  with gr.Row():
486
- cf_total_base_table_out = gr.Dataframe(label="Overall Comparison - Base CF", wrap=True)
487
- cf_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True)
488
- cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate)")
489
- cf_scatter_cashflows_base_out = gr.Image(label="Scatter: Per-Cluster Cashflows (Base)")
490
- with gr.Accordion("Present Value Comparisons (Totals)", open=False):
491
  with gr.Row():
492
- cf_pv_total_base_out = gr.Dataframe(label="PVs - Base", wrap=True)
493
- cf_pv_total_lapse_out = gr.Dataframe(label="PVs - Lapse Stress", wrap=True)
494
- cf_pv_total_mort_out = gr.Dataframe(label="PVs - Mortality Stress", wrap=True)
495
 
496
  with gr.TabItem("👤 Policy Attribute Calibration"):
497
  gr.Markdown("### Results: Using Policy Attributes as Calibration Variables")
498
  with gr.Row():
499
- attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base CF", wrap=True)
500
- attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True)
501
- attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate)")
502
- attr_scatter_cashflows_base_out = gr.Image(label="Scatter: Per-Cluster Cashflows (Base)")
503
- with gr.Accordion("Present Value Comparisons (Totals)", open=False):
504
- attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario", wrap=True)
505
 
506
  with gr.TabItem("💰 Present Value Calibration"):
507
- gr.Markdown("### Results: Using Present Values (Base) as Calibration Variables")
508
  with gr.Row():
509
- pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base CF", wrap=True)
510
- pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True)
511
- pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate)")
512
- pv_scatter_pvs_base_out = gr.Image(label="Scatter: Per-Cluster PVs (Base)")
513
- with gr.Accordion("Present Value Comparisons (Totals)", open=False):
514
  with gr.Row():
515
- pv_total_pv_base_out = gr.Dataframe(label="PVs - Base", wrap=True)
516
- pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress", wrap=True)
517
- pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress", wrap=True)
518
-
519
- output_components = [
520
- summary_plot_output,
521
- cf_total_base_table_out, cf_policy_attrs_total_out, cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
522
- cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
523
- attr_total_cf_base_out, attr_policy_attrs_total_out, attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
524
- pv_total_cf_base_out, pv_policy_attrs_total_out, pv_cashflow_plot_out, pv_scatter_pvs_base_out,
525
- pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
526
- ]
 
 
 
 
 
 
 
 
527
 
528
- def handle_analysis_click(f1, f2, f3, f4, f5, f6, f7):
529
- all_files_present = all(f is not None for f in [f1, f2, f3, f4, f5, f6, f7])
530
- if not all_files_present:
531
- gr.Warning("Not all files have been provided. Please upload all 7 files or load example data.")
532
- return [None] * len(output_components) # Return Nones for all output components
 
533
 
534
- # file objects (f1, etc.) from gr.File are TemporaryFileWrapper or string paths if loaded by example
535
  file_paths = []
536
- for f_obj in [f1, f2, f3, f4, f5, f6, f7]:
537
- if hasattr(f_obj, 'name') and isinstance(f_obj.name, str): # Uploaded file
 
 
 
 
 
 
538
  file_paths.append(f_obj.name)
539
- elif isinstance(f_obj, str): # Path from example load
 
540
  file_paths.append(f_obj)
541
- else: # Should not happen if files are present
542
- gr.Error(f"Invalid file input: {f_obj}. Please re-upload or reload examples.")
543
- return [None] * len(output_components)
544
-
545
- analysis_results = process_files(*file_paths)
 
546
 
547
- if "error" in analysis_results: # Error handled and displayed by process_files
548
- return [None] * len(output_components)
 
549
 
550
- # Map results to output components
551
  return [
552
- analysis_results.get('summary_plot'),
553
- analysis_results.get('cf_total_base_table'), analysis_results.get('cf_policy_attrs_total'),
554
- analysis_results.get('cf_cashflow_plot'), analysis_results.get('cf_scatter_cashflows_base'),
555
- analysis_results.get('cf_pv_total_base'), analysis_results.get('cf_pv_total_lapse'), analysis_results.get('cf_pv_total_mort'),
556
- analysis_results.get('attr_total_cf_base'), analysis_results.get('attr_policy_attrs_total'),
557
- analysis_results.get('attr_cashflow_plot'), analysis_results.get('attr_scatter_cashflows_base'), analysis_results.get('attr_total_pv_base'),
558
- analysis_results.get('pv_total_cf_base'), analysis_results.get('pv_policy_attrs_total'),
559
- analysis_results.get('pv_cashflow_plot'), analysis_results.get('pv_scatter_pvs_base'),
560
- analysis_results.get('pv_total_pv_base'), analysis_results.get('pv_total_pv_lapse'), analysis_results.get('pv_total_pv_mort')
 
 
 
561
  ]
562
 
563
  analyze_btn.click(
564
- handle_analysis_click,
565
  inputs=[cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
566
  policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input],
567
- outputs=output_components
568
  )
569
 
570
- input_file_components = [
571
- cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
572
- policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input
573
- ]
574
- def load_example_files_action():
575
- missing_example_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
576
- if missing_example_files:
577
- gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_example_files)}. Please ensure they exist.")
578
- return [None] * len(input_file_components)
579
- gr.Info(f"Example data paths loaded from '{EXAMPLE_DATA_DIR}'. Click 'Analyze Dataset'.")
580
  return [
581
  EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
582
  EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
583
  EXAMPLE_FILES["pv_mort"]
584
  ]
585
- load_example_btn.click(load_example_files_action, inputs=[], outputs=input_file_components)
 
 
 
 
 
 
 
586
  return demo
587
 
588
if __name__ == "__main__":
    # Script entry point: make sure the example-data folder exists, print
    # usage hints, then build and launch the Gradio app.
    example_dir_missing = not os.path.exists(EXAMPLE_DATA_DIR)
    if example_dir_missing:
        try:
            os.makedirs(EXAMPLE_DATA_DIR)
            print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.")
            print(f"Expected files: {list(EXAMPLE_FILES.keys())}")
        except OSError as e:
            # Best effort only: the app can still start without the folder.
            print(f"Error creating directory {EXAMPLE_DATA_DIR}: {e}. Please create it manually.")

    startup_notes = (
        "Starting Gradio application...",
        f"Note: Ensure your example Excel files are placed in the '{os.getcwd()}{os.sep}{EXAMPLE_DATA_DIR}' folder.",
        "Required policy data columns: 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth' (and an index col).",
        "Recommended PV files column for summary: 'PV_NetCF' (and an index col).",
    )
    for note in startup_notes:
        print(note)

    demo_app = create_interface()
    demo_app.launch()
 
2
  import numpy as np
3
  import pandas as pd
4
  from sklearn.cluster import KMeans
5
+ from sklearn.metrics import pairwise_distances_argmin_min, r2_score
6
  import matplotlib.pyplot as plt
7
  import matplotlib.cm
8
  import io
 
15
  "cashflow_base": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"),
16
  "cashflow_lapse": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_lapse50.xlsx"),
17
  "cashflow_mort": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K_mort15.xlsx"),
18
+ "policy_data": os.path.join(EXAMPLE_DATA_DIR, "model_point_table.xlsx"), # Assuming this is the correct path/name for the example
19
  "pv_base": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K.xlsx"),
20
  "pv_lapse": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_lapse50.xlsx"),
21
  "pv_mort": os.path.join(EXAMPLE_DATA_DIR, "pv_seriatim_10K_mort15.xlsx"),
 
23
 
24
class Clusters:
    """K-means clusters over policies with one representative per cluster.

    Every row of the calibration table is treated as one policy. After
    fitting, the row closest to each cluster centre is recorded as that
    cluster's representative, and the number of rows assigned to each cluster
    is kept so representatives can be scaled back up to portfolio level.
    """

    def __init__(self, loc_vars, n_clusters=1000):
        """Fit K-means on *loc_vars* and select the representative policies.

        Args:
            loc_vars: 2-D table (DataFrame or array-like) of calibration
                variables, one row per policy.
            n_clusters: Upper bound on the number of clusters. It is clamped
                to the number of rows, since ``KMeans`` cannot fit more
                clusters than there are samples.

        Raises:
            ValueError: If *loc_vars* has no rows.
        """
        n_rows = len(loc_vars)
        if n_rows == 0:
            raise ValueError("Input data for KMeans (loc_vars) is empty.")

        points = np.ascontiguousarray(loc_vars)
        # Fitted estimator; ``labels_`` holds the cluster id of each input row.
        self.kmeans = KMeans(
            n_clusters=min(n_clusters, n_rows), random_state=0, n_init=10
        ).fit(points)
        # Positional index of the row nearest to each cluster centre.
        closest, _ = pairwise_distances_argmin_min(self.kmeans.cluster_centers_, points)

        # 0-based row positions -> 1-based ids.
        # NOTE(review): assumes policy ids are 1..N in row order - confirm.
        rep_ids = pd.Series(data=(closest + 1))
        rep_ids.name = 'policy_id'
        rep_ids.index.name = 'cluster_id'
        # Series indexed by cluster_id giving the representative policy_id.
        self.rep_ids = rep_ids

        # Series indexed by cluster_id giving the number of rows per cluster.
        self.policy_count = self.agg_by_cluster(
            pd.DataFrame({'policy_count': [1] * n_rows})
        )['policy_count']

    def agg_by_cluster(self, df, agg=None):
        """Aggregate the columns of *df* within each cluster.

        Args:
            df: DataFrame with one row per clustered policy, in the same row
                order as the data passed to ``__init__``.
            agg: Optional mapping ``column -> aggregation name``. Columns not
                listed (or all columns when *agg* is falsy) are summed.

        Returns:
            DataFrame indexed by ``cluster_id``.
        """
        temp = df.copy()
        temp['cluster_id'] = self.kmeans.labels_
        temp = temp.set_index('cluster_id')
        how = {c: agg.get(c, 'sum') for c in temp.columns} if agg else 'sum'
        return temp.groupby(temp.index).agg(how)

    def extract_reps(self, df):
        """Return the rows of *df* that belong to the representative policies.

        *df* must expose ``policy_id`` after ``reset_index()`` (i.e. as its
        index name or as a column); the result is indexed by ``cluster_id``.
        """
        temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
        temp.index.name = 'cluster_id'
        return temp.drop('policy_id', axis=1)

    def extract_and_scale_reps(self, df, agg=None):
        """Return representative rows scaled by their cluster's policy count.

        Args:
            df: Same requirements as for ``extract_reps``.
            agg: Optional mapping ``column -> aggregation name``. Columns
                whose aggregation is ``'sum'`` (or that are not listed) are
                multiplied by the policy count; all other columns are left
                unscaled. When *agg* is falsy every column is scaled.

        Returns:
            DataFrame indexed by ``cluster_id``.
        """
        reps = self.extract_reps(df)
        if not agg:
            return reps.mul(self.policy_count, axis=0)

        weights = self.policy_count.to_numpy()
        # Passing an explicit index lets the frame be built even when every
        # multiplier is the scalar 1 (no 'sum' column at all).
        mult = pd.DataFrame(
            {c: (weights if agg.get(c, 'sum') == 'sum' else 1) for c in df.columns},
            index=reps.index,
        )
        return reps.mul(mult)

    def compare(self, df, agg=None):
        """Return per-cluster actual vs. estimate values.

        The result has a two-level index (cluster id, column name) and the
        columns ``'actual'`` (aggregate over all policies in the cluster) and
        ``'estimate'`` (scaled representative).
        """
        source = self.agg_by_cluster(df, agg)
        target = self.extract_and_scale_reps(df, agg)
        return pd.DataFrame({'actual': source.stack(), 'estimate': target.stack()})

    def compare_total(self, df, agg=None):
        """Compare portfolio totals with their model-point estimates.

        Args:
            df: Same requirements as for ``extract_reps``.
            agg: Optional mapping ``column -> 'sum' | 'mean'``; unlisted
                columns are summed. When falsy, every column is summed.

        Returns:
            DataFrame indexed by the columns of *df* with ``'actual'``,
            ``'estimate'`` and ``'error'`` (``estimate / actual - 1``).
            For ``'sum'`` columns the estimate is the count-weighted sum of
            the representatives; for ``'mean'`` columns it is the
            count-weighted mean. Any other aggregation has no estimate (NaN).
        """
        if agg:
            op = {c: agg.get(c, 'sum') for c in df.columns}
            actual = df.agg(op)

            reps = self.extract_reps(df)
            total_weight = self.policy_count.sum()
            estimate_values = {}
            # Every column is computed independently so the result does not
            # depend on the column order of df.
            for c in df.columns:
                weighted_sum = (reps[c] * self.policy_count).sum()
                if op[c] == 'sum':
                    estimate_values[c] = weighted_sum
                elif op[c] == 'mean':
                    estimate_values[c] = weighted_sum / total_weight if total_weight else 0
                else:
                    estimate_values[c] = np.nan
            estimate = pd.Series(estimate_values)
        else:
            actual = df.sum()
            estimate = self.extract_and_scale_reps(df).sum()

        return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': estimate / actual - 1})
 
128
 
129
 
130
def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
    """Plot actual vs. estimated totals for each cashflow scenario.

    Args:
        cfs_list: Sequence of cashflow DataFrames, one per scenario.
        cluster_obj: Object exposing ``compare_total(df)`` that returns a
            frame with ``'actual'`` and ``'estimate'`` columns.
        titles: Sequence of subplot titles, paired with *cfs_list*.

    Returns:
        A PIL image of the figure (two subplots per row), or ``None`` when
        any of the three arguments is empty/falsy.
    """
    if not cfs_list or not cluster_obj or not titles:
        return None

    n_cols = 2
    n_rows = (len(cfs_list) + n_cols - 1) // n_cols

    # squeeze=False keeps ``axes`` 2-D even for a single row.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    try:
        axes = axes.flatten()

        used = 0
        for ax, df, title in zip(axes, cfs_list, titles):
            comparison = cluster_obj.compare_total(df)
            comparison[['actual', 'estimate']].plot(ax=ax, grid=True, title=title)
            ax.set_xlabel('Time')
            ax.set_ylabel('Value')
            used += 1

        # Drop the grid cells that received no scenario.
        for spare_ax in axes[used:]:
            fig.delaxes(spare_ax)

        # Operate on ``fig`` explicitly instead of pyplot's global
        # "current figure", which another request could have changed.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100)
        buf.seek(0)
        return Image.open(buf)
    finally:
        # Always release the figure, even if plotting raised.
        plt.close(fig)
163
 
164
  def plot_scatter_comparison(df_compare_output, title):
165
+ """Create scatter plot comparison from compare() output"""
166
  if df_compare_output is None or df_compare_output.empty:
167
+ # Create a blank plot with a message
168
+ fig, ax = plt.subplots(figsize=(12, 8))
169
+ ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
170
+ ax.set_title(title)
171
+ buf = io.BytesIO()
172
+ plt.savefig(buf, format='png', dpi=100)
173
+ buf.seek(0)
174
+ img = Image.open(buf)
175
+ plt.close(fig)
176
+ return img
177
+
178
+ fig, ax = plt.subplots(figsize=(12, 8)) # Use a single Axes object
179
 
180
  if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
181
+ gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
182
+ ax.scatter(df_compare_output['actual'], df_compare_output['estimate'], s=9, alpha=0.6)
183
  else:
184
  unique_levels = df_compare_output.index.get_level_values(1).unique()
185
  colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
186
 
187
  for item_level, color_val in zip(unique_levels, colors):
188
  subset = df_compare_output.xs(item_level, level=1)
189
+ ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=item_level)
190
+ if len(unique_levels) > 1 and len(unique_levels) <=10: # Add legend if not too many items
191
  ax.legend(title=df_compare_output.index.names[1])
192
 
193
+
194
  ax.set_xlabel('Actual')
195
  ax.set_ylabel('Estimate')
196
  ax.set_title(title)
197
  ax.grid(True)
198
 
199
+ # Draw identity line
200
+ lims = [
201
+ np.min([ax.get_xlim(), ax.get_ylim()]),
202
+ np.max([ax.get_xlim(), ax.get_ylim()]),
203
+ ]
204
+ if lims[0] != lims[1]: # Avoid issues if all data is zero or a single point
205
+ ax.plot(lims, lims, 'r-', linewidth=0.5)
206
+ ax.set_xlim(lims)
207
+ ax.set_ylim(lims)
 
 
 
 
208
 
209
  buf = io.BytesIO()
210
  plt.savefig(buf, format='png', dpi=100)
 
216
 
217
def _abs_total_pv_errors(total_tables, use_net_cf):
    """Return the absolute relative error of each ``compare_total`` table.

    With *use_net_cf* the error of the 'PV_NetCF' row is taken; otherwise the
    mean of the whole 'error' column is used.
    """
    if use_net_cf:
        return [abs(table.loc['PV_NetCF', 'error']) for table in total_tables]
    return [abs(table['error'].mean()) for table in total_tables]


def _render_summary_plot(summary_df):
    """Draw the per-scenario error bar chart and return it as a PIL image."""
    fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
    try:
        summary_df.plot(kind='bar', ax=ax_summary, grid=True)
        ax_summary.set_ylabel('Mean Absolute Error (of PV_NetCF)')
        ax_summary.set_title('Calibration Method Comparison - Error in Total PV Net Cashflow')
        ax_summary.tick_params(axis='x', rotation=0)
        fig_summary.tight_layout()

        buf_summary = io.BytesIO()
        fig_summary.savefig(buf_summary, format='png', dpi=100)
        buf_summary.seek(0)
        return Image.open(buf_summary)
    finally:
        # Release the figure even when rendering fails.
        plt.close(fig_summary)


def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
                  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
    """Run the three calibrations and collect every table and plot.

    Each argument is the path of an Excel workbook whose first column is the
    index. Clusters are fitted on (1) base cashflows, (2) min-max scaled
    policy attributes and (3) base present values; each clustering is then
    compared against the seriatim data.

    Returns:
        On success, a dict with the result tables/images under the keys
        ``cf_*``, ``attr_*``, ``pv_*`` and ``summary_plot``. On failure, a
        dict with the single key ``"error"`` holding a message; this function
        does not display the message itself, so the caller has to surface it.
    """
    try:
        cfs = pd.read_excel(cashflow_base_path, index_col=0)
        cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
        cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)

        pol_data_full = pd.read_excel(policy_data_path, index_col=0)
        required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
        if all(col in pol_data_full.columns for col in required_cols):
            pol_data = pol_data_full[required_cols]
        else:
            # Fall back to whatever columns are present.
            gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
            pol_data = pol_data_full

        pvs = pd.read_excel(pv_base_path, index_col=0)
        pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
        pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)

        cfs_list = [cfs, cfs_lapse50, cfs_mort15]
        scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']

        results = {}

        # Sum assured is totalled; the other attributes are averaged.
        mean_attrs = {'age_at_entry': 'mean', 'policy_term': 'mean', 'duration_mth': 'mean', 'sum_assured': 'sum'}

        # --- 1. Cashflow Calibration ---
        cluster_cfs = Clusters(cfs)

        results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
        results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)

        results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
        results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
        results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)

        results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
        results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')

        # --- 2. Policy Attribute Calibration ---
        cluster_attrs = None
        if pol_data.empty:
            gr.Warning("Policy data for attribute calibration is empty. Skipping attribute calibration.")
            results['attr_total_cf_base'] = pd.DataFrame()
            results['attr_policy_attrs_total'] = pd.DataFrame()
            results['attr_total_pv_base'] = pd.DataFrame()
            results['attr_cashflow_plot'] = None
            results['attr_scatter_cashflows_base'] = None
        else:
            # Min-max scale every attribute to [0, 1]. A constant column has
            # zero range; dividing by 1 instead maps it to 0 rather than
            # producing NaN/inf that K-means cannot handle.
            col_min = pol_data.min()
            col_range = (pol_data.max() - col_min).replace(0, 1)
            loc_vars_attrs = (pol_data - col_min) / col_range

            cluster_attrs = Clusters(loc_vars_attrs)
            results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
            results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
            results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
            results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
            results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')

        # --- 3. Present Value Calibration ---
        cluster_pvs = Clusters(pvs)

        results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
        results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)

        results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
        results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
        results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)

        results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
        results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')

        # --- Summary Comparison Plot Data ---
        # Absolute relative error of total 'PV_NetCF' per scenario; when that
        # column is absent, the mean of the 'error' column is used instead.
        # The tables computed above are reused rather than recomputed.
        has_net_cf = 'PV_NetCF' in pvs.columns
        error_data = {}

        error_data['CF Calib. (PV NetCF)'] = _abs_total_pv_errors(
            [results['cf_pv_total_base'], results['cf_pv_total_lapse'], results['cf_pv_total_mort']],
            has_net_cf,
        )

        if cluster_attrs is not None and has_net_cf:
            attr_pv_totals = [
                results['attr_total_pv_base'],
                cluster_attrs.compare_total(pvs_lapse50),
                cluster_attrs.compare_total(pvs_mort15),
            ]
            error_data['Attr Calib. (PV NetCF)'] = _abs_total_pv_errors(attr_pv_totals, True)
        else:
            # Placeholder so the summary keeps three series.
            error_data['Attr Calib. (PV NetCF)'] = [np.nan, np.nan, np.nan]

        error_data['PV Calib. (PV NetCF)'] = _abs_total_pv_errors(
            [results['pv_total_pv_base'], results['pv_total_pv_lapse'], results['pv_total_pv_mort']],
            has_net_cf,
        )

        summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
        results['summary_plot'] = _render_summary_plot(summary_df)

        return results

    # The previous code built ``gr.Error(...)`` objects here without raising
    # them, which displays nothing; the failure is reported through the
    # returned dict only.
    except FileNotFoundError as e:
        return {"error": f"File not found: {e.filename}"}
    except KeyError as e:
        return {"error": f"Missing column: {e}"}
    except Exception as e:
        return {"error": f"Error processing files: {str(e)}"}
 
 
384
 
385
 
386
  def create_interface():
387
+ with gr.Blocks(title="Cluster Model Points Analysis") as demo: # Removed theme
388
  gr.Markdown("""
389
  # Cluster Model Points Analysis
390
+
391
+ This application applies cluster analysis to model point selection for insurance portfolios.
392
+ Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
393
+
394
+ **Required Files (Excel .xlsx):**
395
  - Cashflows - Base Scenario
396
  - Cashflows - Lapse Stress (+50%)
397
  - Cashflows - Mortality Stress (+15%)
398
+ - Policy Data (including 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth')
399
+ - Present Values - Base Scenario
400
+ - Present Values - Lapse Stress
401
+ - Present Values - Mortality Stress
402
  """)
403
 
404
  with gr.Row():
405
  with gr.Column(scale=1):
406
+ gr.Markdown("### Upload Files or Load Examples")
407
+
408
+ load_example_btn = gr.Button("Load Example Data")
409
+
410
  with gr.Row():
411
  cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
412
  cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
 
417
  pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
418
  with gr.Row():
419
  pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
420
+
421
+ analyze_btn = gr.Button("Analyze Dataset", variant="primary", size="lg")
 
 
 
422
 
423
  with gr.Tabs():
424
  with gr.TabItem("📊 Summary"):
425
+ summary_plot_output = gr.Image(label="Calibration Methods Comparison (Error in Total PV Net Cashflow)")
426
 
427
  with gr.TabItem("💸 Cashflow Calibration"):
428
+ gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
429
  with gr.Row():
430
+ cf_total_base_table_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
431
+ cf_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
432
+ cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
433
+ cf_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
434
+ with gr.Accordion("Present Value Comparisons (Total)", open=False):
435
  with gr.Row():
436
+ cf_pv_total_base_out = gr.Dataframe(label="PVs - Base Total")
437
+ cf_pv_total_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
438
+ cf_pv_total_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
439
 
440
  with gr.TabItem("👤 Policy Attribute Calibration"):
441
  gr.Markdown("### Results: Using Policy Attributes as Calibration Variables")
442
  with gr.Row():
443
+ attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
444
+ attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
445
+ attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
446
+ attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
447
+ with gr.Accordion("Present Value Comparisons (Total)", open=False):
448
+ attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")
449
 
450
  with gr.TabItem("💰 Present Value Calibration"):
451
+ gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
452
  with gr.Row():
453
+ pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)")
454
+ pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes")
455
+ pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
456
+ pv_scatter_pvs_base_out = gr.Image(label="Scatter Plot - Per-Cluster Present Values (Base Scenario)")
457
+ with gr.Accordion("Present Value Comparisons (Total)", open=False):
458
  with gr.Row():
459
+ pv_total_pv_base_out = gr.Dataframe(label="PVs - Base Total")
460
+ pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
461
+ pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
462
+
463
+ # --- Helper function to prepare outputs ---
464
def get_all_output_components():
    """Return every output widget in the order the analysis handler fills them.

    The order is: summary plot, cashflow-calibration outputs,
    attribute-calibration outputs, present-value-calibration outputs.
    """
    cashflow_calib_outputs = [
        cf_total_base_table_out,
        cf_policy_attrs_total_out,
        cf_cashflow_plot_out,
        cf_scatter_cashflows_base_out,
        cf_pv_total_base_out,
        cf_pv_total_lapse_out,
        cf_pv_total_mort_out,
    ]
    attribute_calib_outputs = [
        attr_total_cf_base_out,
        attr_policy_attrs_total_out,
        attr_cashflow_plot_out,
        attr_scatter_cashflows_base_out,
        attr_total_pv_base_out,
    ]
    pv_calib_outputs = [
        pv_total_cf_base_out,
        pv_policy_attrs_total_out,
        pv_cashflow_plot_out,
        pv_scatter_pvs_base_out,
        pv_total_pv_base_out,
        pv_total_pv_lapse_out,
        pv_total_pv_mort_out,
    ]
    return [
        summary_plot_output,
        *cashflow_calib_outputs,
        *attribute_calib_outputs,
        *pv_calib_outputs,
    ]
479
 
480
+ # --- Action for Analyze Button ---
481
def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
    """Resolve the seven file inputs to paths, run the analysis, map outputs.

    Each argument is either a string path or an object with a string
    ``name`` attribute holding the path; ``None`` means the input is unset.

    Returns:
        The result values in the same order as the output components wired
        to the analyze button.

    Raises:
        gr.Error: If an input is missing or of an unexpected type, or if
            ``process_files`` reports a failure. ``gr.Error`` only reaches
            the user when raised; the previous code merely constructed it,
            so failures produced blank outputs with no message. When it is
            raised the outputs are left as they were instead of being
            cleared.
    """
    file_paths = []
    for position, f_obj in enumerate((f1, f2, f3, f4, f5, f6, f7), start=1):
        if f_obj is None:
            raise gr.Error(f"Missing file input for argument {position}. Please upload all files or load examples.")

        if isinstance(f_obj, str):
            # Already a path.
            file_paths.append(f_obj)
        elif isinstance(getattr(f_obj, 'name', None), str):
            # File-like object exposing its path as ``.name``.
            file_paths.append(f_obj.name)
        else:
            raise gr.Error(f"Invalid file input for argument {position}. Type: {type(f_obj)}")

    results = process_files(*file_paths)

    if "error" in results:
        raise gr.Error(results["error"])

    output_keys = (
        'summary_plot',
        # CF Calib
        'cf_total_base_table', 'cf_policy_attrs_total',
        'cf_cashflow_plot', 'cf_scatter_cashflows_base',
        'cf_pv_total_base', 'cf_pv_total_lapse', 'cf_pv_total_mort',
        # Attr Calib
        'attr_total_cf_base', 'attr_policy_attrs_total',
        'attr_cashflow_plot', 'attr_scatter_cashflows_base', 'attr_total_pv_base',
        # PV Calib
        'pv_total_cf_base', 'pv_policy_attrs_total',
        'pv_cashflow_plot', 'pv_scatter_pvs_base',
        'pv_total_pv_base', 'pv_total_pv_lapse', 'pv_total_pv_mort',
    )
    return [results.get(key) for key in output_keys]
525
 
526
  analyze_btn.click(
527
+ handle_analysis,
528
  inputs=[cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
529
  policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input],
530
+ outputs=get_all_output_components()
531
  )
532
 
533
+ # --- Action for Load Example Data Button ---
534
def load_example_files():
    """Return the bundled example file paths for the seven file inputs.

    Returns:
        A list of seven paths from ``EXAMPLE_FILES``, in the order cashflow
        base/lapse/mort, policy data, PV base/lapse/mort.

    Raises:
        gr.Error: If any example file is missing on disk. The exception must
            be raised to be displayed; constructing it alone shows nothing.
    """
    missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
    if missing_files:
        raise gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")

    gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
    return [
        EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
        EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
        EXAMPLE_FILES["pv_mort"]
    ]
547
+
548
+ load_example_btn.click(
549
+ load_example_files,
550
+ inputs=[],
551
+ outputs=[cashflow_base_input, cashflow_lapse_input, cashflow_mort_input,
552
+ policy_data_input, pv_base_input, pv_lapse_input, pv_mort_input]
553
+ )
554
+
555
  return demo
556
 
557
if __name__ == "__main__":
    # Make sure the example-data directory exists, then launch the app.
    # The directory is only a convenience: the user still has to put the
    # real Excel files there, so a failure to create it must not stop the
    # app from starting (e.g. on a read-only filesystem).
    if not os.path.exists(EXAMPLE_DATA_DIR):
        try:
            # exist_ok guards against the directory appearing between the
            # check above and this call.
            os.makedirs(EXAMPLE_DATA_DIR, exist_ok=True)
        except OSError as e:
            print(f"Error creating directory {EXAMPLE_DATA_DIR}: {e}. Please create it manually.")
        else:
            print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.")
        print(f"Expected files in '{EXAMPLE_DATA_DIR}': {list(EXAMPLE_FILES.values())}")

    demo_app = create_interface()
    demo_app.launch()