AmnaHassan committed on
Commit
ea4398d
·
verified ·
1 Parent(s): bfd5a34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -88
app.py CHANGED
@@ -2,11 +2,12 @@ import streamlit as st
2
  import torch
3
  import torch.nn as nn
4
  import numpy as np
 
5
  import joblib
6
  import plotly.graph_objects as go
7
 
8
  # -----------------------------
9
- # Page config (premium look)
10
  # -----------------------------
11
  st.set_page_config(
12
  page_title="Cyber Threat Detection Dashboard",
@@ -15,11 +16,24 @@ st.set_page_config(
15
  )
16
 
17
  # -----------------------------
18
- # Title Section
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # -----------------------------
20
  st.markdown("""
21
- # 🛡️ Cyber Threat Detection Dashboard
22
- **Deep Learning–Based Suspicious Activity Detection**
23
  """)
24
 
25
  st.markdown("---")
@@ -28,11 +42,10 @@ st.markdown("---")
28
  # Load scaler
29
  # -----------------------------
30
  scaler = joblib.load("scaler.pkl")
31
- INPUT_DIM = scaler.mean_.shape[0]
32
 
33
  # -----------------------------
34
- # Model definition
35
- # (MUST match training exactly)
36
  # -----------------------------
37
  model = nn.Sequential(
38
  nn.Linear(INPUT_DIM, 128),
@@ -48,88 +61,120 @@ model = nn.Sequential(
48
  nn.Linear(64, 1)
49
  )
50
 
51
- # -----------------------------
52
- # Load trained weights
53
- # -----------------------------
54
- state_dict = torch.load("model.pth", map_location="cpu")
55
- model.load_state_dict(state_dict)
56
  model.eval()
57
 
58
- # -----------------------------
59
- # Sidebar – Log Input
60
- # -----------------------------
61
- st.sidebar.header("🔍 Log Feature Input")
62
-
63
- features = []
64
- for i in range(INPUT_DIM):
65
- val = st.sidebar.number_input(
66
- f"Feature {i+1}",
67
- value=0.0,
68
- step=0.01
69
- )
70
- features.append(val)
71
 
72
- # -----------------------------
73
- # Prediction
74
- # -----------------------------
75
- if st.sidebar.button("🚨 Analyze Event"):
76
- x = np.array(features).reshape(1, -1)
77
- x_scaled = scaler.transform(x)
78
- x_tensor = torch.tensor(x_scaled, dtype=torch.float32)
79
-
80
- with torch.no_grad():
81
- prob = torch.sigmoid(model(x_tensor)).item()
82
-
83
- # -------------------------
84
- # Risk level logic
85
- # -------------------------
86
- if prob > 0.7:
87
- risk = "HIGH"
88
- color = "red"
89
- action = "Immediate investigation required. Isolate affected system."
90
- elif prob > 0.4:
91
- risk = "MEDIUM"
92
- color = "orange"
93
- action = "Monitor closely. Correlate with other logs."
94
- else:
95
- risk = "LOW"
96
- color = "green"
97
- action = "No action required. Log for auditing."
98
-
99
- # -------------------------
100
- # Main dashboard layout
101
- # -------------------------
102
- col1, col2, col3 = st.columns(3)
103
-
104
- col1.metric("Risk Level", risk)
105
- col2.metric("Suspicion Probability", f"{prob:.2f}")
106
- col3.metric("Recommended Action", action)
107
-
108
- # -------------------------
109
- # Gauge chart
110
- # -------------------------
111
- fig = go.Figure(go.Indicator(
112
- mode="gauge+number",
113
- value=prob * 100,
114
- title={'text': "Threat Confidence (%)"},
115
- gauge={
116
- 'axis': {'range': [0, 100]},
117
- 'bar': {'color': color},
118
- 'steps': [
119
- {'range': [0, 40], 'color': "lightgreen"},
120
- {'range': [40, 70], 'color': "orange"},
121
- {'range': [70, 100], 'color': "red"}
122
- ],
123
- }
124
- ))
125
-
126
- st.plotly_chart(fig, use_container_width=True)
127
-
128
- # -------------------------
129
- # Explanation block
130
- # -------------------------
131
- st.markdown("### 🧠 Model Interpretation")
132
- st.info(
133
- "The model analyzes behavioral patterns in system logs and assigns a risk score "
134
- "based on learned representations of malicious activity from the BETH dataset."
 
 
 
 
 
 
135
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import torch.nn as nn
4
  import numpy as np
5
+ import pandas as pd
6
  import joblib
7
  import plotly.graph_objects as go
8
 
9
  # -----------------------------
10
+ # Page config
11
  # -----------------------------
12
  st.set_page_config(
13
  page_title="Cyber Threat Detection Dashboard",
 
16
  )
17
 
18
  # -----------------------------
19
+ # Feature names (MUST match training order)
20
+ # -----------------------------
21
+ FEATURE_COLUMNS = [
22
+ "processId",
23
+ "threadId",
24
+ "parentProcessId",
25
+ "userId",
26
+ "mountNamespace",
27
+ "argsNum",
28
+ "returnValue"
29
+ ]
30
+
31
+ # -----------------------------
32
+ # Title
33
  # -----------------------------
34
  st.markdown("""
35
+ # 🛡️ Cyber Threat Detection Dashboard
36
+ **SOC-Style Deep Learning–Based Suspicious Activity Detection**
37
  """)
38
 
39
  st.markdown("---")
 
42
  # Load scaler
43
  # -----------------------------
44
  scaler = joblib.load("scaler.pkl")
45
+ INPUT_DIM = len(FEATURE_COLUMNS)
46
 
47
  # -----------------------------
48
+ # Model (EXACT training architecture)
 
49
  # -----------------------------
50
  model = nn.Sequential(
51
  nn.Linear(INPUT_DIM, 128),
 
61
  nn.Linear(64, 1)
62
  )
63
 
64
+ model.load_state_dict(torch.load("model.pth", map_location="cpu"))
 
 
 
 
65
  model.eval()
66
 
67
+ # =============================
68
+ # SIDEBAR
69
+ # =============================
70
+ st.sidebar.header("🔍 Analysis Mode")
 
 
 
 
 
 
 
 
 
71
 
72
+ mode = st.sidebar.radio(
73
+ "Choose input type",
74
+ ["Single Log Event", "Upload CSV (Batch Analysis)"]
75
+ )
76
+
77
+ # =============================
78
+ # SINGLE EVENT MODE
79
+ # =============================
80
+ if mode == "Single Log Event":
81
+ st.sidebar.subheader("Log Features")
82
+
83
+ features = []
84
+ for col in FEATURE_COLUMNS:
85
+ val = st.sidebar.number_input(col, value=0)
86
+ features.append(val)
87
+
88
+ if st.sidebar.button("🚨 Analyze Event"):
89
+ x = np.array(features).reshape(1, -1)
90
+ x_scaled = scaler.transform(x)
91
+ x_tensor = torch.tensor(x_scaled, dtype=torch.float32)
92
+
93
+ with torch.no_grad():
94
+ prob = torch.sigmoid(model(x_tensor)).item()
95
+
96
+ # Risk logic
97
+ if prob > 0.7:
98
+ risk = "HIGH"
99
+ color = "red"
100
+ action = "Immediate investigation required. Isolate affected system."
101
+ elif prob > 0.4:
102
+ risk = "MEDIUM"
103
+ color = "orange"
104
+ action = "Monitor closely. Correlate with other logs."
105
+ else:
106
+ risk = "LOW"
107
+ color = "green"
108
+ action = "No action required. Log for auditing."
109
+
110
+ col1, col2, col3 = st.columns(3)
111
+ col1.metric("Risk Level", risk)
112
+ col2.metric("Suspicion Probability", f"{prob:.2f}")
113
+ col3.metric("Recommended Action", action)
114
+
115
+ fig = go.Figure(go.Indicator(
116
+ mode="gauge+number",
117
+ value=prob * 100,
118
+ title={'text': "Threat Confidence (%)"},
119
+ gauge={
120
+ 'axis': {'range': [0, 100]},
121
+ 'bar': {'color': color},
122
+ 'steps': [
123
+ {'range': [0, 40], 'color': "lightgreen"},
124
+ {'range': [40, 70], 'color': "orange"},
125
+ {'range': [70, 100], 'color': "red"}
126
+ ],
127
+ }
128
+ ))
129
+
130
+ st.plotly_chart(fig, use_container_width=True)
131
+
132
+ # =============================
133
+ # CSV UPLOAD MODE
134
+ # =============================
135
+ else:
136
+ st.subheader("📄 Batch Log Analysis")
137
+
138
+ uploaded_file = st.file_uploader(
139
+ "Upload CSV file (validation/test logs)",
140
+ type=["csv"]
141
  )
142
+
143
+ if uploaded_file:
144
+ df = pd.read_csv(uploaded_file)
145
+
146
+ # Drop label column if present
147
+ if "sus_label" in df.columns:
148
+ df = df.drop(columns=["sus_label"])
149
+
150
+ # Ensure correct columns
151
+ missing = set(FEATURE_COLUMNS) - set(df.columns)
152
+ if missing:
153
+ st.error(f"Missing required columns: {missing}")
154
+ else:
155
+ df = df[FEATURE_COLUMNS] # enforce correct order
156
+ X_scaled = scaler.transform(df.values)
157
+ X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
158
+
159
+ with torch.no_grad():
160
+ probs = torch.sigmoid(model(X_tensor)).numpy().flatten()
161
+
162
+ df["suspicion_probability"] = probs
163
+ df["risk_level"] = df["suspicion_probability"].apply(
164
+ lambda p: "HIGH" if p > 0.7 else "MEDIUM" if p > 0.4 else "LOW"
165
+ )
166
+
167
+ st.success("Batch analysis completed")
168
+ st.dataframe(df, use_container_width=True)
169
+
170
+ st.markdown("### 📊 Risk Distribution")
171
+ st.bar_chart(df["risk_level"].value_counts())
172
+
173
+ # =============================
174
+ # Footer
175
+ # =============================
176
+ st.markdown("---")
177
+ st.info(
178
+ "This dashboard simulates a Security Operations Center (SOC) workflow by "
179
+ "analyzing system logs using a deep learning model trained on the BETH dataset."
180
+ )