# Tani21's picture
# Update app.py
# af531ed verified
import json, joblib, pandas as pd, numpy as np
import gradio as gr
import seaborn as sns
import matplotlib.pyplot as plt
# Load model + metadata + dataset
# Trained model artifact; the rest of the module calls predict/predict_proba on it.
model = joblib.load("maternal_rf_model.joblib")

# The metadata file supplies the feature-column lists used below.
with open("maternal_metadata.json", "r", encoding="utf-8") as meta_file:
    meta = json.load(meta_file)

# The cleaned dataset only feeds the exploratory plots, so a missing file
# is tolerated and signalled with None instead of aborting start-up.
try:
    df_clean = pd.read_csv("maternal_cleaned.csv")
except FileNotFoundError:
    df_clean = None

numeric_features = meta["numeric_features"]
categorical_features = meta["categorical_features"]

# ---------- Prediction history ----------
# Module-level list: one dict per prediction made since the process started.
prediction_history = []
# ---------- Prediction function ----------
def predict_risk(age, gravida, gest_weeks, weight, height_cm,
                 bp_sys, bp_dias, fetal_hr,
                 anaemia, jaundice, fetal_position, fetal_movement,
                 urine_albumin, urine_sugar):
    """Score a single patient record and log the outcome.

    The form values are packed into a one-row DataFrame whose columns are
    ordered as ``numeric_features + categorical_features``. The global
    ``model`` is asked for both the probability in column 1 of
    ``predict_proba`` and the predicted class.

    Side effect: appends the inputs together with the label and rounded
    probability to the module-level ``prediction_history`` list.

    Returns:
        dict with keys ``"Prediction"`` (``"High Risk"`` when the predicted
        class equals 1, otherwise ``"Not High Risk"``) and
        ``"Probability_high_risk"`` (probability rounded to 4 decimals).
    """
    record = dict(
        Age=age, Gravida=gravida, GestationWeeks=gest_weeks,
        WeightKg=weight, HeightCm=height_cm,
        BP_Systolic=bp_sys, BP_Diastolic=bp_dias,
        FetalHR=fetal_hr,
        Anaemia=anaemia, Jaundice=jaundice,
        FetalPosition=fetal_position, FetalMovement=fetal_movement,
        UrineAlbumin=urine_albumin, UrineSugar=urine_sugar,
    )

    feature_order = numeric_features + categorical_features
    frame = pd.DataFrame([record], columns=feature_order)

    high_risk_prob = model.predict_proba(frame)[:, 1][0]
    predicted_class = int(model.predict(frame)[0])

    if predicted_class == 1:
        label = "High Risk"
    else:
        label = "Not High Risk"
    rounded_prob = round(high_risk_prob, 4)

    # Save to history: the raw inputs first, then the two result fields.
    prediction_history.append(
        {**record, "Prediction": label, "Probability": rounded_prob}
    )

    return {"Prediction": label, "Probability_high_risk": rounded_prob}
# ---------- Plot functions ----------
def plot_age_distribution():
    """Return a histogram with KDE overlay of the ``Age`` column of ``df_clean``.

    When the cleaned dataset was not loaded (``df_clean is None``) an empty
    figure is returned instead.

    The figure is built with ``matplotlib.figure.Figure`` rather than
    ``plt.subplots``/``plt.figure``. pyplot keeps every figure it creates in
    a global registry until ``plt.close`` is called, so figures created on
    each call would pile up in a long-running server process. A bare
    ``Figure`` is released as soon as the caller drops it.
    """
    from matplotlib.figure import Figure

    if df_clean is None:
        return Figure()

    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    sns.histplot(df_clean["Age"], bins=10, kde=True, ax=ax, color="skyblue")
    ax.set_title("Age Distribution")
    return fig
def plot_risk_counts():
    """Return a count plot of the ``HighRisk`` column of ``df_clean``.

    When the cleaned dataset was not loaded (``df_clean is None``) an empty
    figure is returned instead.

    The figure is built with ``matplotlib.figure.Figure`` rather than
    ``plt.subplots``/``plt.figure``. pyplot keeps every figure it creates in
    a global registry until ``plt.close`` is called, so figures created on
    each call would pile up in a long-running server process. A bare
    ``Figure`` is released as soon as the caller drops it.
    """
    from matplotlib.figure import Figure

    if df_clean is None:
        return Figure()

    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    sns.countplot(x="HighRisk", data=df_clean, ax=ax, palette="Set2")
    ax.set_title("High Risk vs Non-Risk Counts")
    return fig
def plot_gestation_box():
    """Return a box plot of ``GestationWeeks`` grouped by ``HighRisk``.

    When the cleaned dataset was not loaded (``df_clean is None``) an empty
    figure is returned instead.

    The figure is built with ``matplotlib.figure.Figure`` rather than
    ``plt.subplots``/``plt.figure``. pyplot keeps every figure it creates in
    a global registry until ``plt.close`` is called, so figures created on
    each call would pile up in a long-running server process. A bare
    ``Figure`` is released as soon as the caller drops it.
    """
    from matplotlib.figure import Figure

    if df_clean is None:
        return Figure()

    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    sns.boxplot(x="HighRisk", y="GestationWeeks", data=df_clean, ax=ax, palette="Set2")
    ax.set_title("Gestation Weeks vs Risk")
    return fig
def plot_feature_importance():
    """Return a horizontal bar chart of the ten largest feature importances.

    Feature names are rebuilt as the numeric feature list followed by the
    one-hot encoder's output names, then paired positionally with the
    ``feature_importances_`` of the pipeline's ``"clf"`` step.

    NOTE(review): this pairing assumes the preprocessor emits the numeric
    columns first and the one-hot columns second, with nothing else in
    between; confirm against the pipeline definition.

    The figure is built with ``matplotlib.figure.Figure`` rather than
    ``plt.subplots``. pyplot keeps every figure it creates in a global
    registry until ``plt.close`` is called, so figures created on each call
    would pile up in a long-running server process.
    """
    from matplotlib.figure import Figure

    # Dig the fitted one-hot encoder out of the preprocessing step.
    ohe = model.named_steps["preprocessor"].named_transformers_["cat"].named_steps["onehot"]
    cat_names = ohe.get_feature_names_out(categorical_features)
    feature_names = numeric_features + list(cat_names)

    importances = model.named_steps["clf"].feature_importances_
    feat_imp = pd.DataFrame({"Feature": feature_names, "Importance": importances})
    feat_imp = feat_imp.sort_values("Importance", ascending=False).head(10)

    fig = Figure(figsize=(8, 5))
    ax = fig.subplots()
    sns.barplot(x="Importance", y="Feature", data=feat_imp, ax=ax, palette="viridis")
    ax.set_title("Top 10 Feature Importances")
    return fig
def plot_corr_heatmap():
    """Return an annotated correlation heatmap of the numeric features and ``HighRisk``.

    When the cleaned dataset was not loaded (``df_clean is None``) an empty
    figure is returned instead.

    The figure is built with ``matplotlib.figure.Figure`` rather than
    ``plt.subplots``/``plt.figure``. pyplot keeps every figure it creates in
    a global registry until ``plt.close`` is called, so figures created on
    each call would pile up in a long-running server process. A bare
    ``Figure`` is released as soon as the caller drops it.
    """
    from matplotlib.figure import Figure

    if df_clean is None:
        return Figure()

    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    corr = df_clean[numeric_features + ["HighRisk"]].corr()
    sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", ax=ax)
    ax.set_title("Correlation Heatmap")
    return fig
# ---------- History update ----------
def update_history():
    """Return the logged predictions as a DataFrame, one row per prediction.

    Reads the module-level ``prediction_history`` list; the list is shared
    by every caller in this process.
    """
    history_frame = pd.DataFrame(data=prediction_history)
    return history_frame
# ---------- Gradio UI ----------
# NOTE(review): the layout nesting below was reconstructed from a source whose
# indentation was lost; confirm the Result panel and Predict button are meant
# to sit under the three input rows rather than inside the last row.
with gr.Blocks(title="Maternal Risk Prediction Dashboard") as demo:
    gr.Markdown("## Maternal Risk Prediction Dashboard")

    # Tab 1: form whose values are passed positionally to predict_risk.
    with gr.Tab("Prediction"):
        gr.Markdown("Enter maternal health parameters to predict risk.")

        with gr.Row():
            age = gr.Number(label="Age")
            gravida = gr.Number(label="Gravida")
            gest = gr.Number(label="Gestation Weeks")
            weight = gr.Number(label="Weight (kg)")
            height = gr.Number(label="Height (cm)")

        with gr.Row():
            bp_sys = gr.Number(label="BP Systolic")
            bp_dias = gr.Number(label="BP Diastolic")
            fetal_hr = gr.Number(label="Fetal Heart Rate")
            anaemia = gr.Dropdown(choices=["None", "Minimal", "Medium", "Higher"], label="Anaemia")
            jaundice = gr.Dropdown(choices=["None", "Minimal", "Medium"], label="Jaundice")

        with gr.Row():
            fetal_pos = gr.Dropdown(choices=["Normal", "Abnormal"], label="Fetal Position")
            fetal_mov = gr.Dropdown(choices=["Yes", "No"], label="Fetal Movement")
            urine_alb = gr.Dropdown(choices=["Negative", "Positive"], label="Urine Albumin")
            urine_sug = gr.Dropdown(choices=["Negative", "Positive"], label="Urine Sugar")

        out = gr.JSON(label="Result")
        btn = gr.Button("Predict Risk")

        # Order must match the parameter order of predict_risk.
        form_inputs = [
            age, gravida, gest, weight, height,
            bp_sys, bp_dias, fetal_hr,
            anaemia, jaundice, fetal_pos, fetal_mov, urine_alb, urine_sug,
        ]
        btn.click(fn=predict_risk, inputs=form_inputs, outputs=out)

    # Tab 2: dataset-level charts; each Plot receives the function itself.
    with gr.Tab("Data Insights"):
        gr.Plot(value=plot_age_distribution)
        gr.Plot(value=plot_risk_counts)
        gr.Plot(value=plot_gestation_box)

    # Tab 3: model-level charts.
    with gr.Tab("Model Insights"):
        gr.Plot(value=plot_feature_importance)
        gr.Plot(value=plot_corr_heatmap)

    # Tab 4: table that is filled only when the refresh button is pressed.
    with gr.Tab("Prediction History"):
        history_table = gr.DataFrame(label="Prediction History", interactive=False)
        refresh_btn = gr.Button("Refresh History")
        refresh_btn.click(fn=update_history, outputs=history_table)

    # Tab 5: static description (string content kept flush-left on purpose).
    with gr.Tab("About"):
        gr.Markdown("""
### About this App
This dashboard predicts maternal high-risk pregnancy using a RandomForest model.
- **Dataset:** Cleaned maternal health records
- **Features:** Age, Gravida, Gestation Weeks, Weight, Height, BP, Fetal HR, Anaemia, Jaundice, Fetal Position, Fetal Movement, Urine Albumin, Urine Sugar
- **Output:** High Risk vs Not High Risk with probability
""")

demo.launch()