rtik007 commited on
Commit
cadabca
·
verified ·
1 Parent(s): 6a25285

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -16
app.py CHANGED
@@ -5,6 +5,7 @@ from sklearn.ensemble import IsolationForest
5
  import shap
6
  import matplotlib.pyplot as plt
7
  from itertools import combinations
 
8
 
9
  # Generate synthetic data with 20 features
10
  np.random.seed(42)
@@ -44,24 +45,32 @@ df["Anomaly_Label"] = np.where(anomaly_labels == -1, "Anomaly", "Normal")
44
  explainer = shap.Explainer(iso_forest, df[columns])
45
  shap_values = explainer(df[columns])
46
 
47
- # SHAP Summary Plot (Global Explainability)
48
- shap.summary_plot(shap_values, df[columns], feature_names=columns)
49
 
50
- # SHAP Waterfall Plot for a Specific Data Point (Local Explainability)
51
- specific_index = df[df["Anomaly_Label"] == "Anomaly"].index[0] # Select first anomaly
52
- shap.waterfall_plot(
53
- shap.Explanation(
54
- values=shap_values.values[specific_index],
55
- base_values=shap_values.base_values[specific_index],
56
- data=df.iloc[specific_index],
57
- feature_names=columns
58
- )
59
- )
60
 
61
- # Scatter plots for pairwise combinations of features
62
- feature_combinations = list(combinations(columns[:5], 2)) # Use first 5 features for simplicity
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- for feature1, feature2 in feature_combinations:
 
65
  plt.figure(figsize=(8, 6))
66
  plt.scatter(
67
  df[feature1],
@@ -74,4 +83,33 @@ for feature1, feature2 in feature_combinations:
74
  plt.title(f"Isolation Forest - {feature1} vs {feature2}")
75
  plt.xlabel(feature1)
76
  plt.ylabel(feature2)
77
- plt.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import shap
6
  import matplotlib.pyplot as plt
7
  from itertools import combinations
8
+ import gradio as gr
9
 
10
  # Generate synthetic data with 20 features
11
  np.random.seed(42)
 
45
  explainer = shap.Explainer(iso_forest, df[columns])
46
  shap_values = explainer(df[columns])
47
 
48
+ # Define functions for Gradio
 
49
 
50
+ def get_shap_summary():
51
+ """Generates SHAP summary plot."""
52
+ plt.figure()
53
+ shap.summary_plot(shap_values, df[columns], feature_names=columns, show=False)
54
+ plt.savefig("shap_summary.png")
55
+ return "shap_summary.png"
 
 
 
 
56
 
57
+ def get_shap_waterfall(index):
58
+ """Generates SHAP waterfall plot for a specific data point."""
59
+ specific_index = int(index)
60
+ plt.figure()
61
+ shap.waterfall_plot(
62
+ shap.Explanation(
63
+ values=shap_values.values[specific_index],
64
+ base_values=shap_values.base_values[specific_index],
65
+ data=df.iloc[specific_index],
66
+ feature_names=columns
67
+ )
68
+ )
69
+ plt.savefig("shap_waterfall.png")
70
+ return "shap_waterfall.png"
71
 
72
+ def get_scatter_plot(feature1, feature2):
73
+ """Generates scatter plot for two features."""
74
  plt.figure(figsize=(8, 6))
75
  plt.scatter(
76
  df[feature1],
 
83
  plt.title(f"Isolation Forest - {feature1} vs {feature2}")
84
  plt.xlabel(feature1)
85
  plt.ylabel(feature2)
86
+ plt.savefig("scatter_plot.png")
87
+ return "scatter_plot.png"
88
+
89
+ # Create Gradio interface
90
+ with gr.Blocks() as demo:
91
+ gr.Markdown("# Isolation Forest Anomaly Detection")
92
+
93
+ with gr.Tab("SHAP Summary"):
94
+ gr.Markdown("### Global Explainability: SHAP Summary Plot")
95
+ shap_button = gr.Button("Generate SHAP Summary Plot")
96
+ shap_image = gr.Image()
97
+ shap_button.click(get_shap_summary, outputs=shap_image)
98
+
99
+ with gr.Tab("SHAP Waterfall"):
100
+ gr.Markdown("### Local Explainability: SHAP Waterfall Plot")
101
+ index_input = gr.Number(label="Data Point Index", value=0)
102
+ shap_waterfall_button = gr.Button("Generate SHAP Waterfall Plot")
103
+ shap_waterfall_image = gr.Image()
104
+ shap_waterfall_button.click(get_shap_waterfall, inputs=index_input, outputs=shap_waterfall_image)
105
+
106
+ with gr.Tab("Feature Scatter Plot"):
107
+ gr.Markdown("### Feature Interaction: Scatter Plot")
108
+ feature1_dropdown = gr.Dropdown(choices=columns, label="Feature 1")
109
+ feature2_dropdown = gr.Dropdown(choices=columns, label="Feature 2")
110
+ scatter_button = gr.Button("Generate Scatter Plot")
111
+ scatter_image = gr.Image()
112
+ scatter_button.click(get_scatter_plot, inputs=[feature1_dropdown, feature2_dropdown], outputs=scatter_image)
113
+
114
+ # Launch the Gradio app
115
+ demo.launch()