# ACA050's picture
# Update app.py
# 96ef2ff verified
# app.py
import os
import joblib
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from sklearn.pipeline import Pipeline # Keep for type hinting/structure
# =====================================================================================
# PART 1: MODEL LOADING
# This section loads the trained model file that has been uploaded to the repository.
# =====================================================================================
# MODEL_FILE must match the filename of the uploaded model.
MODEL_FILE = "machine_failure_prediction_model.joblib"

# No model is trained at startup: the serialized model is read straight from
# MODEL_FILE, which is expected to exist in the Hugging Face repository.
try:
    loaded_model = joblib.load(MODEL_FILE)
except FileNotFoundError:
    print(f"Error: Model file not found at {MODEL_FILE}. Make sure the file is uploaded to the repository.")
    # Keep the module importable; predict_failure checks for None and
    # reports the problem instead of crashing.
    loaded_model = None
else:
    print(f"Successfully loaded model from {MODEL_FILE}")
# =====================================================================================
# PART 2: BACKEND LOGIC (Prediction and SHAP Calculation)
# =====================================================================================
def predict_failure(Type, air_temperature, process_temperature, rotational_speed, torque, tool_wear):
    """Predict machine failure and compute SHAP values using the loaded model.

    The six arguments are placed into a one-row DataFrame whose column names
    are 'Type', 'Air temperature [K]', 'Process temperature [K]',
    'Rotational speed [rpm]', 'Torque [Nm]' and 'Tool wear [min]'. The row is
    transformed by the pipeline's 'preprocessor' step and scored by its
    'classifier' step.

    Args:
        Type: Value for the 'Type' column.
        air_temperature: Value for the 'Air temperature [K]' column.
        process_temperature: Value for the 'Process temperature [K]' column.
        rotational_speed: Value for the 'Rotational speed [rpm]' column.
        torque: Value for the 'Torque [Nm]' column.
        tool_wear: Value for the 'Tool wear [min]' column.

    Returns:
        On success, a 4-tuple ``(probability, shap_values, feature_names,
        base_value)``: the predicted probability of class 1 ("Failure"), the
        per-feature SHAP values for that class, the preprocessor's output
        feature names, and the explainer's base value for that class.
        If the model could not be loaded, the 2-tuple
        ``("Error: Model not loaded.", None)``.
    """
    if loaded_model is None:
        return "Error: Model not loaded.", None

    input_data = pd.DataFrame({
        'Type': [Type], 'Air temperature [K]': [air_temperature],
        'Process temperature [K]': [process_temperature], 'Rotational speed [rpm]': [rotational_speed],
        'Torque [Nm]': [torque], 'Tool wear [min]': [tool_wear]
    })

    preprocessor = loaded_model.named_steps['preprocessor']
    classifier = loaded_model.named_steps['classifier']

    # The explainer works on the classifier alone, so it must be given the
    # already-preprocessed features rather than the raw input row.
    input_processed = preprocessor.transform(input_data)
    probability = classifier.predict_proba(input_processed)[:, 1]

    explainer = shap.TreeExplainer(classifier)
    shap_values = explainer.shap_values(input_processed)
    feature_names = preprocessor.get_feature_names_out()

    shap_val_failure, base_val_failure = _failure_class_explanation(
        shap_values, explainer.expected_value
    )
    return probability[0], shap_val_failure, feature_names, base_val_failure


def _failure_class_explanation(shap_values, expected_value):
    """Return (SHAP values, base value) of the "Failure" class for the first row.

    Handles the three layouts ``TreeExplainer.shap_values`` can produce:
    a list with one ``(n_samples, n_features)`` array per class (older SHAP
    releases), a single ``(n_samples, n_features, n_classes)`` array (newer
    SHAP releases, 0.45+), or a single ``(n_samples, n_features)`` array for
    models with one output. Class index 1 is used whenever it exists.
    """
    if isinstance(shap_values, list):
        class_index = 1 if len(shap_values) > 1 else 0
        row_values = np.asarray(shap_values[class_index])[0]
    else:
        values = np.asarray(shap_values)
        if values.ndim == 3:
            class_index = 1 if values.shape[2] > 1 else 0
            row_values = values[0, :, class_index]
        else:
            # Single-output model: the only output is the one being explained.
            row_values = values[0]

    # expected_value is a scalar for single-output models and one value per
    # class otherwise; flatten so both cases index the same way.
    base_values = np.ravel(expected_value)
    base_value = base_values[1] if base_values.size > 1 else base_values[0]
    return row_values, base_value
# =====================================================================================
# PART 3: FRONTEND LOGIC (Plotting and Gradio Interface)
# =====================================================================================
def generate_shap_plot(shap_values, feature_names, base_value):
    """Build the SHAP waterfall figure shown in the Gradio interface.

    Args:
        shap_values: SHAP values for a single prediction.
        feature_names: Names matching ``shap_values`` position by position.
        base_value: Base (expected) value the waterfall starts from.

    Returns:
        The matplotlib figure created for this plot.
    """
    # Discard figures left open by earlier calls before creating a new one.
    plt.close('all')

    waterfall_data = shap.Explanation(
        values=shap_values,
        base_values=base_value,
        feature_names=feature_names,
    )

    # show=False keeps SHAP from displaying the plot itself; presumably it
    # draws onto the current pyplot figure, i.e. the one created here.
    figure = plt.subplots()[0]
    shap.waterfall_plot(waterfall_data, max_display=10, show=False)
    plt.tight_layout()
    return figure
def predict_and_generate_plot(Type, air_temperature, process_temperature, rotational_speed, torque, tool_wear):
    """Run the backend prediction and turn its result into Gradio outputs.

    Args:
        Type, air_temperature, process_temperature, rotational_speed, torque,
        tool_wear: Forwarded unchanged to ``predict_failure``.

    Returns:
        A 2-tuple ``(text, figure)``. On success ``text`` is the failure
        probability formatted as a percentage and ``figure`` is the SHAP
        waterfall plot. If the prediction could not be made, ``text`` is the
        error message and ``figure`` is a placeholder plot carrying an error
        notice.
    """
    result = predict_failure(
        Type, air_temperature, process_temperature, rotational_speed, torque, tool_wear
    )

    # Only a 4-tuple is a successful prediction. predict_failure reports a
    # missing model as ("Error: Model not loaded.", None), which is also a
    # tuple, so checking the type alone would send it into the 4-way unpack.
    if isinstance(result, tuple) and len(result) == 4:
        probability, shap_values, feature_names, base_value = result
        shap_plot = generate_shap_plot(shap_values, feature_names, base_value)
        return f"{probability:.2%}", shap_plot

    # Error path: accept either a bare message or a (message, ...) tuple.
    if isinstance(result, tuple) and result:
        error_message = result[0]
    else:
        error_message = result

    # Close earlier figures so repeated calls in the error state do not
    # accumulate open matplotlib figures.
    plt.close('all')
    empty_plot = plt.figure()
    plt.text(0.5, 0.5, 'Error: Model could not be loaded.', horizontalalignment='center', verticalalignment='center')
    return error_message, empty_plot
# Gradio layout: input controls in the narrower column, prediction outputs in the wider one.
with gr.Blocks(theme=gr.themes.Soft()) as iface_with_shap:
    gr.Markdown("# Machine Failure Prediction with Live SHAP Analysis")
    gr.Markdown("Adjust the sliders to see the real-time probability of machine failure and how each feature's value contributes to the prediction.")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input Features")
            type_input = gr.Dropdown(label="Type", choices=['L', 'M', 'H'], value='L')
            air_temp_input = gr.Slider(minimum=295, maximum=305, value=300, label="Air temperature [K]")
            proc_temp_input = gr.Slider(minimum=305, maximum=315, value=310, label="Process temperature [K]")
            rpm_input = gr.Slider(minimum=1000, maximum=3000, value=1500, label="Rotational speed [rpm]")
            torque_input = gr.Slider(minimum=5, maximum=80, value=40, label="Torque [Nm]")
            wear_input = gr.Slider(minimum=0, maximum=250, value=100, label="Tool wear [min]")
            # Order matters: it must match the parameter order of predict_and_generate_plot.
            inputs = [type_input, air_temp_input, proc_temp_input, rpm_input, torque_input, wear_input]

        with gr.Column(scale=2):
            gr.Markdown("### Prediction Outputs")
            probability_output = gr.Textbox(label="Probability of Machine Failure")
            plot_output = gr.Plot(label="Feature Contribution to Failure (SHAP Waterfall Plot)")

    # Every trigger runs the same callback with the same inputs and outputs,
    # so the wiring is described once and reused.
    prediction_event = {
        "fn": predict_and_generate_plot,
        "inputs": inputs,
        "outputs": [probability_output, plot_output],
    }

    # Compute the first prediction as soon as the page loads.
    iface_with_shap.load(**prediction_event)

    # Recompute whenever any input component changes.
    for input_comp in inputs:
        input_comp.change(**prediction_event)
# Launch the application only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    iface_with_shap.launch()