alidenewade commited on
Commit
4f94e21
·
verified ·
1 Parent(s): 2f2f1dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +355 -328
app.py CHANGED
@@ -2,11 +2,12 @@ import gradio as gr
2
  import numpy as np
3
  import pandas as pd
4
  from sklearn.cluster import KMeans
5
- from sklearn.metrics import pairwise_distances_argmin_min # r2_score is not used in the provided snippet.
6
  import matplotlib.pyplot as plt
7
- import matplotlib.cm
 
8
  import io
9
- import os # Added for path joining
10
  from PIL import Image
11
 
12
  # Define the paths for example data
@@ -22,98 +23,98 @@ EXAMPLE_FILES = {
22
  }
23
 
24
  class Clusters:
25
- def __init__(self, loc_vars_df): # Expecting a pandas DataFrame
26
- # "Quantisize" by converting input DataFrame to float32 for KMeans.
27
- # This reduces precision, potentially speeding up calculations and lowering memory.
28
- # Results might have minor numerical differences compared to float64.
29
- # Ensure data is a C-contiguous NumPy array.
30
- if loc_vars_df.empty:
31
- # Handle empty DataFrame case to avoid errors with .values or astype
32
- # KMeans would fail anyway, but this prevents issues before that.
33
- loc_vars_np_float32 = np.array([], dtype=np.float32).reshape(0, loc_vars_df.shape[1] if loc_vars_df.shape[1] > 0 else 0)
34
- else:
35
- loc_vars_np_float32 = np.ascontiguousarray(loc_vars_df.astype(np.float32).values)
36
-
37
- # Initialize KMeans with algorithm="elkan" for potential speedup
38
- # and fit on the float32 data.
39
- self.kmeans = KMeans(
40
- n_clusters=1000,
41
- random_state=0,
42
- n_init=10,
43
- algorithm="elkan" # Added for speed optimization
44
- ).fit(loc_vars_np_float32)
45
-
46
- # cluster_centers_ will be float32 if fitted on float32 data.
47
- # Pass the same float32 NumPy array for distance calculations.
48
- closest, _ = pairwise_distances_argmin_min(
49
- self.kmeans.cluster_centers_,
50
- loc_vars_np_float32
51
- )
52
 
53
- self.rep_ids = pd.Series(data=(closest + 1)) # 0-based to 1-based indexes
54
- self.rep_ids.name = 'policy_id'
55
- self.rep_ids.index.name = 'cluster_id'
 
56
 
57
- # policy_count is based on the number of items in the input data.
58
- # Use loc_vars_np_float32.shape[0] which is the number of rows.
59
- self.policy_count = self.agg_by_cluster(pd.DataFrame({'policy_count': [1] * loc_vars_np_float32.shape[0]}))['policy_count']
60
 
61
  def agg_by_cluster(self, df, agg=None):
62
  """Aggregate columns by cluster"""
63
  temp = df.copy()
64
- temp['cluster_id'] = self.kmeans.labels_
65
  temp = temp.set_index('cluster_id')
66
- agg = {c: (agg[c] if agg and c in agg else 'sum') for c in temp.columns} if agg else "sum"
67
- return temp.groupby(temp.index).agg(agg)
 
 
 
68
 
69
  def extract_reps(self, df):
70
  """Extract the rows of representative policies"""
71
- # Ensure policy_id in df is of the same type as self.rep_ids if it's not already the index
72
- # Typically, df here will have 'policy_id' as its index as per original data.
73
- # If df's index is not 'policy_id', ensure 'policy_id' column exists and has compatible type.
74
- current_df_index_name = df.index.name
75
- # If 'policy_id' is not the index, reset it. Otherwise, use the index.
76
  if 'policy_id' not in df.columns and df.index.name != 'policy_id':
77
- # This case should ideally not happen if inputs are consistent
78
- # Forcing index to be named 'policy_id' if it's the policy identifier
79
- df_indexed = df.copy()
80
- if df_indexed.index.name is None: # Or some other logic to identify the policy_id column
81
- gr.Warning("DataFrame passed to extract_reps has no index name, assuming index is policy_id.")
82
- df_indexed.index.name = 'policy_id'
83
-
84
- temp = pd.merge(self.rep_ids, df_indexed.reset_index(), how='left', on='policy_id')
85
-
86
- elif 'policy_id' in df.columns and df.index.name == 'policy_id' and df.index.name in df.columns: # if policy_id is both index and a column
87
- temp = pd.merge(self.rep_ids, df, how='left', on='policy_id') # Merge on column if available
88
 
89
- elif df.index.name == 'policy_id':
90
- temp = pd.merge(self.rep_ids, df.reset_index(), how='left', on='policy_id')
91
-
92
- else: # 'policy_id' is a column, not the index
93
- temp = pd.merge(self.rep_ids, df.reset_index(drop=df.index.name is None), how='left', on='policy_id')
 
 
 
 
 
94
 
95
 
96
- temp.index.name = 'cluster_id' # The merge result's index is not cluster_id by default
97
- temp = temp.set_index(self.rep_ids.index) # Set index to be cluster_id from self.rep_ids
98
- return temp.drop('policy_id', axis=1, errors='ignore')
 
99
 
100
 
101
  def extract_and_scale_reps(self, df, agg=None):
102
  """Extract and scale the rows of representative policies"""
103
  extracted_df = self.extract_reps(df)
104
  if agg:
105
- cols = extracted_df.columns # Use columns from extracted_df
106
- mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
107
- mult.index = extracted_df.index # Align index
108
- return extracted_df.mul(mult)
 
 
 
 
 
 
 
 
 
109
  else:
110
- return extracted_df.mul(self.policy_count, axis=0)
 
 
 
 
 
 
111
 
112
  def compare(self, df, agg=None):
113
  """Returns a multi-indexed Dataframe comparing actual and estimate"""
114
  source = self.agg_by_cluster(df, agg)
115
  target = self.extract_and_scale_reps(df, agg)
116
- return pd.DataFrame({'actual': source.stack(), 'estimate':target.stack()})
 
 
 
 
 
 
117
 
118
  def compare_total(self, df, agg=None):
119
  """Aggregate df by columns"""
@@ -130,37 +131,35 @@ class Clusters:
130
  estimate_values = {}
131
 
132
  for col in df.columns: # Iterate over original df columns to ensure all are covered
133
- if col not in reps_unscaled.columns: # Column might not be in reps_unscaled if it was dropped or not selected
134
- if agg.get(col, 'sum') == 'mean':
135
- estimate_values[col] = np.nan # Or some other placeholder like 0, or actual.get(col, 0)
136
- else:
137
- estimate_values[col] = 0
138
- gr.Warning(f"Column '{col}' not found in representative policies output for 'compare_total'. Estimate will be 0/NaN.")
139
  continue
140
 
141
  if agg.get(col, 'sum') == 'mean':
142
- weighted_sum = (reps_unscaled[col] * self.policy_count).sum()
143
- total_weight = self.policy_count.sum()
144
- estimate_values[col] = weighted_sum / total_weight if total_weight > 0 else 0
 
 
 
145
  else: # sum
146
- estimate_values[col] = (reps_unscaled[col] * self.policy_count).sum()
147
-
148
  estimate = pd.Series(estimate_values)
149
-
150
- else: # Original logic if no agg is specified (all sum)
151
  actual = df.sum()
152
  estimate = self.extract_and_scale_reps(df).sum()
153
 
154
- # Ensure alignment for error calculation
155
- actual, estimate = actual.align(estimate, fill_value=0)
156
- error = np.where(actual != 0, estimate / actual - 1, 0)
157
 
158
  return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
159
 
160
-
161
- # --- Plotting functions (plot_cashflows_comparison, plot_scatter_comparison) remain unchanged ---
162
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
163
- """Create cashflow comparison plots"""
 
164
  if not cfs_list or not cluster_obj or not titles:
165
  return None
166
  num_plots = len(cfs_list)
@@ -173,20 +172,30 @@ def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
173
  fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
174
  axes = axes.flatten()
175
 
176
- for i, (df, title) in enumerate(zip(cfs_list, titles)):
177
  if i < len(axes):
178
- # Ensure df passed to compare_total is appropriate.
179
- # If df has policy_id as index, it matches expectations of downstream functions in Clusters.
180
- # If not, ensure policy_id is a column or handle appropriately.
181
- if df.index.name != 'policy_id' and 'policy_id' not in df.columns:
182
- gr.Warning(f"DataFrame for plot '{title}' does not have 'policy_id' as index or column. Results may be incorrect.")
183
 
184
- comparison = cluster_obj.compare_total(df.set_index('policy_id') if 'policy_id' in df.columns and df.index.name != 'policy_id' else df)
185
- comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
186
- axes[i].set_xlabel('Time')
187
- axes[i].set_ylabel('Value')
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
- for j in range(i + 1, len(axes)):
190
  fig.delaxes(axes[j])
191
 
192
  plt.tight_layout()
@@ -198,47 +207,77 @@ def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
198
  return img
199
 
200
  def plot_scatter_comparison(df_compare_output, title):
201
- """Create scatter plot comparison from compare() output"""
 
 
 
202
  if df_compare_output is None or df_compare_output.empty:
203
- fig, ax = plt.subplots(figsize=(12, 8))
204
  ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
205
  ax.set_title(title)
206
- buf = io.BytesIO()
207
- plt.savefig(buf, format='png', dpi=100)
208
- buf.seek(0)
209
- img = Image.open(buf)
210
- plt.close(fig)
211
- return img
212
-
213
- fig, ax = plt.subplots(figsize=(12, 8))
214
-
215
- if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
216
  gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
217
- ax.scatter(df_compare_output['actual'], df_compare_output['estimate'], s=9, alpha=0.6)
 
218
  else:
219
- unique_levels = df_compare_output.index.get_level_values(1).unique()
220
- colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
221
-
222
- for item_level, color_val in zip(unique_levels, colors):
223
- subset = df_compare_output.xs(item_level, level=1)
224
- ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=str(item_level)) # Ensure label is string
225
- if len(unique_levels) > 1 and len(unique_levels) <= 20: # Increased legend item limit slightly
226
- ax.legend(title=str(df_compare_output.index.names[1]))
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  ax.set_xlabel('Actual')
229
  ax.set_ylabel('Estimate')
230
- ax.set_title(title)
231
- ax.grid(True)
232
-
233
- lims = [
234
- np.nanmin([ax.get_xlim(), ax.get_ylim()]), # Use nanmin/nanmax
235
- np.nanmax([ax.get_xlim(), ax.get_ylim()]),
236
- ]
237
- if lims[0] != lims[1] and np.isfinite(lims[0]) and np.isfinite(lims[1]): # Check for valid limits
238
- ax.plot(lims, lims, 'r-', linewidth=0.5)
239
- ax.set_xlim(lims)
240
- ax.set_ylim(lims)
241
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  buf = io.BytesIO()
243
  plt.savefig(buf, format='png', dpi=100)
244
  buf.seek(0)
@@ -246,56 +285,74 @@ def plot_scatter_comparison(df_compare_output, title):
246
  plt.close(fig)
247
  return img
248
 
249
- # --- Main processing function (process_files) ---
250
- # Ensure DataFrames passed to Clusters methods have 'policy_id' as index if expected.
251
  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
252
  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
253
  """Main processing function - now accepts file paths"""
254
  try:
255
- # Consider using engine='calamine' for faster Excel reading if available (pip install pandas[calamine])
256
- # e.g., cfs = pd.read_excel(cashflow_base_path, index_col=0, engine='calamine')
257
- cfs = pd.read_excel(cashflow_base_path, index_col=0)
258
- cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
259
- cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
260
-
261
- pol_data_full = pd.read_excel(policy_data_path, index_col=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
263
 
264
- # Ensure index is named 'policy_id' if it's not already named, assuming index is the policy identifier
265
- for df in [cfs, cfs_lapse50, cfs_mort15, pol_data_full]:
266
- if df.index.name is None:
267
- df.index.name = 'policy_id'
268
- if 'policy_id' not in df.columns and df.index.name == 'policy_id': # Add policy_id as column if its only an index
269
- df.reset_index(inplace=True) # this makes policy_id a column
270
- df.set_index('policy_id', inplace=True) # and keeps it as index
271
-
272
- if all(col in pol_data_full.columns or col == pol_data_full.index.name for col in required_cols):
273
- # If policy_id is index, it won't be in columns. Adjust selection.
274
- cols_to_select = [col for col in required_cols if col in pol_data_full.columns]
275
- if pol_data_full.index.name in required_cols and pol_data_full.index.name not in cols_to_select:
276
- # This case is tricky; if an ID is part of required_cols and is the index.
277
- # For simplicity, assume required_cols are actual data columns.
278
- pass # Let it proceed, it might be handled by selection or error later.
279
-
280
- pol_data = pol_data_full[cols_to_select].copy() # Use .copy() to avoid SettingWithCopyWarning
281
- # If 'policy_id' was the index and required, it's implicitly handled or needs specific logic.
282
- # For K-Means, policy_id itself is usually not a feature.
283
  else:
284
- missing_req_cols = [col for col in required_cols if col not in pol_data_full.columns and col != pol_data_full.index.name]
285
- gr.Warning(f"Policy data might be missing required columns: {missing_req_cols}. Found: {pol_data_full.columns.tolist()}")
286
- pol_data = pol_data_full # Fallback, but ensure it's numeric for clustering/scaling
287
-
288
- pvs = pd.read_excel(pv_base_path, index_col=0)
289
- pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
290
- pvs_mort15 = pd.read_excel(pv_mort_path, index_col=0)
291
-
292
- for df in [pvs, pvs_lapse50, pvs_mort15]:
293
- if df.index.name is None:
294
- df.index.name = 'policy_id'
295
- if 'policy_id' not in df.columns and df.index.name == 'policy_id':
296
- df.reset_index(inplace=True)
297
- df.set_index('policy_id', inplace=True)
298
 
 
 
 
 
 
299
  cfs_list = [cfs, cfs_lapse50, cfs_mort15]
300
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
301
 
@@ -303,44 +360,44 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
303
 
304
  mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
305
 
306
- # DataFrames passed to Clusters should be policy_id indexed for .values to exclude it.
307
- # Or, select only feature columns before passing.
308
- # The Clusters class now expects a DataFrame and will use .values, so pass only feature columns.
309
- # If index is policy_id, df.values will not include it. This is good.
310
-
311
  # --- 1. Cashflow Calibration ---
312
- # Ensure 'cfs' DataFrame does not include 'policy_id' when .values is called in Clusters
313
- cluster_cfs = Clusters(cfs.reset_index().set_index('policy_id')) # Pass with policy_id as index
314
 
315
  results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
316
  results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
317
-
318
  results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
319
  results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
320
  results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
321
-
322
  results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
323
  results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
324
 
325
  # --- 2. Policy Attribute Calibration ---
326
- loc_vars_attrs = pd.DataFrame() # Initialize
327
- if not pol_data.empty:
328
- # Ensure pol_data is purely numeric for scaling and KMeans
329
- numeric_pol_data = pol_data.select_dtypes(include=np.number)
330
- if not numeric_pol_data.empty and not (numeric_pol_data.max(numeric_only=True) - numeric_pol_data.min(numeric_only=True) == 0).all():
331
- loc_vars_attrs = (numeric_pol_data - numeric_pol_data.min(numeric_only=True)) / \
332
- (numeric_pol_data.max(numeric_only=True) - numeric_pol_data.min(numeric_only=True))
333
- loc_vars_attrs.index = numeric_pol_data.index # Preserve index
 
334
  else:
335
- gr.Warning("Policy data for attribute calibration is empty, non-numeric, or has no variance. Skipping attribute calibration content.")
336
- loc_vars_attrs = numeric_pol_data # or an empty DataFrame with original index
 
 
 
 
 
337
  else:
338
- gr.Warning("Policy data is empty. Skipping attribute calibration content.")
 
339
 
340
- if not loc_vars_attrs.empty:
341
- cluster_attrs = Clusters(loc_vars_attrs.reset_index().set_index('policy_id')) # Pass with policy_id as index
342
  results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
343
- results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
344
  results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
345
  results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
346
  results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
@@ -348,41 +405,39 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
348
  results['attr_total_cf_base'] = pd.DataFrame()
349
  results['attr_policy_attrs_total'] = pd.DataFrame()
350
  results['attr_total_pv_base'] = pd.DataFrame()
351
- results['attr_cashflow_plot'] = plot_scatter_comparison(pd.DataFrame(), 'Policy Attr. Calib. - Cashflows (Base) - No Data') # Empty plot
352
- results['attr_scatter_cashflows_base'] = plot_scatter_comparison(pd.DataFrame(), 'Policy Attr. Calib. - Cashflows (Base) - No Data')
353
 
354
 
355
  # --- 3. Present Value Calibration ---
356
- cluster_pvs = Clusters(pvs.reset_index().set_index('policy_id')) # Pass with policy_id as index
357
 
358
  results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
359
  results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
360
-
361
  results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
362
  results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
363
  results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
364
-
365
  results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
366
  results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
367
 
368
  # --- Summary Comparison Plot Data ---
369
  error_data = {}
370
-
371
- def get_error_safe(compare_result, col_name=None):
372
- if compare_result is None or compare_result.empty or 'error' not in compare_result.columns: # Check if None
373
  return np.nan
374
- if col_name and col_name in compare_result.index:
375
- return abs(compare_result.loc[col_name, 'error'])
376
- else:
377
- return abs(compare_result['error']).mean()
 
 
 
 
378
 
379
  key_pv_col = None
380
- # Use pvs.columns (which should be only feature columns after reset_index().set_index())
381
- # Or, use the original pvs DataFrame if it's guaranteed to have the PV_NetCF column.
382
- # For safety, check in the original pvs DataFrame which has not been stripped of columns.
383
- original_pvs_cols = pd.read_excel(pv_base_path).columns # Quick read just for columns
384
  for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']:
385
- if potential_col in original_pvs_cols: # Check against original columns
386
  key_pv_col = potential_col
387
  break
388
 
@@ -392,11 +447,12 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
392
  get_error_safe(results.get('cf_pv_total_mort'), key_pv_col)
393
  ]
394
 
395
- if not loc_vars_attrs.empty:
396
- error_data['Attr Calib.'] = [
397
- get_error_safe(results.get('attr_total_pv_base'), key_pv_col), # This was pvs, should be fine
398
- get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col), # Re-calculate for pvs_lapse50
399
- get_error_safe(cluster_attrs.compare_total(pvs_mort15), key_pv_col) # Re-calculate for pvs_mort15
 
400
  ]
401
  else:
402
  error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
@@ -407,17 +463,26 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
407
  get_error_safe(results.get('pv_total_pv_mort'), key_pv_col)
408
  ]
409
 
410
- summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
411
 
412
  fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
413
- summary_df.plot(kind='bar', ax=ax_summary, grid=True)
 
 
 
 
 
 
 
414
  ax_summary.set_ylabel('Absolute Error Rate')
415
- title_suffix = f' ({key_pv_col})' if key_pv_col else ' (Mean Absolute Error)'
416
  ax_summary.set_title(f'Calibration Method Comparison - Error in Total PV{title_suffix}')
417
  ax_summary.tick_params(axis='x', rotation=0)
418
- ax_summary.legend(title='Calibration Method')
 
 
 
419
  plt.tight_layout()
420
-
421
  buf_summary = io.BytesIO()
422
  plt.savefig(buf_summary, format='png', dpi=100)
423
  buf_summary.seek(0)
@@ -430,15 +495,22 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
430
  gr.Error(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded.")
431
  return {"error": f"File not found: {e.filename}"}
432
  except KeyError as e:
433
- # Check if the KeyError is from trying to access a column that became an index
434
- gr.Error(f"A required column or index is missing or misnamed: {e}. Please check data format and ensure 'policy_id' is correctly handled as index for feature dataframes.")
435
  return {"error": f"Missing column/index: {e}"}
 
 
 
 
 
436
  except Exception as e:
 
437
  import traceback
438
- gr.Error(f"Error processing files: {str(e)}. Trace: {traceback.format_exc()}")
439
- return {"error": f"Error processing files: {str(e)}"}
440
 
441
- # --- Gradio interface creation (create_interface, etc.) remains unchanged ---
 
 
442
  def create_interface():
443
  with gr.Blocks(title="Cluster Model Points Analysis") as demo:
444
  gr.Markdown("""
@@ -448,15 +520,13 @@ def create_interface():
448
  Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
449
 
450
  **Required Files (Excel .xlsx):**
451
- - Cashflows - Base Scenario (index = policy_id, columns = time periods)
452
- - Cashflows - Lapse Stress (+50%) (index = policy_id)
453
- - Cashflows - Mortality Stress (+15%) (index = policy_id)
454
- - Policy Data (index = policy_id, including 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth' as columns)
455
- - Present Values - Base Scenario (index = policy_id, columns = PV components like 'PV_NetCF')
456
- - Present Values - Lapse Stress (index = policy_id)
457
- - Present Values - Mortality Stress (index = policy_id)
458
-
459
- *Note: Ensure 'policy_id' is the index for all input files for correct processing.*
460
  """)
461
 
462
  with gr.Row():
@@ -503,11 +573,7 @@ def create_interface():
503
  attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
504
  attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
505
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
506
- with gr.Row(): # Changed to Row for consistency
507
- attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")
508
- # Added placeholders for other scenarios if they were intended
509
- # attr_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
510
- # attr_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
511
 
512
  with gr.TabItem("💰 Present Value Calibration"):
513
  gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
@@ -522,62 +588,46 @@ def create_interface():
522
  pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
523
  pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
524
 
525
- # --- Helper function to prepare outputs ---
526
  def get_all_output_components():
527
  return [
528
  summary_plot_output,
529
- # Cashflow Calib Outputs
530
  cf_total_base_table_out, cf_policy_attrs_total_out,
531
  cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
532
  cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
533
- # Attribute Calib Outputs
534
  attr_total_cf_base_out, attr_policy_attrs_total_out,
535
  attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
536
- # PV Calib Outputs
537
  pv_total_cf_base_out, pv_policy_attrs_total_out,
538
  pv_cashflow_plot_out, pv_scatter_pvs_base_out,
539
  pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
540
  ]
541
 
542
- # --- Action for Analyze Button ---
543
  def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
544
  files = [f1, f2, f3, f4, f5, f6, f7]
545
-
546
  file_paths = []
547
- # Check if any FileData object is None (no file uploaded for a slot)
548
- if any(f_obj is None for f_obj in files):
549
- # Attempt to load from EXAMPLE_FILES if any input is missing
550
- # This logic might be complex if mixing examples and uploads.
551
- # For now, strict: all files must be present.
552
- gr.Error("Missing file input for one or more fields. Please upload all required files or load the complete example dataset.")
553
- return [None] * len(get_all_output_components())
554
-
555
- for i, f_obj in enumerate(files):
556
- # f_obj is TempFilePath (older Gradio) or FileData (newer) or str (from example load)
557
- if hasattr(f_obj, 'name') and isinstance(f_obj.name, str): # Gradio FileData or similar
558
- file_paths.append(f_obj.name)
559
- elif isinstance(f_obj, str): # Path from example load
560
- file_paths.append(f_obj)
561
- else: # Should not happen if inputs are Files or paths
562
- gr.Error(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")
563
  return [None] * len(get_all_output_components())
 
 
 
 
564
 
565
  results = process_files(*file_paths)
566
 
567
- if "error" in results : # Check if process_files returned an error dict
568
- # Error already shown by gr.Error in process_files
569
- return [None] * len(get_all_output_components())
570
-
 
571
  return [
572
  results.get('summary_plot'),
573
- # CF Calib
574
  results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
575
  results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
576
  results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
577
- # Attr Calib
578
  results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
579
  results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
580
- # PV Calib
581
  results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
582
  results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
583
  results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
@@ -590,50 +640,50 @@ def create_interface():
590
  outputs=get_all_output_components()
591
  )
592
 
593
- # --- Action for Load Example Data Button ---
594
  def load_example_files():
595
- # Create dummy example files if they don't exist for demonstration if needed
596
- # For this exercise, we assume they exist or user is warned.
597
- os.makedirs(EXAMPLE_DATA_DIR, exist_ok=True) # Ensure dir exists
598
-
599
- missing_files = []
600
  for key, fp in EXAMPLE_FILES.items():
601
  if not os.path.exists(fp):
602
- missing_files.append(fp)
603
- # Create a minimal dummy Excel file if it's missing
604
  try:
605
- dummy_df_data = {'policy_id': [1,2,3], 'col1': [0.1,0.2,0.3], 'col2':[10,20,30]}
606
- if "cashflow" in key or "pv" in key: # Time series like
607
- dummy_df_data = {'policy_id': [1,2,3], '0': [1,2,3], '1': [4,5,6]}
 
 
608
  elif "policy_data" in key:
609
- dummy_df_data = {'policy_id': [1,2,3], 'age_at_entry': [20,30,40], 'policy_term': [10,20,15],
610
- 'sum_assured': [1000,2000,1500], 'duration_mth': [5,10,7]}
611
-
612
- dummy_df = pd.DataFrame(dummy_df_data).set_index('policy_id')
613
- dummy_df.to_excel(fp)
614
- gr.Warning(f"Example file '{fp}' was missing and a dummy file has been created. Results may not be meaningful.")
 
 
 
 
 
 
 
 
615
  except Exception as e:
616
- gr.Warning(f"Could not create dummy file for {fp}: {e}")
 
617
 
618
-
619
- if missing_files and not all(os.path.exists(fp) for fp in EXAMPLE_FILES.values()): # Re-check after dummy creation attempt
620
- # If still missing after trying to create dummies
621
- gr.Error(f"Critical example data files are missing from '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist or check permissions.")
622
- return [None] * 7 # Return None for all file inputs
623
 
624
  gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
625
- # Return the string paths for the file components
626
  return [
627
- gr.File(value=EXAMPLE_FILES["cashflow_base"], Labeled_input=cashflow_base_input.label),
628
- gr.File(value=EXAMPLE_FILES["cashflow_lapse"], Labeled_input=cashflow_lapse_input.label),
629
- gr.File(value=EXAMPLE_FILES["cashflow_mort"], Labeled_input=cashflow_mort_input.label),
630
- gr.File(value=EXAMPLE_FILES["policy_data"], Labeled_input=policy_data_input.label),
631
- gr.File(value=EXAMPLE_FILES["pv_base"], Labeled_input=pv_base_input.label),
632
- gr.File(value=EXAMPLE_FILES["pv_lapse"], Labeled_input=pv_lapse_input.label),
633
- gr.File(value=EXAMPLE_FILES["pv_mort"], Labeled_input=pv_mort_input.label)
634
  ]
635
 
636
-
637
  load_example_btn.click(
638
  load_example_files,
639
  inputs=[],
@@ -646,30 +696,7 @@ def create_interface():
646
  if __name__ == "__main__":
647
  if not os.path.exists(EXAMPLE_DATA_DIR):
648
  os.makedirs(EXAMPLE_DATA_DIR)
649
- print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there or they will be generated as dummies.")
650
-
651
- # Simple check and dummy file creation for example data if not present
652
- for key, fp in EXAMPLE_FILES.items():
653
- if not os.path.exists(fp):
654
- print(f"Example file {fp} not found. Attempting to create a dummy file.")
655
- try:
656
- dummy_df_data = {'policy_id': [1,2,3], 'col1': [0.1,0.2,0.3], 'col2':[10,20,30]}
657
- if "cashflow" in key or "pv" in key:
658
- dummy_df_data = {f'{i}':np.random.rand(3) for i in range(10)} # 10 time periods
659
- dummy_df_data['policy_id'] = [f'P{j}' for j in range(3)]
660
- elif "policy_data" in key:
661
- dummy_df_data = {'policy_id': [f'P{j}' for j in range(3)],
662
- 'age_at_entry': np.random.randint(20, 50, 3),
663
- 'policy_term': np.random.randint(10, 30, 3),
664
- 'sum_assured': np.random.randint(10000, 50000, 3),
665
- 'duration_mth': np.random.randint(1, 120, 3)}
666
-
667
- dummy_df = pd.DataFrame(dummy_df_data).set_index('policy_id')
668
- dummy_df.to_excel(fp)
669
- print(f"Dummy file for '{fp}' created.")
670
- except Exception as e:
671
- print(f"Could not create dummy file for {fp}: {e}")
672
-
673
 
674
  demo_app = create_interface()
675
  demo_app.launch()
 
2
  import numpy as np
3
  import pandas as pd
4
  from sklearn.cluster import KMeans
5
+ from sklearn.metrics import pairwise_distances_argmin_min, r2_score # r2_score is not used but kept from original
6
  import matplotlib.pyplot as plt
7
+ # import matplotlib.cm # No longer explicitly needed for rainbow
8
+ import seaborn as sns # Added Seaborn
9
  import io
10
+ import os
11
  from PIL import Image
12
 
13
  # Define the paths for example data
 
23
  }
24
 
25
  class Clusters:
26
def __init__(self, loc_vars, n_clusters=1000):
    """Cluster the calibration variables and pick one representative per cluster.

    Args:
        loc_vars: DataFrame (one row per policy) or array-like of numeric
            calibration variables. When it is a DataFrame whose index is
            named 'policy_id', the index values are used as policy IDs.
        n_clusters: Requested number of clusters. It is capped at the number
            of rows because KMeans rejects n_samples < n_clusters.

    Raises:
        ValueError: If ``loc_vars`` contains no rows.

    Attributes set:
        kmeans: The fitted KMeans model.
        rep_ids: Series named 'policy_id', indexed by 'cluster_id', holding
            the policy ID of the row closest to each cluster centre.
        policy_count: Number of policies assigned to each cluster.
    """
    # KMeans needs a plain numeric array; a DataFrame contributes its values.
    if isinstance(loc_vars, pd.DataFrame):
        loc_vars_np = np.ascontiguousarray(loc_vars.values)
        # Only trust the index as policy IDs when it is explicitly named so.
        policy_ids = loc_vars.index if loc_vars.index.name == 'policy_id' else None
    else:
        loc_vars_np = np.ascontiguousarray(loc_vars)
        policy_ids = None

    n_samples = len(loc_vars_np)
    if n_samples == 0:
        raise ValueError("Cannot cluster an empty set of calibration variables.")

    # Small inputs would otherwise fail with "n_samples should be >= n_clusters".
    effective_clusters = max(1, min(int(n_clusters), n_samples))

    self.kmeans = KMeans(n_clusters=effective_clusters, random_state=0, n_init=10).fit(loc_vars_np)
    # For every cluster centre, the row position of the nearest policy.
    closest, _ = pairwise_distances_argmin_min(self.kmeans.cluster_centers_, loc_vars_np)

    if policy_ids is not None:
        # Real IDs from the index; identical to closest + 1 when IDs are 1..N in row order.
        rep_ids = pd.Series(data=np.asarray(policy_ids)[closest])
    else:
        # No ID information available: fall back to 1-based row positions.
        rep_ids = pd.Series(data=(closest + 1))
    rep_ids.name = 'policy_id'
    rep_ids.index.name = 'cluster_id'  # Position in cluster_centers_ (0-based)
    self.rep_ids = rep_ids

    # One unit per input row, summed per cluster, gives the cluster sizes.
    self.policy_count = self.agg_by_cluster(
        pd.DataFrame({'policy_count': [1] * n_samples})
    )['policy_count']
44
+
45
 
46
def agg_by_cluster(self, df, agg=None):
    """Group the rows of ``df`` by their fitted cluster label and aggregate.

    Args:
        df: DataFrame whose rows are in the same order as the data the model
            was fitted on (labels are attached positionally).
        agg: Optional mapping of column name to aggregation name. Columns not
            listed are summed. When omitted or empty, every column is summed.

    Returns:
        DataFrame indexed by 'cluster_id' with one aggregated row per cluster.
    """
    labelled = df.copy()
    # kmeans.labels_ holds the 0-based cluster assignment of every input row.
    labelled['cluster_id'] = self.kmeans.labels_
    labelled = labelled.set_index('cluster_id')

    if agg:
        how = {}
        for column in labelled.columns:
            if column == 'cluster_id':
                continue
            how[column] = agg[column] if column in agg else 'sum'
    else:
        how = "sum"

    if not how:
        # Nothing to aggregate: hand back an empty frame keyed by the cluster ids.
        return pd.DataFrame(index=labelled.index.unique())
    return labelled.groupby(level='cluster_id').agg(how)
55
+
56
 
57
def extract_reps(self, df):
    """Return the rows of ``df`` that belong to the representative policies.

    Args:
        df: DataFrame carrying 'policy_id' either as its index name or as a
            regular column.

    Returns:
        DataFrame indexed by 'cluster_id' (one row per representative), with
        the 'policy_id' column removed.

    Raises:
        ValueError: If ``df`` exposes 'policy_id' neither as index nor column.
    """
    id_is_index = df.index.name == 'policy_id'
    if not id_is_index and 'policy_id' not in df.columns:
        raise ValueError("DataFrame for extract_reps must have 'policy_id' as index or column.")

    # Bring policy_id into the columns so it can serve as the join key.
    candidates = df.reset_index() if id_is_index else df.copy()

    # Left join keeps exactly one row per cluster, in cluster order.
    reps = self.rep_ids.reset_index().merge(candidates, how='left', on='policy_id')
    reps = reps.set_index('cluster_id')
    return reps.drop(columns=['policy_id'], errors='ignore')
79
 
80
 
81
def extract_and_scale_reps(self, df, agg=None):
    """Extract the representative rows and scale them up to cluster totals.

    Args:
        df: DataFrame accepted by ``extract_reps``.
        agg: Optional mapping of column name to aggregation name. Columns that
            are missing from the mapping or mapped to 'sum' are multiplied by
            the cluster's policy count; all other columns (e.g. 'mean') are
            left unscaled. Without ``agg`` every numeric column is scaled.

    Returns:
        DataFrame indexed by 'cluster_id' with the scaled representative rows.
    """
    extracted_df = self.extract_reps(df)
    result_df = extracted_df.copy()

    if agg:
        # Only additive columns are scaled; averaged columns already describe the cluster.
        cols_to_scale = [
            col for col in df.columns
            if col in extracted_df.columns and (col not in agg or agg[col] == 'sum')
        ]
    else:
        cols_to_scale = list(extracted_df.select_dtypes(include=np.number).columns)

    # Multiply column by column, aligned on cluster_id. This avoids building a
    # multiplier frame from scalars, which pandas rejects when no column is 'sum'.
    for col in cols_to_scale:
        result_df[col] = extracted_df[col].mul(self.policy_count, axis=0)
    return result_df
105
+
106
 
107
def compare(self, df, agg=None):
    """Build a per-cluster, per-column comparison of actual and estimated values.

    Args:
        df: DataFrame to compare, in the row order the model was fitted on.
        agg: Optional column-to-aggregation mapping forwarded to both sides.

    Returns:
        DataFrame with columns 'actual' (aggregated by cluster) and 'estimate'
        (scaled representatives), indexed by the stacked (cluster, column) pairs.
    """
    actual_by_cluster = self.agg_by_cluster(df, agg)
    estimate_by_cluster = self.extract_and_scale_reps(df, agg)

    # Restrict both sides to the columns they share so the stacked indexes line up.
    shared = actual_by_cluster.columns.intersection(estimate_by_cluster.columns)
    stacked = {
        'actual': actual_by_cluster[shared].stack(),
        'estimate': estimate_by_cluster[shared].stack(),
    }
    return pd.DataFrame(stacked)
118
 
119
  def compare_total(self, df, agg=None):
120
  """Aggregate df by columns"""
 
131
  estimate_values = {}
132
 
133
  for col in df.columns: # Iterate over original df columns to ensure all are covered
134
+ if col not in reps_unscaled.columns:
135
+ estimate_values[col] = np.nan # Column not in representative policies
 
 
 
 
136
  continue
137
 
138
  if agg.get(col, 'sum') == 'mean':
139
+ if self.policy_count.sum() > 0:
140
+ weighted_sum = (reps_unscaled[col].astype(float) * self.policy_count.astype(float)).sum()
141
+ total_weight = self.policy_count.sum()
142
+ estimate_values[col] = weighted_sum / total_weight
143
+ else:
144
+ estimate_values[col] = np.nan # Avoid division by zero
145
  else: # sum
146
+ estimate_values[col] = (reps_unscaled[col].astype(float) * self.policy_count.astype(float)).sum()
 
147
  estimate = pd.Series(estimate_values)
148
+ else:
 
149
  actual = df.sum()
150
  estimate = self.extract_and_scale_reps(df).sum()
151
 
152
+ actual, estimate = actual.align(estimate, fill_value=0) # Align before calculating error
153
+ error = np.where(actual != 0, (estimate / actual) - 1, 0) # estimate/actual can be NaN if actual is 0
154
+ error = np.nan_to_num(error, nan=0.0) # Replace NaNs from 0/0 with 0
155
 
156
  return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
157
 
158
+ ## Plotting Functions (Modified for Seaborn)
159
+ # ---
160
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
161
+ """Create cashflow comparison plots using Seaborn"""
162
+ sns.set_style("whitegrid") # Apply Seaborn styling
163
  if not cfs_list or not cluster_obj or not titles:
164
  return None
165
  num_plots = len(cfs_list)
 
172
  fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
173
  axes = axes.flatten()
174
 
175
+ for i, (df_orig, title) in enumerate(zip(cfs_list, titles)):
176
  if i < len(axes):
177
+ ax = axes[i]
178
+ # Assuming df_orig has policy_id as index, or it's handled before compare_total
179
+ comparison_df = cluster_obj.compare_total(df_orig)
 
 
180
 
181
+ # Prepare data for Seaborn lineplot (long format)
182
+ plot_data = comparison_df[['actual', 'estimate']].copy()
183
+ # Assuming the index of comparison_df represents 'Time'
184
+ plot_data['Time'] = plot_data.index.astype(str) # Ensure Time is string for categorical plotting if not truly numeric
185
+ try: # If Time can be numeric, use it as such.
186
+ plot_data['Time'] = pd.to_numeric(plot_data['Time'])
187
+ except ValueError:
188
+ pass # Keep as string if not convertible
189
+
190
+ plot_data_melted = plot_data.melt(id_vars='Time', var_name='Legend', value_name='Value')
191
+
192
+ sns.lineplot(x='Time', y='Value', hue='Legend', data=plot_data_melted, ax=ax, errorbar=None)
193
+ ax.set_title(title)
194
+ ax.set_xlabel('Time')
195
+ ax.set_ylabel('Value')
196
+ # ax.grid(True) # whitegrid style includes a grid
197
 
198
+ for j in range(i + 1, len(axes)): # Hide any unused subplots
199
  fig.delaxes(axes[j])
200
 
201
  plt.tight_layout()
 
207
  return img
208
 
209
  def plot_scatter_comparison(df_compare_output, title):
210
+ """Create scatter plot comparison from compare() output using Seaborn"""
211
+ sns.set_style("whitegrid")
212
+ fig, ax = plt.subplots(figsize=(12, 8))
213
+
214
  if df_compare_output is None or df_compare_output.empty:
 
215
  ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
216
  ax.set_title(title)
217
+ elif not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
 
 
 
 
 
 
 
 
 
218
  gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
219
+ sns.scatterplot(x='actual', y='estimate', data=df_compare_output, s=25, alpha=0.7, ax=ax, legend=False)
220
+ ax.set_title(title)
221
  else:
222
+ plot_data = df_compare_output.reset_index()
223
+ hue_col_name = df_compare_output.index.names[1]
224
+ # Ensure the hue column is treated as categorical by converting to string
225
+ plot_data[hue_col_name] = plot_data[hue_col_name].astype(str)
226
+
227
+ unique_levels = plot_data[hue_col_name].nunique()
228
+ show_legend_flag = "auto"
229
+ if unique_levels == 1:
230
+ show_legend_flag = False
231
+ elif unique_levels > 10:
232
+ show_legend_flag = False
233
+ gr.Warning(f"Warning: Too many unique values ({unique_levels}) in '{hue_col_name}' for scatter plot legend. Legend hidden.")
234
+
235
+ sns.scatterplot(x='actual', y='estimate', hue=hue_col_name, data=plot_data,
236
+ s=25, alpha=0.7, ax=ax, legend=show_legend_flag)
237
+ ax.set_title(title)
238
+ if show_legend_flag == True and ax.get_legend() is not None:
239
+ ax.get_legend().set_title(str(hue_col_name))
240
+ elif show_legend_flag == "auto" and ax.get_legend() is not None: # Seaborn decided to show it
241
+ ax.get_legend().set_title(str(hue_col_name))
242
+
243
 
244
  ax.set_xlabel('Actual')
245
  ax.set_ylabel('Estimate')
246
+ # ax.grid(True) # whitegrid includes it
247
+
248
+ # Draw identity line
249
+ # Must draw after scatterplot to get correct limits
250
+ # Delay lims calculation until after plot, ensure data exists
251
+ if not (df_compare_output is None or df_compare_output.empty):
252
+ # Get limits from data if axes limits are too wide or default
253
+ # This ensures the identity line is relevant to the plotted data
254
+ all_values = pd.concat([plot_data['actual'], plot_data['estimate']]).dropna() if 'plot_data' in locals() else \
255
+ pd.concat([df_compare_output['actual'], df_compare_output['estimate']]).dropna()
256
+
257
+ if not all_values.empty:
258
+ min_val = all_values.min()
259
+ max_val = all_values.max()
260
+
261
+ # Use current axis limits if they are tighter than data range (e.g., user zoomed)
262
+ # But if they are default (-0.05 to 0.05 for empty data), use data range.
263
+ ax_xlims = ax.get_xlim()
264
+ ax_ylims = ax.get_ylim()
265
+
266
+ plot_min = np.nanmin([min_val, ax_xlims[0], ax_ylims[0]])
267
+ plot_max = np.nanmax([max_val, ax_xlims[1], ax_ylims[1]])
268
+
269
+ # Handle cases where min and max might be too close or NaN
270
+ if np.isfinite(plot_min) and np.isfinite(plot_max) and plot_min < plot_max:
271
+ ax.plot([plot_min, plot_max], [plot_min, plot_max], 'r-', linewidth=0.7, alpha=0.8, zorder=0)
272
+ ax.set_xlim(plot_min, plot_max)
273
+ ax.set_ylim(plot_min, plot_max)
274
+ elif np.isfinite(plot_min) and np.isfinite(plot_max) and plot_min == plot_max: # Single point
275
+ margin = abs(plot_min * 0.1) if plot_min != 0 else 0.1
276
+ ax.plot([plot_min], [plot_min], 'ro') # Mark the point
277
+ ax.set_xlim(plot_min - margin, plot_min + margin)
278
+ ax.set_ylim(plot_min - margin, plot_min + margin)
279
+
280
+
281
  buf = io.BytesIO()
282
  plt.savefig(buf, format='png', dpi=100)
283
  buf.seek(0)
 
285
  plt.close(fig)
286
  return img
287
 
288
+ ## Main Processing and Gradio UI (Largely Unchanged)
289
+ # ---
290
  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
291
  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
292
  """Main processing function - now accepts file paths"""
293
  try:
294
+ # Ensure 'policy_id' is the index for dataframes used in clustering/comparison
295
+ def read_and_prep_excel(path, set_policy_id_index=True):
296
+ df = pd.read_excel(path) # Read first, then set index
297
+ if 'policy_id' not in df.columns:
298
+ # Try to find it in unnamed index columns if any, or assume first column
299
+ # This is risky; ideally, 'policy_id' is an explicit column name
300
+ gr.Warning(f"'policy_id' column not found in {os.path.basename(path)}. Attempting to use first column or existing index.")
301
+ if df.columns[0].lower() == 'policy_id' or 'policyid' in df.columns[0].lower():
302
+ df.rename(columns={df.columns[0]: 'policy_id'}, inplace=True)
303
+ # Or if it is in the index already but unnamed
304
+ elif df.index.name is None and len(df.index) == len(df): # A heuristic
305
+ pass # keep as is, will try to use index later
306
+ else: # Fallback if no clear policy_id column found and index is not it
307
+ gr.Error(f"Cannot reliably find 'policy_id' in {os.path.basename(path)}.")
308
+ # For this example, let's assume files WILL have policy_id column or as first column
309
+ # This part needs robust handling based on expected file structures.
310
+ # If it's always index_col=0 as in original:
311
+ df = pd.read_excel(path, index_col=0)
312
+ if df.index.name != 'policy_id': # if index_col=0 was not named 'policy_id'
313
+ df.index.name = 'policy_id' # Name it 'policy_id'
314
+ return df.reset_index() # Make policy_id a column then set as index
315
+
316
+ if set_policy_id_index:
317
+ return df.set_index('policy_id')
318
+ return df
319
+
320
+ cfs = read_and_prep_excel(cashflow_base_path).select_dtypes(include=np.number)
321
+ cfs_lapse50 = read_and_prep_excel(cashflow_lapse_path).select_dtypes(include=np.number)
322
+ cfs_mort15 = read_and_prep_excel(cashflow_mort_path).select_dtypes(include=np.number)
323
+
324
+ pol_data_full_raw = read_and_prep_excel(policy_data_path, set_policy_id_index=False)
325
+ # Ensure the correct columns are selected for pol_data
326
  required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
327
 
328
+ # Check if required_cols exist, case-insensitively, and normalize names
329
+ rename_map = {}
330
+ available_cols_lower = {col.lower(): col for col in pol_data_full_raw.columns}
331
+ for req_col in required_cols:
332
+ if req_col.lower() in available_cols_lower:
333
+ rename_map[available_cols_lower[req_col.lower()]] = req_col # Map original to standardized
334
+ pol_data_full_renamed = pol_data_full_raw.rename(columns=rename_map)
335
+
336
+ if all(col in pol_data_full_renamed.columns for col in required_cols):
337
+ pol_data = pol_data_full_renamed.set_index('policy_id')[required_cols].select_dtypes(include=np.number)
 
 
 
 
 
 
 
 
 
338
  else:
339
+ missing = [col for col in required_cols if col not in pol_data_full_renamed.columns]
340
+ gr.Warning(f"Policy data might be missing required columns: {missing}. Found: {pol_data_full_renamed.columns.tolist()}")
341
+ # Fallback: use all numeric columns if required are missing, set policy_id as index
342
+ pol_data = pol_data_full_renamed.set_index('policy_id').select_dtypes(include=np.number)
343
+ if pol_data.empty and not pol_data_full_renamed.select_dtypes(include=np.number).empty:
344
+ gr.Warning("Policy data became empty after trying to select numeric types with policy_id index. Check input.")
345
+
346
+
347
+ pvs = read_and_prep_excel(pv_base_path).select_dtypes(include=np.number)
348
+ pvs_lapse50 = read_and_prep_excel(pv_lapse_path).select_dtypes(include=np.number)
349
+ pvs_mort15 = read_and_prep_excel(pv_mort_path).select_dtypes(include=np.number)
 
 
 
350
 
351
+ # DataFrames for Clusters class should not include the policy_id if it's an index
352
+ # The class constructor expects features only (typically a DataFrame where .values gives numeric data)
353
+ # The current read_and_prep_excel sets policy_id as index. This is fine.
354
+ # KMeans will be called on df.values implicitly.
355
+
356
  cfs_list = [cfs, cfs_lapse50, cfs_mort15]
357
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
358
 
 
360
 
361
  mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
362
 
 
 
 
 
 
363
  # --- 1. Cashflow Calibration ---
364
+ # Pass DataFrame with features only. If policy_id is index, df.values is correct.
365
+ cluster_cfs = Clusters(cfs)
366
 
367
  results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
368
  results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
 
369
  results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
370
  results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
371
  results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
 
372
  results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
373
  results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
374
 
375
  # --- 2. Policy Attribute Calibration ---
376
+ loc_vars_attrs_input = pol_data # pol_data is already features with policy_id as index
377
+ if not loc_vars_attrs_input.empty:
378
+ # Standardize policy attributes if there's variance
379
+ min_vals = loc_vars_attrs_input.min()
380
+ max_vals = loc_vars_attrs_input.max()
381
+ range_vals = max_vals - min_vals
382
+ if (range_vals == 0).all(): # No variance
383
+ gr.Warning("Policy data for attribute calibration has no variance. Using original values (may lead to poor clustering if scales differ).")
384
+ loc_vars_attrs_scaled = loc_vars_attrs_input
385
  else:
386
+ # Scale only columns with variance, keep others as is (or handle as 0 if appropriate)
387
+ loc_vars_attrs_scaled = loc_vars_attrs_input.copy()
388
+ for col in range_vals.index:
389
+ if range_vals[col] > 1e-9: # Check for non-zero range with tolerance
390
+ loc_vars_attrs_scaled[col] = (loc_vars_attrs_input[col] - min_vals[col]) / range_vals[col]
391
+ else: # if no variance, scaled value is 0 or 0.5 (or original)
392
+ loc_vars_attrs_scaled[col] = 0.0 # Or np.nan, or keep original: loc_vars_attrs_input[col]
393
  else:
394
+ gr.Warning("Policy data for attribute calibration is empty. Skipping attribute calibration plots.")
395
+ loc_vars_attrs_scaled = pd.DataFrame(index=pol_data.index) # Empty DF with correct index
396
 
397
+ if not loc_vars_attrs_scaled.empty:
398
+ cluster_attrs = Clusters(loc_vars_attrs_scaled) # Pass the scaled data
399
  results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
400
+ results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs) # Compare against original pol_data
401
  results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
402
  results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
403
  results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
 
405
  results['attr_total_cf_base'] = pd.DataFrame()
406
  results['attr_policy_attrs_total'] = pd.DataFrame()
407
  results['attr_total_pv_base'] = pd.DataFrame()
408
+ results['attr_cashflow_plot'] = plot_cashflows_comparison([], None, []) # Empty plot
409
+ results['attr_scatter_cashflows_base'] = plot_scatter_comparison(pd.DataFrame(), 'Policy Attr. Calib. - No Data')
410
 
411
 
412
  # --- 3. Present Value Calibration ---
413
+ cluster_pvs = Clusters(pvs)
414
 
415
  results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
416
  results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
 
417
  results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
418
  results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
419
  results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
 
420
  results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
421
  results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
422
 
423
  # --- Summary Comparison Plot Data ---
424
  error_data = {}
425
+ def get_error_safe(compare_result_df, col_name=None):
426
+ if compare_result_df is None or compare_result_df.empty or 'error' not in compare_result_df.columns:
 
427
  return np.nan
428
+ # Ensure col_name, if provided, is actually an index in the DataFrame
429
+ # compare_result_df has an index of column names of the original data (e.g. PV_NetCF)
430
+ if col_name and col_name in compare_result_df.index:
431
+ error_val = compare_result_df.loc[col_name, 'error']
432
+ return abs(error_val) if pd.notna(error_val) else np.nan
433
+ else: # Mean absolute error of all error column values
434
+ valid_errors = compare_result_df['error'].dropna()
435
+ return abs(valid_errors).mean() if not valid_errors.empty else np.nan
436
 
437
  key_pv_col = None
438
+ # pvs dataframe here has policy_id as index, columns are features.
 
 
 
439
  for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']:
440
+ if potential_col in pvs.columns: # pvs is already loaded and indexed
441
  key_pv_col = potential_col
442
  break
443
 
 
447
  get_error_safe(results.get('cf_pv_total_mort'), key_pv_col)
448
  ]
449
 
450
+ if not loc_vars_attrs_scaled.empty: # Check if attribute calibration was performed
451
+ error_data['Attr Calib.'] = [
452
+ get_error_safe(results.get('attr_total_pv_base'), key_pv_col),
453
+ # For stressed PVs under Attr Calib, we need to call compare_total from cluster_attrs
454
+ get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col),
455
+ get_error_safe(cluster_attrs.compare_total(pvs_mort15), key_pv_col)
456
  ]
457
  else:
458
  error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
 
463
  get_error_safe(results.get('pv_total_pv_mort'), key_pv_col)
464
  ]
465
 
466
+ summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%']).astype(float) # Ensure float for plotting
467
 
468
  fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
469
+ sns.set_style("whitegrid")
470
+
471
+ # Melt for Seaborn barplot
472
+ summary_df_melted = summary_df.reset_index().rename(columns={'index': 'Scenario'})
473
+ summary_df_melted = summary_df_melted.melt(id_vars='Scenario', var_name='Calibration Method', value_name='Absolute Error Rate')
474
+
475
+ sns.barplot(x='Scenario', y='Absolute Error Rate', hue='Calibration Method', data=summary_df_melted, ax=ax_summary)
476
+
477
  ax_summary.set_ylabel('Absolute Error Rate')
478
+ title_suffix = f' for {key_pv_col}' if key_pv_col else ' (Mean Absolute Error)'
479
  ax_summary.set_title(f'Calibration Method Comparison - Error in Total PV{title_suffix}')
480
  ax_summary.tick_params(axis='x', rotation=0)
481
+ if ax_summary.get_legend():
482
+ ax_summary.get_legend().set_title('Calibration Method')
483
+ ax_summary.grid(True, axis='y') # Horizontal grid lines for bar plot
484
+
485
  plt.tight_layout()
 
486
  buf_summary = io.BytesIO()
487
  plt.savefig(buf_summary, format='png', dpi=100)
488
  buf_summary.seek(0)
 
495
  gr.Error(f"File not found: {e.filename}. Please ensure example files are in '{EXAMPLE_DATA_DIR}' or all files are uploaded.")
496
  return {"error": f"File not found: {e.filename}"}
497
  except KeyError as e:
498
+ gr.Error(f"A required column/index ('policy_id' or feature column) is missing or misnamed: {e}. Please check data format.")
 
499
  return {"error": f"Missing column/index: {e}"}
500
+ except ValueError as e: # Catch other value errors like from plotting or data prep
501
+ gr.Error(f"Data processing or plotting error: {str(e)}")
502
+ import traceback
503
+ traceback.print_exc()
504
+ return {"error": f"Data error: {str(e)}"}
505
  except Exception as e:
506
+ gr.Error(f"An unexpected error occurred: {str(e)}")
507
  import traceback
508
+ traceback.print_exc()
509
+ return {"error": f"Unexpected error: {str(e)}"}
510
 
511
+ # --- Gradio interface creation (create_interface, etc.) ---
512
+ # This part remains unchanged from your original script.
513
+ # Ensure dummy file creation in if __name__ == "__main__": handles policy_id correctly.
514
  def create_interface():
515
  with gr.Blocks(title="Cluster Model Points Analysis") as demo:
516
  gr.Markdown("""
 
520
  Upload your Excel files or use the example data to analyze cashflows, policy attributes, and present values using different calibration methods.
521
 
522
  **Required Files (Excel .xlsx):**
523
+ - Cashflows - Base Scenario (should contain a 'policy_id' column, or it's the first column/index)
524
+ - Cashflows - Lapse Stress (+50%) (similar structure)
525
+ - Cashflows - Mortality Stress (+15%) (similar structure)
526
+ - Policy Data (should contain 'policy_id', 'age_at_entry', 'policy_term', 'sum_assured', 'duration_mth')
527
+ - Present Values - Base Scenario (should contain 'policy_id' and PV columns like 'PV_NetCF')
528
+ - Present Values - Lapse Stress (similar structure)
529
+ - Present Values - Mortality Stress (similar structure)
 
 
530
  """)
531
 
532
  with gr.Row():
 
573
  attr_cashflow_plot_out = gr.Image(label="Cashflow Value Comparisons (Actual vs. Estimate) Across Scenarios")
574
  attr_scatter_cashflows_base_out = gr.Image(label="Scatter Plot - Per-Cluster Cashflows (Base Scenario)")
575
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
576
+ attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total") # Only one PV table shown in original UI for this tab
 
 
 
 
577
 
578
  with gr.TabItem("💰 Present Value Calibration"):
579
  gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
 
588
  pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
589
  pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
590
 
 
591
def get_all_output_components():
    """Return every output component as one flat list.

    The order is: summary plot, cashflow-calibration outputs,
    policy-attribute-calibration outputs, present-value-calibration outputs.
    """
    summary_outputs = [summary_plot_output]
    cashflow_outputs = [
        cf_total_base_table_out, cf_policy_attrs_total_out,
        cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
        cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
    ]
    attribute_outputs = [
        attr_total_cf_base_out, attr_policy_attrs_total_out,
        attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
    ]
    present_value_outputs = [
        pv_total_cf_base_out, pv_policy_attrs_total_out,
        pv_cashflow_plot_out, pv_scatter_pvs_base_out,
        pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out,
    ]
    return summary_outputs + cashflow_outputs + attribute_outputs + present_value_outputs
603
 
 
604
  def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
605
  files = [f1, f2, f3, f4, f5, f6, f7]
 
606
  file_paths = []
607
+ # Gradio File component now passes full path for temporary files
608
+ for i, f_obj_path in enumerate(files): # f_obj is now a path string or None
609
+ if f_obj_path is None:
610
+ gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
 
 
 
 
 
 
 
 
 
 
 
 
611
  return [None] * len(get_all_output_components())
612
+ if not isinstance(f_obj_path, str): # Should be a path string
613
+ gr.Error(f"Invalid file input for argument {i+1}. Expected path, got {type(f_obj_path)}")
614
+ return [None] * len(get_all_output_components())
615
+ file_paths.append(f_obj_path)
616
 
617
  results = process_files(*file_paths)
618
 
619
+ if "error" in results :
620
+ return [gr.Plot.update(None)] * len(get_all_output_components()) # Clear plots on error
621
+
622
+ # Ensure DataFrames are converted to a format Gradio can display (e.g. List of Lists or pandas)
623
+ # For Dataframe components, pandas DataFrames are fine. For Image, PIL Image is fine.
624
  return [
625
  results.get('summary_plot'),
 
626
  results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
627
  results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
628
  results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
 
629
  results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
630
  results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
 
631
  results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
632
  results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
633
  results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
 
640
  outputs=get_all_output_components()
641
  )
642
 
 
643
def load_example_files():
    """Return the seven example file paths, creating dummy workbooks if needed.

    Any example file that does not exist is generated with random data so the
    demo can still run. The dummy data uses integer policy IDs 1..N and more
    rows than the 1000 clusters fitted by ``Clusters``, because the clustering
    step maps representatives back to policies via 1-based row positions.

    Returns:
        List of seven file paths in the order of the file input components:
        cashflows (base, lapse, mortality), policy data, PVs (base, lapse,
        mortality).

    Raises:
        gr.Error: If a dummy file cannot be written or an example file is
            still missing afterwards. gr.Error only reaches the user when it
            is raised; merely constructing it has no effect.
    """
    os.makedirs(EXAMPLE_DATA_DIR, exist_ok=True)

    # Must exceed the 1000 clusters requested by Clusters, or KMeans refuses to fit.
    num_policies = 1200

    for key, fp in EXAMPLE_FILES.items():
        if os.path.exists(fp):
            continue
        gr.Info(f"Example file {fp} not found. Attempting to create a dummy file.")
        try:
            # Integer IDs starting at 1 match the 1-based representative IDs.
            policy_ids = list(range(1, num_policies + 1))
            if "cashflow" in key or "pv" in key:
                dummy_data = {'policy_id': policy_ids}
                for i in range(10):  # 10 time periods / PV components
                    dummy_data[f't{i}'] = np.random.rand(num_policies) * 1000
            elif "policy_data" in key:
                dummy_data = {
                    'policy_id': policy_ids,
                    'age_at_entry': np.random.randint(20, 50, num_policies),
                    'policy_term': np.random.randint(10, 30, num_policies),
                    'sum_assured': np.random.randint(10000, 50000, num_policies),
                    'duration_mth': np.random.randint(1, 240, num_policies),
                }
            else:  # Default dummy for any other key
                dummy_data = {'policy_id': policy_ids, 'feature1': np.random.rand(num_policies)}

            # policy_id stays a regular column; the reader sets the index itself.
            pd.DataFrame(dummy_data).to_excel(fp, index=False)
            gr.Info(f"Dummy file for '{os.path.basename(fp)}' created in '{EXAMPLE_DATA_DIR}'.")
        except Exception as e:
            raise gr.Error(f"Could not create dummy file for {fp}: {e}") from e

    missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
    if missing_files:
        raise gr.Error(
            f"Still missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist."
        )

    gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
    # Plain path strings are accepted as values by the File components.
    return [
        EXAMPLE_FILES["cashflow_base"], EXAMPLE_FILES["cashflow_lapse"], EXAMPLE_FILES["cashflow_mort"],
        EXAMPLE_FILES["policy_data"], EXAMPLE_FILES["pv_base"], EXAMPLE_FILES["pv_lapse"],
        EXAMPLE_FILES["pv_mort"],
    ]
686
 
 
687
  load_example_btn.click(
688
  load_example_files,
689
  inputs=[],
 
696
if __name__ == "__main__":
    # Ensure the example-data folder exists before the UI is built, and tell
    # the operator when it had to be created.
    example_dir_present = os.path.exists(EXAMPLE_DATA_DIR)
    if not example_dir_present:
        os.makedirs(EXAMPLE_DATA_DIR)
        print(f"Created directory '{EXAMPLE_DATA_DIR}'. Please place example Excel files there or dummy files will be generated on 'Load Example Data'.")

    # Build the Gradio app and start serving it.
    demo_app = create_interface()
    demo_app.launch()