Spaces:
Runtime error
Runtime error
import os
import traceback

import gradio as gr
import hopsworks
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap
from sklearn.pipeline import make_pipeline
# Human-readable names of the four model inputs.
# NOTE(review): not referenced anywhere else in this file; kept in case other code imports it.
feature_names = ["Age", "BMI", "HbA1c", "Blood Glucose"]

# Connect to the Hopsworks project that hosts the feature store and the model registry.
project = hopsworks.login(project="SonyaStern_Lab1")
fs = project.get_feature_store()

print("trying to dl model")
mr = project.get_model_registry()
# Registry metadata entry; kept separate from the fitted estimator loaded below,
# so that `model` only ever names one kind of object.
model_entry = mr.get_model("diabetes_gan_model", version=2)
model_dir = model_entry.download()
# joblib.load unpickles the artifact: only load files from this trusted registry.
model = joblib.load(os.path.join(model_dir, "diabetes_gan_model.pkl"))
print("Model downloaded")

diabetes_fg = fs.get_feature_group(name="diabetes_gan", version=1)
query = diabetes_fg.select_all()
# Creates the feature view on first run; afterwards returns the existing one.
feature_view = fs.get_or_create_feature_view(
    name="diabetes_gan",
    version=1,
    description="Read from Diabetes dataset",
    labels=["diabetes"],
    query=query,
)
# Feature-group contents as a DataFrame; the UI callback compares user inputs against it.
diabetes_df = pd.DataFrame(diabetes_fg.read())
# Page layout: user inputs on the left, prediction text and comparison chart on
# the right, and SHAP explainability charts in a collapsed accordion.
# The Blocks context stays open: the callback and its click wiring follow inside it.
with gr.Blocks() as demo:
    with gr.Row():
        gr.HTML(value="<h1 style='text-align: center;'>Diabetes prediction</h1>")
    with gr.Row():
        # Left column: the four model inputs plus two questions that are not model features.
        with gr.Column():
            age_input = gr.Number(label="age")
            bmi_input = gr.Slider(10, 100, label="bmi", info="Body Mass Index")
            hba1c_input = gr.Slider(
                3.5, 9, label="hba1c_level", info="Glycated Haemoglobin"
            )
            blood_glucose_input = gr.Slider(
                80, 300, label="blood_glucose_level", info="Blood Glucose Level"
            )
            existent_info_input = gr.Radio(
                ["yes", "no", "Don't know"],
                label="Do you already know if you have diabetes? (This will not be used for the prediction)",
            )
            consent_input = gr.Radio(
                ["accept", "decline"],
                label="I consent that my personal data will be saved and potentially be used for the model training",
            )
            btn = gr.Button("Submit")
        # Right column: prediction text and the comparison bar chart.
        with gr.Column():
            with gr.Row():
                output = gr.Text(label="Model prediction")
            with gr.Row():
                mean_plot = gr.Plot()
    # NOTE(review): the original indentation was lost; this accordion row is assumed
    # to sit below both columns rather than inside the right-hand column.
    with gr.Row():
        with gr.Accordion("See model explainability", open=False):
            with gr.Row():
                with gr.Column():
                    waterfall_plot = gr.Plot()
                with gr.Column():
                    summary_plot = gr.Plot()
            with gr.Row():
                with gr.Column():
                    importance_plot = gr.Plot()
                with gr.Column():
                    decision_plot = gr.Plot()
| def submit_inputs( | |
| age_input, | |
| bmi_input, | |
| hba1c_input, | |
| blood_glucose_input, | |
| existent_info_input, | |
| consent_input, | |
| ): | |
| df = pd.DataFrame( | |
| [[age_input, float(bmi_input), hba1c_input, blood_glucose_input]], | |
| columns=["age", "bmi", "hba1c_level", "blood_glucose_level"], | |
| ) | |
| res = model.predict(df) | |
| if res == [0]: | |
| res_str = "the model prediction is: You don't have diabetes" | |
| elif res == [1]: | |
| res_str = "the model prediction is: You have diabetes" | |
| mean_for_age = diabetes_df[ | |
| (diabetes_df["diabetes"] == 0) & (diabetes_df["age"] == age_input) | |
| ].mean() | |
| print( | |
| "your bmi is:", bmi_input, "the mean for ur age is :", mean_for_age["bmi"] | |
| ) | |
| categories = ["BMI", "HbA1c", "Blood Level"] | |
| fig, ax = plt.subplots() | |
| bar_width = 0.35 | |
| indices = np.arange(len(categories)) | |
| ax.bar( | |
| indices, | |
| [ | |
| mean_for_age.bmi, | |
| mean_for_age.hba1c_level, | |
| mean_for_age.blood_glucose_level, | |
| ], | |
| bar_width, | |
| label="Reference", | |
| color="b", | |
| alpha=0.7, | |
| ) | |
| ax.bar( | |
| indices + bar_width, | |
| [bmi_input, hba1c_input, blood_glucose_input], | |
| bar_width, | |
| label="User", | |
| color="r", | |
| alpha=0.7, | |
| ) | |
| ax.legend() | |
| ax.set_xlabel("Variables") | |
| ax.set_ylabel("Values") | |
| ax.set_title("Comparison with average non-diabetic values for your age") | |
| ax.set_xticks(indices + bar_width / 2) | |
| ax.set_xticklabels(categories) | |
| ## explainability plots | |
| rf_classifier = model.named_steps["randomforestclassifier"] | |
| transformer_pipeline = make_pipeline( | |
| *[ | |
| step | |
| for name, step in model.named_steps.items() | |
| if name != "randomforestclassifier" | |
| ] | |
| ) | |
| transformed_df = transformer_pipeline.transform(df) | |
| # Generate the SHAP waterfall plot for fig2 | |
| explainer = shap.TreeExplainer(rf_classifier) | |
| shap_values = explainer.shap_values( | |
| transformed_df | |
| ) # Compute SHAP values directly on the DataFrame | |
| predicted_class = rf_classifier.predict(transformed_df)[0] | |
| shap_values_for_predicted_class = shap_values[predicted_class] | |
| # Select the SHAP values for the first instance and the positive class | |
| shap_explanation = shap.Explanation( | |
| values=shap_values_for_predicted_class[0], | |
| base_values=explainer.expected_value[predicted_class], | |
| data=df.iloc[0], | |
| feature_names=["age", "bmi", "hba1c", "glucose"], | |
| ) | |
| fig2 = plt.figure(figsize=(3, 3)) # Create a new figure for SHAP plot | |
| fig2.tight_layout() | |
| plt.gca().set_position((0, 0, 1, 1)) | |
| plt.title("SHAP Waterfall Plot") # Optionally set a title for the SHAP plot | |
| plt.tight_layout() | |
| plt.tick_params(axis="y", labelsize=3) | |
| shap.waterfall_plot(shap_explanation) | |
| fig3 = plt.figure(figsize=(3, 3)) | |
| plt.title("SHAP Summary Plot") | |
| shap.summary_plot( | |
| shap_values, | |
| features=transformed_df, | |
| feature_names=["age", "bmi", "hba1c", "glucose"], | |
| ) | |
| fig4 = plt.figure(figsize=(4, 3)) | |
| feature_importances = rf_classifier.feature_importances_ | |
| plt.title("Feature Importances") | |
| sns.barplot(x=feature_importances, y=["age", "bmi", "hba1c", "glucose"]) | |
| fig5 = plt.figure(figsize=(3, 3)) | |
| fig5.tight_layout() | |
| plt.gca().set_position((0, 0, 1, 1)) | |
| plt.title("SHAP Interaction Plot") | |
| plt.tight_layout() | |
| shap.decision_plot( | |
| explainer.expected_value[predicted_class], | |
| shap_values_for_predicted_class, | |
| df.iloc[0], | |
| ) | |
| ## save user's data in hopsworks | |
| if consent_input == "accept": | |
| print("user consented to save their data, now trying to save to hopsworks") | |
| user_data_fg = fs.get_or_create_feature_group( | |
| name="diabetes_user_data", | |
| version=1, | |
| primary_key=[ | |
| "age", | |
| "bmi", | |
| "hba1c_level", | |
| "blood_glucose_level", | |
| "diabetes", | |
| ], | |
| description="Submitted user data", | |
| ) | |
| user_data_df = df.copy() | |
| user_data_df["diabetes"] = existent_info_input | |
| user_data_df["model_prediction"] = res[0] | |
| user_data_fg.insert(user_data_df) | |
| print("inserted new user data to hopsworks", user_data_df) | |
| return res_str, fig, fig2, fig3, fig4, fig5 | |
| btn.click( | |
| submit_inputs, | |
| inputs=[ | |
| age_input, | |
| bmi_input, | |
| hba1c_input, | |
| blood_glucose_input, | |
| existent_info_input, | |
| consent_input, | |
| ], | |
| outputs=[ | |
| output, | |
| mean_plot, | |
| waterfall_plot, | |
| summary_plot, | |
| importance_plot, | |
| decision_plot, | |
| ], | |
| ) | |
| demo.launch() | |