ACA050 commited on
Commit
96ef2ff
·
verified ·
1 Parent(s): 19abfdf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -59
app.py CHANGED
@@ -1,4 +1,4 @@
1
-
2
 
3
  import os
4
  import joblib
@@ -7,63 +7,25 @@ import matplotlib.pyplot as plt
7
  import numpy as np
8
  import pandas as pd
9
  import shap
10
- from sklearn.compose import ColumnTransformer
11
- from sklearn.ensemble import RandomForestClassifier
12
- from sklearn.pipeline import Pipeline
13
- from sklearn.preprocessing import OneHotEncoder, StandardScaler
14
 
15
  # =====================================================================================
16
- # PART 1: MODEL CREATION AND LOADING (Self-contained for Hugging Face)
17
- # This part creates, trains, and saves a mock model if one doesn't exist.
18
- # This ensures the app is fully reproducible in any environment.
19
  # =====================================================================================
20
 
21
- MODEL_FILE = "machine_failure_model.joblib"
22
-
23
- def create_and_train_model():
24
- """Creates, trains, and saves a mock model pipeline."""
25
- # Mock data that resembles the predictive maintenance dataset
26
- mock_features = pd.DataFrame({
27
- 'Type': ['L', 'M', 'H', 'L', 'M', 'H', 'L', 'M', 'H', 'L'],
28
- 'Air temperature [K]': [298.1, 298.2, 298.3, 298.4, 299.0, 299.5, 300.1, 301.0, 302.5, 303.0],
29
- 'Process temperature [K]': [308.6, 308.7, 308.8, 308.9, 309.1, 309.8, 310.5, 311.0, 312.0, 313.5],
30
- 'Rotational speed [rpm]': [1551, 1428, 1455, 1600, 1750, 2000, 2200, 2500, 2850, 1300],
31
- 'Torque [Nm]': [42.8, 46.3, 40.0, 50.1, 55.2, 60.0, 65.5, 70.0, 75.0, 35.0],
32
- 'Tool wear [min]': [0, 5, 10, 15, 25, 50, 80, 120, 180, 210]
33
- })
34
- # Mock target: 0 = No Failure, 1 = Failure
35
- mock_target = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 0])
36
-
37
- # Define preprocessing steps for different column types
38
- numeric_features = ['Air temperature [K]', 'Process temperature [K]', 'Rotational speed [rpm]', 'Torque [Nm]', 'Tool wear [min]']
39
- categorical_features = ['Type']
40
-
41
- preprocessor = ColumnTransformer(
42
- transformers=[
43
- ('num', StandardScaler(), numeric_features),
44
- ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features)
45
- ])
46
-
47
- # Create the full pipeline with preprocessing and a classifier
48
- model_pipeline = Pipeline(steps=[
49
- ('preprocessor', preprocessor),
50
- ('classifier', RandomForestClassifier(n_estimators=50, random_state=42))
51
- ])
52
-
53
- # Train the model
54
- model_pipeline.fit(mock_features, mock_target)
55
-
56
- # Save the trained model to a file
57
- joblib.dump(model_pipeline, MODEL_FILE)
58
- print(f"Model trained and saved to {MODEL_FILE}")
59
- return model_pipeline
60
-
61
- # Check if the model file exists; if not, create it.
62
- if not os.path.exists(MODEL_FILE):
63
- loaded_model = create_and_train_model()
64
- else:
65
  loaded_model = joblib.load(MODEL_FILE)
66
- print(f"Model loaded from {MODEL_FILE}")
 
 
 
 
67
 
68
 
69
  # =====================================================================================
@@ -72,6 +34,9 @@ else:
72
 
73
  def predict_failure(Type, air_temperature, process_temperature, rotational_speed, torque, tool_wear):
74
  """Predicts machine failure and calculates SHAP values using the loaded model."""
 
 
 
75
  input_data = pd.DataFrame({
76
  'Type': [Type], 'Air temperature [K]': [air_temperature],
77
  'Process temperature [K]': [process_temperature], 'Rotational speed [rpm]': [rotational_speed],
@@ -100,7 +65,7 @@ def predict_failure(Type, air_temperature, process_temperature, rotational_speed
100
 
101
  def generate_shap_plot(shap_values, feature_names, base_value):
102
  """Generates a SHAP waterfall plot for the Gradio interface."""
103
- plt.close('all') # Ensure plots don't stack in memory
104
  explanation = shap.Explanation(
105
  values=shap_values, base_values=base_value, feature_names=feature_names
106
  )
@@ -111,11 +76,21 @@ def generate_shap_plot(shap_values, feature_names, base_value):
111
 
112
  def predict_and_generate_plot(Type, air_temperature, process_temperature, rotational_speed, torque, tool_wear):
113
  """Wrapper function that connects the backend prediction to the frontend plot."""
114
- probability, shap_values, feature_names, base_value = predict_failure(
115
  Type, air_temperature, process_temperature, rotational_speed, torque, tool_wear
116
  )
117
- shap_plot = generate_shap_plot(shap_values, feature_names, base_value)
118
- return f"{probability:.2%}", shap_plot # Format probability as percentage
 
 
 
 
 
 
 
 
 
 
119
 
120
  # Define the Gradio interface layout and components
121
  with gr.Blocks(theme=gr.themes.Soft()) as iface_with_shap:
@@ -137,7 +112,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as iface_with_shap:
137
  probability_output = gr.Textbox(label="Probability of Machine Failure")
138
  plot_output = gr.Plot(label="Feature Contribution to Failure (SHAP Waterfall Plot)")
139
 
140
- # Connect the inputs to the function and outputs
 
 
 
 
 
 
 
141
  for input_comp in inputs:
142
  input_comp.change(
143
  fn=predict_and_generate_plot,
@@ -147,4 +129,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as iface_with_shap:
147
 
148
  # Launch the application
149
  if __name__ == "__main__":
150
- iface_with_shap.launch(debug=True)
 
1
+ # app.py
2
 
3
  import os
4
  import joblib
 
7
  import numpy as np
8
  import pandas as pd
9
  import shap
10
+ from sklearn.pipeline import Pipeline # Keep for type hinting/structure
 
 
 
11
 
12
  # =====================================================================================
13
+ # PART 1: MODEL LOADING
14
+ # This section now directly loads the model file you have uploaded to the repository.
 
15
  # =====================================================================================
16
 
17
+ # <<< CHANGE 1: Update the filename to match your uploaded model.
18
+ MODEL_FILE = "machine_failure_prediction_model.joblib"
19
+
20
+ # <<< CHANGE 2: Remove the model creation logic. We will now directly load your file.
21
+ # The app assumes this file exists in your Hugging Face repository.
22
+ try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  loaded_model = joblib.load(MODEL_FILE)
24
+ print(f"Successfully loaded model from {MODEL_FILE}")
25
+ except FileNotFoundError:
26
+ print(f"Error: Model file not found at {MODEL_FILE}. Make sure the file is uploaded to the repository.")
27
+ # You might want to display an error in the Gradio app itself if the model fails to load.
28
+ loaded_model = None # Set to None to handle the error case
29
 
30
 
31
  # =====================================================================================
 
34
 
35
  def predict_failure(Type, air_temperature, process_temperature, rotational_speed, torque, tool_wear):
36
  """Predicts machine failure and calculates SHAP values using the loaded model."""
37
+ if loaded_model is None:
38
+ return "Error: Model not loaded.", None
39
+
40
  input_data = pd.DataFrame({
41
  'Type': [Type], 'Air temperature [K]': [air_temperature],
42
  'Process temperature [K]': [process_temperature], 'Rotational speed [rpm]': [rotational_speed],
 
65
 
66
  def generate_shap_plot(shap_values, feature_names, base_value):
67
  """Generates a SHAP waterfall plot for the Gradio interface."""
68
+ plt.close('all')
69
  explanation = shap.Explanation(
70
  values=shap_values, base_values=base_value, feature_names=feature_names
71
  )
 
76
 
77
  def predict_and_generate_plot(Type, air_temperature, process_temperature, rotational_speed, torque, tool_wear):
78
  """Wrapper function that connects the backend prediction to the frontend plot."""
79
+ result = predict_failure(
80
  Type, air_temperature, process_temperature, rotational_speed, torque, tool_wear
81
  )
82
+
83
+ if isinstance(result, tuple):
84
+ probability, shap_values, feature_names, base_value = result
85
+ shap_plot = generate_shap_plot(shap_values, feature_names, base_value)
86
+ return f"{probability:.2%}", shap_plot
87
+ else:
88
+ # Handle the case where the model failed to load
89
+ error_message = result
90
+ empty_plot = plt.figure()
91
+ plt.text(0.5, 0.5, 'Error: Model could not be loaded.', horizontalalignment='center', verticalalignment='center')
92
+ return error_message, empty_plot
93
+
94
 
95
  # Define the Gradio interface layout and components
96
  with gr.Blocks(theme=gr.themes.Soft()) as iface_with_shap:
 
112
  probability_output = gr.Textbox(label="Probability of Machine Failure")
113
  plot_output = gr.Plot(label="Feature Contribution to Failure (SHAP Waterfall Plot)")
114
 
115
+ # This makes the app load the first prediction on startup
116
+ iface_with_shap.load(
117
+ fn=predict_and_generate_plot,
118
+ inputs=inputs,
119
+ outputs=[probability_output, plot_output]
120
+ )
121
+
122
+ # This connects the UI changes to the prediction function
123
  for input_comp in inputs:
124
  input_comp.change(
125
  fn=predict_and_generate_plot,
 
129
 
130
  # Launch the application
131
  if __name__ == "__main__":
132
+ iface_with_shap.launch()