Spaces:
Sleeping
Sleeping
| import json, joblib, pandas as pd, numpy as np | |
| import gradio as gr | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
# ---------- Artifacts: fitted model, feature metadata, optional dataset ----------
# Fitted pipeline used for prediction; the plotting code expects named steps
# "preprocessor" and "clf".
model = joblib.load("maternal_rf_model.joblib")

# Metadata JSON provides the numeric / categorical feature-name lists that fix
# the column order of the model input built in predict_risk.
with open("maternal_metadata.json", encoding="utf-8") as f:
    meta = json.load(f)

# The cleaned CSV only feeds the dataset-based insight plots; when the file is
# absent df_clean stays None and those plots render a blank figure.
df_clean = None
try:
    df_clean = pd.read_csv("maternal_cleaned.csv")
except FileNotFoundError:
    pass  # deliberate: the dataset is optional for the app to run

numeric_features = meta["numeric_features"]
categorical_features = meta["categorical_features"]
# ---------- Prediction history ----------
# Module-level log of predictions: predict_risk appends one dict per call and
# update_history turns the list into the "Prediction History" table.
prediction_history = list()
# ---------- Prediction function ----------
def predict_risk(age, gravida, gest_weeks, weight, height_cm,
                 bp_sys, bp_dias, fetal_hr,
                 anaemia, jaundice, fetal_position, fetal_movement,
                 urine_albumin, urine_sugar, *, max_history=1000):
    """Score one patient with the loaded model and log the result.

    The 14 positional parameters are the Gradio inputs, in the order they are
    wired to the "Predict Risk" button. They are mapped onto the dataset's
    column names and laid out in ``numeric_features + categorical_features``
    order, as loaded from the metadata file.

    Args:
        max_history: Keep at most this many entries in ``prediction_history``
            (oldest dropped first). This stops the module-level list from
            growing without bound in a long-running server. ``None`` disables
            the cap. The argument is keyword-only, so Gradio never passes it.

    Returns:
        dict with ``"Prediction"`` (``"High Risk"`` or ``"Not High Risk"``)
        and ``"Probability_high_risk"`` (column 1 of ``predict_proba``, as a
        plain float rounded to 4 decimal places).
    """
    row = {
        "Age": age, "Gravida": gravida, "GestationWeeks": gest_weeks,
        "WeightKg": weight, "HeightCm": height_cm,
        "BP_Systolic": bp_sys, "BP_Diastolic": bp_dias,
        "FetalHR": fetal_hr,
        "Anaemia": anaemia, "Jaundice": jaundice,
        "FetalPosition": fetal_position, "FetalMovement": fetal_movement,
        "UrineAlbumin": urine_albumin, "UrineSugar": urine_sugar,
    }
    # Reindex to the metadata's column order so the pipeline sees the same
    # layout it was fitted on.
    X = pd.DataFrame([row], columns=numeric_features + categorical_features)
    # Round once and convert the numpy scalar to a plain float, so the JSON
    # result and the history row hold the identical value.
    prob = round(float(model.predict_proba(X)[0, 1]), 4)
    pred = int(model.predict(X)[0])
    label = "High Risk" if pred == 1 else "Not High Risk"

    # NOTE(review): prediction_history is module-global, so every session
    # served by this process can see these patient inputs in the History tab.
    # Consider per-session gr.State for health data.
    prediction_history.append({**row, "Prediction": label, "Probability": prob})
    if max_history is not None and len(prediction_history) > max_history:
        # Drop the oldest entries so only the newest max_history remain.
        del prediction_history[:len(prediction_history) - max_history]
    return {"Prediction": label, "Probability_high_risk": prob}
# ---------- Plot functions ----------
def plot_age_distribution():
    """Return a histogram (10 bins, KDE overlay) of the dataset's ``Age`` column.

    Returns a blank figure when ``maternal_cleaned.csv`` was not found at
    startup. Gradio re-evaluates this callable on each page load. The figure
    is therefore closed in pyplot before returning: pyplot otherwise keeps
    every figure it created alive, so memory grows with each render. The
    returned Figure object stays usable for saving/rendering after close.
    """
    if df_clean is None:
        fig = plt.figure()
    else:
        fig, ax = plt.subplots(figsize=(6, 4))
        sns.histplot(df_clean["Age"], bins=10, kde=True, ax=ax, color="skyblue")
        ax.set_title("Age Distribution")
    plt.close(fig)
    return fig
def plot_risk_counts():
    """Return a bar chart of row counts per ``HighRisk`` value in the dataset.

    Returns a blank figure when ``maternal_cleaned.csv`` was not found at
    startup. The figure is closed in pyplot before returning, so repeated
    renders do not accumulate figures in pyplot's global registry. The
    Figure object itself remains renderable.
    """
    if df_clean is None:
        fig = plt.figure()
    else:
        fig, ax = plt.subplots(figsize=(6, 4))
        sns.countplot(x="HighRisk", data=df_clean, ax=ax, palette="Set2")
        ax.set_title("High Risk vs Non-Risk Counts")
    plt.close(fig)
    return fig
def plot_gestation_box():
    """Return box plots of ``GestationWeeks`` grouped by ``HighRisk``.

    Returns a blank figure when ``maternal_cleaned.csv`` was not found at
    startup. The figure is closed in pyplot before returning, so repeated
    renders do not accumulate figures in pyplot's global registry. The
    Figure object itself remains renderable.
    """
    if df_clean is None:
        fig = plt.figure()
    else:
        fig, ax = plt.subplots(figsize=(6, 4))
        sns.boxplot(x="HighRisk", y="GestationWeeks", data=df_clean, ax=ax,
                    palette="Set2")
        ax.set_title("Gestation Weeks vs Risk")
    plt.close(fig)
    return fig
def plot_feature_importance():
    """Return a bar chart of the 10 largest ``feature_importances_`` of the ``clf`` step.

    Feature labels are rebuilt as ``numeric_features`` followed by the output
    names of the ``onehot`` encoder inside the preprocessor's ``cat``
    transformer.

    NOTE(review): this assumes the preprocessor emits the numeric columns
    first and does not change their count. Confirm against the training
    pipeline. If the label count does not match the importances, the
    DataFrame construction below raises ValueError.

    The figure is closed in pyplot before returning, so repeated renders do
    not accumulate figures in pyplot's global registry.
    """
    preprocessor = model.named_steps["preprocessor"]
    ohe = preprocessor.named_transformers_["cat"].named_steps["onehot"]
    cat_names = ohe.get_feature_names_out(categorical_features)
    feature_names = numeric_features + list(cat_names)
    importances = model.named_steps["clf"].feature_importances_
    feat_imp = pd.DataFrame({"Feature": feature_names, "Importance": importances})
    feat_imp = feat_imp.sort_values("Importance", ascending=False).head(10)
    fig, ax = plt.subplots(figsize=(8, 5))
    sns.barplot(x="Importance", y="Feature", data=feat_imp, ax=ax, palette="viridis")
    ax.set_title("Top 10 Feature Importances")
    plt.close(fig)
    return fig
def plot_corr_heatmap():
    """Return an annotated correlation heatmap of the numeric features plus ``HighRisk``.

    Uses ``DataFrame.corr()`` with its default (Pearson) method. Returns a
    blank figure when ``maternal_cleaned.csv`` was not found at startup. The
    figure is closed in pyplot before returning, so repeated renders do not
    accumulate figures in pyplot's global registry.
    """
    if df_clean is None:
        fig = plt.figure()
    else:
        fig, ax = plt.subplots(figsize=(8, 6))
        corr = df_clean[numeric_features + ["HighRisk"]].corr()
        sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", ax=ax)
        ax.set_title("Correlation Heatmap")
    plt.close(fig)
    return fig
# ---------- History update ----------
def update_history():
    """Build the "Prediction History" table from the logged predictions.

    Each dict in ``prediction_history`` becomes one row and its keys become
    the columns. An empty log yields an empty DataFrame.
    """
    history_frame = pd.DataFrame(data=prediction_history)
    return history_frame
# ---------- Gradio UI ----------
with gr.Blocks(title="Maternal Risk Prediction Dashboard") as demo:
    gr.Markdown("## Maternal Risk Prediction Dashboard")

    with gr.Tab("Prediction"):
        gr.Markdown("Enter maternal health parameters to predict risk.")
        with gr.Row():
            age = gr.Number(label="Age")
            gravida = gr.Number(label="Gravida")
            gest = gr.Number(label="Gestation Weeks")
            weight = gr.Number(label="Weight (kg)")
            height = gr.Number(label="Height (cm)")
        with gr.Row():
            bp_sys = gr.Number(label="BP Systolic")
            bp_dias = gr.Number(label="BP Diastolic")
            fetal_hr = gr.Number(label="Fetal Heart Rate")
            anaemia = gr.Dropdown(["None","Minimal","Medium","Higher"], label="Anaemia")
            jaundice = gr.Dropdown(["None","Minimal","Medium"], label="Jaundice")
        with gr.Row():
            fetal_pos = gr.Dropdown(["Normal","Abnormal"], label="Fetal Position")
            fetal_mov = gr.Dropdown(["Yes","No"], label="Fetal Movement")
            urine_alb = gr.Dropdown(["Negative","Positive"], label="Urine Albumin")
            urine_sug = gr.Dropdown(["Negative","Positive"], label="Urine Sugar")
        out = gr.JSON(label="Result")
        btn = gr.Button("Predict Risk")

    # Passing a function as a Plot's value makes Gradio call it when the page loads.
    with gr.Tab("Data Insights"):
        gr.Plot(plot_age_distribution)
        gr.Plot(plot_risk_counts)
        gr.Plot(plot_gestation_box)

    with gr.Tab("Model Insights"):
        gr.Plot(plot_feature_importance)
        gr.Plot(plot_corr_heatmap)

    with gr.Tab("Prediction History"):
        history_table = gr.DataFrame(label="Prediction History", interactive=False)
        refresh_btn = gr.Button("Refresh History")

    with gr.Tab("About"):
        gr.Markdown("""
### About this App
This dashboard predicts maternal high-risk pregnancy using a RandomForest model.
- **Dataset:** Cleaned maternal health records
- **Features:** Age, Gravida, Gestation Weeks, Weight, Height, BP, Fetal HR, Anaemia, Jaundice, Fetal Position, Fetal Movement, Urine Albumin, Urine Sugar
- **Output:** High Risk vs Not High Risk with probability
""")

    # Events are wired after all tabs exist, so a prediction can also refresh
    # the history table that lives in a later tab. The input order must match
    # predict_risk's positional parameters.
    predict_inputs = [age, gravida, gest, weight, height,
                      bp_sys, bp_dias, fetal_hr,
                      anaemia, jaundice, fetal_pos, fetal_mov, urine_alb, urine_sug]
    btn.click(predict_risk, inputs=predict_inputs, outputs=out).then(
        update_history, outputs=history_table
    )
    refresh_btn.click(fn=update_history, outputs=history_table)


# Launch only when run as a script (e.g. `python app.py`), so importing the
# module (tests, `gradio app.py` reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch()