File size: 6,100 Bytes
4d654cd
06a2968
4d654cd
 
06a2968
4d654cd
af531ed
4d654cd
06a2968
e02d280
 
 
 
 
06a2968
 
 
 
e02d280
 
 
06a2968
4d654cd
 
 
 
06a2968
4d654cd
 
 
06a2968
4d654cd
 
 
06a2968
4d654cd
 
06a2968
4d654cd
e02d280
 
 
 
 
 
 
 
4d654cd
 
 
e02d280
4d654cd
 
 
 
 
 
e02d280
4d654cd
 
 
 
 
 
e02d280
4d654cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e02d280
4d654cd
 
 
 
 
06a2968
e02d280
 
 
 
06a2968
e02d280
4d654cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e02d280
 
 
 
 
4d654cd
 
 
 
 
 
 
 
06a2968
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import json

import gradio as gr
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure

# Load model + metadata + dataset
# NOTE: joblib.load unpickles the file, and unpickling can execute arbitrary
# code, so the model artifact must come from a trusted source.
# All three paths are relative, i.e. resolved against the working directory.
model = joblib.load("maternal_rf_model.joblib")
with open("maternal_metadata.json","r",encoding="utf-8") as f:
    meta = json.load(f)

# The cleaned dataset is optional: when the CSV is missing, df_clean is None
# and the dataset-based plot functions below return an empty figure instead.
# Only FileNotFoundError is tolerated; any other read/parse error propagates.
try:
    df_clean = pd.read_csv("maternal_cleaned.csv")
except FileNotFoundError:
    df_clean = None

# Feature-name lists read from the metadata file. They are concatenated
# (numeric first, then categorical) to build the model's input column order.
numeric_features = meta["numeric_features"]
categorical_features = meta["categorical_features"]

# ---------- Prediction history ----------
# Module-level list that predict_risk appends to. Being a process-wide global,
# it is shared by everything running in this process, grows without bound,
# and is lost on restart.
prediction_history = []

# ---------- Prediction function ----------
def predict_risk(age, gravida, gest_weeks, weight, height_cm,
                 bp_sys, bp_dias, fetal_hr,
                 anaemia, jaundice, fetal_position, fetal_movement,
                 urine_albumin, urine_sugar):
    """Score a single patient record with the loaded model and log the result.

    The arguments are assembled into a one-row DataFrame whose columns follow
    ``numeric_features + categorical_features``, which is then passed to the
    module-level ``model``.

    Returns:
        dict: ``{"Prediction": <label>, "Probability_high_risk": <float>}``
        where the label is ``"High Risk"`` when the model predicts class 1 and
        ``"Not High Risk"`` otherwise, and the probability is column 1 of
        ``predict_proba`` rounded to 4 decimals.

    Side effects:
        Appends the inputs plus ``"Prediction"`` and ``"Probability"`` to the
        module-level ``prediction_history`` list.
    """
    record = dict(
        Age=age,
        Gravida=gravida,
        GestationWeeks=gest_weeks,
        WeightKg=weight,
        HeightCm=height_cm,
        BP_Systolic=bp_sys,
        BP_Diastolic=bp_dias,
        FetalHR=fetal_hr,
        Anaemia=anaemia,
        Jaundice=jaundice,
        FetalPosition=fetal_position,
        FetalMovement=fetal_movement,
        UrineAlbumin=urine_albumin,
        UrineSugar=urine_sugar,
    )

    # Column order is dictated by the metadata lists, not by the dict above.
    feature_order = numeric_features + categorical_features
    frame = pd.DataFrame([record], columns=feature_order)

    # Probability of class 1 for the single row, then the hard class decision.
    high_risk_prob = model.predict_proba(frame)[0, 1]
    predicted_class = int(model.predict(frame)[0])

    if predicted_class == 1:
        label = "High Risk"
    else:
        label = "Not High Risk"
    rounded_prob = round(high_risk_prob, 4)

    # Save to history: the raw inputs followed by the two result fields.
    prediction_history.append(
        {**record, "Prediction": label, "Probability": rounded_prob}
    )

    return {"Prediction": label, "Probability_high_risk": rounded_prob}

# ---------- Plot functions ----------
def plot_age_distribution():
    """Return a histogram (10 bins, KDE overlay) of the dataset's ``Age`` column.

    The figure is created with the object-oriented ``Figure`` API rather than
    ``plt.subplots``/``plt.figure``: pyplot keeps a global reference to every
    figure it creates, so figures that are never closed accumulate for the
    lifetime of the server process.

    Returns:
        matplotlib.figure.Figure: the histogram, or an empty figure when the
        cleaned dataset was not loaded (``df_clean is None``).
    """
    if df_clean is None:
        return Figure()

    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    sns.histplot(df_clean["Age"], bins=10, kde=True, ax=ax, color="skyblue")
    ax.set_title("Age Distribution")
    return fig

def plot_risk_counts():
    """Return a bar chart counting rows per value of the ``HighRisk`` column.

    The figure is created with the object-oriented ``Figure`` API rather than
    ``plt.subplots``/``plt.figure``: pyplot keeps a global reference to every
    figure it creates, so figures that are never closed accumulate for the
    lifetime of the server process.

    Returns:
        matplotlib.figure.Figure: the count plot, or an empty figure when the
        cleaned dataset was not loaded (``df_clean is None``).
    """
    if df_clean is None:
        return Figure()

    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    sns.countplot(x="HighRisk", data=df_clean, ax=ax, palette="Set2")
    ax.set_title("High Risk vs Non-Risk Counts")
    return fig

def plot_gestation_box():
    """Return a box plot of ``GestationWeeks`` grouped by ``HighRisk``.

    The figure is created with the object-oriented ``Figure`` API rather than
    ``plt.subplots``/``plt.figure``: pyplot keeps a global reference to every
    figure it creates, so figures that are never closed accumulate for the
    lifetime of the server process.

    Returns:
        matplotlib.figure.Figure: the box plot, or an empty figure when the
        cleaned dataset was not loaded (``df_clean is None``).
    """
    if df_clean is None:
        return Figure()

    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    sns.boxplot(x="HighRisk", y="GestationWeeks", data=df_clean, ax=ax, palette="Set2")
    ax.set_title("Gestation Weeks vs Risk")
    return fig

def plot_feature_importance():
    """Return a horizontal bar chart of the 10 largest feature importances.

    Feature names are rebuilt from the fitted pipeline: the numeric feature
    names as-is, followed by the one-hot encoder's output names for the
    categorical features. Importances come from the pipeline's ``"clf"`` step.

    The figure is created with the object-oriented ``Figure`` API rather than
    ``plt.subplots``: pyplot keeps a global reference to every figure it
    creates, so figures that are never closed accumulate for the lifetime of
    the server process.

    Returns:
        matplotlib.figure.Figure: the bar chart, most important feature first.
    """
    preprocessor = model.named_steps["preprocessor"]
    ohe = preprocessor.named_transformers_["cat"].named_steps["onehot"]
    cat_names = ohe.get_feature_names_out(categorical_features)

    # NOTE(review): assumes the preprocessor emits the numeric columns first
    # and the one-hot columns after them, with no other columns - confirm
    # against the training pipeline. A length mismatch makes the DataFrame
    # constructor below raise.
    feature_names = numeric_features + list(cat_names)
    importances = model.named_steps["clf"].feature_importances_

    feat_imp = pd.DataFrame({"Feature": feature_names, "Importance": importances})
    feat_imp = feat_imp.sort_values("Importance", ascending=False).head(10)

    fig = Figure(figsize=(8, 5))
    ax = fig.subplots()
    sns.barplot(x="Importance", y="Feature", data=feat_imp, ax=ax, palette="viridis")
    ax.set_title("Top 10 Feature Importances")
    return fig

def plot_corr_heatmap():
    """Return an annotated correlation heatmap of numeric features and ``HighRisk``.

    Correlations are computed with ``DataFrame.corr()`` over the columns
    ``numeric_features + ["HighRisk"]`` of the cleaned dataset.

    The figure is created with the object-oriented ``Figure`` API rather than
    ``plt.subplots``/``plt.figure``: pyplot keeps a global reference to every
    figure it creates, so figures that are never closed accumulate for the
    lifetime of the server process.

    Returns:
        matplotlib.figure.Figure: the heatmap, or an empty figure when the
        cleaned dataset was not loaded (``df_clean is None``).
    """
    if df_clean is None:
        return Figure()

    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    corr = df_clean[numeric_features + ["HighRisk"]].corr()
    sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", ax=ax)
    ax.set_title("Correlation Heatmap")
    return fig

# ---------- History update ----------
def update_history():
    """Build a table of every prediction logged so far.

    Returns:
        pandas.DataFrame: one row per entry of the module-level
        ``prediction_history`` list, constructed fresh on each call.
    """
    history_frame = pd.DataFrame(prediction_history)
    return history_frame

# ---------- Gradio UI ----------
# Builds the whole dashboard as a tabbed Blocks app bound to the name `demo`.
# Components are laid out in the order they are created inside each context.
with gr.Blocks(title="Maternal Risk Prediction Dashboard") as demo:
    gr.Markdown("## Maternal Risk Prediction Dashboard")

    # --- Tab 1: input form wired to predict_risk ---
    with gr.Tab("Prediction"):
        gr.Markdown("Enter maternal health parameters to predict risk.")
        # Row 1: numeric inputs describing the mother / pregnancy.
        with gr.Row():
            age = gr.Number(label="Age")
            gravida = gr.Number(label="Gravida")
            gest = gr.Number(label="Gestation Weeks")
            weight = gr.Number(label="Weight (kg)")
            height = gr.Number(label="Height (cm)")
        # Row 2: remaining numeric vitals plus the first two categorical inputs.
        # NOTE(review): dropdown choices presumably need to match the category
        # strings the model was trained on - verify against the training data.
        with gr.Row():
            bp_sys = gr.Number(label="BP Systolic")
            bp_dias = gr.Number(label="BP Diastolic")
            fetal_hr = gr.Number(label="Fetal Heart Rate")
            anaemia = gr.Dropdown(["None","Minimal","Medium","Higher"], label="Anaemia")
            jaundice = gr.Dropdown(["None","Minimal","Medium"], label="Jaundice")
        # Row 3: remaining categorical inputs.
        with gr.Row():
            fetal_pos = gr.Dropdown(["Normal","Abnormal"], label="Fetal Position")
            fetal_mov = gr.Dropdown(["Yes","No"], label="Fetal Movement")
            urine_alb = gr.Dropdown(["Negative","Positive"], label="Urine Albumin")
            urine_sug = gr.Dropdown(["Negative","Positive"], label="Urine Sugar")
        out = gr.JSON(label="Result")
        btn = gr.Button("Predict Risk")
        # The inputs list is passed positionally, so its order must stay in
        # sync with predict_risk's parameter order.
        btn.click(predict_risk,
                  inputs=[age,gravida,gest,weight,height,
                          bp_sys,bp_dias,fetal_hr,
                          anaemia,jaundice,fetal_pos,fetal_mov,urine_alb,urine_sug],
                  outputs=out)

    # --- Tab 2: dataset-level plots ---
    # The plot functions themselves (not their results) are passed as the
    # component value, so Gradio calls them to produce the figure.
    with gr.Tab("Data Insights"):
        gr.Plot(plot_age_distribution)
        gr.Plot(plot_risk_counts)
        gr.Plot(plot_gestation_box)

    # --- Tab 3: model-level plots ---
    with gr.Tab("Model Insights"):
        gr.Plot(plot_feature_importance)
        gr.Plot(plot_corr_heatmap)

    # --- Tab 4: read-only table of logged predictions ---
    # The table is only filled when the refresh button is clicked; it is not
    # updated automatically after a prediction.
    with gr.Tab("Prediction History"):
        history_table = gr.DataFrame(label="Prediction History", interactive=False)
        refresh_btn = gr.Button("Refresh History")
        refresh_btn.click(fn=update_history, outputs=history_table)

    # --- Tab 5: static description of the app ---
    with gr.Tab("About"):
        gr.Markdown("""
        ### About this App
        This dashboard predicts maternal high-risk pregnancy using a RandomForest model.
        - **Dataset:** Cleaned maternal health records
        - **Features:** Age, Gravida, Gestation Weeks, Weight, Height, BP, Fetal HR, Anaemia, Jaundice, Fetal Position, Fetal Movement, Urine Albumin, Urine Sugar
        - **Output:** High Risk vs Not High Risk with probability
        """)

# Starts the server unconditionally: there is no __main__ guard, so this also
# runs when the module is imported.
demo.launch()