import streamlit as st
import pandas as pd
import numpy as np
import random
import os
import logging
import joblib
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
from groq import Groq
from training_utils import train_model_cv
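# train_model_cv (training_utils.py) is assumed to return, in order: the
# fitted model, a metrics dict with 'aggregate' and 'folds' entries, the full
# feature matrix X and encoded labels y, out-of-fold ATTACK probabilities
# with their matching labels, and the fitted label encoder.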
# --- REPRODUCIBILITY & LOGGING ---
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
os.environ['PYTHONHASHSEED'] = str(SEED)
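# note: PYTHONHASHSEED only affects hash randomization when set before the
# interpreter starts; setting it here mainly documents intent and propagates
# the value to any subprocesses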
logging.basicConfig(level=logging.INFO, filename='training.log', filemode='a',
format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger('nids')
logger.info('App start, seed=%s', SEED)
# --- PAGE SETUP ---
st.set_page_config(page_title="AI-NIDS Student Project", layout="wide")
st.title("AI-Based Network Intrusion Detection System")
st.markdown("""
**Student Project**: This system uses a **Random Forest** classifier to detect network attacks and **Groq AI** to explain flagged packets.
""")
# --- CONFIGURATION ---
DATA_FILE = "sample_data/sample_small.csv"
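# default to the bundled sample dataset so the Space demo stays reproducible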
# --- SIDEBAR: SETTINGS ---
st.sidebar.header("1. Settings")
groq_api_key = st.sidebar.text_input(
"Groq API Key (optional)", type="password")
st.sidebar.caption("[Get a free key here](https://console.groq.com/keys)")
st.sidebar.header("2. Model Training")
NROWS = st.sidebar.number_input(
"Max rows to load", min_value=1000, value=15000, step=1000)
N_EST = st.sidebar.slider("RF trees (n_estimators)", 10, 500, 100, step=10)
MAX_DEPTH = st.sidebar.slider("Max tree depth (0 = no limit)", 0, 50, 12)
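# a depth of 0 disables the depth cap; it is converted to None at the
# training call below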
@st.cache_data
def load_data(filepath, nrows=None):
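    """Load the flow CSV, clean up infinities and missing values, and
    verify that a 'Label' column is present.

    Returns a cleaned DataFrame, or None if the file is missing or has
    no 'Label' column.
    """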
try:
df = pd.read_csv(filepath, nrows=nrows)
df.columns = df.columns.str.strip()
df.replace([np.inf, -np.inf], np.nan, inplace=True)
# require Label column
if 'Label' not in df.columns:
logger.error('Label column missing in dataset')
return None
# drop rows missing label
df = df.dropna(subset=['Label'])
# fill numeric nulls with median
num_cols = df.select_dtypes(include=[np.number]).columns
        for c in num_cols:
            # assign back instead of inplace fillna, which is deprecated
            # on column slices in recent pandas
            df[c] = df[c].fillna(df[c].median())
return df
except FileNotFoundError:
logger.exception('Data file not found: %s', filepath)
return None
# Training and metric utilities moved to training_utils.py
# --- APP LOGIC ---
df = load_data(DATA_FILE, nrows=NROWS)
if df is None:
st.error(
f"Error: File '{DATA_FILE}' not found or missing required columns. Please upload it to the Files tab.")
st.stop()
st.sidebar.success(f"Dataset Loaded: {len(df)} rows")
DEFAULT_FEATURES = ['Flow Duration', 'Total Fwd Packets', 'Total Backward Packets',
'Total Length of Fwd Packets', 'Fwd Packet Length Max',
'Flow IAT Mean', 'Flow IAT Std', 'Flow Packets/s']
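# these names appear to follow the CICIDS2017 / CICFlowMeter flow-feature
# convention; they must match the CSV header (whitespace is stripped on load)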
if st.sidebar.button("Train Model Now"):
with st.spinner("Training model (Stratified CV)..."):
try:
            clf, metrics, X_all, y_all, val_probas, val_labels, encoder = train_model_cv(
                df, DEFAULT_FEATURES, n_splits=5, n_estimators=N_EST,
                max_depth=(MAX_DEPTH or None), seed=SEED
            )
st.session_state['model'] = clf
st.session_state['features'] = list(X_all.columns)
# keep small test slice for simulation UI
st.session_state['X_test'] = pd.DataFrame(
X_all).sample(frac=0.2, random_state=SEED)
st.session_state['y_test'] = pd.Series(
y_all)[st.session_state['X_test'].index]
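            # note: this slice overlaps the CV training data, so it serves the
            # simulation demo only, not as an unbiased hold-out evaluation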
st.session_state['encoder'] = encoder
st.session_state['val_probas'] = val_probas
st.session_state['val_labels'] = val_labels
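            # persist the trained model so it can be reloaded without
            # retraining; the filename here is illustrative, not fixed
            joblib.dump(clf, 'nids_rf_model.joblib')
            logger.info('Saved model to nids_rf_model.joblib')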
# display aggregate metrics
agg = metrics['aggregate']
st.sidebar.success(
f"Training complete — F1: {agg['f1_mean']:.3f} ± {agg['f1_std']:.3f}")
st.subheader('Training Metrics (aggregate)')
st.json(agg)
st.subheader('Per-fold metrics')
st.json(metrics['folds'])
        except Exception as e:
            logger.exception('Training failed')
            st.error(f"Training failed: {e}")

# Threshold tuning lives outside the train-button block so the slider
# persists across Streamlit reruns instead of disappearing after the
# first one; it reads the CV validation outputs from session state.
if 'val_probas' in st.session_state:
    st.sidebar.header('3. Threshold Tuning')
    threshold = st.sidebar.slider(
        'Decision threshold (ATTACK probability)',
        0.10, 0.90, st.session_state.get('threshold', 0.50), 0.01
    )
    st.session_state['threshold'] = threshold
    # compute metrics at the selected threshold using CV validation outputs
    try:
        probs = np.array(st.session_state['val_probas'])
        labels = np.array(st.session_state['val_labels'])
        preds_thr = (probs >= threshold).astype(int)
        prec, rec, f1, _ = precision_recall_fscore_support(
            labels, preds_thr, average='binary', zero_division=0)
        cm = confusion_matrix(labels, preds_thr)
        tn, fp, fn, tp = cm.ravel() if cm.size == 4 else (0, 0, 0, 0)
        st.sidebar.markdown('**At selected threshold:**')
        st.sidebar.write(
            f'Precision: {prec:.3f} — Recall: {rec:.3f} — F1: {f1:.3f}')
        st.sidebar.write(
            f'False Positives: {int(fp)} — False Negatives: {int(fn)}')
        st.sidebar.caption(
            'Raising the threshold trades recall (catching more attacks) '
            'against false positives (false alarms).')
    except Exception:
        st.sidebar.warning('Unable to compute threshold metrics.')
st.header("Threat Analysis Dashboard")
if 'model' in st.session_state:
col1, col2 = st.columns(2)
with col1:
st.subheader("Simulation")
st.info("Pick a random packet from the test data to simulate live traffic.")
if st.button("🎲 Capture Random Packet"):
random_idx = np.random.randint(0, len(st.session_state['X_test']))
packet_data = st.session_state['X_test'].iloc[random_idx]
actual_label = st.session_state['y_test'].iloc[random_idx]
st.session_state['current_packet'] = packet_data
st.session_state['actual_label'] = actual_label
if 'current_packet' in st.session_state:
packet = st.session_state['current_packet']
with col1:
st.write("**Packet Header Info:**")
st.dataframe(packet, use_container_width=True)
with col2:
st.subheader("AI Detection Result")
            # a single-row DataFrame keeps feature names aligned with training
            pred_proba = st.session_state['model'].predict_proba(
                packet.to_frame().T)[0, 1]
threshold = st.session_state.get('threshold', 0.5)
prediction = 'ATTACK' if pred_proba >= threshold else 'BENIGN'
st.caption(f"Active decision threshold: {threshold:.2f}")
            if prediction == "BENIGN":
                st.success(
                    f"✅ STATUS: **SAFE (BENIGN)** — score {pred_proba:.3f}")
            else:
                st.error(
                    f"🚨 STATUS: **ATTACK DETECTED** — score {pred_proba:.3f}")
gt = st.session_state['encoder'].inverse_transform(
[st.session_state['actual_label']])[0]
st.caption(f"Ground Truth Label: {gt}")
st.markdown("---")
            st.subheader("Ask AI Analyst (Groq)")
if st.button("Generate Explanation"):
if not groq_api_key:
                st.warning(
                    "Please enter your Groq API Key in the sidebar first.")
else:
try:
client = Groq(api_key=groq_api_key)
prompt = f"""
You are a cybersecurity analyst.
A network packet window was scored as: {prediction} (score={pred_proba:.3f}).
Packet Technical Details:
{packet.to_string()}
Please explain succinctly why the model might classify this traffic as {prediction}.
"""
with st.spinner("Groq is analyzing the packet..."):
completion = client.chat.completions.create(
model="llama-3.3-70b-versatile",
messages=[{"role": "user", "content": prompt}],
temperature=0.6,
)
st.info(completion.choices[0].message.content)
except Exception as e:
st.error(f"API Error: {e}")
else:
    st.info("Waiting for model training. Click **'Train Model Now'** in the sidebar.")