ACA050 commited on
Commit
08c0349
·
verified ·
1 Parent(s): 7dfedf3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -0
app.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #app.py
2
+ import pandas as pd
3
+ import shap
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import joblib
7
+ import gradio as gr
8
+
9
+ # Load the model
10
+ try:
11
+ loaded_model = joblib.load('machine_failure_prediction_model.joblib')
12
+ print("Model loaded successfully for Gradio.")
13
+ except FileNotFoundError:
14
+ print("Error: 'machine_failure_prediction_model.joblib' not found. Ensure the model file is in the correct directory.")
15
+ loaded_model = None # Set to None if loading fails
16
+
17
+ def predict_failure_with_shap(Type, air_temp, process_temp, rotational_speed, torque, tool_wear):
18
+ """
19
+ Predicts machine failure, generates SHAP analysis, and returns prediction,
20
+ probabilities, and SHAP waterfall plot for the given input.
21
+
22
+ Args:
23
+ Type (str): Machine type (L, M, or H).
24
+ air_temp (float): Air temperature in Kelvin.
25
+ process_temp (float): Process temperature in Kelvin.
26
+ rotational_speed (int): Rotational speed in rpm.
27
+ torque (float): Torque in Nm.
28
+ tool_wear (int): Tool wear in minutes.
29
+
30
+ Returns:
31
+ tuple: A tuple containing:
32
+ - str: Formatted prediction string.
33
+ - str: Formatted probabilities string.
34
+ - matplotlib.figure.Figure: SHAP waterfall plot figure object.
35
+ """
36
+ if loaded_model is None:
37
+ return "Error: Model not loaded.", "", None # Return error if model loading failed
38
+
39
+ # Create a DataFrame from the input
40
+ input_data = pd.DataFrame({
41
+ 'Type': [Type],
42
+ 'Air temperature [K]': [air_temp],
43
+ 'Process temperature [K]': [process_temp],
44
+ 'Rotational speed [rpm]': [rotational_speed],
45
+ 'Torque [Nm]': [torque],
46
+ 'Tool wear [min]': [tool_wear]
47
+ })
48
+
49
+ # Make prediction and get probabilities using the loaded pipeline
50
+ predicted_failure = loaded_model.predict(input_data)[0]
51
+ predicted_proba = loaded_model.predict_proba(input_data)[0]
52
+
53
+ # Format the prediction and probabilities for Gradio output
54
+ prediction_label = f"Predicted Failure: {'Failure' if predicted_failure == 1 else 'No Failure'}"
55
+ probabilities_label = f"Probabilities (No Failure, Failure): {predicted_proba[0]:.4f}, {predicted_proba[1]:.4f}"
56
+
57
+
58
+ # Get preprocessor and classifier from the pipeline
59
+ preprocessor = loaded_model.named_steps['preprocessor']
60
+ classifier = loaded_model.named_steps['classifier']
61
+
62
+ # Transform the input data
63
+ X_transformed = preprocessor.transform(input_data)
64
+
65
+ # Initialize SHAP explainer and calculate SHAP values
66
+ explainer = shap.TreeExplainer(classifier)
67
+ # Ensure SHAP values are calculated for the transformed input data
68
+ shap_values = explainer.shap_values(X_transformed)
69
+
70
+ # Get feature names after preprocessing
71
+ feature_names = preprocessor.get_feature_names_out()
72
+
73
+ # Handle multi-output SHAP values (for binary classification, usually list of arrays)
74
+ if isinstance(shap_values, list):
75
+ # For binary classification, shap_values[0] is for class 0, shap_values[1] for class 1
76
+ shap_val = shap_values[1][0] # Get values for the positive class (Failure)
77
+ base_val = explainer.expected_value[1] # Get expected value for the positive class
78
+ else:
79
+ # For single-output models, or if shap_values is a single array
80
+ # Assuming the positive class is at index 1 for probability output
81
+ shap_val = shap_values[0, :] if shap_values.ndim == 2 else shap_values[0, :, 1] # Get values for the positive class
82
+ base_val = explainer.expected_value if not isinstance(explainer.expected_value, (list, tuple, np.ndarray)) else explainer.expected_value[1] # Get expected value for the positive class
83
+
84
+ # Generate SHAP waterfall plot
85
+ # Use a different figure explicitly to avoid interference with other plots
86
+ fig = plt.figure()
87
+ shap.waterfall_plot(
88
+ shap.Explanation(
89
+ values=shap_val,
90
+ base_values=base_val,
91
+ data=X_transformed[0],
92
+ feature_names=feature_names
93
+ ), show=False
94
+ )
95
+ plt.title("SHAP Waterfall Plot for Failure Prediction")
96
+ plt.tight_layout() # Adjust layout to prevent labels overlapping
97
+
98
+ # Define the Gradio input components
99
+ inputs = [
100
+ gr.Dropdown(choices=['L', 'M', 'H'], label='Machine Type'),
101
+ gr.Number(label='Air temperature [K]', step=0.1),
102
+ gr.Number(label='Process temperature [K]', step=0.1),
103
+ gr.Number(label='Rotational speed [rpm]', step=1),
104
+ gr.Number(label='Torque [Nm]', step=0.1),
105
+ gr.Number(label='Tool wear [min]', step=1)
106
+ ]
107
+
108
+ # Define the Gradio output components
109
+ outputs = [
110
+ gr.Label(label='Predicted Failure (0=No Failure, 1=Failure)'),
111
+ gr.Label(label='Prediction Probabilities'),
112
+ gr.Plot(label='SHAP Waterfall Plot')
113
+ ]
114
+
115
+ # Create the Gradio interface using the corrected parameter
116
+ iface = gr.Interface(
117
+ fn=predict_failure_with_shap,
118
+ inputs=inputs,
119
+ outputs=outputs,
120
+ title="Machine Failure Prediction with SHAP Analysis",
121
+ description="Enter the machine parameters to predict failure and see the SHAP analysis.",
122
+ flagging_mode='never' # Using the recommended parameter
123
+ )
124
+
125
+ # Return the outputs for Gradio
126
+ return prediction_label, probabilities_label, fig
127
+
128
+ # To run this app locally for testing, uncomment the line below:
129
+ # iface.launch()