ACA050's picture
Create app.py
08c0349 verified
raw
history blame
5.37 kB
#app.py
import pandas as pd
import shap
import matplotlib.pyplot as plt
import numpy as np
import joblib
import gradio as gr
# Load the serialized prediction pipeline once, at import time.
# `loaded_model` defaults to None so the prediction callback can detect a
# missing model file and report it instead of raising.
loaded_model = None
try:
    loaded_model = joblib.load('machine_failure_prediction_model.joblib')
except FileNotFoundError:
    print("Error: 'machine_failure_prediction_model.joblib' not found. Ensure the model file is in the correct directory.")
else:
    print("Model loaded successfully for Gradio.")
def _positive_class_contributions(shap_values, expected_value):
    """Pick the SHAP values and baseline for the positive ("Failure") class.

    Handles the output layouts this app expects from ``shap.TreeExplainer``:

    * a list of per-class arrays (index 1 is taken as the positive class),
    * a 2-D array ``(samples, features)`` for single-output models,
    * a 3-D array ``(samples, features, classes)``.

    Args:
        shap_values: Result of ``explainer.shap_values(...)`` for one sample.
        expected_value: ``explainer.expected_value`` (scalar or per-class
            sequence).

    Returns:
        tuple: ``(shap_val, base_val)`` for the first (only) sample.
    """
    if isinstance(shap_values, list):
        # One array per class: index 0 is class 0, index 1 is class 1.
        return shap_values[1][0], expected_value[1]

    if shap_values.ndim == 2:
        shap_val = shap_values[0, :]
    else:
        shap_val = shap_values[0, :, 1]

    if isinstance(expected_value, (list, tuple, np.ndarray)):
        baseline = np.ravel(expected_value)
        # A single-output model may expose a length-1 baseline; indexing [1]
        # unconditionally would raise IndexError in that case.
        base_val = baseline[1] if baseline.size > 1 else baseline[0]
    else:
        base_val = expected_value
    return shap_val, base_val


def predict_failure_with_shap(Type, air_temp, process_temp, rotational_speed, torque, tool_wear):
    """
    Predict machine failure for one input row and explain it with SHAP.

    Args:
        Type (str): Machine type (L, M, or H).
        air_temp (float): Air temperature in Kelvin.
        process_temp (float): Process temperature in Kelvin.
        rotational_speed (int): Rotational speed in rpm.
        torque (float): Torque in Nm.
        tool_wear (int): Tool wear in minutes.

    Returns:
        tuple: A tuple containing:
            - str: Formatted prediction string.
            - str: Formatted probabilities string.
            - matplotlib.figure.Figure: SHAP waterfall plot figure object
              (``None`` when the model could not be loaded).
    """
    if loaded_model is None:
        # Model loading failed at import time; report it in the first output.
        return "Error: Model not loaded.", "", None

    # Single-row frame whose column names are the ones the pipeline is fed.
    input_data = pd.DataFrame({
        'Type': [Type],
        'Air temperature [K]': [air_temp],
        'Process temperature [K]': [process_temp],
        'Rotational speed [rpm]': [rotational_speed],
        'Torque [Nm]': [torque],
        'Tool wear [min]': [tool_wear],
    })

    # Prediction and class probabilities come from the full pipeline.
    predicted_failure = loaded_model.predict(input_data)[0]
    predicted_proba = loaded_model.predict_proba(input_data)[0]

    prediction_label = f"Predicted Failure: {'Failure' if predicted_failure == 1 else 'No Failure'}"
    probabilities_label = f"Probabilities (No Failure, Failure): {predicted_proba[0]:.4f}, {predicted_proba[1]:.4f}"

    # SHAP explains the classifier step, so feed it the preprocessed features.
    preprocessor = loaded_model.named_steps['preprocessor']
    classifier = loaded_model.named_steps['classifier']
    X_transformed = preprocessor.transform(input_data)
    if hasattr(X_transformed, "toarray"):
        # A preprocessor may emit a SciPy sparse matrix; densify it so that
        # X_transformed[0] below is a plain 1-D feature row.
        X_transformed = X_transformed.toarray()

    explainer = shap.TreeExplainer(classifier)
    shap_values = explainer.shap_values(X_transformed)
    feature_names = preprocessor.get_feature_names_out()

    shap_val, base_val = _positive_class_contributions(
        shap_values, explainer.expected_value
    )

    # Draw on a dedicated figure so the plot does not land on another figure.
    # NOTE(review): figures created here are never closed and stay registered
    # with pyplot — consider closing them once Gradio has rendered the plot.
    fig = plt.figure()
    shap.waterfall_plot(
        shap.Explanation(
            values=shap_val,
            base_values=base_val,
            data=X_transformed[0],
            feature_names=feature_names,
        ),
        show=False,
    )
    plt.title("SHAP Waterfall Plot for Failure Prediction")
    plt.tight_layout()  # Adjust layout to prevent labels overlapping

    # Return the outputs for Gradio
    return prediction_label, probabilities_label, fig


# Gradio input components, in the same order as the callback's parameters.
inputs = [
    gr.Dropdown(choices=['L', 'M', 'H'], label='Machine Type'),
    gr.Number(label='Air temperature [K]', step=0.1),
    gr.Number(label='Process temperature [K]', step=0.1),
    gr.Number(label='Rotational speed [rpm]', step=1),
    gr.Number(label='Torque [Nm]', step=0.1),
    gr.Number(label='Tool wear [min]', step=1),
]

# Gradio output components, in the same order as the callback's return tuple.
outputs = [
    gr.Label(label='Predicted Failure (0=No Failure, 1=Failure)'),
    gr.Label(label='Prediction Probabilities'),
    gr.Plot(label='SHAP Waterfall Plot'),
]

# Build the interface once at module level so `iface` exists for launching.
iface = gr.Interface(
    fn=predict_failure_with_shap,
    inputs=inputs,
    outputs=outputs,
    title="Machine Failure Prediction with SHAP Analysis",
    description="Enter the machine parameters to predict failure and see the SHAP analysis.",
    flagging_mode='never',
)
# To run this app locally for testing, uncomment the line below:
# iface.launch()