Taranpreet Singh committed
Commit 98d799c · 1 Parent(s): 8b03389

Phase 1: Offline NIDS prototype with CV, threshold tuning

Files changed (4)
  1. README.md +77 -0
  2. app.py +226 -0
  3. requirements.txt +5 -0
  4. sample_data/sample_small.csv +11 -0
README.md ADDED
@@ -0,0 +1,77 @@
---
title: AI NIDS Student Project
emoji: 🛡️
colorFrom: blue
colorTo: green
sdk: streamlit
sdk_version: 1.39.0
app_file: app.py
pinned: false
---

# 🛡️ AI-Based Network Intrusion Detection System (Student Project)

Project Status: Phase 1 – Pre-Production / Offline Prototype

This project demonstrates how to use **Machine Learning (Random Forest)** and **Generative AI (Groq)** to detect and explain network attacks (specifically DDoS).

## 🚀 How to Use

1. **Enter API Key:** Paste your Groq API key in the sidebar (optional, for AI explanations).
2. **Train Model:** Click the "Train Model Now" button. The system loads the `Friday-WorkingHours...` dataset automatically.
3. **Simulate:** Click "🎲 Capture Random Packet" to pick a real network packet from the test set.
4. **Analyze:** See whether the model flags the packet as **BENIGN** or **ATTACK** (here, DDoS), adjust the decision threshold in the sidebar to trade recall against false alarms, and ask Groq to explain the verdict (see the sketch below).
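
Step 4 boils down to a thresholded probability from the Random Forest. A minimal restatement of the rule applied in `app.py` (the function and argument names below are illustrative, not part of the app's API):

```python
# Illustrative restatement of the verdict logic in app.py; in the real app the
# trained model, feature list, and threshold come from the Streamlit session state.
import pandas as pd

def classify_flow(model, feature_names, packet_values, threshold=0.5):
    """Return (verdict, attack_probability) for one captured flow."""
    row = pd.DataFrame([packet_values], columns=feature_names)
    proba_attack = model.predict_proba(row)[0, 1]  # probability of the ATTACK class
    verdict = "ATTACK" if proba_attack >= threshold else "BENIGN"
    return verdict, proba_attack
```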

## 📂 Files

- `app.py`: The main Streamlit application.
- `requirements.txt`: Python dependencies.
- `sample_data/sample_small.csv`: Small synthetic sample used for demos.
- `Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv`: The dataset (CIC-IDS2017 subset); not committed to the repo (see the dataset note below).

## 🔧 PHASE 0 — Foundation Hardening (completed)

This repository includes an incremental, production-aligned hardening of the original student project:

- Deterministic reproducibility (global seed, logging).
- Explicit data validation and feature checks.
- Class-imbalance handling via `class_weight='balanced'`.
- Stratified 5-fold cross-validation with per-fold metrics (sketched below).
- Accuracy-only evaluation replaced with precision, recall, F1, PR-AUC, ROC-AUC, and confusion matrices.
- Artifacts saved to `models/` and `metrics/` (see below).

These changes are intentionally small and reversible; see `training_utils.py` for the training implementation.
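
For reference, here is a minimal sketch of the stratified-CV-with-class-weighting approach described above. The actual implementation lives in `training_utils.py`, which is not part of this commit, so treat this as an illustrative approximation rather than the shipped `train_model_cv` function:

```python
# Illustrative sketch only: assumes X and y are NumPy arrays of features and 0/1 labels.
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import precision_recall_fscore_support, average_precision_score

def cv_sketch(X, y, n_estimators=100, max_depth=12, seed=42, n_splits=5):
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    fold_metrics = []
    for train_idx, val_idx in skf.split(X, y):
        clf = RandomForestClassifier(
            n_estimators=n_estimators,
            max_depth=max_depth or None,   # the UI uses 0 to mean "no limit"
            class_weight="balanced",       # handle class imbalance
            random_state=seed,
            n_jobs=-1,
        )
        clf.fit(X[train_idx], y[train_idx])
        probas = clf.predict_proba(X[val_idx])[:, 1]
        preds = (probas >= 0.5).astype(int)
        prec, rec, f1, _ = precision_recall_fscore_support(
            y[val_idx], preds, average="binary", zero_division=0)
        fold_metrics.append({
            "precision": prec, "recall": rec, "f1": f1,
            "pr_auc": average_precision_score(y[val_idx], probas),
        })
    return fold_metrics
```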

## 📦 Artifacts (generated after training)

- `models/rf_model.joblib`: serialized RandomForest model (best fold).
- `metrics/training_metrics.json`: timestamped CV metrics including the PR curve, seed, and feature list.
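
A quick way to inspect these artifacts after a training run (paths as listed above; the exact field names in the metrics JSON are an assumption and may differ from the real schema):

```python
# Load and inspect the saved training artifacts.
import json
import joblib

model = joblib.load("models/rf_model.joblib")
with open("metrics/training_metrics.json") as f:
    metrics = json.load(f)

print(type(model).__name__)    # RandomForestClassifier
print(sorted(metrics.keys()))  # e.g. aggregate, folds, seed, features (assumed keys)
```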

## ⚠️ Dataset & Publishing

- ⚠️ Dataset Note: The full CIC-IDS2017 CSV (~96 MB) is intentionally excluded from GitHub. This repository focuses on model architecture and training logic. A small sample or synthetic dataset (`sample_data/sample_small.csv`) is included for demos; the full dataset is not committed.

## ▶️ Run locally

1. Create a virtual environment and install dependencies:

```powershell
python -m venv .venv
.\.venv\Scripts\Activate.ps1
pip install -r requirements.txt
```

2. Run the Streamlit app:

```powershell
streamlit run app.py
```

Note: training expects `Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv` in the repository root (see `DATA_FILE` in `app.py`); the app stops with an error if the file is missing.
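
For a quick smoke test of the data-loading path without the full CSV, the bundled sample mirrors the expected column layout; pointing the loader at it is a one-line edit in `app.py` (assumption: there is no UI switch for this yet, and the 10-row sample is far too small for the 5-fold CV training step):

```python
# app.py -- swap the data source for a loading smoke test only (illustrative edit).
DATA_FILE = "sample_data/sample_small.csv"  # instead of the full CIC-IDS2017 CSV
```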

## Contact / Next steps

A small, shareable sample CSV (e.g., 1k rows) can be generated from the full dataset so the repository can be published to GitHub safely; a minimal sketch is shown below.
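
The sketch assumes the full CSV sits in the repository root and has a `Label` column; the output file name is illustrative:

```python
# Generate a small, shareable, roughly stratified sample from the full CIC-IDS2017 CSV.
import pandas as pd

SEED = 42
full = pd.read_csv("Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv")
full.columns = full.columns.str.strip()

# Keep the BENIGN/DDoS ratio roughly intact, ~1000 rows total.
frac = min(1.0, 1000 / len(full))
sample = full.groupby("Label").sample(frac=frac, random_state=SEED)
sample.to_csv("sample_data/sample_1k.csv", index=False)
```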

## 🎓 About

Created for a university cybersecurity project to demonstrate the integration of traditional ML and LLMs in security operations.
app.py ADDED
@@ -0,0 +1,226 @@
import streamlit as st
import pandas as pd
import numpy as np
import random
import os
import json
import logging
import joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (precision_recall_fscore_support, roc_auc_score,
                             average_precision_score, confusion_matrix)
from groq import Groq
from training_utils import train_model_cv

# --- REPRODUCIBILITY & LOGGING ---
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
os.environ['PYTHONHASHSEED'] = str(SEED)

logging.basicConfig(level=logging.INFO, filename='training.log', filemode='a',
                    format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger('nids')
logger.info('App start, seed=%s', SEED)

# --- PAGE SETUP ---
st.set_page_config(page_title="AI-NIDS Student Project", layout="wide")

st.title("AI-Based Network Intrusion Detection System")
st.markdown("""
**Student Project**: This system uses **Random Forest** to detect network attacks and **Groq AI** to explain the packets.
""")

# --- CONFIGURATION ---
DATA_FILE = "Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv"

# --- SIDEBAR: SETTINGS ---
st.sidebar.header("1. Settings")
groq_api_key = st.sidebar.text_input(
    "Groq API Key (optional)", type="password")
st.sidebar.caption("[Get a free key here](https://console.groq.com/keys)")

st.sidebar.header("2. Model Training")
NROWS = st.sidebar.number_input(
    "Max rows to load", min_value=1000, value=15000, step=1000)
N_EST = st.sidebar.slider("RF trees (n_estimators)", 10, 500, 100, step=10)
MAX_DEPTH = st.sidebar.slider("Max tree depth (0=none)", 0, 50, 12)


@st.cache_data
def load_data(filepath, nrows=None):
    try:
        df = pd.read_csv(filepath, nrows=nrows)
        df.columns = df.columns.str.strip()
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        # require Label column
        if 'Label' not in df.columns:
            logger.error('Label column missing in dataset')
            return None
        # drop rows missing label
        df = df.dropna(subset=['Label'])
        # fill numeric nulls with median
        num_cols = df.select_dtypes(include=[np.number]).columns
        for c in num_cols:
            med = df[c].median()
            df[c] = df[c].fillna(med)
        return df
    except FileNotFoundError:
        logger.exception('Data file not found: %s', filepath)
        return None


# Training and metric utilities moved to training_utils.py


# --- APP LOGIC ---
df = load_data(DATA_FILE, nrows=NROWS)

if df is None:
    st.error(
        f"Error: File '{DATA_FILE}' not found or missing required columns. Please upload it to the Files tab.")
    st.stop()

st.sidebar.success(f"Dataset Loaded: {len(df)} rows")

DEFAULT_FEATURES = ['Flow Duration', 'Total Fwd Packets', 'Total Backward Packets',
                    'Total Length of Fwd Packets', 'Fwd Packet Length Max',
                    'Flow IAT Mean', 'Flow IAT Std', 'Flow Packets/s']

if st.sidebar.button("Train Model Now"):
    with st.spinner("Training model (Stratified CV)..."):
        try:
            clf, metrics, X_all, y_all, val_probas, val_labels, encoder = train_model_cv(
                df, DEFAULT_FEATURES, n_splits=5, n_estimators=N_EST, max_depth=MAX_DEPTH, seed=SEED
            )
            st.session_state['model'] = clf
            st.session_state['features'] = list(X_all.columns)
            # keep a small test slice for the simulation UI
            st.session_state['X_test'] = pd.DataFrame(
                X_all).sample(frac=0.2, random_state=SEED)
            st.session_state['y_test'] = pd.Series(
                y_all)[st.session_state['X_test'].index]
            st.session_state['encoder'] = encoder
            st.session_state['val_probas'] = val_probas
            st.session_state['val_labels'] = val_labels

            # display aggregate metrics
            agg = metrics['aggregate']
            st.sidebar.success(
                f"Training complete — F1: {agg['f1_mean']:.3f} ± {agg['f1_std']:.3f}")
            st.subheader('Training Metrics (aggregate)')
            st.json(agg)
            st.subheader('Per-fold metrics')
            st.json(metrics['folds'])

            # Threshold tuning UI (operational)
            st.sidebar.header('3. Threshold Tuning')
            threshold = st.sidebar.slider(
                'Decision threshold (ATTACK probability)',
                0.10, 0.90, 0.50, 0.01
            )
            st.session_state['threshold'] = threshold

            # compute metrics at the selected threshold using CV validation outputs
            try:
                probs = np.array(val_probas)
                labels = np.array(val_labels)
                preds_thr = (probs >= threshold).astype(int)
                prec, rec, f1, _ = precision_recall_fscore_support(
                    labels, preds_thr, average='binary', zero_division=0)
                cm = confusion_matrix(labels, preds_thr)
                tn, fp, fn, tp = cm.ravel() if cm.size == 4 else (0, 0, 0, 0)
                st.sidebar.markdown('**At selected threshold:**')
                st.sidebar.write(
                    f'Precision: {prec:.3f} — Recall: {rec:.3f} — F1: {f1:.3f}')
                st.sidebar.write(
                    f'False Positives: {int(fp)} — False Negatives: {int(fn)}')
                st.sidebar.caption(
                    'Threshold selection trades off catching more attacks (recall) against raising false alarms (FP).')
            except Exception:
                st.sidebar.warning('Unable to compute threshold metrics.')

        except Exception as e:
            logger.exception('Training failed')
            st.error(f"Training failed: {e}")


st.header("3. Threat Analysis Dashboard")

if 'model' in st.session_state:
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Simulation")
        st.info("Pick a random packet from the test data to simulate live traffic.")

        if st.button("🎲 Capture Random Packet"):
            random_idx = np.random.randint(0, len(st.session_state['X_test']))
            packet_data = st.session_state['X_test'].iloc[random_idx]
            actual_label = st.session_state['y_test'].iloc[random_idx]

            st.session_state['current_packet'] = packet_data
            st.session_state['actual_label'] = actual_label

    if 'current_packet' in st.session_state:
        packet = st.session_state['current_packet']

        with col1:
            st.write("**Packet Header Info:**")
            st.dataframe(packet, use_container_width=True)

        with col2:
            st.subheader("AI Detection Result")
            # build a one-row frame with the training feature names to keep the input shape consistent
            packet_df = pd.DataFrame(
                [packet.values], columns=st.session_state['features'])
            pred_proba = st.session_state['model'].predict_proba(packet_df)[0, 1]
            threshold = st.session_state.get('threshold', 0.5)
            prediction = 'ATTACK' if pred_proba >= threshold else 'BENIGN'

            st.caption(f"Active decision threshold: {threshold:.2f}")

            if prediction == "BENIGN":
                st.success(
                    f"STATUS: **SAFE (BENIGN)** — score {pred_proba:.3f}")
            else:
                st.error(
                    f"🚨 STATUS: **ATTACK DETECTED ({prediction})** — score {pred_proba:.3f}")

            gt = st.session_state['encoder'].inverse_transform(
                [st.session_state['actual_label']])[0]
            st.caption(f"Ground Truth Label: {gt}")

            st.markdown("---")
            st.subheader("Ask AI Analyst (Groq)")

            if st.button("Generate Explanation"):
                if not groq_api_key:
                    st.warning(
                        "Please enter your Groq API Key in the sidebar first.")
                else:
                    try:
                        client = Groq(api_key=groq_api_key)

                        prompt = f"""
You are a cybersecurity analyst.
A network packet window was scored as: {prediction} (score={pred_proba:.3f}).

Packet Technical Details:
{packet.to_string()}

Please explain succinctly why the model might have assigned the {prediction} verdict to this flow.
"""

                        with st.spinner("Groq is analyzing the packet..."):
                            completion = client.chat.completions.create(
                                model="llama-3.3-70b-versatile",
                                messages=[{"role": "user", "content": prompt}],
                                temperature=0.6,
                            )
                        st.info(completion.choices[0].message.content)

                    except Exception as e:
                        st.error(f"API Error: {e}")
else:
    st.info("Waiting for model training. Click **'Train Model Now'** in the sidebar.")
requirements.txt ADDED
@@ -0,0 +1,5 @@
streamlit
pandas
numpy
scikit-learn
groq
sample_data/sample_small.csv ADDED
@@ -0,0 +1,11 @@
Flow Duration,Total Fwd Packets,Total Backward Packets,Total Length of Fwd Packets,Fwd Packet Length Max,Flow IAT Mean,Flow IAT Std,Flow Packets/s,Label
0.1,5,3,400,150,0.02,0.005,12.5,BENIGN
0.2,8,2,900,200,0.03,0.01,20.0,BENIGN
0.05,2,1,120,120,0.01,0.002,8.0,BENIGN
0.3,10,4,1500,300,0.04,0.02,25.0,ATTACK
0.15,6,3,600,160,0.02,0.007,15.0,BENIGN
0.12,7,3,700,140,0.025,0.008,16.0,BENIGN
0.5,50,45,40000,1500,0.5,0.2,120.0,ATTACK
0.08,4,2,300,110,0.015,0.004,10.0,BENIGN
0.25,12,6,1800,350,0.035,0.015,30.0,ATTACK
0.07,3,1,200,100,0.012,0.003,9.0,BENIGN