ACA050 commited on
Commit
32b3395
·
verified ·
1 Parent(s): e1621a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -63
app.py CHANGED
@@ -1,4 +1,3 @@
1
- #app.py
2
  import pandas as pd
3
  import shap
4
  import matplotlib.pyplot as plt
@@ -12,31 +11,12 @@ try:
12
  print("Model loaded successfully for Gradio.")
13
  except FileNotFoundError:
14
  print("Error: 'machine_failure_prediction_model.joblib' not found. Ensure the model file is in the correct directory.")
15
- loaded_model = None # Set to None if loading fails
16
 
17
  def predict_failure_with_shap(Type, air_temp, process_temp, rotational_speed, torque, tool_wear):
18
- """
19
- Predicts machine failure, generates SHAP analysis, and returns prediction,
20
- probabilities, and SHAP waterfall plot for the given input.
21
-
22
- Args:
23
- Type (str): Machine type (L, M, or H).
24
- air_temp (float): Air temperature in Kelvin.
25
- process_temp (float): Process temperature in Kelvin.
26
- rotational_speed (int): Rotational speed in rpm.
27
- torque (float): Torque in Nm.
28
- tool_wear (int): Tool wear in minutes.
29
-
30
- Returns:
31
- tuple: A tuple containing:
32
- - str: Formatted prediction string.
33
- - str: Formatted probabilities string.
34
- - matplotlib.figure.Figure: SHAP waterfall plot figure object.
35
- """
36
  if loaded_model is None:
37
- return "Error: Model not loaded.", "", None # Return error if model loading failed
38
 
39
- # Create a DataFrame from the input
40
  input_data = pd.DataFrame({
41
  'Type': [Type],
42
  'Air temperature [K]': [air_temp],
@@ -46,43 +26,27 @@ def predict_failure_with_shap(Type, air_temp, process_temp, rotational_speed, to
46
  'Tool wear [min]': [tool_wear]
47
  })
48
 
49
- # Make prediction and get probabilities using the loaded pipeline
50
  predicted_failure = loaded_model.predict(input_data)[0]
51
  predicted_proba = loaded_model.predict_proba(input_data)[0]
52
 
53
- # Format the prediction and probabilities for Gradio output
54
  prediction_label = f"Predicted Failure: {'Failure' if predicted_failure == 1 else 'No Failure'}"
55
  probabilities_label = f"Probabilities (No Failure, Failure): {predicted_proba[0]:.4f}, {predicted_proba[1]:.4f}"
56
 
57
-
58
- # Get preprocessor and classifier from the pipeline
59
  preprocessor = loaded_model.named_steps['preprocessor']
60
  classifier = loaded_model.named_steps['classifier']
61
-
62
- # Transform the input data
63
  X_transformed = preprocessor.transform(input_data)
64
 
65
- # Initialize SHAP explainer and calculate SHAP values
66
  explainer = shap.TreeExplainer(classifier)
67
- # Ensure SHAP values are calculated for the transformed input data
68
  shap_values = explainer.shap_values(X_transformed)
69
-
70
- # Get feature names after preprocessing
71
  feature_names = preprocessor.get_feature_names_out()
72
 
73
- # Handle multi-output SHAP values (for binary classification, usually list of arrays)
74
  if isinstance(shap_values, list):
75
- # For binary classification, shap_values[0] is for class 0, shap_values[1] for class 1
76
- shap_val = shap_values[1][0] # Get values for the positive class (Failure)
77
- base_val = explainer.expected_value[1] # Get expected value for the positive class
78
  else:
79
- # For single-output models, or if shap_values is a single array
80
- # Assuming the positive class is at index 1 for probability output
81
- shap_val = shap_values[0, :] if shap_values.ndim == 2 else shap_values[0, :, 1] # Get values for the positive class
82
- base_val = explainer.expected_value if not isinstance(explainer.expected_value, (list, tuple, np.ndarray)) else explainer.expected_value[1] # Get expected value for the positive class
83
 
84
- # Generate SHAP waterfall plot
85
- # Use a different figure explicitly to avoid interference with other plots
86
  fig = plt.figure()
87
  shap.waterfall_plot(
88
  shap.Explanation(
@@ -93,37 +57,31 @@ def predict_failure_with_shap(Type, air_temp, process_temp, rotational_speed, to
93
  ), show=False
94
  )
95
  plt.title("SHAP Waterfall Plot for Failure Prediction")
96
- plt.tight_layout() # Adjust layout to prevent labels overlapping
 
 
97
 
98
- # Define the Gradio input components
99
- inputs = [
 
 
 
100
  gr.Dropdown(choices=['L', 'M', 'H'], label='Machine Type'),
101
  gr.Number(label='Air temperature [K]', step=0.1),
102
  gr.Number(label='Process temperature [K]', step=0.1),
103
  gr.Number(label='Rotational speed [rpm]', step=1),
104
  gr.Number(label='Torque [Nm]', step=0.1),
105
  gr.Number(label='Tool wear [min]', step=1)
106
- ]
107
-
108
- # Define the Gradio output components
109
- outputs = [
110
  gr.Label(label='Predicted Failure (0=No Failure, 1=Failure)'),
111
  gr.Label(label='Prediction Probabilities'),
112
  gr.Plot(label='SHAP Waterfall Plot')
113
- ]
114
-
115
- # Create the Gradio interface using the corrected parameter
116
- iface = gr.Interface(
117
- fn=predict_failure_with_shap,
118
- inputs=inputs,
119
- outputs=outputs,
120
- title="Machine Failure Prediction with SHAP Analysis",
121
- description="Enter the machine parameters to predict failure and see the SHAP analysis.",
122
- flagging_mode='never' # Using the recommended parameter
123
- )
124
-
125
- # Return the outputs for Gradio
126
- return prediction_label, probabilities_label, fig
127
 
128
  if __name__ == "__main__":
129
  iface.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
  import pandas as pd
2
  import shap
3
  import matplotlib.pyplot as plt
 
11
  print("Model loaded successfully for Gradio.")
12
  except FileNotFoundError:
13
  print("Error: 'machine_failure_prediction_model.joblib' not found. Ensure the model file is in the correct directory.")
14
+ loaded_model = None
15
 
16
  def predict_failure_with_shap(Type, air_temp, process_temp, rotational_speed, torque, tool_wear):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  if loaded_model is None:
18
+ return "Error: Model not loaded.", "", None
19
 
 
20
  input_data = pd.DataFrame({
21
  'Type': [Type],
22
  'Air temperature [K]': [air_temp],
 
26
  'Tool wear [min]': [tool_wear]
27
  })
28
 
 
29
  predicted_failure = loaded_model.predict(input_data)[0]
30
  predicted_proba = loaded_model.predict_proba(input_data)[0]
31
 
 
32
  prediction_label = f"Predicted Failure: {'Failure' if predicted_failure == 1 else 'No Failure'}"
33
  probabilities_label = f"Probabilities (No Failure, Failure): {predicted_proba[0]:.4f}, {predicted_proba[1]:.4f}"
34
 
 
 
35
  preprocessor = loaded_model.named_steps['preprocessor']
36
  classifier = loaded_model.named_steps['classifier']
 
 
37
  X_transformed = preprocessor.transform(input_data)
38
 
 
39
  explainer = shap.TreeExplainer(classifier)
 
40
  shap_values = explainer.shap_values(X_transformed)
 
 
41
  feature_names = preprocessor.get_feature_names_out()
42
 
 
43
  if isinstance(shap_values, list):
44
+ shap_val = shap_values[1][0]
45
+ base_val = explainer.expected_value[1]
 
46
  else:
47
+ shap_val = shap_values[0, :] if shap_values.ndim == 2 else shap_values[0, :, 1]
48
+ base_val = explainer.expected_value if not isinstance(explainer.expected_value, (list, tuple, np.ndarray)) else explainer.expected_value[1]
 
 
49
 
 
 
50
  fig = plt.figure()
51
  shap.waterfall_plot(
52
  shap.Explanation(
 
57
  ), show=False
58
  )
59
  plt.title("SHAP Waterfall Plot for Failure Prediction")
60
+ plt.tight_layout()
61
+
62
+ return prediction_label, probabilities_label, fig
63
 
64
+
65
+ # Define Gradio interface OUTSIDE the function
66
+ iface = gr.Interface(
67
+ fn=predict_failure_with_shap,
68
+ inputs=[
69
  gr.Dropdown(choices=['L', 'M', 'H'], label='Machine Type'),
70
  gr.Number(label='Air temperature [K]', step=0.1),
71
  gr.Number(label='Process temperature [K]', step=0.1),
72
  gr.Number(label='Rotational speed [rpm]', step=1),
73
  gr.Number(label='Torque [Nm]', step=0.1),
74
  gr.Number(label='Tool wear [min]', step=1)
75
+ ],
76
+ outputs=[
 
 
77
  gr.Label(label='Predicted Failure (0=No Failure, 1=Failure)'),
78
  gr.Label(label='Prediction Probabilities'),
79
  gr.Plot(label='SHAP Waterfall Plot')
80
+ ],
81
+ title="Machine Failure Prediction with SHAP Analysis",
82
+ description="Enter the machine parameters to predict failure and see the SHAP analysis.",
83
+ flagging_mode='never'
84
+ )
 
 
 
 
 
 
 
 
 
85
 
86
  if __name__ == "__main__":
87
  iface.launch(server_name="0.0.0.0", server_port=7860)