alidenewade committed on
Commit
f647840
·
verified ·
1 Parent(s): b9e8b82

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -99
app.py CHANGED
@@ -2,117 +2,146 @@ import gradio as gr
2
  import pandas as pd
3
  import numpy as np
4
  from sklearn.cluster import KMeans
5
- from sklearn.metrics import r2_score
6
  import matplotlib.pyplot as plt
7
  import io
 
8
 
9
def cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters):
    """Select representative model points from a seriatim policy file with KMeans.

    Args:
        policy_file: uploaded file object (exposes a ``.name`` path) with policy
            data; must contain IssueAge, PolicyTerm, SumAssured and Duration.
        cashflow_file: uploaded Excel workbook of per-policy cashflows; the
            first column is used as the index.
        pv_file: uploaded Excel workbook of per-policy present values; the
            first column is used as the index.
        num_clusters: number of KMeans clusters, i.e. model points to select.

    Returns:
        Tuple ``(csv_text, cashflow_png_bytes, pv_png_bytes, metrics_text)``.
        On failure the first three slots are None and the last carries the
        error message (errors are reported, never raised, so the UI stays up).
    """
    # Read the three workbooks up front; report any failure instead of raising.
    try:
        policy_df = pd.read_excel(policy_file.name)
        cashflow_df = pd.read_excel(cashflow_file.name, index_col=0)
        pv_df = pd.read_excel(pv_file.name, index_col=0)
    except Exception as e:
        return (None, None, None, f"Error reading files: {e}")

    # Cluster on basic policy attributes only.
    required_cols = ['IssueAge', 'PolicyTerm', 'SumAssured', 'Duration']
    if not all(col in policy_df.columns for col in required_cols):
        return (None, None, None, f"Policy data missing required columns: {required_cols}")

    X = policy_df[required_cols].fillna(0)
    # Standardise features. A constant column has std == 0, which yields NaNs
    # here and would crash KMeans — replace those NaNs with 0.
    X_scaled = ((X - X.mean()) / X.std()).fillna(0)

    try:
        kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init=10)
        kmeans.fit(X_scaled)
        policy_df['Cluster'] = kmeans.labels_
    except Exception as e:
        return (None, None, None, f"Clustering error: {e}")

    # Model points = the real policies nearest each cluster centre.
    from sklearn.metrics import pairwise_distances_argmin_min
    centers = kmeans.cluster_centers_
    closest, _ = pairwise_distances_argmin_min(centers, X_scaled)
    model_points = policy_df.iloc[closest].copy()

    # Weight each model point by the size of the cluster it represents.
    counts = policy_df['Cluster'].value_counts()
    model_points['Weight'] = model_points['Cluster'].map(counts)

    # CSV payload for download.
    csv_buffer = io.StringIO()
    model_points.to_csv(csv_buffer, index=False)
    csv_data = csv_buffer.getvalue()

    # Weighted proxy cashflows vs. full seriatim totals.
    # NOTE(review): assumes cashflow_df/pv_df rows share index labels with
    # policy_df rows — verify against the input workbooks.
    proxy_cashflows = cashflow_df.loc[model_points.index].multiply(model_points['Weight'], axis=0).sum()
    seriatim_cashflows = cashflow_df.sum()

    # Cashflow comparison plot (render explicitly from `fig`, not the
    # implicit current figure).
    fig, ax = plt.subplots(figsize=(8, 4))
    seriatim_cashflows.plot(ax=ax, label='Seriatim Cashflows')
    proxy_cashflows.plot(ax=ax, label='Proxy Cashflows', linestyle='--')
    ax.set_title('Aggregated Cashflows Comparison')
    ax.legend()
    ax.grid(True)

    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    cashflow_plot = buf.read()

    # Weighted proxy present value vs. full seriatim total (first PV column).
    proxy_pv = pv_df.loc[model_points.index].multiply(model_points['Weight'], axis=0).sum().values[0]
    seriatim_pv = pv_df.sum().values[0]

    # Present-value comparison bar chart.
    fig2, ax2 = plt.subplots(figsize=(5, 4))
    ax2.bar(['Seriatim PV', 'Proxy PV'], [seriatim_pv, proxy_pv], color=['blue', 'orange'])
    ax2.set_title('Aggregated Present Values')
    ax2.grid(axis='y')

    buf2 = io.BytesIO()
    fig2.savefig(buf2, format='png')
    plt.close(fig2)
    buf2.seek(0)
    pv_plot = buf2.read()

    # Accuracy metrics: R^2 on the overlapping periods, and PV percentage
    # error guarded against a zero seriatim total.
    common_idx = seriatim_cashflows.index.intersection(proxy_cashflows.index)
    r2 = r2_score(seriatim_cashflows.loc[common_idx], proxy_cashflows.loc[common_idx])
    pv_error = abs(proxy_pv - seriatim_pv) / seriatim_pv * 100 if seriatim_pv != 0 else float('inf')

    metrics_text = (
        f"R-squared for aggregated cashflows: {r2:.4f}\n"
        f"Absolute percentage error in present value: {pv_error:.4f}%"
    )

    return csv_data, cashflow_plot, pv_plot, metrics_text
 
94
 
 
95
# Gradio UI: three Excel uploads + a cluster-count slider on the left,
# results (CSV text, two plots, metrics) on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Actuarial Model Point Selection")

    with gr.Row():
        with gr.Column():
            # Inputs: the three workbooks cluster_analysis() expects.
            policy_input = gr.File(label="Upload Policy Data (Excel)")
            cashflow_input = gr.File(label="Upload Cashflow Data (Excel)")
            pv_input = gr.File(label="Upload Present Value Data (Excel)")
            clusters_input = gr.Slider(minimum=2, maximum=100, step=1, value=10, label="Number of Model Points")
            run_btn = gr.Button("Run Clustering")

        with gr.Column():
            # Outputs, in the same order cluster_analysis() returns them.
            output_csv = gr.Textbox(label="Model Points CSV Output", lines=10)
            cashflow_img = gr.Image(label="Aggregated Cashflows Comparison")
            pv_img = gr.Image(label="Aggregated Present Values Comparison")
            metrics_box = gr.Textbox(label="Accuracy Metrics", lines=4)

    # Wire the button straight to the analysis function.
    run_btn.click(
        cluster_analysis,
        inputs=[policy_input, cashflow_input, pv_input, clusters_input],
        outputs=[output_csv, cashflow_img, pv_img, metrics_box]
    )

demo.launch(debug=True)
 
2
  import pandas as pd
3
  import numpy as np
4
  from sklearn.cluster import KMeans
5
+ from sklearn.metrics import r2_score, pairwise_distances_argmin_min
6
  import matplotlib.pyplot as plt
7
  import io
8
+ import os
9
 
10
def run_cluster_analysis(policy_file, cashflow_file, pv_file, num_clusters, cluster_type):
    """Cluster seriatim policies into model points and compare proxy vs. seriatim totals.

    Args:
        policy_file: path to the policy Excel workbook.
        cashflow_file: path to the cashflow workbook (first column = index).
        pv_file: path to the present-value workbook (first column = index).
        num_clusters: number of KMeans clusters / model points (k).
        cluster_type: which variables to cluster on — "Policy Attributes",
            "Net Cashflows" or "Present Values".

    Returns:
        Tuple ``(csv_bytes, cashflow_png_buf, pv_png_buf, metrics_text)``; on
        failure the first three slots are None and the last is the error text.
    """
    try:
        # Load data
        policy_df = pd.read_excel(policy_file)
        cashflow_df = pd.read_excel(cashflow_file, index_col=0)
        pv_df = pd.read_excel(pv_file, index_col=0)

        # Normalize column names for robustness
        policy_df.columns = policy_df.columns.str.strip().str.lower()
        pv_df.columns = pv_df.columns.str.strip().str.lower()

        # Pick the clustering feature matrix.
        if cluster_type == "Policy Attributes":
            required_cols = ['issueage', 'policyterm', 'sumassured', 'duration']
            missing = [col for col in required_cols if col not in policy_df.columns]
            if missing:
                return (None, None, None, f"Policy data missing required columns: {missing}")
            X = policy_df[required_cols].fillna(0)
        elif cluster_type == "Net Cashflows":
            X = cashflow_df.fillna(0)
        elif cluster_type == "Present Values":
            if 'pv_net_cf' not in pv_df.columns:
                return (None, None, None, "Missing 'PV_Net_CF' column in PV file.")
            X = pv_df[['pv_net_cf']].fillna(0)
        else:
            return (None, None, None, "Invalid clustering variable choice.")

        # Scale; constant columns produce NaNs (std == 0), so zero them out.
        X_scaled = ((X - X.mean()) / X.std(ddof=0)).fillna(0)

        # Run KMeans
        kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init='auto')
        kmeans.fit(X_scaled)
        policy_df['cluster'] = kmeans.labels_

        # Model points = real policies nearest each cluster centre, weighted
        # by the number of policies in their cluster.
        # NOTE(review): assumes policy_df rows align positionally with X rows
        # for the cashflow/PV clustering variants — verify against the inputs.
        closest_idxs = pairwise_distances_argmin_min(kmeans.cluster_centers_, X_scaled)[0]
        model_points = policy_df.iloc[closest_idxs].copy()
        cluster_counts = policy_df['cluster'].value_counts()
        model_points['weight'] = model_points['cluster'].map(cluster_counts)

        # Aggregate comparisons (weighted proxy vs. full seriatim totals).
        total_seriatim_cf = cashflow_df.sum(axis=0)
        total_seriatim_pv = pv_df.sum(axis=0)
        proxy_cf = cashflow_df.loc[model_points.index].multiply(model_points['weight'], axis=0).sum(axis=0)
        proxy_pv = pv_df.loc[model_points.index].multiply(model_points['weight'], axis=0).sum(axis=0)

        # Output CSV
        csv_buf = io.StringIO()
        model_points.to_csv(csv_buf, index=False)
        csv_bytes = csv_buf.getvalue().encode()

        # Cashflow plot (render from the explicit figure, not plt's current one).
        fig1, ax1 = plt.subplots()
        total_seriatim_cf.plot(ax=ax1, label="Seriatim", color="blue")
        proxy_cf.plot(ax=ax1, label="Proxy", linestyle="--", color="orange")
        ax1.set_title("Aggregated Cashflows")
        ax1.legend()
        ax1.grid()
        buf1 = io.BytesIO()
        fig1.savefig(buf1, format='png')
        buf1.seek(0)
        plt.close(fig1)

        # PV plot
        fig2, ax2 = plt.subplots()
        pv_plot = pd.DataFrame({
            "Seriatim PV": [total_seriatim_pv.iloc[0]],
            "Proxy PV": [proxy_pv.iloc[0]]
        })
        pv_plot.plot(kind="bar", ax=ax2, color=["blue", "orange"])
        ax2.set_title("Aggregated Present Values")
        ax2.set_xticks([0])
        ax2.set_xticklabels(["Total PV"])
        ax2.grid(axis='y')
        buf2 = io.BytesIO()
        fig2.savefig(buf2, format='png')
        buf2.seek(0)
        plt.close(fig2)

        # Metrics: align seriatim/proxy on their common periods before R^2
        # (they can differ when the proxy subset lacks some columns), and
        # guard the PV error against a zero seriatim total.
        common_idx = total_seriatim_cf.index.intersection(proxy_cf.index)
        r2 = r2_score(total_seriatim_cf.loc[common_idx], proxy_cf.loc[common_idx])
        seriatim_pv_total = total_seriatim_pv.iloc[0]
        if seriatim_pv_total != 0:
            pv_err = abs((proxy_pv.iloc[0] - seriatim_pv_total) / seriatim_pv_total) * 100
        else:
            pv_err = float('inf')
        metrics = (
            f"--- Accuracy Metrics ---\n"
            f"R-squared (Cashflows): {r2:.4f}\n"
            f"Absolute % Error (Present Value): {pv_err:.2f}%"
        )

        return csv_bytes, buf1, buf2, metrics

    except Exception as e:
        return (None, None, None, f"An error occurred: {str(e)}")
103
 
104
# Build UI: uploads + clustering options on top, results row underneath.
with gr.Blocks() as demo:
    gr.Markdown("## Actuarial Model Point Selection via Cluster Analysis")

    with gr.Row():
        with gr.Column():
            policy_file = gr.File(label="Upload Policy Data (.xlsx)", file_types=[".xlsx", ".xls"])
            cashflow_file = gr.File(label="Upload Cashflow Data (.xlsx)", file_types=[".xlsx", ".xls"])
            pv_file = gr.File(label="Upload Present Value Data (.xlsx)", file_types=[".xlsx", ".xls"])

        with gr.Column():
            num_clusters = gr.Slider(10, 2000, value=1000, step=10, label="Number of Model Points (k)")
            cluster_type = gr.Dropdown(
                ["Policy Attributes", "Net Cashflows", "Present Values"],
                value="Present Values",
                label="Clustering Variable"
            )
            run_btn = gr.Button("Run Cluster Analysis")

    with gr.Row():
        output_csv = gr.File(label="Download Model Points (CSV)")
        output_cf_plot = gr.Image(label="Cashflow Comparison")
        output_pv_plot = gr.Image(label="PV Comparison")
        output_metrics = gr.Textbox(label="Accuracy Metrics", lines=5)

    def wrapper(policy_file, cashflow_file, pv_file, num_clusters, cluster_type):
        """Adapt uploaded-file objects to run_cluster_analysis and write the CSV to disk."""
        # A missing upload arrives as None; dereferencing .name would raise.
        if policy_file is None or cashflow_file is None or pv_file is None:
            return None, None, None, "Please upload all three input files."

        csv_bytes, img_cf, img_pv, metrics = run_cluster_analysis(
            policy_file.name, cashflow_file.name, pv_file.name, num_clusters, cluster_type
        )

        csv_path = None
        if csv_bytes is not None:
            import tempfile
            # Use a unique temp file rather than a hard-coded "/tmp/..." path:
            # portable to non-POSIX hosts and safe for concurrent users.
            with tempfile.NamedTemporaryFile(
                mode="wb", suffix=".csv", prefix="model_points_", delete=False
            ) as f:
                f.write(csv_bytes)
                csv_path = f.name

        return csv_path, img_cf, img_pv, metrics

    run_btn.click(
        fn=wrapper,
        inputs=[policy_file, cashflow_file, pv_file, num_clusters, cluster_type],
        outputs=[output_csv, output_cf_plot, output_pv_plot, output_metrics]
    )

demo.launch()