alidenewade committed on
Commit
8e9768e
·
verified ·
1 Parent(s): 005e14d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +336 -463
app.py CHANGED
@@ -2,15 +2,15 @@ import gradio as gr
2
  import numpy as np
3
  import pandas as pd
4
  from sklearn.cluster import KMeans
5
- from sklearn.metrics import pairwise_distances_argmin_min, r2_score
6
- import matplotlib.pyplot as plt
7
- import seaborn as sns
 
8
  import io
9
  import os
10
  from PIL import Image
11
 
12
  # Define the paths for example data
13
- # For Hugging Face Spaces, these paths will be relative to the app's root
14
  EXAMPLE_DATA_DIR = "eg_data"
15
  EXAMPLE_FILES = {
16
  "cashflow_base": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"),
@@ -24,455 +24,306 @@ EXAMPLE_FILES = {
24
 
25
class Clusters:
    """K-means clustering of policies with representative-policy selection.

    Fits a 1000-cluster KMeans on the supplied calibration variables, picks
    the policy nearest each centroid as that cluster's representative, and
    provides helpers to aggregate per-policy values by cluster and to compare
    actual portfolio totals against estimates scaled up from the
    representatives.
    """

    def __init__(self, loc_vars):
        # loc_vars: calibration variables, one row per policy
        # (pd.DataFrame or array-like; rows assumed aligned with policy order
        # — TODO confirm against callers).
        if isinstance(loc_vars, pd.DataFrame):
            loc_vars_np = np.ascontiguousarray(loc_vars.values)
        else:
            loc_vars_np = np.ascontiguousarray(loc_vars)

        # NOTE(review): n_clusters is hard-coded to 1000; inputs with fewer
        # than 1000 rows would make KMeans raise — confirm callers guarantee
        # larger portfolios.
        self.kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(loc_vars_np)
        # Row position of the policy closest to each centroid (one per cluster).
        closest, _ = pairwise_distances_argmin_min(self.kmeans.cluster_centers_, loc_vars_np)

        # +1 converts 0-based row positions to policy ids
        # (assumes policy_id == row position + 1 — TODO confirm).
        rep_ids = pd.Series(data=(closest + 1))
        rep_ids.name = 'policy_id'
        rep_ids.index.name = 'cluster_id'
        self.rep_ids = rep_ids

        # Number of policies in each cluster; used as the scaling weight when
        # estimating portfolio totals from the representatives.
        self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars_np)}))['policy_count']

    def agg_by_cluster(self, df, agg=None):
        """Aggregate df (one row per policy) by cluster.

        agg maps column name -> aggregation function name; columns not listed
        default to 'sum'. With agg=None everything is summed.
        """
        temp = df.copy()
        temp['cluster_id'] = self.kmeans.labels_
        temp = temp.set_index('cluster_id')
        # Either a per-column aggregation dict or the plain string "sum".
        agg_dict = {c: (agg[c] if agg and c in agg else 'sum') for c in temp.columns if c != 'cluster_id'} if agg else "sum"
        if not agg_dict:
            return pd.DataFrame(index=temp.index.unique())
        return temp.groupby(level='cluster_id').agg(agg_dict)

    def extract_reps(self, df):
        """Return the representative policies' rows of df, indexed by cluster_id."""
        # NOTE(review): the df_to_merge assigned inside this branch is
        # immediately overwritten by the unconditional assignment below, so
        # only the ValueError side effect of this block survives.
        if 'policy_id' not in df.columns and df.index.name != 'policy_id':
            # Try to use the first part of the index if it's a MultiIndex and called 'policy_id'
            if isinstance(df.index, pd.MultiIndex) and 'policy_id' in df.index.names:
                df_to_merge = df.reset_index()  # Reset all levels, policy_id becomes a column
            else:
                raise ValueError("DataFrame for extract_reps must have 'policy_id' as a named index or a column.")

        df_to_merge = df.reset_index() if df.index.name == 'policy_id' or (isinstance(df.index, pd.MultiIndex) and 'policy_id' in df.index.names) else df.copy()

        if 'policy_id' not in df_to_merge.columns:
            # This is a fallback if policy_id was expected but still not a column.
            # This might happen if the index was unnamed and thought to be policy_id.
            # A robust solution depends on stricter input guarantees.
            gr.Warning("extract_reps: 'policy_id' column not found after attempting to reset index. Merging may fail or be incorrect.")

        # Left-join so every cluster keeps a row even if its representative
        # policy id is missing from df.
        temp = pd.merge(self.rep_ids.reset_index(), df_to_merge, how='left', on='policy_id')
        temp = temp.set_index('cluster_id')
        return temp.drop(columns=['policy_id'], errors='ignore')

    def extract_and_scale_reps(self, df, agg=None):
        """Representative rows of df scaled by cluster size.

        Columns aggregated with 'sum' are multiplied by the cluster's policy
        count; columns marked 'mean' in agg are left unscaled.
        """
        extracted_df = self.extract_reps(df)
        if agg:
            cols_to_multiply = [col for col in df.columns if col in extracted_df.columns]
            mult_data = {}
            for c in cols_to_multiply:
                # Ensure self.policy_count is aligned with extracted_df.index if it's a Series
                if isinstance(self.policy_count, pd.Series):
                    policy_count_for_col = self.policy_count.reindex(extracted_df.index).fillna(1)  # Default to 1 if cluster missing
                else:  # Should be a scalar or array-like usable directly
                    policy_count_for_col = self.policy_count

                # 'sum' columns get scaled by cluster size; 'mean' columns keep
                # the representative's raw value (multiplier 1).
                mult_data[c] = policy_count_for_col if (c not in agg or agg[c] == 'sum') else 1

            mult = pd.DataFrame(mult_data, index=extracted_df.index)

            result_df = extracted_df.copy()
            for col in cols_to_multiply:
                if col in mult.columns:  # Ensure column exists in multiplier
                    result_df[col] = extracted_df[col].mul(mult[col])
            return result_df
        else:
            # No agg spec: scale every numeric column by the cluster size.
            numeric_cols = extracted_df.select_dtypes(include=np.number).columns
            result_df = extracted_df.copy()
            for col in numeric_cols:
                if isinstance(self.policy_count, pd.Series):
                    policy_count_for_col = self.policy_count.reindex(extracted_df.index).fillna(0)  # Fill with 0 if not found
                    result_df[col] = extracted_df[col].mul(policy_count_for_col, axis=0)
                else:  # Assuming self.policy_count is a scalar or compatible array
                    result_df[col] = extracted_df[col].mul(self.policy_count, axis=0)
            return result_df

    def compare(self, df, agg=None):
        """Per-cluster actual vs. estimate, stacked to one row per
        (cluster_id, column) pair. Returns columns 'actual' and 'estimate'."""
        source = self.agg_by_cluster(df, agg)
        target = self.extract_and_scale_reps(df, agg)

        common_columns = source.columns.intersection(target.columns)
        if common_columns.empty and (not source.empty or not target.empty):
            gr.Warning("Compare function: No common columns between source and target. Result will be empty.")
            return pd.DataFrame({'actual': pd.Series(dtype=float), 'estimate': pd.Series(dtype=float)})

        source_stacked = source[common_columns].stack(dropna=False)  # keepna=True for older pandas
        target_stacked = target[common_columns].stack(dropna=False)

        return pd.DataFrame({'actual': source_stacked, 'estimate': target_stacked})

    def compare_total(self, df, agg=None):
        """Portfolio-level actual vs. estimate per column, plus relative error.

        With agg, 'mean' columns are compared as policy-count-weighted means;
        all other columns as sums. Error is (estimate / actual) - 1, forced to
        0 where actual is 0 or the ratio is NaN.
        """
        if agg:
            actual_values = {}
            for col in df.columns:
                if agg.get(col, 'sum') == 'mean':
                    actual_values[col] = df[col].mean()
                else:
                    actual_values[col] = df[col].sum()
            actual = pd.Series(actual_values)

            reps_unscaled = self.extract_reps(df)
            estimate_values = {}

            for col_orig_df in df.columns:
                if col_orig_df not in reps_unscaled.columns:
                    estimate_values[col_orig_df] = np.nan
                    continue

                current_col_data = reps_unscaled[col_orig_df].astype(float)  # Ensure numeric for calcs
                policy_counts_aligned = self.policy_count.reindex(current_col_data.index).astype(float)  # Align and ensure numeric

                if agg.get(col_orig_df, 'sum') == 'mean':
                    # Weighted mean: each representative stands in for its
                    # whole cluster, so weight by cluster size.
                    weighted_sum = (current_col_data * policy_counts_aligned).sum()
                    total_weight = policy_counts_aligned.sum()
                    estimate_values[col_orig_df] = weighted_sum / total_weight if total_weight > 0 else np.nan
                else:
                    estimate_values[col_orig_df] = (current_col_data * policy_counts_aligned).sum()
            estimate = pd.Series(estimate_values)
        else:
            actual = df.sum()
            estimate = self.extract_and_scale_reps(df).sum()  # This sum might need to be on numeric cols only

        # Align on column labels so missing entries compare as 0.
        actual, estimate = actual.align(estimate, fill_value=0)
        error = np.where(actual != 0, (estimate / actual) - 1, 0)
        error = np.nan_to_num(error, nan=0.0)

        return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
155
 
156
- # Plotting Functions (Modified for Seaborn)
157
def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
    """Render a grid of actual-vs-estimate total cashflow line plots.

    One subplot per non-empty DataFrame in cfs_list (paired with titles);
    totals come from cluster_obj.compare_total. Returns a PIL Image, or a
    small placeholder image when there is nothing to plot.
    """
    sns.set_style("whitegrid")
    if not cfs_list or not cluster_obj or not titles or not any(cfs_list):  # Check if cfs_list contains any non-None df
        # Return a placeholder image indicating no data
        fig, ax = plt.subplots(figsize=(7.5, 2.5))  # Smaller placeholder
        ax.text(0.5, 0.5, "No cashflow data to plot", ha='center', va='center', fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100)
        buf.seek(0)
        img = Image.open(buf)
        plt.close(fig)
        return img

    # Filter out None DataFrames from cfs_list to prevent errors
    valid_cfs_data = [(df, title) for df, title in zip(cfs_list, titles) if df is not None and not df.empty]
    if not valid_cfs_data:  # If all DFs were None or empty
        return plot_cashflows_comparison([], None, [])  # Recurse to get placeholder

    num_plots = len(valid_cfs_data)
    cols = 2
    rows = (num_plots + cols - 1) // cols  # ceil division: enough rows for all plots

    fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
    axes = axes.flatten()
    plot_idx = 0  # Separate index for placing plots

    for df_orig, title in valid_cfs_data:
        if plot_idx < len(axes):
            ax = axes[plot_idx]
            comparison_df = cluster_obj.compare_total(df_orig)

            if comparison_df.empty:
                ax.text(0.5, 0.5, f"No comparison data for\n{title}", ha='center', va='center', fontsize=9)
                ax.set_title(title)
                plot_idx += 1
                continue

            plot_data = comparison_df[['actual', 'estimate']].copy()
            plot_data['Time'] = plot_data.index.astype(str)
            # Prefer a numeric x-axis when the index labels parse as numbers.
            try:
                plot_data['Time'] = pd.to_numeric(plot_data['Time'])
            except ValueError:
                pass

            # Long format so seaborn draws one line per actual/estimate series.
            plot_data_melted = plot_data.melt(id_vars='Time', var_name='Legend', value_name='Value')

            sns.lineplot(x='Time', y='Value', hue='Legend', data=plot_data_melted, ax=ax, errorbar=None)
            ax.set_title(title)
            ax.set_xlabel('Time')
            ax.set_ylabel('Value')
            plot_idx += 1

    for j in range(plot_idx, len(axes)):  # Hide any unused subplots
        fig.delaxes(axes[j])

    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
    return img
221
 
222
def plot_scatter_comparison(df_compare_output, title):
    """Scatter-plot per-cluster 'actual' vs 'estimate' values.

    df_compare_output is expected to be the output of Clusters.compare (a
    two-level index, second level used as hue); a flat frame is plotted
    without hue as a fallback. Draws an x=y reference line over the data
    range and returns the figure as a PIL Image.
    """
    sns.set_style("whitegrid")
    fig, ax = plt.subplots(figsize=(12, 8))  # Define fig and ax here for all paths

    plot_data_available = False  # Flag to check if we have data to plot for limits

    if df_compare_output is None or df_compare_output.empty:
        ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
        ax.set_title(title)
    elif not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
        # Fallback: not the expected (cluster_id, column) MultiIndex shape.
        gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
        if not df_compare_output[['actual', 'estimate']].empty:
            sns.scatterplot(x='actual', y='estimate', data=df_compare_output, s=25, alpha=0.7, ax=ax, legend=False)
            plot_data_available = True
        else:
            ax.text(0.5, 0.5, "Data for scatter plot is empty.", ha='center', va='center', fontsize=15)
        ax.set_title(title)
    else:
        plot_data_internal = df_compare_output.reset_index()
        if plot_data_internal[['actual', 'estimate']].dropna().empty:
            ax.text(0.5, 0.5, "Comparison data (actual/estimate) is empty or all NaN.", ha='center', va='center', fontsize=15)
            ax.set_title(title)
        else:
            # Second index level (e.g. the cashflow column name) becomes the hue.
            hue_col_name = df_compare_output.index.names[1]
            plot_data_internal[hue_col_name] = plot_data_internal[hue_col_name].astype(str)

            unique_levels = plot_data_internal[hue_col_name].nunique()
            show_legend_flag = "auto"
            if unique_levels == 1:
                show_legend_flag = False
            elif unique_levels > 10:  # Max 10 items in legend for clarity
                show_legend_flag = False
                gr.Warning(f"Warning: Too many unique values ({unique_levels}) in '{hue_col_name}' for scatter plot legend. Legend hidden.")

            sns.scatterplot(x='actual', y='estimate', hue=hue_col_name, data=plot_data_internal,
                            s=25, alpha=0.7, ax=ax, legend=show_legend_flag)
            plot_data_available = True
            ax.set_title(title)

            if ax.get_legend() is not None:  # If legend is shown
                ax.get_legend().set_title(str(hue_col_name))

    ax.set_xlabel('Actual')
    ax.set_ylabel('Estimate')

    if plot_data_available:
        # Use 'plot_data_internal' if it exists, else 'df_compare_output' for non-multi-index case
        current_plot_df = plot_data_internal if 'plot_data_internal' in locals() and not plot_data_internal.empty else df_compare_output

        if current_plot_df is not None and not current_plot_df.empty:
            all_values = pd.concat([current_plot_df['actual'], current_plot_df['estimate']]).dropna()
            if not all_values.empty:
                min_val = all_values.min()
                max_val = all_values.max()

                # Fallback if min_val and max_val are the same (e.g. single point data)
                if min_val == max_val:
                    margin = abs(min_val * 0.1) if min_val != 0 else 0.1  # 10% margin or 0.1 if value is 0
                    plot_min, plot_max = min_val - margin, max_val + margin
                else:
                    plot_min, plot_max = min_val, max_val

                # Ensure plot_min and plot_max are finite and distinct
                if np.isfinite(plot_min) and np.isfinite(plot_max) and plot_min < plot_max:
                    # x=y reference line: points on it are perfect estimates.
                    ax.plot([plot_min, plot_max], [plot_min, plot_max], 'r-', linewidth=0.7, alpha=0.8, zorder=0)
                    ax.set_xlim(plot_min, plot_max)
                    ax.set_ylim(plot_min, plot_max)
                elif np.isfinite(plot_min) and np.isfinite(plot_max) and plot_min == plot_max:  # Handles single point case after margin
                    ax.plot([plot_min], [plot_min], 'ro', markersize=5)  # Mark the point
                    ax.set_xlim(plot_min - (abs(plot_min*0.1) if plot_min !=0 else 0.1), plot_min + (abs(plot_min*0.1) if plot_min !=0 else 0.1))
                    ax.set_ylim(plot_min - (abs(plot_min*0.1) if plot_min !=0 else 0.1), plot_min + (abs(plot_min*0.1) if plot_min !=0 else 0.1))

    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
    return img
302
 
303
- # Main Processing and Gradio UI
304
  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
305
  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
 
306
  try:
307
        def read_and_prep_excel(file_path, is_policy_data=False):
            """Read one Excel file, index it by 'policy_id', and (unless it is
            the policy-data file) keep only numeric columns.

            Raises FileNotFoundError when file_path does not exist.
            """
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"File not found: {file_path}")

            # For Hugging Face, ensure files are readable.
            # The path provided by gr.File is usually to a temp copy.
            df = pd.read_excel(file_path)

            # Try to identify policy_id:
            # 1. Explicit 'policy_id' column (case-insensitive)
            # 2. First column if no explicit 'policy_id'
            pid_col_name = None
            for col in df.columns:
                if str(col).lower() == 'policy_id':
                    pid_col_name = col
                    break

            if pid_col_name:
                df = df.rename(columns={pid_col_name: 'policy_id'})
                df = df.set_index('policy_id')
            elif df.index.name and df.index.name.lower() == 'policy_id':  # Already indexed by policy_id
                pass  # Keep as is
            else:  # Assume first column is policy_id if no explicit one is found
                gr.Warning(f"No explicit 'policy_id' column/index in {os.path.basename(file_path)}. Assuming first column is policy_id.")
                df = df.rename(columns={df.columns[0]: 'policy_id'})
                df = df.set_index('policy_id')

            if is_policy_data:
                return df  # Return all columns for policy data, selection happens next
            return df.select_dtypes(include=np.number)
337
-
338
-
339
- cfs = read_and_prep_excel(cashflow_base_path)
340
- cfs_lapse50 = read_and_prep_excel(cashflow_lapse_path)
341
- cfs_mort15 = read_and_prep_excel(cashflow_mort_path)
342
-
343
- pol_data_full = read_and_prep_excel(policy_data_path, is_policy_data=True)
344
 
345
- required_cols_std = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
346
-
347
- # Normalize available column names for matching
348
- available_cols_map = {col.lower().replace("_", "").replace(" ", ""): col for col in pol_data_full.columns}
349
- cols_to_select = []
350
- final_rename_map = {}
351
-
352
- for req_col_std in required_cols_std:
353
- req_col_norm = req_col_std.lower().replace("_", "").replace(" ", "")
354
- if req_col_norm in available_cols_map:
355
- original_name = available_cols_map[req_col_norm]
356
- cols_to_select.append(original_name)
357
- if original_name != req_col_std: # if original name was 'Age At Entry' map to 'age_at_entry'
358
- final_rename_map[original_name] = req_col_std
359
- else: # If after normalization, it's still not found.
360
- gr.Warning(f"Required policy data column '{req_col_std}' not found or could not be matched.")
361
-
362
-
363
- if len(cols_to_select) == len(required_cols_std):
364
- pol_data = pol_data_full[cols_to_select].rename(columns=final_rename_map)
365
- pol_data = pol_data.select_dtypes(include=np.number) # Ensure numeric after selection
366
  else:
367
- missing_cols_display = [rc for rc in required_cols_std if rc not in final_rename_map.values()]
368
- gr.Warning(f"Policy data is missing some required columns: {missing_cols_display}. Using all available numeric columns instead.")
369
- pol_data = pol_data_full.select_dtypes(include=np.number)
370
-
371
- if pol_data.index.name != 'policy_id': # safety check if index was lost
372
- gr.Error("Policy data lost 'policy_id' index during processing.")
373
- # Attempt to recover if 'policy_id' is a column
374
- if 'policy_id' in pol_data.columns:
375
- pol_data = pol_data.set_index('policy_id')
376
- else: # cannot proceed with pol_data
377
- pol_data = pd.DataFrame() # Make it empty to signal issues later
378
-
379
- pvs = read_and_prep_excel(pv_base_path)
380
- pvs_lapse50 = read_and_prep_excel(pv_lapse_path)
381
- pvs_mort15 = read_and_prep_excel(pv_mort_path)
382
 
383
  cfs_list = [cfs, cfs_lapse50, cfs_mort15]
384
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
385
 
386
  results = {}
 
387
  mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
388
 
389
  # --- 1. Cashflow Calibration ---
390
- if cfs.empty: gr.Warning("Base cashflow data (cfs) is empty. CF Calib may fail or produce no results.")
391
- cluster_cfs = Clusters(cfs)
392
  results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
393
- results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs) if not pol_data.empty else pd.DataFrame()
394
- results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs) if not pvs.empty else pd.DataFrame()
395
- results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50) if not pvs_lapse50.empty else pd.DataFrame()
396
- results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15) if not pvs_mort15.empty else pd.DataFrame()
 
 
397
  results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
398
  results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
399
 
400
-
401
  # --- 2. Policy Attribute Calibration ---
402
- loc_vars_attrs_scaled = pd.DataFrame() # Initialize
403
- if not pol_data.empty:
404
- min_vals, max_vals = pol_data.min(), pol_data.max()
405
- range_vals = max_vals - min_vals
406
- if (range_vals.abs() < 1e-9).all(): # Check if all ranges are effectively zero
407
- gr.Warning("Policy data for attribute calibration has no variance. Using unscaled data (0s).")
408
- loc_vars_attrs_scaled = pd.DataFrame(0, index=pol_data.index, columns=pol_data.columns)
409
- else:
410
- loc_vars_attrs_scaled = pol_data.copy()
411
- for col in range_vals.index:
412
- if range_vals[col] > 1e-9:
413
- loc_vars_attrs_scaled[col] = (pol_data[col] - min_vals[col]) / range_vals[col]
414
- else:
415
- loc_vars_attrs_scaled[col] = 0.0 # Column with no variance becomes 0
416
- loc_vars_attrs_scaled = loc_vars_attrs_scaled.fillna(0) # Handle any NaNs from division by zero if range_vals was exactly 0
417
  else:
418
- gr.Warning("Policy data is empty. Skipping attribute calibration.")
419
-
420
- if not loc_vars_attrs_scaled.empty:
421
- cluster_attrs = Clusters(loc_vars_attrs_scaled)
422
- results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs) if not cfs.empty else pd.DataFrame()
423
- results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs) # Compare with original pol_data
424
- results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs) if not pvs.empty else pd.DataFrame()
 
425
  results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
426
  results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
427
  else:
428
  results['attr_total_cf_base'] = pd.DataFrame()
429
  results['attr_policy_attrs_total'] = pd.DataFrame()
430
  results['attr_total_pv_base'] = pd.DataFrame()
431
- results['attr_cashflow_plot'] = plot_cashflows_comparison([None,None,None], None, scen_titles) # Pass None to get placeholder
432
- results['attr_scatter_cashflows_base'] = plot_scatter_comparison(pd.DataFrame(), 'Policy Attr. Calib. - No Data')
 
433
 
434
  # --- 3. Present Value Calibration ---
435
- if pvs.empty: gr.Warning("Base Present Value data (pvs) is empty. PV Calib may fail or produce no results.")
436
  cluster_pvs = Clusters(pvs)
437
- results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs) if not cfs.empty else pd.DataFrame()
438
- results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs) if not pol_data.empty else pd.DataFrame()
 
 
439
  results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
440
- results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50) if not pvs_lapse50.empty else pd.DataFrame()
441
- results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15) if not pvs_mort15.empty else pd.DataFrame()
 
442
  results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
443
  results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
444
 
445
-
446
  # --- Summary Comparison Plot Data ---
447
  error_data = {}
448
- def get_error_safe(compare_result_df, col_name=None):
449
- if compare_result_df is None or compare_result_df.empty or 'error' not in compare_result_df.columns:
 
450
  return np.nan
451
- if col_name and col_name in compare_result_df.index:
452
- error_val = compare_result_df.loc[col_name, 'error']
453
- return abs(error_val) if pd.notna(error_val) else np.nan
 
454
  else:
455
- valid_errors = compare_result_df['error'].dropna()
456
- return abs(valid_errors).mean() if not valid_errors.empty else np.nan
457
 
458
  key_pv_col = None
459
- if not pvs.empty:
460
- for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']:
461
- if potential_col in pvs.columns:
462
- key_pv_col = potential_col
 
 
 
 
463
  break
464
-
 
 
 
 
 
 
465
  error_data['CF Calib.'] = [
466
  get_error_safe(results.get('cf_pv_total_base'), key_pv_col),
467
  get_error_safe(results.get('cf_pv_total_lapse'), key_pv_col),
468
  get_error_safe(results.get('cf_pv_total_mort'), key_pv_col)
469
  ]
470
 
471
- if not loc_vars_attrs_scaled.empty and 'cluster_attrs' in locals() : # Check if attribute calibration was performed
472
- error_data['Attr Calib.'] = [
473
- get_error_safe(results.get('attr_total_pv_base'), key_pv_col),
474
- get_error_safe(cluster_attrs.compare_total(pvs_lapse50) if not pvs_lapse50.empty else pd.DataFrame(), key_pv_col),
475
- get_error_safe(cluster_attrs.compare_total(pvs_mort15) if not pvs_mort15.empty else pd.DataFrame(), key_pv_col)
476
  ]
477
  else:
478
  error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
@@ -483,72 +334,74 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
483
  get_error_safe(results.get('pv_total_pv_mort'), key_pv_col)
484
  ]
485
 
486
- summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%']).astype(float)
487
 
488
- fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
489
  sns.set_style("whitegrid")
490
-
491
- summary_df_melted = summary_df.reset_index().rename(columns={'index': 'Scenario'})
492
- summary_df_melted = summary_df_melted.melt(id_vars='Scenario', var_name='Calibration Method', value_name='Absolute Error Rate')
493
-
494
- sns.barplot(x='Scenario', y='Absolute Error Rate', hue='Calibration Method', data=summary_df_melted, ax=ax_summary)
495
 
496
- ax_summary.set_ylabel('Absolute Error Rate')
497
- title_suffix = f' for {key_pv_col}' if key_pv_col else ' (Mean Absolute Error)'
 
 
 
 
 
 
 
498
  ax_summary.set_title(f'Calibration Method Comparison - Error in Total PV{title_suffix}')
 
499
  ax_summary.tick_params(axis='x', rotation=0)
500
- if ax_summary.get_legend(): ax_summary.get_legend().set_title('Calibration Method')
501
- ax_summary.grid(True, axis='y')
502
-
503
- plt.tight_layout()
504
  buf_summary = io.BytesIO()
505
- plt.savefig(buf_summary, format='png', dpi=100)
506
  buf_summary.seek(0)
507
  results['summary_plot'] = Image.open(buf_summary)
508
- plt.close(fig_summary)
509
 
510
  return results
511
 
512
  except FileNotFoundError as e:
513
- gr.Error(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded and paths are correct for Hugging Face.")
514
  return {"error": f"File not found: {e.filename}"}
515
  except KeyError as e:
516
- gr.Error(f"A required column/index ('policy_id' or feature column) is missing or misnamed: {e}. Please check data format and Hugging Face file structure.")
517
- import traceback
518
- traceback.print_exc()
519
- return {"error": f"Missing column/index: {e}"}
520
  except ValueError as e:
521
- gr.Error(f"Data processing or plotting error: {str(e)}. Check data consistency and formats.")
522
- import traceback
523
- traceback.print_exc()
524
- return {"error": f"Data error: {str(e)}"}
525
  except Exception as e:
526
- gr.Error(f"An unexpected error occurred: {str(e)}. Check logs for details.")
527
  import traceback
528
- traceback.print_exc()
529
- return {"error": f"Unexpected error: {str(e)}"}
 
 
530
 
531
  def create_interface():
532
- with gr.Blocks(title="Cluster Model Points Analysis") as demo:
533
  gr.Markdown("""
534
- # Cluster Model Points Analysis
 
535
  This application applies cluster analysis to model point selection for insurance portfolios.
536
  Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
 
537
  **Required Files (Excel .xlsx):**
538
- - Cashflows - Base Scenario (must contain a 'policy_id' column/index)
539
- - Cashflows - Lapse Stress (+50%) (similar structure)
540
- - Cashflows - Mortality Stress (+15%) (similar structure)
541
- - Policy Data (must contain 'policy_id', and ideally 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth')
542
- - Present Values - Base Scenario (must contain 'policy_id' and PV columns like 'PV_NetCF')
543
- - Present Values - Lapse Stress (similar structure)
544
- - Present Values - Mortality Stress (similar structure)
545
- *Note: Ensure your files are in the `eg_data` directory in your Hugging Face Space if using 'Load Example Data'.*
546
  """)
547
 
548
  with gr.Row():
549
  with gr.Column(scale=1):
550
- gr.Markdown("### Upload Files or Load Examples")
551
- load_example_btn = gr.Button("Load Example Data")
 
 
552
  with gr.Row():
553
  cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
554
  cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
@@ -559,86 +412,115 @@ def create_interface():
559
  pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
560
  with gr.Row():
561
  pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
562
- analyze_btn = gr.Button("Analyze Dataset", variant="primary", size="lg")
 
563
 
564
  with gr.Tabs():
565
  with gr.TabItem("📊 Summary"):
566
- summary_plot_output = gr.Image(label="Calibration Methods Comparison", type="pil")
 
567
  with gr.TabItem("💸 Cashflow Calibration"):
568
  gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
569
  with gr.Row():
570
- cf_total_base_table_out = gr.DataFrame(label="Overall Comparison - Base Scenario (Cashflows)")
571
- cf_policy_attrs_total_out = gr.DataFrame(label="Overall Comparison - Policy Attributes")
572
- cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios", type="pil")
573
- cf_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)", type="pil")
574
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
575
  with gr.Row():
576
- cf_pv_total_base_out = gr.DataFrame(label="PVs - Base Total")
577
- cf_pv_total_lapse_out = gr.DataFrame(label="PVs - Lapse Stress Total")
578
- cf_pv_total_mort_out = gr.DataFrame(label="PVs - Mortality Stress Total")
 
579
  with gr.TabItem("👤 Policy Attribute Calibration"):
580
  gr.Markdown("### Results: Using Policy Attributes as Calibration Variables")
581
  with gr.Row():
582
- attr_total_cf_base_out = gr.DataFrame(label="Overall Comparison - Base Scenario (Cashflows)")
583
- attr_policy_attrs_total_out = gr.DataFrame(label="Overall Comparison - Policy Attributes")
584
- attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios", type="pil")
585
- attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)", type="pil")
586
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
587
- attr_total_pv_base_out = gr.DataFrame(label="PVs - Base Scenario Total")
 
588
  with gr.TabItem("💰 Present Value Calibration"):
589
  gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
590
  with gr.Row():
591
- pv_total_cf_base_out = gr.DataFrame(label="Overall Comparison - Base Scenario (Cashflows)")
592
- pv_policy_attrs_total_out = gr.DataFrame(label="Overall Comparison - Policy Attributes")
593
- pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios", type="pil")
594
- pv_scatter_pvs_base_out = gr.Image(label="Scatter Plot - Per-Cluster Present Values (Base Scenario)", type="pil")
595
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
596
  with gr.Row():
597
- pv_total_pv_base_out = gr.DataFrame(label="PVs - Base Total")
598
- pv_total_pv_lapse_out = gr.DataFrame(label="PVs - Lapse Stress Total")
599
- pv_total_pv_mort_out = gr.DataFrame(label="PVs - Mortality Stress Total")
600
 
 
601
        def get_all_output_components():
            # Ordered list of every output widget; the order must match the
            # result list built by the analysis handler wired to analyze_btn.
            return [
                summary_plot_output, cf_total_base_table_out, cf_policy_attrs_total_out,
                cf_cashflow_plot_out, cf_scatter_cashflows_base_out, cf_pv_total_base_out,
                cf_pv_total_lapse_out, cf_pv_total_mort_out, attr_total_cf_base_out,
                attr_policy_attrs_total_out, attr_cashflow_plot_out, attr_scatter_cashflows_base_out,
                attr_total_pv_base_out, pv_total_cf_base_out, pv_policy_attrs_total_out,
                pv_cashflow_plot_out, pv_scatter_pvs_base_out, pv_total_pv_base_out,
                pv_total_pv_lapse_out, pv_total_pv_mort_out
            ]
611
 
612
        def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
            """Validate the seven uploaded file paths, run process_files, and
            map the results dict onto the output components (same order as
            get_all_output_components). Returns all Nones on any failure."""
            files = [f1, f2, f3, f4, f5, f6, f7]
            file_paths = []
            # In Gradio 4+, File component value is a string path to a temp file or None
            for i, file_obj in enumerate(files):
                if file_obj is None:  # No file uploaded for this slot
                    gr.Error(f"Missing file for input {i+1}. Please upload all required files or use 'Load Example Data'.")
                    return [None] * len(get_all_output_components())
                # The object from gr.File is already the path string
                file_paths.append(file_obj)

            results = process_files(*file_paths)

            if "error" in results:  # If process_files indicated an error
                # Error message already shown by gr.Error in process_files
                # Return Nones to clear outputs
                return [None] * len(get_all_output_components())

            return [
                results.get('summary_plot'), results.get('cf_total_base_table'),
                results.get('cf_policy_attrs_total'), results.get('cf_cashflow_plot'),
                results.get('cf_scatter_cashflows_base'), results.get('cf_pv_total_base'),
                results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
                results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
                results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'),
                results.get('attr_total_pv_base'), results.get('pv_total_cf_base'),
                results.get('pv_policy_attrs_total'), results.get('pv_cashflow_plot'),
                results.get('pv_scatter_pvs_base'), results.get('pv_total_pv_base'),
                results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
            ]
643
 
644
  analyze_btn.click(
@@ -648,41 +530,31 @@ def create_interface():
648
  outputs=get_all_output_components()
649
  )
650
 
 
651
  def load_example_files():
652
- os.makedirs(EXAMPLE_DATA_DIR, exist_ok=True)
653
- # Check for existing files and create dummies if not present
654
- for key, target_path in EXAMPLE_FILES.items():
655
- if not os.path.exists(target_path):
656
- gr.Info(f"Example file {os.path.basename(target_path)} not found in '{EXAMPLE_DATA_DIR}'. Attempting to create a dummy file.")
657
- try:
658
- num_policies = 50
659
- df_data = {'policy_id': [f'P{j:03d}' for j in range(num_policies)]}
660
- if "cashflow" in key or "pv" in key:
661
- for i in range(10): df_data[f't{i}'] = np.random.uniform(100, 1000, num_policies)
662
- elif "policy_data" in key:
663
- df_data.update({
664
- 'age_at_entry': np.random.randint(20, 60, num_policies),
665
- 'policy_term': np.random.randint(5, 30, num_policies),
666
- 'sum_assured': np.random.randint(5000, 200000, num_policies),
667
- 'duration_mth': np.random.randint(1, 300, num_policies)
668
- })
669
- else: df_data['feature1'] = np.random.rand(num_policies)
670
- pd.DataFrame(df_data).to_excel(target_path, index=False)
671
- gr.Info(f"Dummy file '{os.path.basename(target_path)}' created.")
672
- except Exception as e:
673
- gr.Error(f"Failed to create dummy file {os.path.basename(target_path)}: {e}")
674
- return [None] * 7 # Abort if dummy creation fails
675
-
676
- # Verify all files exist after potential dummy creation
677
- if any(not os.path.exists(f) for f in EXAMPLE_FILES.values()):
678
- gr.Error(f"One or more example files are still missing from '{EXAMPLE_DATA_DIR}' after attempting to create dummies. Please check permissions or provide the files.")
679
- return [None] * 7
680
-
681
- gr.Info("Example data loaded. Click 'Analyze Dataset'.")
682
- return [
683
- EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
684
- EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
685
- EXAMPLE_FILES["pv_mort"]
686
  ]
687
 
688
  load_example_btn.click(
@@ -695,11 +567,12 @@ def create_interface():
695
  return demo
696
 
697
  if __name__ == "__main__":
698
- # When running locally, ensure eg_data exists.
699
- # Dummy file creation is now handled by load_example_files if needed.
700
  if not os.path.exists(EXAMPLE_DATA_DIR):
701
  os.makedirs(EXAMPLE_DATA_DIR)
702
- print(f"Directory '{EXAMPLE_DATA_DIR}' created/ensured. Example files will be checked/created by 'Load Example Data' button if not present.")
 
 
 
703
 
704
  demo_app = create_interface()
705
  demo_app.launch()
 
2
  import numpy as np
3
  import pandas as pd
4
  from sklearn.cluster import KMeans
5
+ from sklearn.metrics import pairwise_distances_argmin_min
6
+ # import matplotlib.pyplot as plt # Replaced with seaborn
7
+ # import matplotlib.cm # Replaced with seaborn palettes
8
+ import seaborn as sns # Added Seaborn
9
  import io
10
  import os
11
  from PIL import Image
12
 
13
  # Define the paths for example data
 
14
  EXAMPLE_DATA_DIR = "eg_data"
15
  EXAMPLE_FILES = {
16
  "cashflow_base": os.path.join(EXAMPLE_DATA_DIR, "cashflows_seriatim_10K.xlsx"),
 
24
 
25
  class Clusters:
26
  def __init__(self, loc_vars):
27
+ self.kmeans = kmeans = KMeans(n_clusters=1000, random_state=0, n_init=10).fit(np.ascontiguousarray(loc_vars))
28
+ closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, np.ascontiguousarray(loc_vars))
 
 
 
 
 
29
 
30
+ rep_ids = pd.Series(data=(closest+1)) # 0-based to 1-based indexes
31
  rep_ids.name = 'policy_id'
32
  rep_ids.index.name = 'cluster_id'
33
  self.rep_ids = rep_ids
34
 
35
+ self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * len(loc_vars)}))['policy_count']
36
 
37
  def agg_by_cluster(self, df, agg=None):
38
+ """Aggregate columns by cluster"""
39
  temp = df.copy()
40
  temp['cluster_id'] = self.kmeans.labels_
41
  temp = temp.set_index('cluster_id')
42
+ agg = {c: (agg[c] if agg and c in agg else 'sum') for c in temp.columns} if agg else "sum"
43
+ return temp.groupby(temp.index).agg(agg)
 
 
44
 
45
  def extract_reps(self, df):
46
+ """Extract the rows of representative policies"""
47
+ temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
48
+ temp.index.name = 'cluster_id'
49
+ return temp.drop('policy_id', axis=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def extract_and_scale_reps(self, df, agg=None):
52
+ """Extract and scale the rows of representative policies"""
53
  if agg:
54
+ cols = df.columns
55
+ mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
56
+ extracted_df = self.extract_reps(df)
57
+ mult.index = extracted_df.index
58
+ return extracted_df.mul(mult)
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  else:
60
+ return self.extract_reps(df).mul(self.policy_count, axis=0)
 
 
 
 
 
 
 
 
61
 
62
  def compare(self, df, agg=None):
63
+ """Returns a multi-indexed Dataframe comparing actual and estimate"""
64
  source = self.agg_by_cluster(df, agg)
65
  target = self.extract_and_scale_reps(df, agg)
66
+ return pd.DataFrame({'actual': source.stack(), 'estimate':target.stack()})
 
 
 
 
 
 
 
 
 
67
 
68
  def compare_total(self, df, agg=None):
69
+ """Aggregate df by columns"""
70
  if agg:
71
  actual_values = {}
72
  for col in df.columns:
73
  if agg.get(col, 'sum') == 'mean':
74
  actual_values[col] = df[col].mean()
75
+ else: # sum
76
  actual_values[col] = df[col].sum()
77
  actual = pd.Series(actual_values)
78
 
79
  reps_unscaled = self.extract_reps(df)
80
  estimate_values = {}
81
 
82
+ for col in df.columns:
83
+ if agg.get(col, 'sum') == 'mean':
84
+ weighted_sum = (reps_unscaled[col] * self.policy_count).sum()
85
+ total_weight = self.policy_count.sum()
86
+ estimate_values[col] = weighted_sum / total_weight if total_weight > 0 else 0
87
+ else: # sum
88
+ estimate_values[col] = (reps_unscaled[col] * self.policy_count).sum()
89
+
 
 
 
 
 
 
90
  estimate = pd.Series(estimate_values)
91
+
92
+ else: # Original logic if no agg is specified (all sum)
93
  actual = df.sum()
94
+ estimate = self.extract_and_scale_reps(df).sum()
95
 
96
+ error = np.where(actual != 0, estimate / actual - 1, 0)
 
 
97
 
98
  return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
+ def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
102
+ """Create cashflow comparison plots using Seaborn"""
103
+ if not cfs_list or not cluster_obj or not titles:
104
+ return None
105
+ num_plots = len(cfs_list)
106
+ if num_plots == 0:
107
+ return None
108
 
 
109
  cols = 2
110
  rows = (num_plots + cols - 1) // cols
111
 
112
+ # Use matplotlib's subplots for layout, Seaborn will plot on these axes
113
+ fig, axes = sns.plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
114
  axes = axes.flatten()
115
+ sns.set_style("whitegrid") # Apply Seaborn style
116
+
117
+ for i, (df, title) in enumerate(zip(cfs_list, titles)):
118
+ if i < len(axes):
119
+ ax = axes[i]
120
+ comparison = cluster_obj.compare_total(df)
121
+ # Melt dataframe for Seaborn lineplot
122
+ plot_data = comparison[['actual', 'estimate']].reset_index().melt(
123
+ id_vars='index', var_name='Category', value_name='Value'
124
+ )
125
+ sns.lineplot(x='index', y='Value', hue='Category', data=plot_data, ax=ax, marker="o")
 
 
 
 
 
 
 
 
 
 
 
 
126
  ax.set_title(title)
127
  ax.set_xlabel('Time')
128
  ax.set_ylabel('Value')
129
+ if not plot_data.empty: # Add legend if data exists
130
+ ax.legend(title='Category')
131
+
132
+ for j in range(i + 1, len(axes)):
133
  fig.delaxes(axes[j])
134
 
135
+ sns.plt.tight_layout()
136
  buf = io.BytesIO()
137
+ sns.plt.savefig(buf, format='png', dpi=100)
138
  buf.seek(0)
139
  img = Image.open(buf)
140
+ sns.plt.close(fig) # Use sns.plt to close
141
  return img
142
 
143
  def plot_scatter_comparison(df_compare_output, title):
144
+ """Create scatter plot comparison from compare() output using Seaborn"""
 
 
 
 
145
  if df_compare_output is None or df_compare_output.empty:
146
+ fig, ax = sns.plt.subplots(figsize=(12, 8)) # Use sns.plt
147
+ sns.set_style("whitegrid")
148
  ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
149
  ax.set_title(title)
150
+ buf = io.BytesIO()
151
+ sns.plt.savefig(buf, format='png', dpi=100)
152
+ buf.seek(0)
153
+ img = Image.open(buf)
154
+ sns.plt.close(fig)
155
+ return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
+ fig, ax = sns.plt.subplots(figsize=(12, 8)) # Use sns.plt
158
+ sns.set_style("whitegrid")
159
+
160
+ hue_col = None
161
+ plot_data = df_compare_output.copy()
162
+
163
+ if isinstance(df_compare_output.index, pd.MultiIndex) and df_compare_output.index.nlevels >= 2:
164
+ gr.Info("Plotting with multiple item levels.")
165
+ # Prepare data for seaborn: reset index to use levels as columns
166
+ plot_data = df_compare_output.reset_index()
167
+ hue_col = df_compare_output.index.names[1] # Use the second level for hue
168
+ if hue_col is None or hue_col == "": # Handle unnamed index level
169
+ hue_col = "item_level_1"
170
+ plot_data.rename(columns={plot_data.columns[1]: hue_col}, inplace=True)
171
+
172
+ num_unique_hue = plot_data[hue_col].nunique()
173
+ palette = "viridis" # Default seaborn palette
174
+ if num_unique_hue > 10 : # If too many categories, don't use hue or use a simpler palette
175
+ palette = sns.color_palette("husl", num_unique_hue)
176
+
177
+
178
+ sns.scatterplot(x='actual', y='estimate', hue=hue_col if num_unique_hue <= 20 else None,
179
+ data=plot_data, ax=ax, s=20, alpha=0.7, palette=palette)
180
+ if hue_col and num_unique_hue > 1 and num_unique_hue <= 10:
181
+ ax.legend(title=hue_col)
182
+ elif num_unique_hue > 10:
183
+ ax.legend().set_visible(False) # Hide legend if too many items
184
+ else:
185
+ gr.Warning("Scatter plot data is not in the expected multi-index format or has fewer than 2 levels. Plotting raw actual vs estimate without hue.")
186
+ sns.scatterplot(x='actual', y='estimate', data=plot_data, ax=ax, s=20, alpha=0.7)
187
 
188
  ax.set_xlabel('Actual')
189
  ax.set_ylabel('Estimate')
190
+ ax.set_title(title)
191
+
192
+ # Draw identity line
193
+ lims = [
194
+ np.min([ax.get_xlim(), ax.get_ylim()]),
195
+ np.max([ax.get_xlim(), ax.get_ylim()]),
196
+ ]
197
+ if lims[0] != lims[1] and np.isfinite(lims[0]) and np.isfinite(lims[1]): # Check for valid limits
198
+ ax.plot(lims, lims, 'r-', linewidth=0.7, alpha=0.8, zorder=0)
199
+ ax.set_xlim(lims)
200
+ ax.set_ylim(lims)
201
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  buf = io.BytesIO()
203
+ sns.plt.savefig(buf, format='png', dpi=100)
204
  buf.seek(0)
205
  img = Image.open(buf)
206
+ sns.plt.close(fig)
207
  return img
208
 
209
+
210
  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
211
  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
212
+ """Main processing function - now accepts file paths"""
213
  try:
214
+ cfs = pd.read_excel(cashflow_base_path, index_col=0)
215
+ cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
216
+ cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
+ pol_data_full = pd.read_excel(policy_data_path, index_col=0)
219
+ required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
220
+ if all(col in pol_data_full.columns for col in required_cols):
221
+ pol_data = pol_data_full[required_cols]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  else:
223
+ gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
224
+ pol_data = pol_data_full
225
+
226
+ pvs = pd.read_excel(pv_base_path, index_col=0)
227
+ pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
228
+ pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)
 
 
 
 
 
 
 
 
 
229
 
230
  cfs_list = [cfs, cfs_lapse50, cfs_mort15]
231
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
232
 
233
  results = {}
234
+
235
  mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
236
 
237
  # --- 1. Cashflow Calibration ---
238
+ cluster_cfs = Clusters(cfs)
239
+
240
  results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
241
+ results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
242
+
243
+ results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
244
+ results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
245
+ results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
246
+
247
  results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
248
  results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
249
 
 
250
  # --- 2. Policy Attribute Calibration ---
251
+ if not pol_data.empty and not pol_data.isnull().all().all() and (pol_data.max(numeric_only=True) - pol_data.min(numeric_only=True)).sum() != 0: # Check for actual variance
252
+ loc_vars_attrs = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
253
+ loc_vars_attrs = loc_vars_attrs.fillna(0) # Handle potential NaNs after division if a column is constant
 
 
 
 
 
 
 
 
 
 
 
 
254
  else:
255
+ gr.Warning("Policy data for attribute calibration is empty, all NaNs, or has no variance. Skipping attribute calibration plots.")
256
+ loc_vars_attrs = pol_data # or pd.DataFrame() if you want to ensure it's empty
257
+
258
+ if not loc_vars_attrs.empty:
259
+ cluster_attrs = Clusters(loc_vars_attrs)
260
+ results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
261
+ results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
262
+ results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
263
  results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
264
  results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
265
  else:
266
  results['attr_total_cf_base'] = pd.DataFrame()
267
  results['attr_policy_attrs_total'] = pd.DataFrame()
268
  results['attr_total_pv_base'] = pd.DataFrame()
269
+ results['attr_cashflow_plot'] = plot_scatter_comparison(None, "Policy Attr. Calib. - Cashflows (Base) - No Data") # Generate blank plot
270
+ results['attr_scatter_cashflows_base'] = plot_scatter_comparison(None, "Policy Attr. Calib. - Scatter - No Data")
271
+
272
 
273
  # --- 3. Present Value Calibration ---
 
274
  cluster_pvs = Clusters(pvs)
275
+
276
+ results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
277
+ results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
278
+
279
  results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
280
+ results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
281
+ results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
282
+
283
  results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
284
  results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
285
 
 
286
  # --- Summary Comparison Plot Data ---
287
  error_data = {}
288
+
289
+ def get_error_safe(compare_result, col_name=None):
290
+ if compare_result is None or compare_result.empty:
291
  return np.nan
292
+ if col_name and col_name in compare_result.index:
293
+ return abs(compare_result.loc[col_name, 'error'])
294
+ elif 'error' in compare_result.columns:
295
+ return abs(compare_result['error']).mean()
296
  else:
297
+ return np.nan # Should not happen if compare_result is valid
 
298
 
299
  key_pv_col = None
300
+ for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF', 'PV NET CF']: # Added more common names
301
+ if potential_col in pvs.columns:
302
+ key_pv_col = potential_col
303
+ break
304
+ # Case insensitive check
305
+ for col in pvs.columns:
306
+ if col.lower() == potential_col.lower():
307
+ key_pv_col = col
308
  break
309
+ if key_pv_col:
310
+ break
311
+
312
+ if not key_pv_col and not pvs.empty:
313
+ gr.Warning(f"Could not find a standard PV Net CF column in PV data. Using mean absolute error for all PV columns for summary. Columns available: {pvs.columns.tolist()}")
314
+
315
+
316
  error_data['CF Calib.'] = [
317
  get_error_safe(results.get('cf_pv_total_base'), key_pv_col),
318
  get_error_safe(results.get('cf_pv_total_lapse'), key_pv_col),
319
  get_error_safe(results.get('cf_pv_total_mort'), key_pv_col)
320
  ]
321
 
322
+ if not loc_vars_attrs.empty:
323
+ error_data['Attr Calib.'] = [
324
+ get_error_safe(results.get('attr_total_pv_base'), key_pv_col), # Assuming pvs is the right df here
325
+ get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col), # Recalculate for lapse scenario with attr cluster
326
+ get_error_safe(cluster_attrs.compare_total(pvs_mort15), key_pv_col) # Recalculate for mort scenario with attr cluster
327
  ]
328
  else:
329
  error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
 
334
  get_error_safe(results.get('pv_total_pv_mort'), key_pv_col)
335
  ]
336
 
337
+ summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
338
 
339
+ fig_summary, ax_summary = sns.plt.subplots(figsize=(10, 6)) # Use sns.plt
340
  sns.set_style("whitegrid")
 
 
 
 
 
341
 
342
+ # Melt the DataFrame for Seaborn barplot
343
+ summary_plot_data = summary_df.reset_index().melt(
344
+ id_vars='index', var_name='Calibration Method', value_name='Absolute Error Rate'
345
+ )
346
+
347
+ sns.barplot(x='index', y='Absolute Error Rate', hue='Calibration Method', data=summary_plot_data, ax=ax_summary, palette="muted")
348
+
349
+ ax_summary.set_ylabel('Absolute Error Rate (0.1 = 10%)')
350
+ title_suffix = f' (Key PV Column: {key_pv_col})' if key_pv_col else ' (Mean Absolute Error of PVs)'
351
  ax_summary.set_title(f'Calibration Method Comparison - Error in Total PV{title_suffix}')
352
+ ax_summary.set_xlabel('Scenario')
353
  ax_summary.tick_params(axis='x', rotation=0)
354
+ ax_summary.legend(title='Calibration Method')
355
+ sns.plt.tight_layout()
356
+
 
357
  buf_summary = io.BytesIO()
358
+ sns.plt.savefig(buf_summary, format='png', dpi=100)
359
  buf_summary.seek(0)
360
  results['summary_plot'] = Image.open(buf_summary)
361
+ sns.plt.close(fig_summary)
362
 
363
  return results
364
 
365
  except FileNotFoundError as e:
366
+ gr.Error(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded.")
367
  return {"error": f"File not found: {e.filename}"}
368
  except KeyError as e:
369
+ gr.Error(f"A required column is missing from one of the excel files: {e}. Please check data format.")
370
+ return {"error": f"Missing column: {e}"}
 
 
371
  except ValueError as e:
372
+ gr.Error(f"ValueError during processing: {str(e)}. This might be due to empty data or data format issues (e.g. non-numeric data for clustering).")
373
+ return {"error": f"ValueError: {str(e)}"}
 
 
374
  except Exception as e:
 
375
  import traceback
376
+ print(traceback.format_exc()) # Print full traceback to console for debugging
377
+ gr.Error(f"An unexpected error occurred: {str(e)}. Check console for details.")
378
+ return {"error": f"Error processing files: {str(e)}"}
379
+
380
 
381
  def create_interface():
382
+ with gr.Blocks(theme=gr.themes.Soft(), title="Cluster Model Points Analysis") as demo: # Added a theme
383
  gr.Markdown("""
384
+ # Cluster Model Points Analysis 📊
385
+
386
  This application applies cluster analysis to model point selection for insurance portfolios.
387
  Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
388
+
389
  **Required Files (Excel .xlsx):**
390
+ - Cashflows - Base Scenario
391
+ - Cashflows - Lapse Stress (+50%)
392
+ - Cashflows - Mortality Stress (+15%)
393
+ - Policy Data (including 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth')
394
+ - Present Values - Base Scenario
395
+ - Present Values - Lapse Stress
396
+ - Present Values - Mortality Stress
 
397
  """)
398
 
399
  with gr.Row():
400
  with gr.Column(scale=1):
401
+ gr.Markdown("### 📁 Upload Files or Load Examples")
402
+
403
+ load_example_btn = gr.Button("Load Example Data ✨", variant="secondary")
404
+
405
  with gr.Row():
406
  cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
407
  cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
 
412
  pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
413
  with gr.Row():
414
  pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
415
+
416
+ analyze_btn = gr.Button("Analyze Dataset 🚀", variant="primary", size="lg")
417
 
418
  with gr.Tabs():
419
  with gr.TabItem("📊 Summary"):
420
+ summary_plot_output = gr.Image(label="Calibration Methods Comparison")
421
+
422
  with gr.TabItem("💸 Cashflow Calibration"):
423
  gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
424
  with gr.Row():
425
+ cf_total_base_table_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)", wrap=True, height=300)
426
+ cf_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True, height=300)
427
+ cf_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
428
+ cf_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
429
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
430
  with gr.Row():
431
+ cf_pv_total_base_out = gr.Dataframe(label="PVs - Base Total", wrap=True)
432
+ cf_pv_total_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total", wrap=True)
433
+ cf_pv_total_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total", wrap=True)
434
+
435
  with gr.TabItem("👤 Policy Attribute Calibration"):
436
  gr.Markdown("### Results: Using Policy Attributes as Calibration Variables")
437
  with gr.Row():
438
+ attr_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)", wrap=True, height=300)
439
+ attr_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True, height=300)
440
+ attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
441
+ attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
442
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
443
+ attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total (All Shocks)", wrap=True) # Changed label for clarity
444
+
445
  with gr.TabItem("💰 Present Value Calibration"):
446
  gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
447
  with gr.Row():
448
+ pv_total_cf_base_out = gr.Dataframe(label="Overall Comparison - Base Scenario (Cashflows)", wrap=True, height=300)
449
+ pv_policy_attrs_total_out = gr.Dataframe(label="Overall Comparison - Policy Attributes", wrap=True, height=300)
450
+ pv_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
451
+ pv_scatter_pvs_base_out = gr.Image(label="Scatter Plot - Per-Cluster Present Values (Base Scenario)")
452
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
453
  with gr.Row():
454
+ pv_total_pv_base_out = gr.Dataframe(label="PVs - Base Total", wrap=True)
455
+ pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total", wrap=True)
456
+ pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total", wrap=True)
457
 
458
+ # --- Helper function to prepare outputs ---
459
  def get_all_output_components():
460
  return [
461
+ summary_plot_output,
462
+ # Cashflow Calib Outputs
463
+ cf_total_base_table_out, cf_policy_attrs_total_out,
464
+ cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
465
+ cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
466
+ # Attribute Calib Outputs
467
+ attr_total_cf_base_out, attr_policy_attrs_total_out,
468
+ attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
469
+ # PV Calib Outputs
470
+ pv_total_cf_base_out, pv_policy_attrs_total_out,
471
+ pv_cashflow_plot_out, pv_scatter_pvs_base_out,
472
+ pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
473
  ]
474
 
475
+ # --- Action for Analyze Button ---
476
+ def handle_analysis(f1, f2, f3, f4, f5, f6, f7, progress=gr.Progress(track_tqdm=True)):
477
  files = [f1, f2, f3, f4, f5, f6, f7]
478
+
479
  file_paths = []
480
+ file_labels = ["Cashflows - Base", "Cashflows - Lapse", "Cashflows - Mort",
481
+ "Policy Data", "PVs - Base", "PVs - Lapse", "PVs - Mort"]
482
+
483
+ for i, f_obj in enumerate(files):
484
+ if f_obj is None:
485
+ gr.Error(f"Missing file input for: {file_labels[i]}. Please upload all files or load examples.")
486
+ # Return empty/None for all outputs
487
+ return [None] * len(get_all_output_components())
488
+
489
+ if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
490
+ file_paths.append(f_obj.name)
491
+ elif isinstance(f_obj, str): # Already a path (from example load)
492
+ file_paths.append(f_obj)
493
+ else:
494
+ gr.Error(f"Invalid file input for {file_labels[i]}. Type: {type(f_obj)}")
495
  return [None] * len(get_all_output_components())
 
 
496
 
497
+ progress(0, desc="Starting Analysis...")
498
+ # This is a placeholder for actual progress tracking if process_files were to support it.
499
+ # For now, it just shows activity.
500
+ # You could break down process_files and update progress more granularly if needed.
501
+ for i in range(1, 6):
502
+ progress(i/5, desc=f"Processing Data Step {i}/5...") # Simulate progress
503
+ # time.sleep(0.2) # if you want to see the progress bar update
504
 
505
  results = process_files(*file_paths)
506
+ progress(1, desc="Analysis Complete!")
507
 
508
+ if "error" in results: # Error handled by process_files with gr.Error
 
 
509
  return [None] * len(get_all_output_components())
510
 
511
  return [
512
+ results.get('summary_plot'),
513
+ # CF Calib
514
+ results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
515
+ results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
516
+ results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
517
+ # Attr Calib
518
  results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
519
+ results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
520
+ # PV Calib
521
+ results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
522
+ results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
523
+ results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
524
  ]
525
 
526
  analyze_btn.click(
 
530
  outputs=get_all_output_components()
531
  )
532
 
533
+ # --- Action for Load Example Data Button ---
534
  def load_example_files():
535
+ # Create eg_data directory if it doesn't exist
536
+ if not os.path.exists(EXAMPLE_DATA_DIR):
537
+ os.makedirs(EXAMPLE_DATA_DIR)
538
+ gr.Warning(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there. App will likely fail analysis if files are missing.")
539
+
540
+ missing_files_info = []
541
+ for key, fp in EXAMPLE_FILES.items():
542
+ if not os.path.exists(fp):
543
+ missing_files_info.append(f"'{key}' (expected at '{fp}')")
544
+
545
+ if missing_files_info:
546
+ gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files_info)}. Please ensure they exist or upload files manually.")
547
+ return [None] * 7 # Return None for all file inputs
548
+
549
+ gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
550
+ return [ # Return the paths for the File components
551
+ gr.File(value=EXAMPLE_FILES["cashflow_base"]),
552
+ gr.File(value=EXAMPLE_FILES["cashflow_lapse"]),
553
+ gr.File(value=EXAMPLE_FILES["cashflow_mort"]),
554
+ gr.File(value=EXAMPLE_FILES["policy_data"]),
555
+ gr.File(value=EXAMPLE_FILES["pv_base"]),
556
+ gr.File(value=EXAMPLE_FILES["pv_lapse"]),
557
+ gr.File(value=EXAMPLE_FILES["pv_mort"])
 
 
 
 
 
 
 
 
 
 
 
558
  ]
559
 
560
  load_example_btn.click(
 
567
  return demo
568
 
569
  if __name__ == "__main__":
 
 
570
  if not os.path.exists(EXAMPLE_DATA_DIR):
571
  os.makedirs(EXAMPLE_DATA_DIR)
572
+ print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there.")
573
+ print(f"Expected files in '{EXAMPLE_DATA_DIR}':")
574
+ for key, path in EXAMPLE_FILES.items():
575
+ print(f" - {key}: {os.path.basename(path)}") # Print just file name for cleaner output
576
 
577
  demo_app = create_interface()
578
  demo_app.launch()