"""Gradio dashboard for a maternal high-risk pregnancy classifier.

The script loads a fitted scikit-learn style pipeline and its feature
metadata, exposes a prediction form, a handful of exploratory plots over the
cleaned dataset (when the CSV is present), and a table of the predictions
made since the process started.
"""

import json
from collections import deque

import gradio as gr
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure

# ---------- Load model + metadata + dataset ----------
MODEL_PATH = "maternal_rf_model.joblib"
METADATA_PATH = "maternal_metadata.json"
DATASET_PATH = "maternal_cleaned.csv"

# NOTE: joblib.load unpickles arbitrary objects; only point MODEL_PATH at a
# trusted artifact.
model = joblib.load(MODEL_PATH)

with open(METADATA_PATH, "r", encoding="utf-8") as f:
    meta = json.load(f)

# The dataset is optional: without it the data-driven plots fall back to a
# placeholder figure instead of failing.
try:
    df_clean = pd.read_csv(DATASET_PATH)
except FileNotFoundError:
    df_clean = None

numeric_features = meta["numeric_features"]
categorical_features = meta["categorical_features"]

# Column names of one model input row, in the order predict_risk receives
# its arguments.
INPUT_COLUMNS = (
    "Age",
    "Gravida",
    "GestationWeeks",
    "WeightKg",
    "HeightCm",
    "BP_Systolic",
    "BP_Diastolic",
    "FetalHR",
    "Anaemia",
    "Jaundice",
    "FetalPosition",
    "FetalMovement",
    "UrineAlbumin",
    "UrineSugar",
)
HISTORY_COLUMNS = [*INPUT_COLUMNS, "Prediction", "Probability"]

# ---------- Prediction history ----------
# Bounded so a long-running server cannot grow without limit; the oldest
# entries are dropped first.
# NOTE(review): this history is process-global, so every visitor sees every
# other visitor's inputs. Confirm that is acceptable for health data.
MAX_HISTORY_ROWS = 1000
prediction_history = deque(maxlen=MAX_HISTORY_ROWS)


# ---------- Prediction function ----------
def predict_risk(age, gravida, gest_weeks, weight, height_cm, bp_sys, bp_dias,
                 fetal_hr, anaemia, jaundice, fetal_position, fetal_movement,
                 urine_albumin, urine_sugar):
    """Score one set of maternal health parameters and record the result.

    The arguments are assembled into a single-row DataFrame whose columns are
    ordered as ``numeric_features + categorical_features`` and passed to the
    loaded model.

    Returns:
        A dict with ``"Prediction"`` ("High Risk" or "Not High Risk") and
        ``"Probability_high_risk"`` (second column of ``predict_proba``,
        rounded to 4 decimals).

    Side effects:
        Appends the inputs plus the outcome to ``prediction_history``.
    """
    # NOTE(review): blank form fields arrive as None and are passed through
    # unchanged; whether the pipeline imputes them is not visible here.
    values = (age, gravida, gest_weeks, weight, height_cm, bp_sys, bp_dias,
              fetal_hr, anaemia, jaundice, fetal_position, fetal_movement,
              urine_albumin, urine_sugar)
    row = dict(zip(INPUT_COLUMNS, values))

    X = pd.DataFrame([row], columns=numeric_features + categorical_features)
    # Plain float so the value serialises and displays without NumPy types.
    prob = round(float(model.predict_proba(X)[:, 1][0]), 4)
    pred = int(model.predict(X)[0])
    label = "High Risk" if pred == 1 else "Not High Risk"

    # Save to history
    prediction_history.append({**row, "Prediction": label, "Probability": prob})

    return {"Prediction": label, "Probability_high_risk": prob}


# ---------- Plot helpers ----------
def _new_axes(figsize):
    """Return ``(fig, ax)`` for a figure that is not registered with pyplot.

    Figures made through ``plt.subplots`` stay alive in pyplot's global
    registry until closed, which leaks memory when callbacks run repeatedly
    in a server. Building the ``Figure`` directly avoids that.
    """
    fig = Figure(figsize=figsize)
    ax = fig.subplots()
    return fig, ax


def _placeholder_figure(message="Dataset not available"):
    """Return an axis-less figure that shows *message* centred."""
    fig, ax = _new_axes((6, 4))
    ax.axis("off")
    ax.text(0.5, 0.5, message, ha="center", va="center",
            transform=ax.transAxes)
    return fig


# ---------- Plot functions ----------
def plot_age_distribution():
    """Histogram (with KDE) of the ``Age`` column of the cleaned dataset."""
    if df_clean is None:
        return _placeholder_figure()
    fig, ax = _new_axes((6, 4))
    sns.histplot(df_clean["Age"], bins=10, kde=True, ax=ax, color="skyblue")
    ax.set_title("Age Distribution")
    return fig


def plot_risk_counts():
    """Bar chart of row counts per ``HighRisk`` value."""
    if df_clean is None:
        return _placeholder_figure()
    fig, ax = _new_axes((6, 4))
    sns.countplot(x="HighRisk", data=df_clean, ax=ax, palette="Set2")
    ax.set_title("High Risk vs Non-Risk Counts")
    return fig


def plot_gestation_box():
    """Box plot of ``GestationWeeks`` grouped by ``HighRisk``."""
    if df_clean is None:
        return _placeholder_figure()
    fig, ax = _new_axes((6, 4))
    sns.boxplot(x="HighRisk", y="GestationWeeks", data=df_clean, ax=ax,
                palette="Set2")
    ax.set_title("Gestation Weeks vs Risk")
    return fig


def plot_feature_importance():
    """Horizontal bar chart of the ten largest feature importances.

    Assumes the model is a pipeline with a ``"preprocessor"`` step containing
    a ``"cat"`` transformer whose ``"onehot"`` step is the encoder, and a
    ``"clf"`` step exposing ``feature_importances_``. It also assumes the
    preprocessor emits numeric features before the one-hot columns.
    TODO confirm both against the training code.
    """
    preprocessor = model.named_steps["preprocessor"]
    ohe = preprocessor.named_transformers_["cat"].named_steps["onehot"]
    cat_names = ohe.get_feature_names_out(categorical_features)
    feature_names = numeric_features + list(cat_names)
    importances = model.named_steps["clf"].feature_importances_

    feat_imp = pd.DataFrame(
        {"Feature": feature_names, "Importance": importances}
    )
    feat_imp = feat_imp.sort_values("Importance", ascending=False).head(10)

    fig, ax = _new_axes((8, 5))
    sns.barplot(x="Importance", y="Feature", data=feat_imp, ax=ax,
                palette="viridis")
    ax.set_title("Top 10 Feature Importances")
    return fig


def plot_corr_heatmap():
    """Annotated correlation heatmap of numeric features and ``HighRisk``."""
    if df_clean is None:
        return _placeholder_figure()
    fig, ax = _new_axes((8, 6))
    corr = df_clean[numeric_features + ["HighRisk"]].corr()
    sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", ax=ax)
    ax.set_title("Correlation Heatmap")
    return fig


# ---------- History update ----------
def update_history():
    """Return the recorded predictions as a DataFrame.

    The column set is fixed, so the table keeps its headers even before the
    first prediction has been made.
    """
    return pd.DataFrame(list(prediction_history), columns=HISTORY_COLUMNS)


# ---------- Gradio UI ----------
ABOUT_MARKDOWN = """
### About this App
This dashboard predicts maternal high-risk pregnancy using a RandomForest model.
- **Dataset:** Cleaned maternal health records
- **Features:** Age, Gravida, Gestation Weeks, Weight, Height, BP, Fetal HR, Anaemia, Jaundice, Fetal Position, Fetal Movement, Urine Albumin, Urine Sugar
- **Output:** High Risk vs Not High Risk with probability
"""

with gr.Blocks(title="Maternal Risk Prediction Dashboard") as demo:
    gr.Markdown("## Maternal Risk Prediction Dashboard")

    with gr.Tab("Prediction"):
        gr.Markdown("Enter maternal health parameters to predict risk.")
        with gr.Row():
            age = gr.Number(label="Age")
            gravida = gr.Number(label="Gravida")
            gest = gr.Number(label="Gestation Weeks")
            weight = gr.Number(label="Weight (kg)")
            height = gr.Number(label="Height (cm)")
        with gr.Row():
            bp_sys = gr.Number(label="BP Systolic")
            bp_dias = gr.Number(label="BP Diastolic")
            fetal_hr = gr.Number(label="Fetal Heart Rate")
            anaemia = gr.Dropdown(["None", "Minimal", "Medium", "Higher"],
                                  label="Anaemia")
            jaundice = gr.Dropdown(["None", "Minimal", "Medium"],
                                   label="Jaundice")
        with gr.Row():
            fetal_pos = gr.Dropdown(["Normal", "Abnormal"],
                                    label="Fetal Position")
            fetal_mov = gr.Dropdown(["Yes", "No"], label="Fetal Movement")
            urine_alb = gr.Dropdown(["Negative", "Positive"],
                                    label="Urine Albumin")
            urine_sug = gr.Dropdown(["Negative", "Positive"],
                                    label="Urine Sugar")
        out = gr.JSON(label="Result")
        btn = gr.Button("Predict Risk")
        # Input order must match the positional parameters of predict_risk.
        btn.click(
            predict_risk,
            inputs=[age, gravida, gest, weight, height,
                    bp_sys, bp_dias, fetal_hr,
                    anaemia, jaundice, fetal_pos, fetal_mov,
                    urine_alb, urine_sug],
            outputs=out,
        )

    with gr.Tab("Data Insights"):
        gr.Plot(plot_age_distribution)
        gr.Plot(plot_risk_counts)
        gr.Plot(plot_gestation_box)

    with gr.Tab("Model Insights"):
        gr.Plot(plot_feature_importance)
        gr.Plot(plot_corr_heatmap)

    with gr.Tab("Prediction History"):
        history_table = gr.DataFrame(label="Prediction History",
                                     interactive=False)
        refresh_btn = gr.Button("Refresh History")
        refresh_btn.click(fn=update_history, outputs=history_table)

    with gr.Tab("About"):
        gr.Markdown(ABOUT_MARKDOWN)

# Only start the server when run as a script, so importing this module
# (for reuse or testing) does not block on a running web server.
if __name__ == "__main__":
    demo.launch()