alidenewade commited on
Commit
541bbc3
·
verified ·
1 Parent(s): 6570096

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -207
app.py CHANGED
@@ -2,12 +2,11 @@ import gradio as gr
2
  import numpy as np
3
  import pandas as pd
4
  from sklearn.cluster import KMeans
5
- from sklearn.metrics import pairwise_distances_argmin_min, r2_score # r2_score is not used but kept
6
- import plotly.graph_objects as go # ADDED
7
- import plotly.express as px # ADDED
8
- from plotly.subplots import make_subplots # ADDED
9
  import io
10
- import os
11
  from PIL import Image
12
 
13
  # Define the paths for example data
@@ -53,6 +52,7 @@ class Clusters:
53
  if agg:
54
  cols = df.columns
55
  mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
 
56
  extracted_df = self.extract_reps(df)
57
  mult.index = extracted_df.index
58
  return extracted_df.mul(mult)
@@ -68,199 +68,143 @@ class Clusters:
68
  def compare_total(self, df, agg=None):
69
  """Aggregate df by columns"""
70
  if agg:
 
71
  actual_values = {}
72
  for col in df.columns:
73
  if agg.get(col, 'sum') == 'mean':
74
  actual_values[col] = df[col].mean()
75
- else:
76
  actual_values[col] = df[col].sum()
77
  actual = pd.Series(actual_values)
78
 
 
79
  reps_unscaled = self.extract_reps(df)
80
  estimate_values = {}
81
 
82
  for col in df.columns:
83
  if agg.get(col, 'sum') == 'mean':
 
84
  weighted_sum = (reps_unscaled[col] * self.policy_count).sum()
85
  total_weight = self.policy_count.sum()
86
  estimate_values[col] = weighted_sum / total_weight if total_weight > 0 else 0
87
- else:
88
  estimate_values[col] = (reps_unscaled[col] * self.policy_count).sum()
 
89
  estimate = pd.Series(estimate_values)
90
- else:
 
91
  actual = df.sum()
92
  estimate = self.extract_and_scale_reps(df).sum()
93
 
 
94
  error = np.where(actual != 0, estimate / actual - 1, 0)
 
95
  return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
96
 
97
 
98
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
99
- """Create cashflow comparison plots using Plotly"""
100
  if not cfs_list or not cluster_obj or not titles:
101
  return None
102
  num_plots = len(cfs_list)
103
  if num_plots == 0:
104
  return None
105
 
 
106
  cols = 2
107
  rows = (num_plots + cols - 1) // cols
108
 
109
- # Use subplot titles from the input 'titles'
110
- subplot_titles_full = titles[:num_plots] + [""] * (rows * cols - num_plots)
111
-
112
- fig = make_subplots(
113
- rows=rows, cols=cols,
114
- subplot_titles=subplot_titles_full
115
- )
116
-
117
- plot_idx = 0
118
- for i_df, (df, title) in enumerate(zip(cfs_list, titles)): # Use i_df to avoid conflict with internal loop i
119
- if plot_idx < rows * cols:
120
- r = plot_idx // cols + 1
121
- c = plot_idx % cols + 1
122
  comparison = cluster_obj.compare_total(df)
123
-
124
- fig.add_trace(go.Scatter(x=comparison.index, y=comparison['actual'], name='Actual',
125
- legendgroup='group1', showlegend=(plot_idx == 0)), row=r, col=c)
126
- fig.add_trace(go.Scatter(x=comparison.index, y=comparison['estimate'], name='Estimate',
127
- legendgroup='group2', showlegend=(plot_idx == 0)), row=r, col=c)
128
-
129
- fig.update_xaxes(title_text='Time', showgrid=True, row=r, col=c)
130
- fig.update_yaxes(title_text='Value', showgrid=True, row=r, col=c)
131
- plot_idx += 1
132
 
133
- # Hide unused subplots by making axes invisible and clearing titles
134
- for i in range(plot_idx, rows * cols):
135
- r = i // cols + 1
136
- c = i % cols + 1
137
- fig.update_xaxes(visible=False, row=r, col=c)
138
- fig.update_yaxes(visible=False, row=r, col=c)
139
- if fig.layout.annotations and i < len(fig.layout.annotations):
140
- fig.layout.annotations[i].update(text="")
141
-
142
-
143
- fig_width = 1500
144
- fig_height = 500 * rows
145
- fig.update_layout(
146
- width=fig_width,
147
- height=fig_height,
148
- margin=dict(l=60, r=30, t=60, b=60) # Adjusted margins
149
- )
150
 
151
- try:
152
- # Requires kaleido: pip install kaleido
153
- img_bytes = fig.to_image(format="png", width=fig_width, height=fig_height)
154
- buf = io.BytesIO(img_bytes)
 
 
 
 
 
 
155
  img = Image.open(buf)
 
156
  return img
157
- except Exception as e:
158
- print(f"Error generating cashflow plot image with Plotly/Kaleido: {e}. Ensure Kaleido is installed.")
159
- # Create a placeholder error image
160
- error_fig = go.Figure()
161
- error_fig.add_annotation(text=f"Plot Error: {e}", showarrow=False)
162
- error_fig.update_layout(width=fig_width, height=fig_height)
163
- img_bytes = error_fig.to_image(format="png", width=fig_width, height=fig_height)
164
- return Image.open(io.BytesIO(img_bytes))
165
 
166
-
167
- def plot_scatter_comparison(df_compare_output, title):
168
- """Create scatter plot comparison from compare() output using Plotly"""
169
- fig_width = 1200
170
- fig_height = 800
171
-
172
- if df_compare_output is None or df_compare_output.empty:
173
- fig = go.Figure()
174
- fig.add_annotation(
175
- text="No data to display",
176
- xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
177
- font=dict(size=15)
178
- )
179
- fig.update_layout(title_text=title, width=fig_width, height=fig_height)
180
  else:
181
- if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
182
- gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
183
- fig = px.scatter(df_compare_output, x='actual', y='estimate', title=title)
184
- fig.update_traces(marker=dict(size=5, opacity=0.6)) # Set marker size and opacity
185
- else:
186
- df_reset = df_compare_output.reset_index()
187
- level_1_name = df_compare_output.index.names[1] if df_compare_output.index.names[1] else 'category'
188
- if level_1_name not in df_reset.columns: # Handle case where level name might not be in columns
189
- df_reset = df_reset.rename(columns={df_reset.columns[1]: level_1_name})
190
-
191
-
192
- fig = px.scatter(df_reset, x='actual', y='estimate', color=level_1_name,
193
- title=title,
194
- labels={'actual': 'Actual', 'estimate': 'Estimate', level_1_name: level_1_name})
195
- fig.update_traces(marker=dict(size=5, opacity=0.6)) # Set marker size and opacity
196
-
197
- num_unique_levels = df_reset[level_1_name].nunique()
198
- if num_unique_levels == 0 or num_unique_levels > 10:
199
- fig.update_layout(showlegend=False)
200
- elif num_unique_levels == 1: # Show legend even for one item if it's named
201
- fig.update_layout(showlegend=True)
202
-
203
-
204
- fig.update_xaxes(showgrid=True, title_text='Actual')
205
- fig.update_yaxes(showgrid=True, title_text='Estimate')
206
-
207
- # Draw identity line
208
- if not df_compare_output.empty:
209
- min_val_actual = df_compare_output['actual'].min()
210
- max_val_actual = df_compare_output['actual'].max()
211
- min_val_estimate = df_compare_output['estimate'].min()
212
- max_val_estimate = df_compare_output['estimate'].max()
213
-
214
- # Handle cases where min/max might be NaN (e.g. if all data is NaN)
215
- if pd.isna(min_val_actual) or pd.isna(min_val_estimate) or pd.isna(max_val_actual) or pd.isna(max_val_estimate):
216
- lims = [0,1] # Default if data is problematic
217
- else:
218
- overall_min = min(min_val_actual, min_val_estimate)
219
- overall_max = max(max_val_actual, max_val_estimate)
220
- lims = [overall_min, overall_max]
221
-
222
-
223
- if lims[0] != lims[1]: # Avoid issues if all data is single point or NaN
224
- fig.add_trace(go.Scatter(
225
- x=lims, y=lims, mode='lines', name='Identity',
226
- line=dict(color='red', width=1), # Adjusted width for Plotly
227
- showlegend=False
228
- ))
229
- fig.update_xaxes(range=lims)
230
- fig.update_yaxes(range=lims, scaleanchor="x", scaleratio=1) # Makes axes square based on data range
231
 
232
- fig.update_layout(width=fig_width, height=fig_height)
233
-
234
- try:
235
- # Requires kaleido: pip install kaleido
236
- img_bytes = fig.to_image(format="png", width=fig_width, height=fig_height)
237
- buf = io.BytesIO(img_bytes)
238
- img = Image.open(buf)
239
- return img
240
- except Exception as e:
241
- print(f"Error generating scatter plot image with Plotly/Kaleido: {e}. Ensure Kaleido is installed.")
242
- error_fig = go.Figure()
243
- error_fig.add_annotation(text=f"Plot Error: {e}", showarrow=False)
244
- error_fig.update_layout(width=fig_width, height=fig_height, title_text=title)
245
- img_bytes = error_fig.to_image(format="png", width=fig_width, height=fig_height)
246
- return Image.open(io.BytesIO(img_bytes))
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
 
249
  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
250
  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
251
  """Main processing function - now accepts file paths"""
252
  try:
 
253
  cfs = pd.read_excel(cashflow_base_path, index_col=0)
254
  cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
255
  cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
256
 
257
  pol_data_full = pd.read_excel(policy_data_path, index_col=0)
 
258
  required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
259
  if all(col in pol_data_full.columns for col in required_cols):
260
  pol_data = pol_data_full[required_cols]
261
  else:
262
  gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
263
- pol_data = pol_data_full
264
 
265
  pvs = pd.read_excel(pv_base_path, index_col=0)
266
  pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
@@ -270,43 +214,38 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
270
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
271
 
272
  results = {}
 
273
  mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
274
 
275
  # --- 1. Cashflow Calibration ---
276
  cluster_cfs = Clusters(cfs)
 
277
  results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
278
  results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
 
279
  results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
280
  results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
281
  results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
 
282
  results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
283
  results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
284
 
285
  # --- 2. Policy Attribute Calibration ---
286
- if not pol_data.empty and (pol_data.max(numeric_only=True) - pol_data.min(numeric_only=True)).all() != 0:
287
- loc_vars_attrs = (pol_data - pol_data.min(numeric_only=True)) / (pol_data.max(numeric_only=True) - pol_data.min(numeric_only=True))
288
- loc_vars_attrs = loc_vars_attrs.fillna(0) # Fill NaNs that may result from division by zero if a column has no variance
289
  else:
290
  gr.Warning("Policy data for attribute calibration is empty or has no variance. Skipping attribute calibration plots.")
291
- loc_vars_attrs = pol_data.copy() # Use a copy
292
 
293
- if not loc_vars_attrs.empty and pd.api.types.is_numeric_dtype(loc_vars_attrs.values): # Check if data is numeric for KMeans
294
- try:
295
- cluster_attrs = Clusters(loc_vars_attrs)
296
- results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
297
- results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
298
- results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
299
- results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
300
- results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
301
- except Exception as e_attr_clust: # Catch errors during clustering (e.g. if data is not suitable)
302
- gr.Error(f"Error during policy attribute clustering: {e_attr_clust}")
303
- results['attr_total_cf_base'] = pd.DataFrame()
304
- results['attr_policy_attrs_total'] = pd.DataFrame()
305
- results['attr_total_pv_base'] = pd.DataFrame()
306
- results['attr_cashflow_plot'] = None
307
- results['attr_scatter_cashflows_base'] = None
308
  else:
309
- gr.Warning("Skipping attribute calibration as data is empty or non-numeric after processing.")
310
  results['attr_total_cf_base'] = pd.DataFrame()
311
  results['attr_policy_attrs_total'] = pd.DataFrame()
312
  results['attr_total_pv_base'] = pd.DataFrame()
@@ -316,36 +255,48 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
316
 
317
  # --- 3. Present Value Calibration ---
318
  cluster_pvs = Clusters(pvs)
 
319
  results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
320
  results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
 
321
  results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
322
  results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
323
  results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
 
324
  results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
325
  results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
326
 
327
  # --- Summary Comparison Plot Data ---
 
 
328
  error_data = {}
 
 
329
  def get_error_safe(compare_result, col_name=None):
330
  if compare_result.empty:
331
  return np.nan
332
  if col_name and col_name in compare_result.index:
333
  return abs(compare_result.loc[col_name, 'error'])
334
  else:
 
335
  return abs(compare_result['error']).mean()
336
 
 
337
  key_pv_col = None
338
- for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']:
339
  if potential_col in pvs.columns:
340
  key_pv_col = potential_col
341
  break
342
 
 
343
  error_data['CF Calib.'] = [
344
  get_error_safe(cluster_cfs.compare_total(pvs), key_pv_col),
345
  get_error_safe(cluster_cfs.compare_total(pvs_lapse50), key_pv_col),
346
  get_error_safe(cluster_cfs.compare_total(pvs_mort15), key_pv_col)
347
  ]
348
- if results.get('attr_total_pv_base') is not None and not results['attr_total_pv_base'].empty : # Check if Attr Calib was successful
 
 
349
  error_data['Attr Calib.'] = [
350
  get_error_safe(cluster_attrs.compare_total(pvs), key_pv_col),
351
  get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col),
@@ -354,51 +305,32 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
354
  else:
355
  error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
356
 
 
 
357
  error_data['PV Calib.'] = [
358
  get_error_safe(cluster_pvs.compare_total(pvs), key_pv_col),
359
  get_error_safe(cluster_pvs.compare_total(pvs_lapse50), key_pv_col),
360
  get_error_safe(cluster_pvs.compare_total(pvs_mort15), key_pv_col)
361
  ]
362
 
 
363
  summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
 
 
 
 
364
  title_suffix = f' ({key_pv_col})' if key_pv_col else ' (Mean Absolute Error)'
365
- plot_title = f'Calibration Method Comparison - Error in Total PV{title_suffix}'
366
- fig_width = 1000
367
- fig_height = 600
368
-
369
- summary_df_melted = summary_df.reset_index().melt(id_vars='index', var_name='Calibration Method', value_name='Absolute Error Rate')
370
- summary_df_melted.rename(columns={'index': 'Scenario'}, inplace=True)
371
-
372
-
373
- fig_summary = px.bar(
374
- summary_df_melted,
375
- x='Scenario',
376
- y='Absolute Error Rate',
377
- color='Calibration Method',
378
- barmode='group',
379
- title=plot_title
380
- )
381
- fig_summary.update_layout(
382
- width=fig_width, height=fig_height,
383
- xaxis_tickangle=0,
384
- yaxis_title='Absolute Error Rate',
385
- legend_title_text='Calibration Method'
386
- )
387
- fig_summary.update_yaxes(showgrid=True)
388
-
389
- try:
390
- # Requires kaleido: pip install kaleido
391
- buf_summary_bytes = fig_summary.to_image(format="png", width=fig_width, height=fig_height)
392
- buf_summary = io.BytesIO(buf_summary_bytes)
393
- results['summary_plot'] = Image.open(buf_summary)
394
- except Exception as e:
395
- print(f"Error generating summary plot image with Plotly/Kaleido: {e}. Ensure Kaleido is installed.")
396
- error_fig = go.Figure()
397
- error_fig.add_annotation(text=f"Plot Error: {e}", showarrow=False)
398
- error_fig.update_layout(width=fig_width, height=fig_height, title_text=plot_title)
399
- img_bytes = error_fig.to_image(format="png", width=fig_width, height=fig_height)
400
- results['summary_plot'] = Image.open(io.BytesIO(img_bytes))
401
-
402
  return results
403
 
404
  except FileNotFoundError as e:
@@ -409,9 +341,6 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
409
  return {"error": f"Missing column: {e}"}
410
  except Exception as e:
411
  gr.Error(f"Error processing files: {str(e)}")
412
- # Optionally log the full traceback for debugging
413
- import traceback
414
- traceback.print_exc()
415
  return {"error": f"Error processing files: {str(e)}"}
416
 
417
 
@@ -431,14 +360,14 @@ def create_interface():
431
  - Present Values - Base Scenario
432
  - Present Values - Lapse Stress
433
  - Present Values - Mortality Stress
434
-
435
- **Note:** Plot generation uses Plotly and Kaleido. If plots appear as errors, ensure Kaleido is installed (`pip install kaleido`).
436
  """)
437
 
438
  with gr.Row():
439
  with gr.Column(scale=1):
440
  gr.Markdown("### Upload Files or Load Examples")
 
441
  load_example_btn = gr.Button("Load Example Data")
 
442
  with gr.Row():
443
  cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
444
  cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
@@ -449,11 +378,12 @@ def create_interface():
449
  pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
450
  with gr.Row():
451
  pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
 
452
  analyze_btn = gr.Button("Analyze Dataset", variant="primary", size="lg")
453
 
454
  with gr.Tabs():
455
  with gr.TabItem("📊 Summary"):
456
- summary_plot_output = gr.Image(label="Calibration Methods Comparison") # Stays as gr.Image
457
 
458
  with gr.TabItem("💸 Cashflow Calibration"):
459
  gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
@@ -478,6 +408,7 @@ def create_interface():
478
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
479
  attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")
480
 
 
481
  with gr.TabItem("💰 Present Value Calibration"):
482
  gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
483
  with gr.Row():
@@ -491,46 +422,59 @@ def create_interface():
491
  pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
492
  pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
493
 
 
494
  def get_all_output_components():
495
  return [
496
  summary_plot_output,
 
497
  cf_total_base_table_out, cf_policy_attrs_total_out,
498
  cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
499
  cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
 
500
  attr_total_cf_base_out, attr_policy_attrs_total_out,
501
  attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
 
502
  pv_total_cf_base_out, pv_policy_attrs_total_out,
503
  pv_cashflow_plot_out, pv_scatter_pvs_base_out,
504
  pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
505
  ]
506
 
 
507
  def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
508
  files = [f1, f2, f3, f4, f5, f6, f7]
 
509
  file_paths = []
510
  for i, f_obj in enumerate(files):
511
  if f_obj is None:
512
  gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
513
  return [None] * len(get_all_output_components())
 
 
514
  if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
515
  file_paths.append(f_obj.name)
 
516
  elif isinstance(f_obj, str):
517
  file_paths.append(f_obj)
518
  else:
519
  gr.Error(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")
520
  return [None] * len(get_all_output_components())
521
 
 
522
  results = process_files(*file_paths)
523
 
524
- if "error" in results:
525
  return [None] * len(get_all_output_components())
526
 
527
  return [
528
  results.get('summary_plot'),
 
529
  results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
530
  results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
531
  results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
 
532
  results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
533
  results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
 
534
  results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
535
  results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
536
  results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
@@ -543,11 +487,12 @@ def create_interface():
543
  outputs=get_all_output_components()
544
  )
545
 
 
546
  def load_example_files():
547
  missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
548
  if missing_files:
549
  gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")
550
- return [None] * 7
551
 
552
  gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
553
  return [
 
2
  import numpy as np
3
  import pandas as pd
4
  from sklearn.cluster import KMeans
5
+ from sklearn.metrics import pairwise_distances_argmin_min, r2_score
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib.cm
 
8
  import io
9
+ import os # Added for path joining
10
  from PIL import Image
11
 
12
  # Define the paths for example data
 
52
  if agg:
53
  cols = df.columns
54
  mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
55
+ # Ensure mult has same index as extract_reps(df) for proper alignment
56
  extracted_df = self.extract_reps(df)
57
  mult.index = extracted_df.index
58
  return extracted_df.mul(mult)
 
68
  def compare_total(self, df, agg=None):
69
  """Aggregate df by columns"""
70
  if agg:
71
+ # Calculate actual values using specified aggregation
72
  actual_values = {}
73
  for col in df.columns:
74
  if agg.get(col, 'sum') == 'mean':
75
  actual_values[col] = df[col].mean()
76
+ else: # sum
77
  actual_values[col] = df[col].sum()
78
  actual = pd.Series(actual_values)
79
 
80
+ # Calculate estimate values
81
  reps_unscaled = self.extract_reps(df)
82
  estimate_values = {}
83
 
84
  for col in df.columns:
85
  if agg.get(col, 'sum') == 'mean':
86
+ # Weighted average for mean columns
87
  weighted_sum = (reps_unscaled[col] * self.policy_count).sum()
88
  total_weight = self.policy_count.sum()
89
  estimate_values[col] = weighted_sum / total_weight if total_weight > 0 else 0
90
+ else: # sum
91
  estimate_values[col] = (reps_unscaled[col] * self.policy_count).sum()
92
+
93
  estimate = pd.Series(estimate_values)
94
+
95
+ else: # Original logic if no agg is specified (all sum)
96
  actual = df.sum()
97
  estimate = self.extract_and_scale_reps(df).sum()
98
 
99
+ # Calculate error, handling division by zero
100
  error = np.where(actual != 0, estimate / actual - 1, 0)
101
+
102
  return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
103
 
104
 
105
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
106
+ """Create cashflow comparison plots"""
107
  if not cfs_list or not cluster_obj or not titles:
108
  return None
109
  num_plots = len(cfs_list)
110
  if num_plots == 0:
111
  return None
112
 
113
+ # Determine subplot layout
114
  cols = 2
115
  rows = (num_plots + cols - 1) // cols
116
 
117
+ fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
118
+ axes = axes.flatten()
119
+
120
+ for i, (df, title) in enumerate(zip(cfs_list, titles)):
121
+ if i < len(axes):
 
 
 
 
 
 
 
 
122
  comparison = cluster_obj.compare_total(df)
123
+ comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
124
+ axes[i].set_xlabel('Time')
125
+ axes[i].set_ylabel('Value')
 
 
 
 
 
 
126
 
127
+ # Hide any unused subplots
128
+ for j in range(i + 1, len(axes)):
129
+ fig.delaxes(axes[j])
130
+
131
+ plt.tight_layout()
132
+ buf = io.BytesIO()
133
+ plt.savefig(buf, format='png', dpi=100)
134
+ buf.seek(0)
135
+ img = Image.open(buf)
136
+ plt.close(fig)
137
+ return img
 
 
 
 
 
 
138
 
139
+ def plot_scatter_comparison(df_compare_output, title):
140
+ """Create scatter plot comparison from compare() output"""
141
+ if df_compare_output is None or df_compare_output.empty:
142
+ # Create a blank plot with a message
143
+ fig, ax = plt.subplots(figsize=(12, 8))
144
+ ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
145
+ ax.set_title(title)
146
+ buf = io.BytesIO()
147
+ plt.savefig(buf, format='png', dpi=100)
148
+ buf.seek(0)
149
  img = Image.open(buf)
150
+ plt.close(fig)
151
  return img
 
 
 
 
 
 
 
 
152
 
153
+ fig, ax = plt.subplots(figsize=(12, 8))
154
+
155
+ if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
156
+ gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
157
+ ax.scatter(df_compare_output['actual'], df_compare_output['estimate'], s=9, alpha=0.6)
 
 
 
 
 
 
 
 
 
158
  else:
159
+ unique_levels = df_compare_output.index.get_level_values(1).unique()
160
+ colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
+ for item_level, color_val in zip(unique_levels, colors):
163
+ subset = df_compare_output.xs(item_level, level=1)
164
+ ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=item_level)
165
+ if len(unique_levels) > 1 and len(unique_levels) <= 10: # Add legend if reasonable number of items
166
+ ax.legend(title=df_compare_output.index.names[1])
167
+
168
+ ax.set_xlabel('Actual')
169
+ ax.set_ylabel('Estimate')
170
+ ax.set_title(title)
171
+ ax.grid(True)
172
+
173
+ # Draw identity line
174
+ lims = [
175
+ np.min([ax.get_xlim(), ax.get_ylim()]),
176
+ np.max([ax.get_xlim(), ax.get_ylim()]),
177
+ ]
178
+ if lims[0] != lims[1]: # Avoid issues if data is all zeros or single point
179
+ ax.plot(lims, lims, 'r-', linewidth=0.5)
180
+ ax.set_xlim(lims)
181
+ ax.set_ylim(lims)
182
+
183
+ buf = io.BytesIO()
184
+ plt.savefig(buf, format='png', dpi=100)
185
+ buf.seek(0)
186
+ img = Image.open(buf)
187
+ plt.close(fig)
188
+ return img
189
 
190
 
191
  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
192
  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
193
  """Main processing function - now accepts file paths"""
194
  try:
195
+ # Read uploaded files using paths
196
  cfs = pd.read_excel(cashflow_base_path, index_col=0)
197
  cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
198
  cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
199
 
200
  pol_data_full = pd.read_excel(policy_data_path, index_col=0)
201
+ # Ensure the correct columns are selected for pol_data
202
  required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
203
  if all(col in pol_data_full.columns for col in required_cols):
204
  pol_data = pol_data_full[required_cols]
205
  else:
206
  gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
207
+ pol_data = pol_data_full # proceed with whatever columns are there
208
 
209
  pvs = pd.read_excel(pv_base_path, index_col=0)
210
  pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
 
214
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
215
 
216
  results = {}
217
+
218
  mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
219
 
220
  # --- 1. Cashflow Calibration ---
221
  cluster_cfs = Clusters(cfs)
222
+
223
  results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
224
  results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
225
+
226
  results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
227
  results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
228
  results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
229
+
230
  results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
231
  results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
232
 
233
  # --- 2. Policy Attribute Calibration ---
234
+ # Standardize policy attributes
235
+ if not pol_data.empty and (pol_data.max() - pol_data.min()).all() != 0: # check for variance
236
+ loc_vars_attrs = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
237
  else:
238
  gr.Warning("Policy data for attribute calibration is empty or has no variance. Skipping attribute calibration plots.")
239
+ loc_vars_attrs = pol_data # Use original if no variance, KMeans might handle it or fail gracefully
240
 
241
+ if not loc_vars_attrs.empty:
242
+ cluster_attrs = Clusters(loc_vars_attrs)
243
+ results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
244
+ results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
245
+ results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
246
+ results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
247
+ results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
 
 
 
 
 
 
 
 
248
  else:
 
249
  results['attr_total_cf_base'] = pd.DataFrame()
250
  results['attr_policy_attrs_total'] = pd.DataFrame()
251
  results['attr_total_pv_base'] = pd.DataFrame()
 
255
 
256
  # --- 3. Present Value Calibration ---
257
  cluster_pvs = Clusters(pvs)
258
+
259
  results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
260
  results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
261
+
262
  results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
263
  results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
264
  results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
265
+
266
  results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
267
  results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
268
 
269
  # --- Summary Comparison Plot Data ---
270
+ # Error metric for key PV column or mean absolute error
271
+
272
  error_data = {}
273
+
274
+ # Function to safely get error value
275
  def get_error_safe(compare_result, col_name=None):
276
  if compare_result.empty:
277
  return np.nan
278
  if col_name and col_name in compare_result.index:
279
  return abs(compare_result.loc[col_name, 'error'])
280
  else:
281
+ # Use mean absolute error if specific column not found or col_name is None
282
  return abs(compare_result['error']).mean()
283
 
284
+ # Determine key PV column (try common names)
285
  key_pv_col = None
286
+ for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']: # Add more common names if needed
287
  if potential_col in pvs.columns:
288
  key_pv_col = potential_col
289
  break
290
 
291
+ # Cashflow Calibration Errors
292
  error_data['CF Calib.'] = [
293
  get_error_safe(cluster_cfs.compare_total(pvs), key_pv_col),
294
  get_error_safe(cluster_cfs.compare_total(pvs_lapse50), key_pv_col),
295
  get_error_safe(cluster_cfs.compare_total(pvs_mort15), key_pv_col)
296
  ]
297
+
298
+ # Policy Attribute Calibration Errors
299
+ if not loc_vars_attrs.empty:
300
  error_data['Attr Calib.'] = [
301
  get_error_safe(cluster_attrs.compare_total(pvs), key_pv_col),
302
  get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col),
 
305
  else:
306
  error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
307
 
308
+
309
+ # Present Value Calibration Errors
310
  error_data['PV Calib.'] = [
311
  get_error_safe(cluster_pvs.compare_total(pvs), key_pv_col),
312
  get_error_safe(cluster_pvs.compare_total(pvs_lapse50), key_pv_col),
313
  get_error_safe(cluster_pvs.compare_total(pvs_mort15), key_pv_col)
314
  ]
315
 
316
+ # Create Summary Plot
317
  summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
318
+
319
+ fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
320
+ summary_df.plot(kind='bar', ax=ax_summary, grid=True)
321
+ ax_summary.set_ylabel('Absolute Error Rate')
322
  title_suffix = f' ({key_pv_col})' if key_pv_col else ' (Mean Absolute Error)'
323
+ ax_summary.set_title(f'Calibration Method Comparison - Error in Total PV{title_suffix}')
324
+ ax_summary.tick_params(axis='x', rotation=0)
325
+ ax_summary.legend(title='Calibration Method')
326
+ plt.tight_layout()
327
+
328
+ buf_summary = io.BytesIO()
329
+ plt.savefig(buf_summary, format='png', dpi=100)
330
+ buf_summary.seek(0)
331
+ results['summary_plot'] = Image.open(buf_summary)
332
+ plt.close(fig_summary)
333
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  return results
335
 
336
  except FileNotFoundError as e:
 
341
  return {"error": f"Missing column: {e}"}
342
  except Exception as e:
343
  gr.Error(f"Error processing files: {str(e)}")
 
 
 
344
  return {"error": f"Error processing files: {str(e)}"}
345
 
346
 
 
360
  - Present Values - Base Scenario
361
  - Present Values - Lapse Stress
362
  - Present Values - Mortality Stress
 
 
363
  """)
364
 
365
  with gr.Row():
366
  with gr.Column(scale=1):
367
  gr.Markdown("### Upload Files or Load Examples")
368
+
369
  load_example_btn = gr.Button("Load Example Data")
370
+
371
  with gr.Row():
372
  cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
373
  cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
 
378
  pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
379
  with gr.Row():
380
  pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
381
+
382
  analyze_btn = gr.Button("Analyze Dataset", variant="primary", size="lg")
383
 
384
  with gr.Tabs():
385
  with gr.TabItem("📊 Summary"):
386
+ summary_plot_output = gr.Image(label="Calibration Methods Comparison")
387
 
388
  with gr.TabItem("💸 Cashflow Calibration"):
389
  gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
 
408
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
409
  attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")
410
 
411
+
412
  with gr.TabItem("💰 Present Value Calibration"):
413
  gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
414
  with gr.Row():
 
422
  pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
423
  pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
424
 
425
+ # --- Helper function to prepare outputs ---
426
def get_all_output_components():
    """Return every Gradio output component, in the exact order that
    ``handle_analysis`` emits its results list."""
    summary_outputs = [summary_plot_output]
    # Cashflow-calibration tab outputs
    cf_outputs = [
        cf_total_base_table_out, cf_policy_attrs_total_out,
        cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
        cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
    ]
    # Policy-attribute-calibration tab outputs
    attr_outputs = [
        attr_total_cf_base_out, attr_policy_attrs_total_out,
        attr_cashflow_plot_out, attr_scatter_cashflows_base_out,
        attr_total_pv_base_out,
    ]
    # Present-value-calibration tab outputs
    pv_outputs = [
        pv_total_cf_base_out, pv_policy_attrs_total_out,
        pv_cashflow_plot_out, pv_scatter_pvs_base_out,
        pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out,
    ]
    return summary_outputs + cf_outputs + attr_outputs + pv_outputs
441
 
442
+ # --- Action for Analyze Button ---
443
  def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
444
  files = [f1, f2, f3, f4, f5, f6, f7]
445
+
446
  file_paths = []
447
  for i, f_obj in enumerate(files):
448
  if f_obj is None:
449
  gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
450
  return [None] * len(get_all_output_components())
451
+
452
+ # If f_obj is a Gradio FileData object (from direct upload)
453
  if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
454
  file_paths.append(f_obj.name)
455
+ # If f_obj is already a string path (from example load)
456
  elif isinstance(f_obj, str):
457
  file_paths.append(f_obj)
458
  else:
459
  gr.Error(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")
460
  return [None] * len(get_all_output_components())
461
 
462
+
463
  results = process_files(*file_paths)
464
 
465
+ if "error" in results: # Check if process_files returned an error dictionary
466
  return [None] * len(get_all_output_components())
467
 
468
  return [
469
  results.get('summary_plot'),
470
+ # CF Calib
471
  results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
472
  results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
473
  results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
474
+ # Attr Calib
475
  results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
476
  results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
477
+ # PV Calib
478
  results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
479
  results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
480
  results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
 
487
  outputs=get_all_output_components()
488
  )
489
 
490
+ # --- Action for Load Example Data Button ---
491
  def load_example_files():
492
  missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
493
  if missing_files:
494
  gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")
495
+ return [None] * 7 # Return None for all file inputs
496
 
497
  gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
498
  return [