alidenewade commited on
Commit
6570096
·
verified ·
1 Parent(s): 3adb7ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -152
app.py CHANGED
@@ -2,11 +2,12 @@ import gradio as gr
2
  import numpy as np
3
  import pandas as pd
4
  from sklearn.cluster import KMeans
5
- from sklearn.metrics import pairwise_distances_argmin_min, r2_score
6
- import matplotlib.pyplot as plt
7
- import matplotlib.cm
 
8
  import io
9
- import os # Added for path joining
10
  from PIL import Image
11
 
12
  # Define the paths for example data
@@ -52,7 +53,6 @@ class Clusters:
52
  if agg:
53
  cols = df.columns
54
  mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
55
- # Ensure mult has same index as extract_reps(df) for proper alignment
56
  extracted_df = self.extract_reps(df)
57
  mult.index = extracted_df.index
58
  return extracted_df.mul(mult)
@@ -68,143 +68,199 @@ class Clusters:
68
  def compare_total(self, df, agg=None):
69
  """Aggregate df by columns"""
70
  if agg:
71
- # Calculate actual values using specified aggregation
72
  actual_values = {}
73
  for col in df.columns:
74
  if agg.get(col, 'sum') == 'mean':
75
  actual_values[col] = df[col].mean()
76
- else: # sum
77
  actual_values[col] = df[col].sum()
78
  actual = pd.Series(actual_values)
79
 
80
- # Calculate estimate values
81
  reps_unscaled = self.extract_reps(df)
82
  estimate_values = {}
83
 
84
  for col in df.columns:
85
  if agg.get(col, 'sum') == 'mean':
86
- # Weighted average for mean columns
87
  weighted_sum = (reps_unscaled[col] * self.policy_count).sum()
88
  total_weight = self.policy_count.sum()
89
  estimate_values[col] = weighted_sum / total_weight if total_weight > 0 else 0
90
- else: # sum
91
  estimate_values[col] = (reps_unscaled[col] * self.policy_count).sum()
92
-
93
  estimate = pd.Series(estimate_values)
94
-
95
- else: # Original logic if no agg is specified (all sum)
96
  actual = df.sum()
97
  estimate = self.extract_and_scale_reps(df).sum()
98
 
99
- # Calculate error, handling division by zero
100
  error = np.where(actual != 0, estimate / actual - 1, 0)
101
-
102
  return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
103
 
104
 
105
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
106
- """Create cashflow comparison plots"""
107
  if not cfs_list or not cluster_obj or not titles:
108
  return None
109
  num_plots = len(cfs_list)
110
  if num_plots == 0:
111
  return None
112
 
113
- # Determine subplot layout
114
  cols = 2
115
  rows = (num_plots + cols - 1) // cols
116
 
117
- fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
118
- axes = axes.flatten()
119
-
120
- for i, (df, title) in enumerate(zip(cfs_list, titles)):
121
- if i < len(axes):
 
 
 
 
 
 
 
 
122
  comparison = cluster_obj.compare_total(df)
123
- comparison[['actual', 'estimate']].plot(ax=axes[i], grid=True, title=title)
124
- axes[i].set_xlabel('Time')
125
- axes[i].set_ylabel('Value')
 
 
 
 
 
 
126
 
127
- # Hide any unused subplots
128
- for j in range(i + 1, len(axes)):
129
- fig.delaxes(axes[j])
130
-
131
- plt.tight_layout()
132
- buf = io.BytesIO()
133
- plt.savefig(buf, format='png', dpi=100)
134
- buf.seek(0)
135
- img = Image.open(buf)
136
- plt.close(fig)
137
- return img
 
 
 
 
 
 
138
 
139
- def plot_scatter_comparison(df_compare_output, title):
140
- """Create scatter plot comparison from compare() output"""
141
- if df_compare_output is None or df_compare_output.empty:
142
- # Create a blank plot with a message
143
- fig, ax = plt.subplots(figsize=(12, 8))
144
- ax.text(0.5, 0.5, "No data to display", ha='center', va='center', fontsize=15)
145
- ax.set_title(title)
146
- buf = io.BytesIO()
147
- plt.savefig(buf, format='png', dpi=100)
148
- buf.seek(0)
149
  img = Image.open(buf)
150
- plt.close(fig)
151
  return img
 
 
 
 
 
 
 
 
152
 
153
- fig, ax = plt.subplots(figsize=(12, 8))
154
-
155
- if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
156
- gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
157
- ax.scatter(df_compare_output['actual'], df_compare_output['estimate'], s=9, alpha=0.6)
 
 
 
 
 
 
 
 
 
158
  else:
159
- unique_levels = df_compare_output.index.get_level_values(1).unique()
160
- colors = matplotlib.cm.rainbow(np.linspace(0, 1, len(unique_levels)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
- for item_level, color_val in zip(unique_levels, colors):
163
- subset = df_compare_output.xs(item_level, level=1)
164
- ax.scatter(subset['actual'], subset['estimate'], color=color_val, s=9, alpha=0.6, label=item_level)
165
- if len(unique_levels) > 1 and len(unique_levels) <= 10: # Add legend if reasonable number of items
166
- ax.legend(title=df_compare_output.index.names[1])
167
-
168
- ax.set_xlabel('Actual')
169
- ax.set_ylabel('Estimate')
170
- ax.set_title(title)
171
- ax.grid(True)
172
-
173
- # Draw identity line
174
- lims = [
175
- np.min([ax.get_xlim(), ax.get_ylim()]),
176
- np.max([ax.get_xlim(), ax.get_ylim()]),
177
- ]
178
- if lims[0] != lims[1]: # Avoid issues if data is all zeros or single point
179
- ax.plot(lims, lims, 'r-', linewidth=0.5)
180
- ax.set_xlim(lims)
181
- ax.set_ylim(lims)
182
-
183
- buf = io.BytesIO()
184
- plt.savefig(buf, format='png', dpi=100)
185
- buf.seek(0)
186
- img = Image.open(buf)
187
- plt.close(fig)
188
- return img
189
 
190
 
191
  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
192
  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
193
  """Main processing function - now accepts file paths"""
194
  try:
195
- # Read uploaded files using paths
196
  cfs = pd.read_excel(cashflow_base_path, index_col=0)
197
  cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
198
  cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
199
 
200
  pol_data_full = pd.read_excel(policy_data_path, index_col=0)
201
- # Ensure the correct columns are selected for pol_data
202
  required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
203
  if all(col in pol_data_full.columns for col in required_cols):
204
  pol_data = pol_data_full[required_cols]
205
  else:
206
  gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
207
- pol_data = pol_data_full # proceed with whatever columns are there
208
 
209
  pvs = pd.read_excel(pv_base_path, index_col=0)
210
  pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
@@ -214,38 +270,43 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
214
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
215
 
216
  results = {}
217
-
218
  mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
219
 
220
  # --- 1. Cashflow Calibration ---
221
  cluster_cfs = Clusters(cfs)
222
-
223
  results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
224
  results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
225
-
226
  results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
227
  results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
228
  results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
229
-
230
  results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
231
  results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
232
 
233
  # --- 2. Policy Attribute Calibration ---
234
- # Standardize policy attributes
235
- if not pol_data.empty and (pol_data.max() - pol_data.min()).all() != 0: # check for variance
236
- loc_vars_attrs = (pol_data - pol_data.min()) / (pol_data.max() - pol_data.min())
237
  else:
238
  gr.Warning("Policy data for attribute calibration is empty or has no variance. Skipping attribute calibration plots.")
239
- loc_vars_attrs = pol_data # Use original if no variance, KMeans might handle it or fail gracefully
240
 
241
- if not loc_vars_attrs.empty:
242
- cluster_attrs = Clusters(loc_vars_attrs)
243
- results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
244
- results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
245
- results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
246
- results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
247
- results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
 
 
 
 
 
 
 
 
248
  else:
 
249
  results['attr_total_cf_base'] = pd.DataFrame()
250
  results['attr_policy_attrs_total'] = pd.DataFrame()
251
  results['attr_total_pv_base'] = pd.DataFrame()
@@ -255,48 +316,36 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
255
 
256
  # --- 3. Present Value Calibration ---
257
  cluster_pvs = Clusters(pvs)
258
-
259
  results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
260
  results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
261
-
262
  results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
263
  results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
264
  results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
265
-
266
  results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
267
  results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
268
 
269
  # --- Summary Comparison Plot Data ---
270
- # Error metric for key PV column or mean absolute error
271
-
272
  error_data = {}
273
-
274
- # Function to safely get error value
275
  def get_error_safe(compare_result, col_name=None):
276
  if compare_result.empty:
277
  return np.nan
278
  if col_name and col_name in compare_result.index:
279
  return abs(compare_result.loc[col_name, 'error'])
280
  else:
281
- # Use mean absolute error if specific column not found or col_name is None
282
  return abs(compare_result['error']).mean()
283
 
284
- # Determine key PV column (try common names)
285
  key_pv_col = None
286
- for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']: # Add more common names if needed
287
  if potential_col in pvs.columns:
288
  key_pv_col = potential_col
289
  break
290
 
291
- # Cashflow Calibration Errors
292
  error_data['CF Calib.'] = [
293
  get_error_safe(cluster_cfs.compare_total(pvs), key_pv_col),
294
  get_error_safe(cluster_cfs.compare_total(pvs_lapse50), key_pv_col),
295
  get_error_safe(cluster_cfs.compare_total(pvs_mort15), key_pv_col)
296
  ]
297
-
298
- # Policy Attribute Calibration Errors
299
- if not loc_vars_attrs.empty:
300
  error_data['Attr Calib.'] = [
301
  get_error_safe(cluster_attrs.compare_total(pvs), key_pv_col),
302
  get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col),
@@ -305,32 +354,51 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
305
  else:
306
  error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
307
 
308
-
309
- # Present Value Calibration Errors
310
  error_data['PV Calib.'] = [
311
  get_error_safe(cluster_pvs.compare_total(pvs), key_pv_col),
312
  get_error_safe(cluster_pvs.compare_total(pvs_lapse50), key_pv_col),
313
  get_error_safe(cluster_pvs.compare_total(pvs_mort15), key_pv_col)
314
  ]
315
 
316
- # Create Summary Plot
317
  summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
318
-
319
- fig_summary, ax_summary = plt.subplots(figsize=(10, 6))
320
- summary_df.plot(kind='bar', ax=ax_summary, grid=True)
321
- ax_summary.set_ylabel('Absolute Error Rate')
322
  title_suffix = f' ({key_pv_col})' if key_pv_col else ' (Mean Absolute Error)'
323
- ax_summary.set_title(f'Calibration Method Comparison - Error in Total PV{title_suffix}')
324
- ax_summary.tick_params(axis='x', rotation=0)
325
- ax_summary.legend(title='Calibration Method')
326
- plt.tight_layout()
327
-
328
- buf_summary = io.BytesIO()
329
- plt.savefig(buf_summary, format='png', dpi=100)
330
- buf_summary.seek(0)
331
- results['summary_plot'] = Image.open(buf_summary)
332
- plt.close(fig_summary)
333
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  return results
335
 
336
  except FileNotFoundError as e:
@@ -341,6 +409,9 @@ def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
341
  return {"error": f"Missing column: {e}"}
342
  except Exception as e:
343
  gr.Error(f"Error processing files: {str(e)}")
 
 
 
344
  return {"error": f"Error processing files: {str(e)}"}
345
 
346
 
@@ -360,14 +431,14 @@ def create_interface():
360
  - Present Values - Base Scenario
361
  - Present Values - Lapse Stress
362
  - Present Values - Mortality Stress
 
 
363
  """)
364
 
365
  with gr.Row():
366
  with gr.Column(scale=1):
367
  gr.Markdown("### Upload Files or Load Examples")
368
-
369
  load_example_btn = gr.Button("Load Example Data")
370
-
371
  with gr.Row():
372
  cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
373
  cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
@@ -378,12 +449,11 @@ def create_interface():
378
  pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
379
  with gr.Row():
380
  pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
381
-
382
  analyze_btn = gr.Button("Analyze Dataset", variant="primary", size="lg")
383
 
384
  with gr.Tabs():
385
  with gr.TabItem("📊 Summary"):
386
- summary_plot_output = gr.Image(label="Calibration Methods Comparison")
387
 
388
  with gr.TabItem("💸 Cashflow Calibration"):
389
  gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
@@ -408,7 +478,6 @@ def create_interface():
408
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
409
  attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")
410
 
411
-
412
  with gr.TabItem("💰 Present Value Calibration"):
413
  gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
414
  with gr.Row():
@@ -422,59 +491,46 @@ def create_interface():
422
  pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
423
  pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
424
 
425
- # --- Helper function to prepare outputs ---
426
  def get_all_output_components():
427
  return [
428
  summary_plot_output,
429
- # Cashflow Calib Outputs
430
  cf_total_base_table_out, cf_policy_attrs_total_out,
431
  cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
432
  cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
433
- # Attribute Calib Outputs
434
  attr_total_cf_base_out, attr_policy_attrs_total_out,
435
  attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
436
- # PV Calib Outputs
437
  pv_total_cf_base_out, pv_policy_attrs_total_out,
438
  pv_cashflow_plot_out, pv_scatter_pvs_base_out,
439
  pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out
440
  ]
441
 
442
- # --- Action for Analyze Button ---
443
  def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
444
  files = [f1, f2, f3, f4, f5, f6, f7]
445
-
446
  file_paths = []
447
  for i, f_obj in enumerate(files):
448
  if f_obj is None:
449
  gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
450
  return [None] * len(get_all_output_components())
451
-
452
- # If f_obj is a Gradio FileData object (from direct upload)
453
  if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
454
  file_paths.append(f_obj.name)
455
- # If f_obj is already a string path (from example load)
456
  elif isinstance(f_obj, str):
457
  file_paths.append(f_obj)
458
  else:
459
  gr.Error(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")
460
  return [None] * len(get_all_output_components())
461
 
462
-
463
  results = process_files(*file_paths)
464
 
465
- if "error" in results: # Check if process_files returned an error dictionary
466
  return [None] * len(get_all_output_components())
467
 
468
  return [
469
  results.get('summary_plot'),
470
- # CF Calib
471
  results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
472
  results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
473
  results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
474
- # Attr Calib
475
  results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
476
  results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
477
- # PV Calib
478
  results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
479
  results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
480
  results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
@@ -487,12 +543,11 @@ def create_interface():
487
  outputs=get_all_output_components()
488
  )
489
 
490
- # --- Action for Load Example Data Button ---
491
  def load_example_files():
492
  missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
493
  if missing_files:
494
  gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")
495
- return [None] * 7 # Return None for all file inputs
496
 
497
  gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
498
  return [
 
2
  import numpy as np
3
  import pandas as pd
4
  from sklearn.cluster import KMeans
5
+ from sklearn.metrics import pairwise_distances_argmin_min, r2_score # r2_score is not used but kept
6
+ import plotly.graph_objects as go # ADDED
7
+ import plotly.express as px # ADDED
8
+ from plotly.subplots import make_subplots # ADDED
9
  import io
10
+ import os
11
  from PIL import Image
12
 
13
  # Define the paths for example data
 
53
  if agg:
54
  cols = df.columns
55
  mult = pd.DataFrame({c: (self.policy_count if (c not in agg or agg[c] == 'sum') else 1) for c in cols})
 
56
  extracted_df = self.extract_reps(df)
57
  mult.index = extracted_df.index
58
  return extracted_df.mul(mult)
 
68
  def compare_total(self, df, agg=None):
69
  """Aggregate df by columns"""
70
  if agg:
 
71
  actual_values = {}
72
  for col in df.columns:
73
  if agg.get(col, 'sum') == 'mean':
74
  actual_values[col] = df[col].mean()
75
+ else:
76
  actual_values[col] = df[col].sum()
77
  actual = pd.Series(actual_values)
78
 
 
79
  reps_unscaled = self.extract_reps(df)
80
  estimate_values = {}
81
 
82
  for col in df.columns:
83
  if agg.get(col, 'sum') == 'mean':
 
84
  weighted_sum = (reps_unscaled[col] * self.policy_count).sum()
85
  total_weight = self.policy_count.sum()
86
  estimate_values[col] = weighted_sum / total_weight if total_weight > 0 else 0
87
+ else:
88
  estimate_values[col] = (reps_unscaled[col] * self.policy_count).sum()
 
89
  estimate = pd.Series(estimate_values)
90
+ else:
 
91
  actual = df.sum()
92
  estimate = self.extract_and_scale_reps(df).sum()
93
 
 
94
  error = np.where(actual != 0, estimate / actual - 1, 0)
 
95
  return pd.DataFrame({'actual': actual, 'estimate': estimate, 'error': error})
96
 
97
 
98
  def plot_cashflows_comparison(cfs_list, cluster_obj, titles):
99
+ """Create cashflow comparison plots using Plotly"""
100
  if not cfs_list or not cluster_obj or not titles:
101
  return None
102
  num_plots = len(cfs_list)
103
  if num_plots == 0:
104
  return None
105
 
 
106
  cols = 2
107
  rows = (num_plots + cols - 1) // cols
108
 
109
+ # Use subplot titles from the input 'titles'
110
+ subplot_titles_full = titles[:num_plots] + [""] * (rows * cols - num_plots)
111
+
112
+ fig = make_subplots(
113
+ rows=rows, cols=cols,
114
+ subplot_titles=subplot_titles_full
115
+ )
116
+
117
+ plot_idx = 0
118
+ for i_df, (df, title) in enumerate(zip(cfs_list, titles)): # Use i_df to avoid conflict with internal loop i
119
+ if plot_idx < rows * cols:
120
+ r = plot_idx // cols + 1
121
+ c = plot_idx % cols + 1
122
  comparison = cluster_obj.compare_total(df)
123
+
124
+ fig.add_trace(go.Scatter(x=comparison.index, y=comparison['actual'], name='Actual',
125
+ legendgroup='group1', showlegend=(plot_idx == 0)), row=r, col=c)
126
+ fig.add_trace(go.Scatter(x=comparison.index, y=comparison['estimate'], name='Estimate',
127
+ legendgroup='group2', showlegend=(plot_idx == 0)), row=r, col=c)
128
+
129
+ fig.update_xaxes(title_text='Time', showgrid=True, row=r, col=c)
130
+ fig.update_yaxes(title_text='Value', showgrid=True, row=r, col=c)
131
+ plot_idx += 1
132
 
133
+ # Hide unused subplots by making axes invisible and clearing titles
134
+ for i in range(plot_idx, rows * cols):
135
+ r = i // cols + 1
136
+ c = i % cols + 1
137
+ fig.update_xaxes(visible=False, row=r, col=c)
138
+ fig.update_yaxes(visible=False, row=r, col=c)
139
+ if fig.layout.annotations and i < len(fig.layout.annotations):
140
+ fig.layout.annotations[i].update(text="")
141
+
142
+
143
+ fig_width = 1500
144
+ fig_height = 500 * rows
145
+ fig.update_layout(
146
+ width=fig_width,
147
+ height=fig_height,
148
+ margin=dict(l=60, r=30, t=60, b=60) # Adjusted margins
149
+ )
150
 
151
+ try:
152
+ # Requires kaleido: pip install kaleido
153
+ img_bytes = fig.to_image(format="png", width=fig_width, height=fig_height)
154
+ buf = io.BytesIO(img_bytes)
 
 
 
 
 
 
155
  img = Image.open(buf)
 
156
  return img
157
+ except Exception as e:
158
+ print(f"Error generating cashflow plot image with Plotly/Kaleido: {e}. Ensure Kaleido is installed.")
159
+ # Create a placeholder error image
160
+ error_fig = go.Figure()
161
+ error_fig.add_annotation(text=f"Plot Error: {e}", showarrow=False)
162
+ error_fig.update_layout(width=fig_width, height=fig_height)
163
+ img_bytes = error_fig.to_image(format="png", width=fig_width, height=fig_height)
164
+ return Image.open(io.BytesIO(img_bytes))
165
 
166
+
167
+ def plot_scatter_comparison(df_compare_output, title):
168
+ """Create scatter plot comparison from compare() output using Plotly"""
169
+ fig_width = 1200
170
+ fig_height = 800
171
+
172
+ if df_compare_output is None or df_compare_output.empty:
173
+ fig = go.Figure()
174
+ fig.add_annotation(
175
+ text="No data to display",
176
+ xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
177
+ font=dict(size=15)
178
+ )
179
+ fig.update_layout(title_text=title, width=fig_width, height=fig_height)
180
  else:
181
+ if not isinstance(df_compare_output.index, pd.MultiIndex) or df_compare_output.index.nlevels < 2:
182
+ gr.Warning("Scatter plot data is not in the expected multi-index format. Plotting raw actual vs estimate.")
183
+ fig = px.scatter(df_compare_output, x='actual', y='estimate', title=title)
184
+ fig.update_traces(marker=dict(size=5, opacity=0.6)) # Set marker size and opacity
185
+ else:
186
+ df_reset = df_compare_output.reset_index()
187
+ level_1_name = df_compare_output.index.names[1] if df_compare_output.index.names[1] else 'category'
188
+ if level_1_name not in df_reset.columns: # Handle case where level name might not be in columns
189
+ df_reset = df_reset.rename(columns={df_reset.columns[1]: level_1_name})
190
+
191
+
192
+ fig = px.scatter(df_reset, x='actual', y='estimate', color=level_1_name,
193
+ title=title,
194
+ labels={'actual': 'Actual', 'estimate': 'Estimate', level_1_name: level_1_name})
195
+ fig.update_traces(marker=dict(size=5, opacity=0.6)) # Set marker size and opacity
196
+
197
+ num_unique_levels = df_reset[level_1_name].nunique()
198
+ if num_unique_levels == 0 or num_unique_levels > 10:
199
+ fig.update_layout(showlegend=False)
200
+ elif num_unique_levels == 1: # Show legend even for one item if it's named
201
+ fig.update_layout(showlegend=True)
202
+
203
+
204
+ fig.update_xaxes(showgrid=True, title_text='Actual')
205
+ fig.update_yaxes(showgrid=True, title_text='Estimate')
206
+
207
+ # Draw identity line
208
+ if not df_compare_output.empty:
209
+ min_val_actual = df_compare_output['actual'].min()
210
+ max_val_actual = df_compare_output['actual'].max()
211
+ min_val_estimate = df_compare_output['estimate'].min()
212
+ max_val_estimate = df_compare_output['estimate'].max()
213
+
214
+ # Handle cases where min/max might be NaN (e.g. if all data is NaN)
215
+ if pd.isna(min_val_actual) or pd.isna(min_val_estimate) or pd.isna(max_val_actual) or pd.isna(max_val_estimate):
216
+ lims = [0,1] # Default if data is problematic
217
+ else:
218
+ overall_min = min(min_val_actual, min_val_estimate)
219
+ overall_max = max(max_val_actual, max_val_estimate)
220
+ lims = [overall_min, overall_max]
221
+
222
+
223
+ if lims[0] != lims[1]: # Avoid issues if all data is single point or NaN
224
+ fig.add_trace(go.Scatter(
225
+ x=lims, y=lims, mode='lines', name='Identity',
226
+ line=dict(color='red', width=1), # Adjusted width for Plotly
227
+ showlegend=False
228
+ ))
229
+ fig.update_xaxes(range=lims)
230
+ fig.update_yaxes(range=lims, scaleanchor="x", scaleratio=1) # Makes axes square based on data range
231
 
232
+ fig.update_layout(width=fig_width, height=fig_height)
233
+
234
+ try:
235
+ # Requires kaleido: pip install kaleido
236
+ img_bytes = fig.to_image(format="png", width=fig_width, height=fig_height)
237
+ buf = io.BytesIO(img_bytes)
238
+ img = Image.open(buf)
239
+ return img
240
+ except Exception as e:
241
+ print(f"Error generating scatter plot image with Plotly/Kaleido: {e}. Ensure Kaleido is installed.")
242
+ error_fig = go.Figure()
243
+ error_fig.add_annotation(text=f"Plot Error: {e}", showarrow=False)
244
+ error_fig.update_layout(width=fig_width, height=fig_height, title_text=title)
245
+ img_bytes = error_fig.to_image(format="png", width=fig_width, height=fig_height)
246
+ return Image.open(io.BytesIO(img_bytes))
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
 
249
  def process_files(cashflow_base_path, cashflow_lapse_path, cashflow_mort_path,
250
  policy_data_path, pv_base_path, pv_lapse_path, pv_mort_path):
251
  """Main processing function - now accepts file paths"""
252
  try:
 
253
  cfs = pd.read_excel(cashflow_base_path, index_col=0)
254
  cfs_lapse50 = pd.read_excel(cashflow_lapse_path, index_col=0)
255
  cfs_mort15 = pd.read_excel(cashflow_mort_path, index_col=0)
256
 
257
  pol_data_full = pd.read_excel(policy_data_path, index_col=0)
 
258
  required_cols = ['age_at_entry', 'policy_term', 'sum_assured', 'duration_mth']
259
  if all(col in pol_data_full.columns for col in required_cols):
260
  pol_data = pol_data_full[required_cols]
261
  else:
262
  gr.Warning(f"Policy data might be missing required columns. Found: {pol_data_full.columns.tolist()}")
263
+ pol_data = pol_data_full
264
 
265
  pvs = pd.read_excel(pv_base_path, index_col=0)
266
  pvs_lapse50 = pd.read_excel(pv_lapse_path, index_col=0)
 
270
  scen_titles = ['Base', 'Lapse+50%', 'Mort+15%']
271
 
272
  results = {}
 
273
  mean_attrs = {'age_at_entry':'mean', 'policy_term':'mean', 'duration_mth':'mean', 'sum_assured': 'sum'}
274
 
275
  # --- 1. Cashflow Calibration ---
276
  cluster_cfs = Clusters(cfs)
 
277
  results['cf_total_base_table'] = cluster_cfs.compare_total(cfs)
278
  results['cf_policy_attrs_total'] = cluster_cfs.compare_total(pol_data, agg=mean_attrs)
 
279
  results['cf_pv_total_base'] = cluster_cfs.compare_total(pvs)
280
  results['cf_pv_total_lapse'] = cluster_cfs.compare_total(pvs_lapse50)
281
  results['cf_pv_total_mort'] = cluster_cfs.compare_total(pvs_mort15)
 
282
  results['cf_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_cfs, scen_titles)
283
  results['cf_scatter_cashflows_base'] = plot_scatter_comparison(cluster_cfs.compare(cfs), 'Cashflow Calib. - Cashflows (Base)')
284
 
285
  # --- 2. Policy Attribute Calibration ---
286
+ if not pol_data.empty and (pol_data.max(numeric_only=True) - pol_data.min(numeric_only=True)).all() != 0:
287
+ loc_vars_attrs = (pol_data - pol_data.min(numeric_only=True)) / (pol_data.max(numeric_only=True) - pol_data.min(numeric_only=True))
288
+ loc_vars_attrs = loc_vars_attrs.fillna(0) # Fill NaNs that may result from division by zero if a column has no variance
289
  else:
290
  gr.Warning("Policy data for attribute calibration is empty or has no variance. Skipping attribute calibration plots.")
291
+ loc_vars_attrs = pol_data.copy() # Use a copy
292
 
293
+ if not loc_vars_attrs.empty and pd.api.types.is_numeric_dtype(loc_vars_attrs.values): # Check if data is numeric for KMeans
294
+ try:
295
+ cluster_attrs = Clusters(loc_vars_attrs)
296
+ results['attr_total_cf_base'] = cluster_attrs.compare_total(cfs)
297
+ results['attr_policy_attrs_total'] = cluster_attrs.compare_total(pol_data, agg=mean_attrs)
298
+ results['attr_total_pv_base'] = cluster_attrs.compare_total(pvs)
299
+ results['attr_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_attrs, scen_titles)
300
+ results['attr_scatter_cashflows_base'] = plot_scatter_comparison(cluster_attrs.compare(cfs), 'Policy Attr. Calib. - Cashflows (Base)')
301
+ except Exception as e_attr_clust: # Catch errors during clustering (e.g. if data is not suitable)
302
+ gr.Error(f"Error during policy attribute clustering: {e_attr_clust}")
303
+ results['attr_total_cf_base'] = pd.DataFrame()
304
+ results['attr_policy_attrs_total'] = pd.DataFrame()
305
+ results['attr_total_pv_base'] = pd.DataFrame()
306
+ results['attr_cashflow_plot'] = None
307
+ results['attr_scatter_cashflows_base'] = None
308
  else:
309
+ gr.Warning("Skipping attribute calibration as data is empty or non-numeric after processing.")
310
  results['attr_total_cf_base'] = pd.DataFrame()
311
  results['attr_policy_attrs_total'] = pd.DataFrame()
312
  results['attr_total_pv_base'] = pd.DataFrame()
 
316
 
317
  # --- 3. Present Value Calibration ---
318
  cluster_pvs = Clusters(pvs)
 
319
  results['pv_total_cf_base'] = cluster_pvs.compare_total(cfs)
320
  results['pv_policy_attrs_total'] = cluster_pvs.compare_total(pol_data, agg=mean_attrs)
 
321
  results['pv_total_pv_base'] = cluster_pvs.compare_total(pvs)
322
  results['pv_total_pv_lapse'] = cluster_pvs.compare_total(pvs_lapse50)
323
  results['pv_total_pv_mort'] = cluster_pvs.compare_total(pvs_mort15)
 
324
  results['pv_cashflow_plot'] = plot_cashflows_comparison(cfs_list, cluster_pvs, scen_titles)
325
  results['pv_scatter_pvs_base'] = plot_scatter_comparison(cluster_pvs.compare(pvs), 'PV Calib. - PVs (Base)')
326
 
327
  # --- Summary Comparison Plot Data ---
 
 
328
  error_data = {}
 
 
329
  def get_error_safe(compare_result, col_name=None):
330
  if compare_result.empty:
331
  return np.nan
332
  if col_name and col_name in compare_result.index:
333
  return abs(compare_result.loc[col_name, 'error'])
334
  else:
 
335
  return abs(compare_result['error']).mean()
336
 
 
337
  key_pv_col = None
338
+ for potential_col in ['PV_NetCF', 'pv_net_cf', 'net_cf_pv', 'PV_Net_CF']:
339
  if potential_col in pvs.columns:
340
  key_pv_col = potential_col
341
  break
342
 
 
343
  error_data['CF Calib.'] = [
344
  get_error_safe(cluster_cfs.compare_total(pvs), key_pv_col),
345
  get_error_safe(cluster_cfs.compare_total(pvs_lapse50), key_pv_col),
346
  get_error_safe(cluster_cfs.compare_total(pvs_mort15), key_pv_col)
347
  ]
348
+ if results.get('attr_total_pv_base') is not None and not results['attr_total_pv_base'].empty : # Check if Attr Calib was successful
 
 
349
  error_data['Attr Calib.'] = [
350
  get_error_safe(cluster_attrs.compare_total(pvs), key_pv_col),
351
  get_error_safe(cluster_attrs.compare_total(pvs_lapse50), key_pv_col),
 
354
  else:
355
  error_data['Attr Calib.'] = [np.nan, np.nan, np.nan]
356
 
 
 
357
  error_data['PV Calib.'] = [
358
  get_error_safe(cluster_pvs.compare_total(pvs), key_pv_col),
359
  get_error_safe(cluster_pvs.compare_total(pvs_lapse50), key_pv_col),
360
  get_error_safe(cluster_pvs.compare_total(pvs_mort15), key_pv_col)
361
  ]
362
 
 
363
  summary_df = pd.DataFrame(error_data, index=['Base', 'Lapse+50%', 'Mort+15%'])
 
 
 
 
364
  title_suffix = f' ({key_pv_col})' if key_pv_col else ' (Mean Absolute Error)'
365
+ plot_title = f'Calibration Method Comparison - Error in Total PV{title_suffix}'
366
+ fig_width = 1000
367
+ fig_height = 600
368
+
369
+ summary_df_melted = summary_df.reset_index().melt(id_vars='index', var_name='Calibration Method', value_name='Absolute Error Rate')
370
+ summary_df_melted.rename(columns={'index': 'Scenario'}, inplace=True)
371
+
372
+
373
+ fig_summary = px.bar(
374
+ summary_df_melted,
375
+ x='Scenario',
376
+ y='Absolute Error Rate',
377
+ color='Calibration Method',
378
+ barmode='group',
379
+ title=plot_title
380
+ )
381
+ fig_summary.update_layout(
382
+ width=fig_width, height=fig_height,
383
+ xaxis_tickangle=0,
384
+ yaxis_title='Absolute Error Rate',
385
+ legend_title_text='Calibration Method'
386
+ )
387
+ fig_summary.update_yaxes(showgrid=True)
388
+
389
+ try:
390
+ # Requires kaleido: pip install kaleido
391
+ buf_summary_bytes = fig_summary.to_image(format="png", width=fig_width, height=fig_height)
392
+ buf_summary = io.BytesIO(buf_summary_bytes)
393
+ results['summary_plot'] = Image.open(buf_summary)
394
+ except Exception as e:
395
+ print(f"Error generating summary plot image with Plotly/Kaleido: {e}. Ensure Kaleido is installed.")
396
+ error_fig = go.Figure()
397
+ error_fig.add_annotation(text=f"Plot Error: {e}", showarrow=False)
398
+ error_fig.update_layout(width=fig_width, height=fig_height, title_text=plot_title)
399
+ img_bytes = error_fig.to_image(format="png", width=fig_width, height=fig_height)
400
+ results['summary_plot'] = Image.open(io.BytesIO(img_bytes))
401
+
402
  return results
403
 
404
  except FileNotFoundError as e:
 
409
  return {"error": f"Missing column: {e}"}
410
  except Exception as e:
411
  gr.Error(f"Error processing files: {str(e)}")
412
+ # Optionally log the full traceback for debugging
413
+ import traceback
414
+ traceback.print_exc()
415
  return {"error": f"Error processing files: {str(e)}"}
416
 
417
 
 
431
  - Present Values - Base Scenario
432
  - Present Values - Lapse Stress
433
  - Present Values - Mortality Stress
434
+
435
+ **Note:** Plot generation uses Plotly and Kaleido. If plots appear as errors, ensure Kaleido is installed (`pip install kaleido`).
436
  """)
437
 
438
  with gr.Row():
439
  with gr.Column(scale=1):
440
  gr.Markdown("### Upload Files or Load Examples")
 
441
  load_example_btn = gr.Button("Load Example Data")
 
442
  with gr.Row():
443
  cashflow_base_input = gr.File(label="Cashflows - Base", file_types=[".xlsx"])
444
  cashflow_lapse_input = gr.File(label="Cashflows - Lapse Stress", file_types=[".xlsx"])
 
449
  pv_lapse_input = gr.File(label="Present Values - Lapse Stress", file_types=[".xlsx"])
450
  with gr.Row():
451
  pv_mort_input = gr.File(label="Present Values - Mortality Stress", file_types=[".xlsx"])
 
452
  analyze_btn = gr.Button("Analyze Dataset", variant="primary", size="lg")
453
 
454
  with gr.Tabs():
455
  with gr.TabItem("📊 Summary"):
456
+ summary_plot_output = gr.Image(label="Calibration Methods Comparison") # Stays as gr.Image
457
 
458
  with gr.TabItem("💸 Cashflow Calibration"):
459
  gr.Markdown("### Results: Using Annual Cashflows as Calibration Variables")
 
478
  with gr.Accordion("Present Value Comparisons (Total)", open=False):
479
  attr_total_pv_base_out = gr.Dataframe(label="PVs - Base Scenario Total")
480
 
 
481
  with gr.TabItem("💰 Present Value Calibration"):
482
  gr.Markdown("### Results: Using Present Values (Base Scenario) as Calibration Variables")
483
  with gr.Row():
 
491
  pv_total_pv_lapse_out = gr.Dataframe(label="PVs - Lapse Stress Total")
492
  pv_total_pv_mort_out = gr.Dataframe(label="PVs - Mortality Stress Total")
493
 
 
494
def get_all_output_components():
    """Collect every result widget, in the exact order handle_analysis emits them."""
    summary = [summary_plot_output]
    cashflow_tab = [
        cf_total_base_table_out, cf_policy_attrs_total_out,
        cf_cashflow_plot_out, cf_scatter_cashflows_base_out,
        cf_pv_total_base_out, cf_pv_total_lapse_out, cf_pv_total_mort_out,
    ]
    attribute_tab = [
        attr_total_cf_base_out, attr_policy_attrs_total_out,
        attr_cashflow_plot_out, attr_scatter_cashflows_base_out, attr_total_pv_base_out,
    ]
    pv_tab = [
        pv_total_cf_base_out, pv_policy_attrs_total_out,
        pv_cashflow_plot_out, pv_scatter_pvs_base_out,
        pv_total_pv_base_out, pv_total_pv_lapse_out, pv_total_pv_mort_out,
    ]
    # Order matters: it must line up with the list returned by handle_analysis.
    return summary + cashflow_tab + attribute_tab + pv_tab
506
 
 
507
  def handle_analysis(f1, f2, f3, f4, f5, f6, f7):
508
  files = [f1, f2, f3, f4, f5, f6, f7]
 
509
  file_paths = []
510
  for i, f_obj in enumerate(files):
511
  if f_obj is None:
512
  gr.Error(f"Missing file input for argument {i+1}. Please upload all files or load examples.")
513
  return [None] * len(get_all_output_components())
 
 
514
  if hasattr(f_obj, 'name') and isinstance(f_obj.name, str):
515
  file_paths.append(f_obj.name)
 
516
  elif isinstance(f_obj, str):
517
  file_paths.append(f_obj)
518
  else:
519
  gr.Error(f"Invalid file input for argument {i+1}. Type: {type(f_obj)}")
520
  return [None] * len(get_all_output_components())
521
 
 
522
  results = process_files(*file_paths)
523
 
524
+ if "error" in results:
525
  return [None] * len(get_all_output_components())
526
 
527
  return [
528
  results.get('summary_plot'),
 
529
  results.get('cf_total_base_table'), results.get('cf_policy_attrs_total'),
530
  results.get('cf_cashflow_plot'), results.get('cf_scatter_cashflows_base'),
531
  results.get('cf_pv_total_base'), results.get('cf_pv_total_lapse'), results.get('cf_pv_total_mort'),
 
532
  results.get('attr_total_cf_base'), results.get('attr_policy_attrs_total'),
533
  results.get('attr_cashflow_plot'), results.get('attr_scatter_cashflows_base'), results.get('attr_total_pv_base'),
 
534
  results.get('pv_total_cf_base'), results.get('pv_policy_attrs_total'),
535
  results.get('pv_cashflow_plot'), results.get('pv_scatter_pvs_base'),
536
  results.get('pv_total_pv_base'), results.get('pv_total_pv_lapse'), results.get('pv_total_pv_mort')
 
543
  outputs=get_all_output_components()
544
  )
545
 
 
546
  def load_example_files():
547
  missing_files = [fp for fp in EXAMPLE_FILES.values() if not os.path.exists(fp)]
548
  if missing_files:
549
  gr.Error(f"Missing example data files in '{EXAMPLE_DATA_DIR}': {', '.join(missing_files)}. Please ensure they exist.")
550
+ return [None] * 7
551
 
552
  gr.Info("Example data paths loaded. Click 'Analyze Dataset'.")
553
  return [