|
|
import streamlit as st |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import random |
|
|
import os |
|
|
import json |
|
|
import logging |
|
|
import joblib |
|
|
from sklearn.ensemble import RandomForestClassifier |
|
|
from sklearn.model_selection import StratifiedKFold |
|
|
from sklearn.metrics import (precision_recall_fscore_support, roc_auc_score, |
|
|
average_precision_score, confusion_matrix) |
|
|
from groq import Groq |
|
|
from training_utils import train_model_cv |
|
|
|
|
|
|
|
|
# Reproducibility: pin every RNG the pipeline touches to a single seed.
SEED = 42
for _seed_fn in (random.seed, np.random.seed):
    _seed_fn(SEED)
os.environ['PYTHONHASHSEED'] = str(SEED)

# Append-mode file logging so repeated app runs accumulate one history file.
logging.basicConfig(
    level=logging.INFO,
    filename='training.log',
    filemode='a',
    format='%(asctime)s %(levelname)s %(message)s',
)
logger = logging.getLogger('nids')
logger.info('App start, seed=%s', SEED)
|
|
|
|
|
|
|
|
# --- Page chrome -------------------------------------------------------
st.set_page_config(page_title="AI-NIDS Student Project", layout="wide")

st.title("AI-Based Network Intrusion Detection System")
st.markdown("""
**Student Project**: This system uses **Random Forest** to detect Network attacks and **Groq AI** to explain the packets.
""")

# Relative path to the bundled demo dataset (flow-statistics CSV).
DATA_FILE = "sample_data/sample_small.csv"

# --- Sidebar: credentials ---------------------------------------------
st.sidebar.header("1. Settings")
# Optional key; the Groq explanation feature below is skipped without it.
groq_api_key = st.sidebar.text_input(
    "Groq API Key (optional)", type="password")
st.sidebar.caption("[Get a free key here](https://console.groq.com/keys)")

# --- Sidebar: training hyper-parameters -------------------------------
st.sidebar.header("2. Model Training")
NROWS = st.sidebar.number_input(
    "Max rows to load", min_value=1000, value=15000, step=1000)
N_EST = st.sidebar.slider("RF trees (n_estimators)", 10, 500, 100, step=10)
# NOTE(review): the label says 0 means "no depth limit", but the value is
# passed through unchanged to train_model_cv — confirm that helper maps
# 0 to None before relying on it.
MAX_DEPTH = st.sidebar.slider("Max tree depth (0=none)", 0, 50, 12)
|
|
|
|
|
|
|
|
@st.cache_data
def load_data(filepath, nrows=None):
    """Load the flow CSV, clean it, and return a DataFrame ready for training.

    Cleaning steps:
      * strip stray whitespace from column names,
      * replace +/-inf with NaN so they are imputed like other gaps,
      * drop rows whose 'Label' is missing,
      * median-impute the remaining numeric NaNs, per column.

    Args:
        filepath: Path of the CSV file to read.
        nrows: Optional cap on the number of rows loaded.

    Returns:
        A cleaned pandas.DataFrame, or None when the file is missing or
        the required 'Label' column is absent.
    """
    # Keep the try narrow: only the read can raise FileNotFoundError.
    try:
        df = pd.read_csv(filepath, nrows=nrows)
    except FileNotFoundError:
        logger.exception('Data file not found: %s', filepath)
        return None

    df.columns = df.columns.str.strip()
    df.replace([np.inf, -np.inf], np.nan, inplace=True)

    if 'Label' not in df.columns:
        logger.error('Label column missing in dataset')
        return None

    df = df.dropna(subset=['Label'])

    # Median-impute numeric gaps. Assign back instead of calling
    # fillna(inplace=True) on a column slice: that chained-inplace fill is
    # deprecated in pandas 2.x and an error under pandas 3 copy-on-write.
    num_cols = df.select_dtypes(include=[np.number]).columns
    df[num_cols] = df[num_cols].fillna(df[num_cols].median())
    return df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load (and cache) the dataset; abort the whole app run if it is unusable.
df = load_data(DATA_FILE, nrows=NROWS)

if df is None:
    st.error(
        f"Error: File '{DATA_FILE}' not found or missing required columns. Please upload it to the Files tab.")
    st.stop()

st.sidebar.success(f"Dataset Loaded: {len(df)} rows")

# Flow-statistics feature columns handed to train_model_cv.
# NOTE(review): assumes these exact column names exist in the CSV after
# whitespace stripping — verify against the dataset header.
DEFAULT_FEATURES = ['Flow Duration', 'Total Fwd Packets', 'Total Backward Packets',
                    'Total Length of Fwd Packets', 'Fwd Packet Length Max',
                    'Flow IAT Mean', 'Flow IAT Std', 'Flow Packets/s']
|
|
|
|
|
# --- Training handler --------------------------------------------------
if st.sidebar.button("Train Model Now"):
    with st.spinner("Training model (Stratified CV)..."):
        try:
            # Cross-validated training (project helper). Returns the fitted
            # classifier, per-fold/aggregate metrics, the full feature matrix
            # and labels, out-of-fold probabilities/labels for threshold
            # tuning, and the label encoder.
            clf, metrics, X_all, y_all, val_probas, val_labels, encoder = train_model_cv(
                df, DEFAULT_FEATURES, n_splits=5, n_estimators=N_EST, max_depth=MAX_DEPTH, seed=SEED
            )
            # Persist everything the dashboard needs across Streamlit reruns.
            st.session_state['model'] = clf
            st.session_state['features'] = list(X_all.columns)

            # 20% sample used by the "capture a packet" simulation.
            # NOTE(review): this sample is drawn from data the model was
            # trained on, so it is a demo source, not a true held-out test set.
            st.session_state['X_test'] = pd.DataFrame(
                X_all).sample(frac=0.2, random_state=SEED)
            st.session_state['y_test'] = pd.Series(
                y_all)[st.session_state['X_test'].index]
            st.session_state['encoder'] = encoder
            st.session_state['val_probas'] = val_probas
            st.session_state['val_labels'] = val_labels

            # Surface aggregate and per-fold CV metrics.
            agg = metrics['aggregate']
            st.sidebar.success(
                f"Training complete — F1: {agg['f1_mean']:.3f} ± {agg['f1_std']:.3f}")
            st.subheader('Training Metrics (aggregate)')
            st.json(agg)
            st.subheader('Per-fold metrics')
            st.json(metrics['folds'])

            # NOTE(review): this slider is created inside the button handler,
            # so it only exists during the rerun triggered by the click and
            # disappears on the next interaction — consider moving it to the
            # top level of the script.
            st.sidebar.header('3. Threshold Tuning')
            threshold = st.sidebar.slider(
                'Decision threshold (ATTACK probability)',
                0.10, 0.90, 0.50, 0.01
            )
            st.session_state['threshold'] = threshold

            # Preview precision/recall/F1 and raw error counts at the chosen
            # threshold, using the out-of-fold validation probabilities.
            try:
                probs = np.array(val_probas)
                labels = np.array(val_labels)
                preds_thr = (probs >= threshold).astype(int)
                prec, rec, f1, _ = precision_recall_fscore_support(
                    labels, preds_thr, average='binary', zero_division=0)
                cm = confusion_matrix(labels, preds_thr)
                # A degenerate confusion matrix (only one class present)
                # does not unpack into 4 cells; fall back to zeros.
                tn, fp, fn, tp = cm.ravel() if cm.size == 4 else (0, 0, 0, 0)
                st.sidebar.markdown('**At selected threshold:**')
                st.sidebar.write(
                    f'Precision: {prec:.3f} — Recall: {rec:.3f} — F1: {f1:.3f}')
                st.sidebar.write(
                    f'False Positives: {int(fp)} — False Negatives: {int(fn)}')
                st.sidebar.caption(
                    'Threshold selection trades off between catching more attacks (recall) and raising false alarms (FP).')
            except Exception:
                # Best-effort preview; never let a metrics failure kill training.
                st.sidebar.warning('Unable to compute threshold metrics.')

        except Exception as e:
            # Top-level boundary: log the traceback, show the user a summary.
            logger.exception('Training failed')
            st.error(f"Training failed: {e}")
|
|
|
|
|
|
|
|
# --- Threat analysis dashboard -----------------------------------------
st.header("3. Threat Analysis Dashboard")

if 'model' in st.session_state:
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Simulation")
        st.info("Pick a random packet from the test data to simulate live traffic.")

        # Draw one random row from the stored sample to play the role of a
        # freshly captured packet; stash it so it survives reruns.
        if st.button("🎲 Capture Random Packet"):
            random_idx = np.random.randint(0, len(st.session_state['X_test']))
            packet_data = st.session_state['X_test'].iloc[random_idx]
            actual_label = st.session_state['y_test'].iloc[random_idx]

            st.session_state['current_packet'] = packet_data
            st.session_state['actual_label'] = actual_label

    if 'current_packet' in st.session_state:
        packet = st.session_state['current_packet']

        with col1:
            st.write("**Packet Header Info:**")
            st.dataframe(packet, use_container_width=True)

        with col2:
            st.subheader("AI Detection Result")

            # Probability of class index 1 for this single row.
            # NOTE(review): assumes encoder maps ATTACK to 1 — confirm
            # against train_model_cv / the label encoder.
            pred_proba = st.session_state['model'].predict_proba([packet.values])[
                0, 1]
            # Fall back to 0.5 if the tuning slider was never shown.
            threshold = st.session_state.get('threshold', 0.5)
            prediction = 'ATTACK' if pred_proba >= threshold else 'BENIGN'

            st.caption(f"Active decision threshold: {threshold:.2f}")

            if prediction == "BENIGN":
                st.success(
                    f" STATUS: **SAFE (BENIGN)** — score {pred_proba:.3f}")
            else:
                st.error(
                    f"🚨 STATUS: **ATTACK DETECTED ({prediction})** — score {pred_proba:.3f}")

            # Decode the stored integer label back to its original string.
            gt = st.session_state['encoder'].inverse_transform(
                [st.session_state['actual_label']])[0]
            st.caption(f"Ground Truth Label: {gt}")

            st.markdown("---")
            st.subheader(" Ask AI Analyst (Groq)")

            # Optional LLM explanation of the verdict (needs an API key).
            if st.button("Generate Explanation"):
                if not groq_api_key:
                    st.warning(
                        " Please enter your Groq API Key in the sidebar first.")
                else:
                    try:
                        client = Groq(api_key=groq_api_key)

                        prompt = f"""
You are a cybersecurity analyst.
A network packet window was scored as: {prediction} (score={pred_proba:.3f}).

Packet Technical Details:
{packet.to_string()}

Please explain succinctly why the model might consider this an {prediction}.
"""

                        with st.spinner("Groq is analyzing the packet..."):
                            completion = client.chat.completions.create(
                                model="llama-3.3-70b-versatile",
                                messages=[{"role": "user", "content": prompt}],
                                temperature=0.6,
                            )
                            st.info(completion.choices[0].message.content)

                    except Exception as e:
                        # Boundary around the network call: surface the error.
                        st.error(f"API Error: {e}")
else:
    st.info(" Waiting for model training. Click **'Train Model Now'** in the sidebar.")
|
|
|