sparsh007 commited on
Commit
e6bdb79
·
verified ·
1 Parent(s): c7f907f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import gradio as gr
4
+ import joblib # For loading the model
5
+ import matplotlib.pyplot as plt
6
+ import io
7
+ from PIL import Image
8
+
9
+ # Load the scaler and model
10
+ scaler_path = "scaler.pkl"
11
+ model_path = "ebm_model.pkl"
12
+
13
+ # Load the pre-trained scaler and model
14
+ scaler = joblib.load(scaler_path)
15
+ ebm = joblib.load(model_path)
16
+ print("Loaded saved model and scaler.")
17
+
18
+ # Global variable to store the most recent input data for explanation
19
+ last_input_data_scaled = None
20
+
21
+ # Define prediction function
22
+ def predict_strength(cement, blast_furnace_slag, fly_ash, water, superplasticizer,
23
+ coarse_aggregate, fine_aggregate, age):
24
+ global last_input_data_scaled
25
+ input_data = pd.DataFrame({
26
+ 'cement': [cement], 'blast_furnace_slag': [blast_furnace_slag],
27
+ 'fly_ash': [fly_ash], 'water': [water],
28
+ 'superplasticizer': [superplasticizer],
29
+ 'coarse_aggregate': [coarse_aggregate],
30
+ 'fine_aggregate': [fine_aggregate], 'age': [age]
31
+ })
32
+ last_input_data_scaled = scaler.transform(input_data)
33
+ prediction = ebm.predict(last_input_data_scaled)
34
+ return prediction[0]
35
+
36
+ # Explanation function for enhanced visual
37
+ def show_local_explanation():
38
+ if last_input_data_scaled is not None:
39
+ local_exp = ebm.explain_local(last_input_data_scaled)
40
+ contributions = local_exp.data(0)['scores']
41
+ names = local_exp.data(0)['names']
42
+
43
+ # Enhanced Plotting
44
+ fig, ax = plt.subplots(figsize=(10, 6))
45
+ colors = ['red' if x < 0 else 'green' for x in contributions]
46
+ ax.barh(names, contributions, color=colors)
47
+ ax.set_xlabel('Contribution to Prediction')
48
+ ax.set_title('Local Explanation for the Most Recent Prediction')
49
+
50
+ # Save plot to a buffer
51
+ buf = io.BytesIO()
52
+ plt.savefig(buf, format='png', bbox_inches='tight')
53
+ buf.seek(0)
54
+ plt.close()
55
+
56
+ # Load image for display
57
+ img = Image.open(buf)
58
+ img_array = np.array(img)
59
+ return img_array
60
+ else:
61
+ return "No prediction has been made yet."
62
+
63
+ # Gradio interface setup with introduction and instructions
64
+ with gr.Blocks() as app:
65
+ gr.Markdown("## Concrete Strength Prediction App")
66
+ gr.Markdown("""
67
+ This app predicts the compressive strength of concrete based on its composition using the Explainable Boosting Machine (EBM).
68
+ EBM is a transparent, interpretable machine learning model that combines the power of boosting techniques with interpretable models,
69
+ making it easier to explain prediction outcomes.
70
+ """)
71
+ gr.Markdown("### Instructions")
72
+ gr.Markdown("""
73
+ - Enter the composition of the concrete in the input fields.
74
+ - Click 'Predict Concrete Strength' to see the predicted strength.
75
+ - Click 'Show Local Explanation' to see the contribution of each feature to the prediction.
76
+ """)
77
+
78
+ with gr.Row():
79
+ cement = gr.Number(label="Cement")
80
+ slag = gr.Number(label="Blast Furnace Slag")
81
+ fly_ash = gr.Number(label="Fly Ash")
82
+ water = gr.Number(label="Water")
83
+ superplasticizer = gr.Number(label="Superplasticizer")
84
+ coarse_agg = gr.Number(label="Coarse Aggregate")
85
+ fine_agg = gr.Number(label="Fine Aggregate")
86
+ age = gr.Number(label="Age")
87
+
88
+ predict_btn = gr.Button("Predict Concrete Strength")
89
+ explanation_btn = gr.Button("Show Local Explanation")
90
+ result = gr.Textbox(label="Predicted Concrete Strength")
91
+ local_image = gr.Image(label="Local Explanation", type="numpy")
92
+
93
+ predict_btn.click(
94
+ fn=predict_strength,
95
+ inputs=[cement, slag, fly_ash, water, superplasticizer, coarse_agg, fine_agg, age],
96
+ outputs=result
97
+ )
98
+ explanation_btn.click(
99
+ fn=show_local_explanation,
100
+ inputs=[],
101
+ outputs=local_image
102
+ )
103
+
104
+ app.launch()