Tani21 commited on
Commit
4d654cd
·
verified ·
1 Parent(s): a4fd28e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -130
app.py CHANGED
@@ -1,147 +1,124 @@
1
- import json
2
- import joblib
3
- import numpy as np
4
- import pandas as pd
5
  import gradio as gr
 
 
6
 
7
- # ---------- Load artifacts ----------
8
- MODEL_PATH = "maternal_rf_model.joblib"
9
- META_PATH = "maternal_metadata.json"
10
- DATA_PATH = "maternal_cleaned.csv" # optional, for example defaults or sanity checks
11
-
12
- model = joblib.load(MODEL_PATH)
13
-
14
- with open(META_PATH, "r", encoding="utf-8") as f:
15
  meta = json.load(f)
 
16
 
17
  numeric_features = meta["numeric_features"]
18
  categorical_features = meta["categorical_features"]
19
- target_col = meta["target"]
20
-
21
- # Optional: load cleaned dataset to compute sensible defaults/ranges
22
- try:
23
- df_clean = pd.read_csv(DATA_PATH)
24
- except Exception:
25
- df_clean = None
26
-
27
- # ---------- Define categorical options ----------
28
- # Ensure these match your training preprocessing categories
29
- ANAEMIA_OPTS = ["None", "Minimal", "Medium", "Higher"]
30
- JAUNDICE_OPTS = ["None", "Minimal", "Medium"]
31
- FETAL_POSITION_OPTS = ["Normal", "Abnormal"]
32
- FETAL_MOVEMENT_OPTS = ["Yes", "No"]
33
- URINE_ALBUMIN_OPTS = ["Negative", "Positive"]
34
- URINE_SUGAR_OPTS = ["Negative", "Positive"]
35
-
36
- # ---------- Defaults from dataset (median or most frequent) ----------
37
- def default_num(name, fallback=0.0):
38
- if df_clean is not None and name in df_clean.columns:
39
- return float(np.nanmedian(df_clean[name].values))
40
- return float(fallback)
41
-
42
- def default_cat(name, options, fallback=None):
43
- if df_clean is not None and name in df_clean.columns:
44
- mode = df_clean[name].dropna().astype(str).mode()
45
- if len(mode) > 0 and mode[0] in options:
46
- return mode[0]
47
- return fallback or options[0]
48
-
49
- DEFAULTS = {
50
- "Age": default_num("Age", 22),
51
- "Gravida": default_num("Gravida", 1),
52
- "GestationWeeks": default_num("GestationWeeks", 30),
53
- "WeightKg": default_num("WeightKg", 56),
54
- "HeightCm": default_num("HeightCm", 160),
55
- "BP_Systolic": default_num("BP_Systolic", 100),
56
- "BP_Diastolic": default_num("BP_Diastolic", 60),
57
- "FetalHR": default_num("FetalHR", 140),
58
- "Anaemia": default_cat("Anaemia", ANAEMIA_OPTS, "None"),
59
- "Jaundice": default_cat("Jaundice", JAUNDICE_OPTS, "None"),
60
- "FetalPosition": default_cat("FetalPosition", FETAL_POSITION_OPTS, "Normal"),
61
- "FetalMovement": default_cat("FetalMovement", FETAL_MOVEMENT_OPTS, "Yes"),
62
- "UrineAlbumin": default_cat("UrineAlbumin", URINE_ALBUMIN_OPTS, "Negative"),
63
- "UrineSugar": default_cat("UrineSugar", URINE_SUGAR_OPTS, "Negative"),
64
- }
65
 
66
  # ---------- Prediction function ----------
67
- def predict_risk(
68
- age, gravida, gest_weeks, weight, height_cm,
69
- bp_sys, bp_dias, fetal_hr,
70
- anaemia, jaundice, fetal_position, fetal_movement, urine_albumin, urine_sugar
71
- ):
72
- # Build a single-row DataFrame with exact column order
73
  row = {
74
- "Age": age,
75
- "Gravida": gravida,
76
- "GestationWeeks": gest_weeks,
77
- "WeightKg": weight,
78
- "HeightCm": height_cm,
79
- "BP_Systolic": bp_sys,
80
- "BP_Diastolic": bp_dias,
81
  "FetalHR": fetal_hr,
82
- "Anaemia": anaemia,
83
- "Jaundice": jaundice,
84
- "FetalPosition": fetal_position,
85
- "FetalMovement": fetal_movement,
86
- "UrineAlbumin": urine_albumin,
87
- "UrineSugar": urine_sugar,
88
  }
89
- X = pd.DataFrame([row], columns=numeric_features + categorical_features)
90
-
91
- # Predict
92
- prob = None
93
- try:
94
- prob = model.predict_proba(X)[:, 1][0]
95
- except Exception:
96
- # If model lacks predict_proba (shouldn’t happen for RandomForest), fallback
97
- prob = float(model.predict(X)[0])
98
-
99
  pred = int(model.predict(X)[0])
100
- label = "High Risk" if pred == 1 else "Not High Risk"
101
-
102
- # Friendly output with rounded probability
103
- return {
104
- "Prediction": label,
105
- "Probability_high_risk": round(float(prob), 4)
106
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  # ---------- Gradio UI ----------
109
  with gr.Blocks(title="Maternal Risk Prediction") as demo:
110
- gr.Markdown(
111
- "## Maternal Risk Prediction\n"
112
- "Enter clinical inputs to estimate high-risk pregnancy likelihood. "
113
- "This tool uses a trained RandomForest model."
114
- )
115
-
116
- with gr.Row():
117
- with gr.Column():
118
- age_in = gr.Number(label="Age (years)", value=DEFAULTS["Age"])
119
- gravida_in = gr.Number(label="Gravida (1/2/3)", value=DEFAULTS["Gravida"])
120
- gest_in = gr.Number(label="Gestation Weeks", value=DEFAULTS["GestationWeeks"])
121
- weight_in = gr.Number(label="Weight (kg)", value=DEFAULTS["WeightKg"])
122
- height_in = gr.Number(label="Height (cm)", value=DEFAULTS["HeightCm"])
123
- with gr.Column():
124
- bp_sys_in = gr.Number(label="BP Systolic (mmHg)", value=DEFAULTS["BP_Systolic"])
125
- bp_dias_in = gr.Number(label="BP Diastolic (mmHg)", value=DEFAULTS["BP_Diastolic"])
126
- fetal_hr_in = gr.Number(label="Fetal Heart Rate (bpm)", value=DEFAULTS["FetalHR"])
127
- anaemia_in = gr.Dropdown(ANAEMIA_OPTS, label="Anaemia", value=DEFAULTS["Anaemia"])
128
- jaundice_in = gr.Dropdown(JAUNDICE_OPTS, label="Jaundice", value=DEFAULTS["Jaundice"])
129
- with gr.Column():
130
- fetal_pos_in = gr.Dropdown(FETAL_POSITION_OPTS, label="Fetal Position", value=DEFAULTS["FetalPosition"])
131
- fetal_mov_in = gr.Dropdown(FETAL_MOVEMENT_OPTS, label="Fetal Movement", value=DEFAULTS["FetalMovement"])
132
- urine_alb_in = gr.Dropdown(URINE_ALBUMIN_OPTS, label="Urine Albumin", value=DEFAULTS["UrineAlbumin"])
133
- urine_sug_in = gr.Dropdown(URINE_SUGAR_OPTS, label="Urine Sugar", value=DEFAULTS["UrineSugar"])
134
-
135
- predict_btn = gr.Button("Predict Risk")
136
-
137
- out_json = gr.JSON(label="Result")
138
-
139
- predict_btn.click(
140
- predict_risk,
141
- inputs=[age_in, gravida_in, gest_in, weight_in, height_in,
142
- bp_sys_in, bp_dias_in, fetal_hr_in,
143
- anaemia_in, jaundice_in, fetal_pos_in, fetal_mov_in, urine_alb_in, urine_sug_in],
144
- outputs=[out_json]
145
- )
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  demo.launch()
 
1
+ import json, joblib, pandas as pd, numpy as np
 
 
 
2
  import gradio as gr
3
+ import seaborn as sns
4
+ import matplotlib.pyplot as plt
5
 
6
+ # Load model + metadata + dataset
7
+ model = joblib.load("maternal_rf_model.joblib")
8
+ with open("maternal_metadata.json","r",encoding="utf-8") as f:
 
 
 
 
 
9
  meta = json.load(f)
10
+ df_clean = pd.read_csv("maternal_cleaned.csv")
11
 
12
  numeric_features = meta["numeric_features"]
13
  categorical_features = meta["categorical_features"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # ---------- Prediction function ----------
16
+ def predict_risk(age, gravida, gest_weeks, weight, height_cm,
17
+ bp_sys, bp_dias, fetal_hr,
18
+ anaemia, jaundice, fetal_position, fetal_movement,
19
+ urine_albumin, urine_sugar):
 
 
20
  row = {
21
+ "Age": age, "Gravida": gravida, "GestationWeeks": gest_weeks,
22
+ "WeightKg": weight, "HeightCm": height_cm,
23
+ "BP_Systolic": bp_sys, "BP_Diastolic": bp_dias,
 
 
 
 
24
  "FetalHR": fetal_hr,
25
+ "Anaemia": anaemia, "Jaundice": jaundice,
26
+ "FetalPosition": fetal_position, "FetalMovement": fetal_movement,
27
+ "UrineAlbumin": urine_albumin, "UrineSugar": urine_sugar
 
 
 
28
  }
29
+ X = pd.DataFrame([row], columns=numeric_features+categorical_features)
30
+ prob = model.predict_proba(X)[:,1][0]
 
 
 
 
 
 
 
 
31
  pred = int(model.predict(X)[0])
32
+ label = "High Risk" if pred==1 else "Not High Risk"
33
+ return {"Prediction":label,"Probability_high_risk":round(prob,4)}
34
+
35
+ # ---------- Plot functions ----------
36
+ def plot_age_distribution():
37
+ fig, ax = plt.subplots(figsize=(6,4))
38
+ sns.histplot(df_clean["Age"], bins=10, kde=True, ax=ax, color="skyblue")
39
+ ax.set_title("Age Distribution")
40
+ return fig
41
+
42
+ def plot_risk_counts():
43
+ fig, ax = plt.subplots(figsize=(6,4))
44
+ sns.countplot(x="HighRisk", data=df_clean, ax=ax, palette="Set2")
45
+ ax.set_title("High Risk vs Non-Risk Counts")
46
+ return fig
47
+
48
+ def plot_gestation_box():
49
+ fig, ax = plt.subplots(figsize=(6,4))
50
+ sns.boxplot(x="HighRisk", y="GestationWeeks", data=df_clean, ax=ax, palette="Set2")
51
+ ax.set_title("Gestation Weeks vs Risk")
52
+ return fig
53
+
54
+ def plot_feature_importance():
55
+ ohe = model.named_steps["preprocessor"].named_transformers_["cat"].named_steps["onehot"]
56
+ cat_names = ohe.get_feature_names_out(categorical_features)
57
+ feature_names = numeric_features + list(cat_names)
58
+ importances = model.named_steps["clf"].feature_importances_
59
+ feat_imp = pd.DataFrame({"Feature":feature_names,"Importance":importances})
60
+ feat_imp = feat_imp.sort_values("Importance",ascending=False).head(10)
61
+ fig, ax = plt.subplots(figsize=(8,5))
62
+ sns.barplot(x="Importance", y="Feature", data=feat_imp, ax=ax, palette="viridis")
63
+ ax.set_title("Top 10 Feature Importances")
64
+ return fig
65
+
66
+ def plot_corr_heatmap():
67
+ fig, ax = plt.subplots(figsize=(8,6))
68
+ corr = df_clean[numeric_features+["HighRisk"]].corr()
69
+ sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", ax=ax)
70
+ ax.set_title("Correlation Heatmap")
71
+ return fig
72
 
73
  # ---------- Gradio UI ----------
74
  with gr.Blocks(title="Maternal Risk Prediction") as demo:
75
+ gr.Markdown("## Maternal Risk Prediction Dashboard")
76
+
77
+ with gr.Tab("Prediction"):
78
+ gr.Markdown("Enter maternal health parameters to predict risk.")
79
+ with gr.Row():
80
+ age = gr.Number(label="Age")
81
+ gravida = gr.Number(label="Gravida")
82
+ gest = gr.Number(label="Gestation Weeks")
83
+ weight = gr.Number(label="Weight (kg)")
84
+ height = gr.Number(label="Height (cm)")
85
+ with gr.Row():
86
+ bp_sys = gr.Number(label="BP Systolic")
87
+ bp_dias = gr.Number(label="BP Diastolic")
88
+ fetal_hr = gr.Number(label="Fetal Heart Rate")
89
+ anaemia = gr.Dropdown(["None","Minimal","Medium","Higher"], label="Anaemia")
90
+ jaundice = gr.Dropdown(["None","Minimal","Medium"], label="Jaundice")
91
+ with gr.Row():
92
+ fetal_pos = gr.Dropdown(["Normal","Abnormal"], label="Fetal Position")
93
+ fetal_mov = gr.Dropdown(["Yes","No"], label="Fetal Movement")
94
+ urine_alb = gr.Dropdown(["Negative","Positive"], label="Urine Albumin")
95
+ urine_sug = gr.Dropdown(["Negative","Positive"], label="Urine Sugar")
96
+ out = gr.JSON(label="Result")
97
+ btn = gr.Button("Predict Risk")
98
+ btn.click(predict_risk,
99
+ inputs=[age,gravida,gest,weight,height,
100
+ bp_sys,bp_dias,fetal_hr,
101
+ anaemia,jaundice,fetal_pos,fetal_mov,urine_alb,urine_sug],
102
+ outputs=out)
103
+
104
+ with gr.Tab("Data Insights"):
105
+ gr.Markdown("### Dataset Overview")
106
+ gr.Plot(plot_age_distribution)
107
+ gr.Plot(plot_risk_counts)
108
+ gr.Plot(plot_gestation_box)
109
+
110
+ with gr.Tab("Model Insights"):
111
+ gr.Markdown("### Model Behavior")
112
+ gr.Plot(plot_feature_importance)
113
+ gr.Plot(plot_corr_heatmap)
114
+
115
+ with gr.Tab("About"):
116
+ gr.Markdown("""
117
+ ### About this App
118
+ This dashboard predicts maternal high-risk pregnancy using a RandomForest model.
119
+ - **Dataset:** Cleaned maternal health records
120
+ - **Features:** Age, Gravida, Gestation Weeks, Weight, Height, BP, Fetal HR, Anaemia, Jaundice, Fetal Position, Fetal Movement, Urine Albumin, Urine Sugar
121
+ - **Output:** High Risk vs Not High Risk with probability
122
+ """)
123
 
124
  demo.launch()