Tani21 commited on
Commit
e02d280
·
verified ·
1 Parent(s): 47ae043

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -6
app.py CHANGED
@@ -4,14 +4,21 @@ import seaborn as sns
4
  import matplotlib.pyplot as plt
5
 
6
  # Load model + metadata + dataset
7
- model = joblib.load("maternal_rf_model.joblib")
8
  with open("maternal_metadata.json","r",encoding="utf-8") as f:
9
  meta = json.load(f)
10
- df_clean = pd.read_csv("maternal_cleaned.csv")
 
 
 
 
11
 
12
  numeric_features = meta["numeric_features"]
13
  categorical_features = meta["categorical_features"]
14
 
 
 
 
15
  # ---------- Prediction function ----------
16
  def predict_risk(age, gravida, gest_weeks, weight, height_cm,
17
  bp_sys, bp_dias, fetal_hr,
@@ -30,22 +37,32 @@ def predict_risk(age, gravida, gest_weeks, weight, height_cm,
30
  prob = model.predict_proba(X)[:,1][0]
31
  pred = int(model.predict(X)[0])
32
  label = "High Risk" if pred==1 else "Not High Risk"
33
- return {"Prediction":label,"Probability_high_risk":round(prob,4)}
 
 
 
 
 
 
 
34
 
35
  # ---------- Plot functions ----------
36
  def plot_age_distribution():
 
37
  fig, ax = plt.subplots(figsize=(6,4))
38
  sns.histplot(df_clean["Age"], bins=10, kde=True, ax=ax, color="skyblue")
39
  ax.set_title("Age Distribution")
40
  return fig
41
 
42
  def plot_risk_counts():
 
43
  fig, ax = plt.subplots(figsize=(6,4))
44
  sns.countplot(x="HighRisk", data=df_clean, ax=ax, palette="Set2")
45
  ax.set_title("High Risk vs Non-Risk Counts")
46
  return fig
47
 
48
  def plot_gestation_box():
 
49
  fig, ax = plt.subplots(figsize=(6,4))
50
  sns.boxplot(x="HighRisk", y="GestationWeeks", data=df_clean, ax=ax, palette="Set2")
51
  ax.set_title("Gestation Weeks vs Risk")
@@ -64,14 +81,19 @@ def plot_feature_importance():
64
  return fig
65
 
66
  def plot_corr_heatmap():
 
67
  fig, ax = plt.subplots(figsize=(8,6))
68
  corr = df_clean[numeric_features+["HighRisk"]].corr()
69
  sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", ax=ax)
70
  ax.set_title("Correlation Heatmap")
71
  return fig
72
 
 
 
 
 
73
  # ---------- Gradio UI ----------
74
- with gr.Blocks(title="Maternal Risk Prediction") as demo:
75
  gr.Markdown("## Maternal Risk Prediction Dashboard")
76
 
77
  with gr.Tab("Prediction"):
@@ -102,16 +124,19 @@ with gr.Blocks(title="Maternal Risk Prediction") as demo:
102
  outputs=out)
103
 
104
  with gr.Tab("Data Insights"):
105
- gr.Markdown("### Dataset Overview")
106
  gr.Plot(plot_age_distribution)
107
  gr.Plot(plot_risk_counts)
108
  gr.Plot(plot_gestation_box)
109
 
110
  with gr.Tab("Model Insights"):
111
- gr.Markdown("### Model Behavior")
112
  gr.Plot(plot_feature_importance)
113
  gr.Plot(plot_corr_heatmap)
114
 
 
 
 
 
 
115
  with gr.Tab("About"):
116
  gr.Markdown("""
117
  ### About this App
 
4
  import matplotlib.pyplot as plt
5
 
6
  # Load model + metadata + dataset
7
+ model = joblib.load("maternal_risk_model.joblib")
8
  with open("maternal_metadata.json","r",encoding="utf-8") as f:
9
  meta = json.load(f)
10
+
11
+ try:
12
+ df_clean = pd.read_csv("maternal_cleaned.csv")
13
+ except FileNotFoundError:
14
+ df_clean = None
15
 
16
  numeric_features = meta["numeric_features"]
17
  categorical_features = meta["categorical_features"]
18
 
19
+ # ---------- Prediction history ----------
20
+ prediction_history = []
21
+
22
  # ---------- Prediction function ----------
23
  def predict_risk(age, gravida, gest_weeks, weight, height_cm,
24
  bp_sys, bp_dias, fetal_hr,
 
37
  prob = model.predict_proba(X)[:,1][0]
38
  pred = int(model.predict(X)[0])
39
  label = "High Risk" if pred==1 else "Not High Risk"
40
+
41
+ # Save to history
42
+ history_row = row.copy()
43
+ history_row["Prediction"] = label
44
+ history_row["Probability"] = round(prob, 4)
45
+ prediction_history.append(history_row)
46
+
47
+ return {"Prediction": label, "Probability_high_risk": round(prob,4)}
48
 
49
  # ---------- Plot functions ----------
50
  def plot_age_distribution():
51
+ if df_clean is None: return plt.figure()
52
  fig, ax = plt.subplots(figsize=(6,4))
53
  sns.histplot(df_clean["Age"], bins=10, kde=True, ax=ax, color="skyblue")
54
  ax.set_title("Age Distribution")
55
  return fig
56
 
57
  def plot_risk_counts():
58
+ if df_clean is None: return plt.figure()
59
  fig, ax = plt.subplots(figsize=(6,4))
60
  sns.countplot(x="HighRisk", data=df_clean, ax=ax, palette="Set2")
61
  ax.set_title("High Risk vs Non-Risk Counts")
62
  return fig
63
 
64
  def plot_gestation_box():
65
+ if df_clean is None: return plt.figure()
66
  fig, ax = plt.subplots(figsize=(6,4))
67
  sns.boxplot(x="HighRisk", y="GestationWeeks", data=df_clean, ax=ax, palette="Set2")
68
  ax.set_title("Gestation Weeks vs Risk")
 
81
  return fig
82
 
83
  def plot_corr_heatmap():
84
+ if df_clean is None: return plt.figure()
85
  fig, ax = plt.subplots(figsize=(8,6))
86
  corr = df_clean[numeric_features+["HighRisk"]].corr()
87
  sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", ax=ax)
88
  ax.set_title("Correlation Heatmap")
89
  return fig
90
 
91
+ # ---------- History update ----------
92
+ def update_history():
93
+ return pd.DataFrame(prediction_history)
94
+
95
  # ---------- Gradio UI ----------
96
+ with gr.Blocks(title="Maternal Risk Prediction Dashboard") as demo:
97
  gr.Markdown("## Maternal Risk Prediction Dashboard")
98
 
99
  with gr.Tab("Prediction"):
 
124
  outputs=out)
125
 
126
  with gr.Tab("Data Insights"):
 
127
  gr.Plot(plot_age_distribution)
128
  gr.Plot(plot_risk_counts)
129
  gr.Plot(plot_gestation_box)
130
 
131
  with gr.Tab("Model Insights"):
 
132
  gr.Plot(plot_feature_importance)
133
  gr.Plot(plot_corr_heatmap)
134
 
135
+ with gr.Tab("Prediction History"):
136
+ history_table = gr.DataFrame(label="Prediction History", interactive=False)
137
+ refresh_btn = gr.Button("Refresh History")
138
+ refresh_btn.click(fn=update_history, outputs=history_table)
139
+
140
  with gr.Tab("About"):
141
  gr.Markdown("""
142
  ### About this App