rtik007 committed on
Commit
a9b33b5
·
verified ·
1 Parent(s): 3d9e921

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -3
app.py CHANGED
@@ -2,6 +2,7 @@ import numpy as np
2
  import pandas as pd
3
  from sklearn.datasets import make_classification
4
  from sklearn.ensemble import IsolationForest
 
5
  import shap
6
  import matplotlib.pyplot as plt
7
  import gradio as gr
@@ -40,12 +41,30 @@ anomaly_labels = iso_forest.predict(df) # -1 for anomaly, 1 for normal
40
  df["Anomaly_Score"] = anomaly_scores
41
  df["Anomaly_Label"] = np.where(anomaly_labels == -1, "Anomaly", "Normal")
42
 
 
 
 
43
  # SHAP Explainability
44
  explainer = shap.Explainer(iso_forest, df[columns])
45
  shap_values = explainer(df[columns])
46
 
47
  # Define functions for Gradio
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def get_anomaly_samples():
50
  """Returns formatted top, middle, and bottom 10 records based on anomaly score."""
51
  sorted_df = df.sort_values("Anomaly_Score", ascending=False)
@@ -63,13 +82,13 @@ with gr.Blocks() as demo:
63
  gr.Markdown("# Isolation Forest Anomaly Detection")
64
 
65
  with gr.Tab("Anomaly Samples"):
66
- gr.Markdown("<h3 style='text-align: center; font-size: 18px; font-weight: bold;'>Top 10 Records (Anomalies)</h3>", unsafe_allow_html=True)
67
  top_table = gr.Dataframe(label="Top 10 Records")
68
 
69
- gr.Markdown("<h3 style='text-align: center; font-size: 18px; font-weight: bold;'>Middle 10 Records (Mixed)</h3>", unsafe_allow_html=True)
70
  middle_table = gr.Dataframe(label="Middle 10 Records")
71
 
72
- gr.Markdown("<h3 style='text-align: center; font-size: 18px; font-weight: bold;'>Bottom 10 Records (Normal)</h3>", unsafe_allow_html=True)
73
  bottom_table = gr.Dataframe(label="Bottom 10 Records")
74
 
75
  anomaly_samples_button = gr.Button("Show Anomaly Samples")
@@ -77,6 +96,12 @@ with gr.Blocks() as demo:
77
  get_anomaly_samples,
78
  outputs=[top_table, middle_table, bottom_table]
79
  )
 
 
 
 
 
 
80
 
81
  # Launch the Gradio app
82
  demo.launch()
 
2
  import pandas as pd
3
  from sklearn.datasets import make_classification
4
  from sklearn.ensemble import IsolationForest
5
+ from sklearn.metrics import roc_curve, auc
6
  import shap
7
  import matplotlib.pyplot as plt
8
  import gradio as gr
 
41
  df["Anomaly_Score"] = anomaly_scores
42
  df["Anomaly_Label"] = np.where(anomaly_labels == -1, "Anomaly", "Normal")
43
 
44
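+ # Derive binary labels (1 for anomaly, 0 for normal) for the ROC curve. NOTE: these come from the Isolation Forest's own predictions in "Anomaly_Label", not from independent ground truth, so the resulting ROC curve compares the model against itself.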
45
+ true_labels = np.where(df["Anomaly_Label"] == "Anomaly", 1, 0)
46
+
47
  # SHAP Explainability
48
  explainer = shap.Explainer(iso_forest, df[columns])
49
  shap_values = explainer(df[columns])
50
 
51
  # Define functions for Gradio
52
 
53
+ def get_roc_curve():
54
+ """Generates the ROC curve plot."""
55
+ fpr, tpr, _ = roc_curve(true_labels, -df["Anomaly_Score"]) # Use -scores as higher scores mean normal
56
+ roc_auc = auc(fpr, tpr)
57
+ plt.figure(figsize=(8, 6))
58
+ plt.plot(fpr, tpr, label=f"ROC Curve (AUC = {roc_auc:.2f})")
59
+ plt.plot([0, 1], [0, 1], "k--", label="Random Guess")
60
+ plt.xlabel("False Positive Rate")
61
+ plt.ylabel("True Positive Rate")
62
+ plt.title("Receiver Operating Characteristic (ROC) Curve")
63
+ plt.legend(loc="lower right")
64
+ plt.grid()
65
+ plt.savefig("roc_curve.png")
66
+ return "roc_curve.png"
67
+
68
  def get_anomaly_samples():
69
  """Returns formatted top, middle, and bottom 10 records based on anomaly score."""
70
  sorted_df = df.sort_values("Anomaly_Score", ascending=False)
 
82
  gr.Markdown("# Isolation Forest Anomaly Detection")
83
 
84
  with gr.Tab("Anomaly Samples"):
85
+ gr.HTML("<h3 style='text-align: center; font-size: 18px; font-weight: bold;'>Top 10 Records (Anomalies)</h3>")
86
  top_table = gr.Dataframe(label="Top 10 Records")
87
 
88
+ gr.HTML("<h3 style='text-align: center; font-size: 18px; font-weight: bold;'>Middle 10 Records (Mixed)</h3>")
89
  middle_table = gr.Dataframe(label="Middle 10 Records")
90
 
91
+ gr.HTML("<h3 style='text-align: center; font-size: 18px; font-weight: bold;'>Bottom 10 Records (Normal)</h3>")
92
  bottom_table = gr.Dataframe(label="Bottom 10 Records")
93
 
94
  anomaly_samples_button = gr.Button("Show Anomaly Samples")
 
96
  get_anomaly_samples,
97
  outputs=[top_table, middle_table, bottom_table]
98
  )
99
+
100
+ with gr.Tab("ROC Curve"):
101
+ gr.Markdown("### ROC Curve for Isolation Forest")
102
+ roc_button = gr.Button("Generate ROC Curve")
103
+ roc_image = gr.Image()
104
+ roc_button.click(get_roc_curve, outputs=roc_image)
105
 
106
  # Launch the Gradio app
107
  demo.launch()