MohammedAH commited on
Commit
bab81ab
·
verified ·
1 Parent(s): c4f3a2e

Upload 2 files

Browse files
Files changed (2) hide show
  1. intrusion.py +433 -0
  2. intrusion_model.h5 +3 -0
intrusion.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ import seaborn as sns
6
+ import matplotlib.pyplot as plt
7
+ from tensorflow.keras.models import load_model
8
+
9
+ # Configure styling
10
+ sns.set_theme(style="whitegrid")
11
+ st.set_page_config(
12
+ page_title="Federated Learning for Anomaly Detection in IOT Environments",
13
+ page_icon="🛡️",
14
+ layout="wide",
15
+ initial_sidebar_state="expanded"
16
+ )
17
+
18
+ # Load the pre-trained model
19
+ @st.cache_resource
20
+ def load_intrusion_model():
21
+ return load_model('intrusion_model.h5')
22
+
23
+ # Define attack type labels
24
+ ATTACK_TYPES = {
25
+ 0: 'Normal', 1: 'Backdoor', 2: 'DDoS_HTTP',
26
+ 3: 'DDoS_ICMP', 4: 'DDoS_TCP', 5: 'DDoS_UDP',
27
+ 6: 'Fingerprinting', 7: 'MITM', 8: 'Password',
28
+ 9: 'Port_Scanning', 10: 'Ransomware', 11: 'SQL_injection',
29
+ 12: 'Uploading', 13: 'Vulnerability_scanner', 14: 'XSS'
30
+ }
31
+
32
+ # Critical attacks that trigger alerts
33
+ CRITICAL_ATTACKS = {
34
+ 'DDoS_HTTP', 'DDoS_ICMP', 'DDoS_TCP', 'DDoS_UDP',
35
+ 'Ransomware', 'SQL_injection', 'Port_Scanning'
36
+ }
37
+
38
+ # Create the Streamlit app
39
+ def main():
40
+ # Sidebar with model information
41
+ st.sidebar.header("About")
42
+ st.sidebar.markdown("""
43
+ **Federated Learning for Anomaly Detection in IOT Environments**
44
+ This system detects and classifies cyber attacks on IoT networks using deep learning.
45
+ The model achieves 93.6% accuracy on validation data.
46
+ """)
47
+
48
+ st.sidebar.subheader("Attack Types")
49
+ for code, name in ATTACK_TYPES.items():
50
+ st.sidebar.caption(f"{code}: {name}")
51
+
52
+ st.sidebar.subheader("Attack Severity")
53
+ st.sidebar.markdown("""
54
+ - 🔴 **Critical**: DDoS, Ransomware, SQL Injection
55
+ - 🟠 **High**: Port Scanning, Backdoor
56
+ - 🟢 **Medium**: Other attacks
57
+ - ⚪ **Normal**: Benign traffic
58
+ """)
59
+
60
+ st.sidebar.divider()
61
+ st.sidebar.info("﹫2025")
62
+ st.sidebar.download_button(
63
+ label="Download Sample CSV",
64
+ data=pd.DataFrame(columns=range(1, 250)).to_csv(index=False),
65
+ file_name="sample_features.csv",
66
+ mime="text/csv"
67
+ )
68
+
69
+ # Main content
70
+ st.title("🛡️ Federated Learning for Anomaly Detection in IOT Environments")
71
+ st.caption("Detect and classify security threats in IoT network traffic")
72
+
73
+ # Initialize session state
74
+ if 'predictions' not in st.session_state:
75
+ st.session_state.predictions = None
76
+ if 'critical_alerts' not in st.session_state:
77
+ st.session_state.critical_alerts = []
78
+
79
+ # Load model
80
+ try:
81
+ model = load_intrusion_model()
82
+ except Exception as e:
83
+ st.error(f"Error loading model: {str(e)}")
84
+ st.stop()
85
+
86
+ # Alert banner area at top
87
+ alert_placeholder = st.empty()
88
+
89
+ # Prediction section
90
+ tab1, tab2 = st.tabs(["📊 Batch Prediction", "🔍 Single Prediction"])
91
+
92
+ with tab1:
93
+ st.subheader("Batch Prediction from CSV")
94
+ uploaded_file = st.file_uploader("Upload IoT device data (CSV)", type="csv")
95
+
96
+ if uploaded_file:
97
+ try:
98
+ df = pd.read_csv(uploaded_file)
99
+ st.success(f"Successfully loaded {len(df)} records")
100
+
101
+ # Show sample data
102
+ if st.checkbox("Show data preview"):
103
+ st.dataframe(df.head())
104
+
105
+ # Validate features
106
+ if len(df.columns) != 249:
107
+ st.warning(f"Data should have 249 features. Found {len(df.columns)} columns.")
108
+ st.info("Ensure your CSV has exactly 249 columns representing the model features")
109
+ else:
110
+ # Make predictions
111
+ if st.button("Run Predictions", type="primary"):
112
+ with st.spinner("Analyzing network traffic..."):
113
+ # Preprocess and predict
114
+ X = df.values.astype('float32')
115
+ pred_probs = model.predict(X, verbose=0)
116
+ pred_classes = np.argmax(pred_probs, axis=1)
117
+ confidence_scores = np.max(pred_probs, axis=1)
118
+
119
+ # Add predictions to dataframe
120
+ df['Predicted_Attack'] = [ATTACK_TYPES[c] for c in pred_classes]
121
+ df['Prediction_Confidence'] = confidence_scores
122
+
123
+ # Store in session state
124
+ st.session_state.predictions = df
125
+ st.session_state.critical_alerts = df[
126
+ df['Predicted_Attack'].isin(CRITICAL_ATTACKS)
127
+ ]
128
+
129
+ except Exception as e:
130
+ st.error(f"Error processing file: {str(e)}")
131
+
132
+ # Display results if available
133
+ if st.session_state.predictions is not None:
134
+ df = st.session_state.predictions
135
+
136
+ # Critical attack alert
137
+ if not st.session_state.critical_alerts.empty:
138
+ critical_count = len(st.session_state.critical_alerts)
139
+ with alert_placeholder.container():
140
+ st.error(f"🚨 **CRITICAL THREAT DETECTED!** - {critical_count} critical attacks identified",
141
+ icon="⚠️")
142
+
143
+ st.subheader("Prediction Results")
144
+
145
+ # Summary stats
146
+ normal_count = len(df[df['Predicted_Attack'] == 'Normal'])
147
+ attack_count = len(df) - normal_count
148
+ critical_count = len(st.session_state.critical_alerts)
149
+
150
+ col1, col2, col3 = st.columns(3)
151
+ col1.metric("Total Records", len(df))
152
+ col2.metric("Attack Traffic", f"{attack_count} ({attack_count/len(df):.1%})")
153
+ col3.metric("Critical Threats", critical_count,
154
+ f"{critical_count/attack_count:.1%}" if attack_count else "0%")
155
+
156
+ # Visualization section
157
+ st.subheader("Attack Analysis")
158
+
159
+ # Tabs for different visualizations
160
+ viz_tab1, viz_tab2, viz_tab3, viz_tab4 = st.tabs([
161
+ "Attack Distribution",
162
+ "Confidence Analysis",
163
+ "Threat Severity",
164
+ "Detailed Results"
165
+ ])
166
+
167
+ with viz_tab1:
168
+ col1, col2 = st.columns([3, 2])
169
+
170
+ with col1:
171
+ # Attack type bar chart
172
+ st.markdown("**Attack Type Distribution**")
173
+ attack_counts = df['Predicted_Attack'].value_counts()
174
+ fig, ax = plt.subplots(figsize=(10, 6))
175
+ sns.barplot(
176
+ x=attack_counts.values,
177
+ y=attack_counts.index,
178
+ palette="viridis",
179
+ ax=ax
180
+ )
181
+ plt.xlabel("Count")
182
+ plt.ylabel("Attack Type")
183
+ plt.title("Attack Frequency Distribution")
184
+ st.pyplot(fig)
185
+
186
+ with col2:
187
+ # Attack type pie chart
188
+ st.markdown("**Attack Proportion**")
189
+ normal_attack = df['Predicted_Attack'] != 'Normal'
190
+ attack_ratio = normal_attack.value_counts(normalize=True)
191
+
192
+ fig, ax = plt.subplots(figsize=(8, 6))
193
+ attack_ratio.plot.pie(
194
+ autopct='%1.1f%%',
195
+ labels=['Normal', 'Attack'],
196
+ colors=['#2ca02c', '#d62728'],
197
+ startangle=90,
198
+ ax=ax
199
+ )
200
+ plt.title("Normal vs Attack Traffic")
201
+ plt.ylabel("")
202
+ st.pyplot(fig)
203
+
204
+ with viz_tab2:
205
+ col1, col2 = st.columns(2)
206
+
207
+ with col1:
208
+ # Confidence histogram
209
+ st.markdown("**Confidence Distribution**")
210
+ fig, ax = plt.subplots(figsize=(10, 6))
211
+ sns.histplot(
212
+ df['Prediction_Confidence'],
213
+ bins=20,
214
+ kde=True,
215
+ color='#1f77b4',
216
+ ax=ax
217
+ )
218
+ plt.axvline(x=0.9, color='r', linestyle='--', label='High Confidence')
219
+ plt.xlabel("Confidence Score")
220
+ plt.ylabel("Frequency")
221
+ plt.title("Prediction Confidence Distribution")
222
+ plt.legend()
223
+ st.pyplot(fig)
224
+
225
+ with col2:
226
+ # Confidence by attack type
227
+ st.markdown("**Confidence by Attack Type**")
228
+ fig, ax = plt.subplots(figsize=(10, 6))
229
+ sns.boxplot(
230
+ x=df['Prediction_Confidence'],
231
+ y=df['Predicted_Attack'],
232
+ palette="Set3",
233
+ ax=ax
234
+ )
235
+ plt.xlabel("Confidence Score")
236
+ plt.ylabel("Attack Type")
237
+ plt.title("Confidence Distribution per Attack Type")
238
+ st.pyplot(fig)
239
+
240
+ with viz_tab3:
241
+ # Define severity levels
242
+ severity_map = {
243
+ 'Normal': 'Normal',
244
+ 'DDoS_HTTP': 'Critical',
245
+ 'DDoS_ICMP': 'Critical',
246
+ 'DDoS_TCP': 'Critical',
247
+ 'DDoS_UDP': 'Critical',
248
+ 'Ransomware': 'Critical',
249
+ 'SQL_injection': 'Critical',
250
+ 'Port_Scanning': 'High',
251
+ 'Backdoor': 'High',
252
+ 'Fingerprinting': 'Medium',
253
+ 'MITM': 'Medium',
254
+ 'Password': 'Medium',
255
+ 'Uploading': 'Medium',
256
+ 'Vulnerability_scanner': 'Medium',
257
+ 'XSS': 'Medium'
258
+ }
259
+
260
+ df['Severity'] = df['Predicted_Attack'].map(severity_map)
261
+
262
+ col1, col2 = st.columns(2)
263
+
264
+ with col1:
265
+ # Severity pie chart
266
+ st.markdown("**Threat Severity Distribution**")
267
+ severity_counts = df['Severity'].value_counts()
268
+
269
+ fig, ax = plt.subplots(figsize=(8, 8))
270
+ colors = {'Critical': '#d62728', 'High': '#ff7f0e',
271
+ 'Medium': '#e377c2', 'Normal': '#2ca02c'}
272
+ severity_counts.plot.pie(
273
+ autopct='%1.1f%%',
274
+ colors=[colors[s] for s in severity_counts.index],
275
+ startangle=90,
276
+ ax=ax
277
+ )
278
+ plt.title("Threat Severity Levels")
279
+ plt.ylabel("")
280
+ st.pyplot(fig)
281
+
282
+ with col2:
283
+ # Severity count plot
284
+ st.markdown("**Threat Severity Counts**")
285
+ fig, ax = plt.subplots(figsize=(10, 6))
286
+ sns.countplot(
287
+ x=df['Severity'],
288
+ order=['Critical', 'High', 'Medium', 'Normal'],
289
+ palette=list(colors.values()),
290
+ ax=ax
291
+ )
292
+ plt.xlabel("Severity Level")
293
+ plt.ylabel("Count")
294
+ plt.title("Threat Severity Distribution")
295
+ st.pyplot(fig)
296
+
297
+ with viz_tab4:
298
+ # Detailed results table
299
+ st.dataframe(df[['Predicted_Attack', 'Prediction_Confidence', 'Severity']].head(50))
300
+
301
+ # Download results
302
+ st.divider()
303
+ csv = df.to_csv(index=False)
304
+ st.download_button(
305
+ label="Download Full Predictions",
306
+ data=csv,
307
+ file_name="intrusion_predictions.csv",
308
+ mime="text/csv",
309
+ type="primary"
310
+ )
311
+
312
+ with tab2:
313
+ st.subheader("Single Prediction")
314
+ st.markdown("Enter feature values manually for real-time threat detection")
315
+
316
+ # Create input form
317
+ with st.form("single_prediction"):
318
+ # Generate sample input features
319
+ sample_features = [0.0] * 249
320
+ inputs = []
321
+
322
+ st.info("For demonstration, only the first 10 features are shown. Others are set to default values.")
323
+
324
+ # Split into 3 columns for better layout
325
+ col1, col2, col3 = st.columns(3)
326
+ cols = [col1, col2, col3]
327
+
328
+ # Only show first 10 features to save space
329
+ features_to_show = 10
330
+
331
+ for i in range(features_to_show):
332
+ with cols[i % 3]:
333
+ inputs.append(
334
+ st.number_input(
335
+ f"Feature {i+1}",
336
+ value=sample_features[i],
337
+ key=f"feature_{i}",
338
+ step=0.001
339
+ )
340
+ )
341
+
342
+ # Fill remaining features with default values
343
+ inputs += sample_features[features_to_show:]
344
+
345
+ submit = st.form_submit_button("Analyze Traffic", type="primary")
346
+
347
+ if submit:
348
+ try:
349
+ # Prepare input data
350
+ input_array = np.array([inputs], dtype='float32')
351
+
352
+ # Make prediction
353
+ pred_prob = model.predict(input_array, verbose=0)
354
+ pred_class = np.argmax(pred_prob, axis=1)[0]
355
+ confidence = np.max(pred_prob)
356
+ attack_name = ATTACK_TYPES[pred_class]
357
+
358
+ # Check if critical
359
+ is_critical = attack_name in CRITICAL_ATTACKS
360
+
361
+ # Display alert
362
+ if is_critical:
363
+ with alert_placeholder.container():
364
+ st.error(f"🚨 **CRITICAL THREAT DETECTED!** - {attack_name} attack identified",
365
+ icon="⚠️")
366
+
367
+ # Display results
368
+ st.subheader("Analysis Result")
369
+
370
+ # Create columns for results
371
+ col1, col2 = st.columns([1, 2])
372
+
373
+ with col1:
374
+ # Attack type card
375
+ severity = "Critical" if is_critical else "Normal" if attack_name == "Normal" else "Warning"
376
+ color = "#d62728" if is_critical else "#2ca02c" if attack_name == "Normal" else "#ff7f0e"
377
+
378
+ st.markdown(f"""
379
+ <div style="
380
+ border: 1px solid {color};
381
+ border-radius: 10px;
382
+ padding: 20px;
383
+ text-align: center;
384
+ background-color: #f0f2f6;
385
+ margin-bottom: 20px;
386
+ ">
387
+ <h3 style="color: {color}; margin-top: 0;">{attack_name}</h3>
388
+ <p style="font-size: 18px; margin-bottom: 5px;">Threat Level: <strong>{severity}</strong></p>
389
+ <p style="font-size: 18px;">Confidence: <strong>{confidence:.2%}</strong></p>
390
+ </div>
391
+ """, unsafe_allow_html=True)
392
+
393
+ # Confidence indicator
394
+ st.metric("Prediction Confidence", f"{confidence:.2%}")
395
+ st.progress(float(confidence))
396
+
397
+ with col2:
398
+ # Probability distribution
399
+ prob_df = pd.DataFrame({
400
+ 'Attack Type': list(ATTACK_TYPES.values()),
401
+ 'Probability': pred_prob[0]
402
+ }).sort_values('Probability', ascending=False)
403
+
404
+ # Top 10 probabilities
405
+ top_probs = prob_df.head(10)
406
+
407
+ fig, ax = plt.subplots(figsize=(10, 6))
408
+ sns.barplot(
409
+ x='Probability',
410
+ y='Attack Type',
411
+ data=top_probs,
412
+ palette="rocket",
413
+ ax=ax
414
+ )
415
+ plt.title("Top 10 Predicted Attack Probabilities")
416
+ plt.xlabel("Probability")
417
+ plt.ylabel("")
418
+ st.pyplot(fig)
419
+
420
+ # Show full probability table
421
+ with st.expander("View Complete Probability Distribution"):
422
+ prob_df['Probability'] = prob_df['Probability'].apply(lambda x: f"{x:.4f}")
423
+ st.dataframe(prob_df)
424
+
425
+ except Exception as e:
426
+ st.error(f"Prediction error: {str(e)}")
427
+
428
+ # Add footer
429
+ st.divider()
430
+ st.caption("IoT Security Dashboard v1.0 | Real-time Threat Detection System")
431
+
432
+ if __name__ == "__main__":
433
+ main()
intrusion_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b1a75a20cb963180b8a87135da6ba592d32ddc6239a5b54165012a3b8232a82
3
+ size 558180