Spaces:
Sleeping
Sleeping
File size: 6,100 Bytes
4d654cd 06a2968 4d654cd 06a2968 4d654cd af531ed 4d654cd 06a2968 e02d280 06a2968 e02d280 06a2968 4d654cd 06a2968 4d654cd 06a2968 4d654cd 06a2968 4d654cd 06a2968 4d654cd e02d280 4d654cd e02d280 4d654cd e02d280 4d654cd e02d280 4d654cd e02d280 4d654cd 06a2968 e02d280 06a2968 e02d280 4d654cd e02d280 4d654cd 06a2968 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 | import json, joblib, pandas as pd, numpy as np
import gradio as gr
import seaborn as sns
import matplotlib.pyplot as plt
# ---------- Load trained pipeline, feature metadata, and optional dataset ----------
# The model and metadata files are required (a missing file raises at import);
# the cleaned CSV is optional and only feeds the dataset plots.
model = joblib.load("maternal_rf_model.joblib")

with open("maternal_metadata.json", "r", encoding="utf-8") as f:
    meta = json.load(f)

try:
    df_clean = pd.read_csv("maternal_cleaned.csv")
except FileNotFoundError:
    # Tolerated: the dataset plot helpers check for None and draw empty figures.
    df_clean = None

# predict_risk builds its input frame with numeric columns first, then categorical.
numeric_features, categorical_features = (
    meta["numeric_features"],
    meta["categorical_features"],
)

# ---------- Prediction history ----------
# Module-level log of predictions: one dict per call to predict_risk,
# shared by every caller in this process.
prediction_history = []
# ---------- Prediction function ----------
# Maximum number of rows kept in prediction_history. The list lives for the
# whole server process, so without a cap it grows by one row per click forever.
MAX_HISTORY_ROWS = 1000


def predict_risk(age, gravida, gest_weeks, weight, height_cm,
                 bp_sys, bp_dias, fetal_hr,
                 anaemia, jaundice, fetal_position, fetal_movement,
                 urine_albumin, urine_sugar):
    """Score one maternal record with the loaded model and log the result.

    The 14 positional arguments are the raw values of the Gradio inputs, in
    the same order as the ``inputs=[...]`` list wired to the Predict button.

    Returns:
        dict with ``"Prediction"`` (``"High Risk"`` or ``"Not High Risk"``)
        and ``"Probability_high_risk"`` (column 1 of ``predict_proba``,
        rounded to 4 decimals).

    Side effects:
        Appends one row (inputs + prediction + probability) to the module-level
        ``prediction_history`` list, keeping at most ``MAX_HISTORY_ROWS`` rows.
    """
    row = {
        "Age": age, "Gravida": gravida, "GestationWeeks": gest_weeks,
        "WeightKg": weight, "HeightCm": height_cm,
        "BP_Systolic": bp_sys, "BP_Diastolic": bp_dias,
        "FetalHR": fetal_hr,
        "Anaemia": anaemia, "Jaundice": jaundice,
        "FetalPosition": fetal_position, "FetalMovement": fetal_movement,
        "UrineAlbumin": urine_albumin, "UrineSugar": urine_sugar,
    }
    # Order columns as listed in the metadata. A metadata column name that is
    # not a key of `row` would silently become NaN here.
    X = pd.DataFrame([row], columns=numeric_features + categorical_features)

    # NOTE(review): column 1 is treated as the high-risk class, i.e. this
    # assumes model.classes_ == [0, 1]; confirm against the training script.
    prob = float(model.predict_proba(X)[:, 1][0])
    pred = int(model.predict(X)[0])
    label = "High Risk" if pred == 1 else "Not High Risk"
    prob_rounded = round(prob, 4)

    # NOTE(review): prediction_history is module-level, so every visitor's
    # inputs appear in every other visitor's History tab; consider gr.State
    # for per-session history if that is not intended.
    prediction_history.append({**row, "Prediction": label, "Probability": prob_rounded})
    if len(prediction_history) > MAX_HISTORY_ROWS:
        # Trim oldest rows in place so the list object (and any references) stay the same.
        del prediction_history[:-MAX_HISTORY_ROWS]

    return {"Prediction": label, "Probability_high_risk": prob_rounded}
# ---------- Plot functions ----------
def plot_age_distribution():
    """Return a histogram (10 bins, with KDE) of ``df_clean["Age"]``.

    Returns an empty figure when the cleaned dataset was not loaded.

    Every figure is deregistered from pyplot (``plt.close``) before it is
    returned: Gradio re-invokes callable ``gr.Plot`` values on each page
    load, and unclosed pyplot figures would otherwise accumulate for the life
    of the process. The returned Figure object itself stays renderable.
    """
    if df_clean is None:
        fig = plt.figure()
        plt.close(fig)
        return fig
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.histplot(df_clean["Age"], bins=10, kde=True, ax=ax, color="skyblue")
    ax.set_title("Age Distribution")
    plt.close(fig)
    return fig
def plot_risk_counts():
    """Return a count plot of rows per ``HighRisk`` value in ``df_clean``.

    Returns an empty figure when the cleaned dataset was not loaded. The
    figure is closed in pyplot before returning so repeated page loads do not
    leak open figures; the Figure object remains renderable by Gradio.
    """
    if df_clean is None:
        fig = plt.figure()
        plt.close(fig)
        return fig
    fig, ax = plt.subplots(figsize=(6, 4))
    # NOTE(review): seaborn >= 0.13 warns when `palette` is given without `hue`.
    sns.countplot(x="HighRisk", data=df_clean, ax=ax, palette="Set2")
    ax.set_title("High Risk vs Non-Risk Counts")
    plt.close(fig)
    return fig
def plot_gestation_box():
    """Return box plots of ``GestationWeeks`` grouped by ``HighRisk``.

    Returns an empty figure when the cleaned dataset was not loaded. The
    figure is closed in pyplot before returning so repeated page loads do not
    leak open figures; the Figure object remains renderable by Gradio.
    """
    if df_clean is None:
        fig = plt.figure()
        plt.close(fig)
        return fig
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.boxplot(x="HighRisk", y="GestationWeeks", data=df_clean, ax=ax, palette="Set2")
    ax.set_title("Gestation Weeks vs Risk")
    plt.close(fig)
    return fig
def plot_feature_importance():
    """Return a bar chart of the 10 largest ``feature_importances_`` of the classifier.

    Feature labels are rebuilt as ``numeric_features`` followed by the
    one-hot column names from the pipeline's ``preprocessor`` -> ``cat`` ->
    ``onehot`` step. NOTE(review): this assumes the preprocessor emits the
    numeric columns first, then the one-hot columns, and nothing else; a
    length mismatch raises ValueError when building the DataFrame.

    The figure is closed in pyplot before returning so repeated page loads do
    not leak open figures; the Figure object remains renderable by Gradio.
    """
    preprocessor = model.named_steps["preprocessor"]
    ohe = preprocessor.named_transformers_["cat"].named_steps["onehot"]
    cat_names = ohe.get_feature_names_out(categorical_features)
    feature_names = numeric_features + list(cat_names)
    importances = model.named_steps["clf"].feature_importances_

    feat_imp = (
        pd.DataFrame({"Feature": feature_names, "Importance": importances})
        .sort_values("Importance", ascending=False)
        .head(10)
    )

    fig, ax = plt.subplots(figsize=(8, 5))
    sns.barplot(x="Importance", y="Feature", data=feat_imp, ax=ax, palette="viridis")
    ax.set_title("Top 10 Feature Importances")
    plt.close(fig)
    return fig
def plot_corr_heatmap():
    """Return an annotated correlation heatmap of the numeric features plus ``HighRisk``.

    Returns an empty figure when the cleaned dataset was not loaded.
    NOTE(review): ``DataFrame.corr`` needs every selected column to be
    numeric (pandas >= 2.0 raises otherwise), so ``HighRisk`` is assumed to
    be stored as 0/1 in the CSV; confirm.

    The figure is closed in pyplot before returning so repeated page loads do
    not leak open figures; the Figure object remains renderable by Gradio.
    """
    if df_clean is None:
        fig = plt.figure()
        plt.close(fig)
        return fig
    fig, ax = plt.subplots(figsize=(8, 6))
    corr = df_clean[numeric_features + ["HighRisk"]].corr()
    sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", ax=ax)
    ax.set_title("Correlation Heatmap")
    plt.close(fig)
    return fig
# ---------- History table refresh ----------
def update_history():
    """Build a table of all logged predictions, one row per predict_risk call.

    Columns follow the keys of the logged dicts; with nothing logged yet the
    result is an empty DataFrame.
    """
    records = prediction_history
    return pd.DataFrame(data=records)
# ---------- Gradio UI ----------
# Inside gr.Blocks, components appear in the order they are created within
# each layout context, so the statement order below defines the page layout.
with gr.Blocks(title="Maternal Risk Prediction Dashboard") as demo:
    gr.Markdown("## Maternal Risk Prediction Dashboard")

    with gr.Tab("Prediction"):
        gr.Markdown("Enter maternal health parameters to predict risk.")
        with gr.Row():
            age = gr.Number(label="Age")
            gravida = gr.Number(label="Gravida")
            gest = gr.Number(label="Gestation Weeks")
            weight = gr.Number(label="Weight (kg)")
            height = gr.Number(label="Height (cm)")
        with gr.Row():
            bp_sys = gr.Number(label="BP Systolic")
            bp_dias = gr.Number(label="BP Diastolic")
            fetal_hr = gr.Number(label="Fetal Heart Rate")
            anaemia = gr.Dropdown(["None","Minimal","Medium","Higher"], label="Anaemia")
            jaundice = gr.Dropdown(["None","Minimal","Medium"], label="Jaundice")
        with gr.Row():
            fetal_pos = gr.Dropdown(["Normal","Abnormal"], label="Fetal Position")
            fetal_mov = gr.Dropdown(["Yes","No"], label="Fetal Movement")
            urine_alb = gr.Dropdown(["Negative","Positive"], label="Urine Albumin")
            urine_sug = gr.Dropdown(["Negative","Positive"], label="Urine Sugar")
        out = gr.JSON(label="Result")
        btn = gr.Button("Predict Risk")
        # The inputs list is passed positionally, so its order must match
        # predict_risk's parameter order.
        btn.click(predict_risk,
                  inputs=[age, gravida, gest, weight, height,
                          bp_sys, bp_dias, fetal_hr,
                          anaemia, jaundice, fetal_pos, fetal_mov, urine_alb, urine_sug],
                  outputs=out)

    with gr.Tab("Data Insights"):
        # Passing the functions (not their results) lets Gradio regenerate
        # each plot when the page loads.
        gr.Plot(plot_age_distribution)
        gr.Plot(plot_risk_counts)
        gr.Plot(plot_gestation_box)

    with gr.Tab("Model Insights"):
        gr.Plot(plot_feature_importance)
        gr.Plot(plot_corr_heatmap)

    with gr.Tab("Prediction History"):
        # Filled only when "Refresh History" is clicked.
        history_table = gr.DataFrame(label="Prediction History", interactive=False)
        refresh_btn = gr.Button("Refresh History")
        refresh_btn.click(fn=update_history, outputs=history_table)

    with gr.Tab("About"):
        gr.Markdown("""
### About this App
This dashboard predicts maternal high-risk pregnancy using a RandomForest model.
- **Dataset:** Cleaned maternal health records
- **Features:** Age, Gravida, Gestation Weeks, Weight, Height, BP, Fetal HR, Anaemia, Jaundice, Fetal Position, Fetal Movement, Urine Albumin, Urine Sugar
- **Output:** High Risk vs Not High Risk with probability
""")

# Start the server only when run as a script (e.g. `python app.py`), so
# importing this module builds `demo` without blocking on a server launch.
if __name__ == "__main__":
    demo.launch()