"""AI-based Network Intrusion Detection System — Streamlit student project.

Loads a flow-feature CSV, trains a Random Forest with stratified CV
(delegated to ``training_utils.train_model_cv``), lets the user tune the
decision threshold, and asks the Groq API to explain individual predictions.
"""

import streamlit as st
import pandas as pd
import numpy as np
import random
import os
import json
import logging
import joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (precision_recall_fscore_support, roc_auc_score,
                             average_precision_score, confusion_matrix)
from groq import Groq
from training_utils import train_model_cv

# --- REPRODUCIBILITY & LOGGING ---
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
os.environ['PYTHONHASHSEED'] = str(SEED)
logging.basicConfig(level=logging.INFO, filename='training.log', filemode='a',
                    format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger('nids')
logger.info('App start, seed=%s', SEED)

# --- PAGE SETUP ---
st.set_page_config(page_title="AI-NIDS Student Project", layout="wide")
st.title("AI-Based Network Intrusion Detection System")
st.markdown("""
**Student Project**: This system uses **Random Forest** to detect Network attacks
and **Groq AI** to explain the packets.
""")

# --- CONFIGURATION ---
DATA_FILE = "sample_data/sample_small.csv"

# --- SIDEBAR: SETTINGS ---
st.sidebar.header("1. Settings")
groq_api_key = st.sidebar.text_input("Groq API Key (optional)", type="password")
st.sidebar.caption("[Get a free key here](https://console.groq.com/keys)")

st.sidebar.header("2. Model Training")
NROWS = st.sidebar.number_input("Max rows to load",
                                min_value=1000, value=15000, step=1000)
N_EST = st.sidebar.slider("RF trees (n_estimators)", 10, 500, 100, step=10)
MAX_DEPTH = st.sidebar.slider("Max tree depth (0=none)", 0, 50, 12)


@st.cache_data
def load_data(filepath, nrows=None):
    """Load and clean the flow CSV; return a DataFrame, or None on failure.

    Cleaning steps: strip whitespace from column names, convert +/-inf to NaN,
    require a 'Label' column, drop rows with a missing label, and
    median-impute remaining numeric NaNs.

    Args:
        filepath: path to the CSV file.
        nrows: optional cap on rows read (passed through to ``pd.read_csv``).
    """
    # Keep the try body minimal: only read_csv can raise FileNotFoundError.
    try:
        df = pd.read_csv(filepath, nrows=nrows)
    except FileNotFoundError:
        logger.exception('Data file not found: %s', filepath)
        return None

    df.columns = df.columns.str.strip()
    df.replace([np.inf, -np.inf], np.nan, inplace=True)
    if 'Label' not in df.columns:
        logger.error('Label column missing in dataset')
        return None
    df = df.dropna(subset=['Label'])

    # FIX: assign the imputed columns back instead of per-column
    # `df[c].fillna(med, inplace=True)`, which is chained assignment and a
    # silent no-op / deprecated under pandas copy-on-write.
    num_cols = df.select_dtypes(include=[np.number]).columns
    df[num_cols] = df[num_cols].fillna(df[num_cols].median())
    return df


# Training and metric utilities moved to training_utils.py

# --- APP LOGIC ---
df = load_data(DATA_FILE, nrows=NROWS)
if df is None:
    st.error(f"Error: File '{DATA_FILE}' not found or missing required "
             "columns. Please upload it to the Files tab.")
    st.stop()

st.sidebar.success(f"Dataset Loaded: {len(df)} rows")

DEFAULT_FEATURES = ['Flow Duration', 'Total Fwd Packets',
                    'Total Backward Packets', 'Total Length of Fwd Packets',
                    'Fwd Packet Length Max', 'Flow IAT Mean', 'Flow IAT Std',
                    'Flow Packets/s']

if st.sidebar.button("Train Model Now"):
    with st.spinner("Training model (Stratified CV)..."):
        try:
            clf, metrics, X_all, y_all, val_probas, val_labels, encoder = \
                train_model_cv(df, DEFAULT_FEATURES, n_splits=5,
                               n_estimators=N_EST, max_depth=MAX_DEPTH,
                               seed=SEED)
            st.session_state['model'] = clf
            st.session_state['features'] = list(X_all.columns)
            # Keep a small slice of the data for the simulation UI.
            st.session_state['X_test'] = pd.DataFrame(X_all).sample(
                frac=0.2, random_state=SEED)
            st.session_state['y_test'] = pd.Series(
                y_all)[st.session_state['X_test'].index]
            st.session_state['encoder'] = encoder
            st.session_state['val_probas'] = val_probas
            st.session_state['val_labels'] = val_labels

            # Display aggregate metrics.
            agg = metrics['aggregate']
            st.sidebar.success(
                f"Training complete — F1: {agg['f1_mean']:.3f} "
                f"± {agg['f1_std']:.3f}")
            st.subheader('Training Metrics (aggregate)')
            st.json(agg)
            st.subheader('Per-fold metrics')
            st.json(metrics['folds'])
        except Exception as e:
            logger.exception('Training failed')
            st.error(f"Training failed: {e}")

# FIX: the threshold-tuning widgets were created inside the
# `if st.sidebar.button("Train Model Now")` branch. Streamlit reruns the whole
# script on every interaction and a button is True only on the click rerun, so
# the slider vanished immediately after training and could never be adjusted.
# Hoisting it here, gated on the persisted CV outputs, keeps it visible.
if 'val_probas' in st.session_state:
    st.sidebar.header('3. Threshold Tuning')
    threshold = st.sidebar.slider(
        'Decision threshold (ATTACK probability)', 0.10, 0.90, 0.50, 0.01)
    st.session_state['threshold'] = threshold
    # Compute metrics at the selected threshold from the CV validation outputs.
    try:
        probs = np.array(st.session_state['val_probas'])
        labels = np.array(st.session_state['val_labels'])
        preds_thr = (probs >= threshold).astype(int)
        prec, rec, f1, _ = precision_recall_fscore_support(
            labels, preds_thr, average='binary', zero_division=0)
        cm = confusion_matrix(labels, preds_thr)
        # A 1x1 matrix (single class present) cannot be unpacked; zero out.
        tn, fp, fn, tp = cm.ravel() if cm.size == 4 else (0, 0, 0, 0)
        st.sidebar.markdown('**At selected threshold:**')
        st.sidebar.write(
            f'Precision: {prec:.3f} — Recall: {rec:.3f} — F1: {f1:.3f}')
        st.sidebar.write(
            f'False Positives: {int(fp)} — False Negatives: {int(fn)}')
        st.sidebar.caption(
            'Threshold selection trades off between catching more attacks '
            '(recall) and raising false alarms (FP).')
    except Exception:
        st.sidebar.warning('Unable to compute threshold metrics.')

st.header("3. Threat Analysis Dashboard")
if 'model' in st.session_state:
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Simulation")
        st.info("Pick a random packet from the test data to simulate "
                "live traffic.")
        if st.button("🎲 Capture Random Packet"):
            random_idx = np.random.randint(
                0, len(st.session_state['X_test']))
            st.session_state['current_packet'] = \
                st.session_state['X_test'].iloc[random_idx]
            st.session_state['actual_label'] = \
                st.session_state['y_test'].iloc[random_idx]

    if 'current_packet' in st.session_state:
        packet = st.session_state['current_packet']
        with col1:
            st.write("**Packet Header Info:**")
            st.dataframe(packet, use_container_width=True)
        with col2:
            st.subheader("AI Detection Result")
            # FIX: score a one-row DataFrame carrying the training feature
            # names; a bare `[packet.values]` list triggers sklearn's
            # feature-name warning and is fragile to column order.
            row = pd.DataFrame([packet.values],
                               columns=st.session_state['features'])
            pred_proba = st.session_state['model'].predict_proba(row)[0, 1]
            threshold = st.session_state.get('threshold', 0.5)
            prediction = 'ATTACK' if pred_proba >= threshold else 'BENIGN'
            st.caption(f"Active decision threshold: {threshold:.2f}")

            if prediction == "BENIGN":
                st.success(
                    f" STATUS: **SAFE (BENIGN)** — score {pred_proba:.3f}")
            else:
                st.error(f"🚨 STATUS: **ATTACK DETECTED ({prediction})** — "
                         f"score {pred_proba:.3f}")

            gt = st.session_state['encoder'].inverse_transform(
                [st.session_state['actual_label']])[0]
            st.caption(f"Ground Truth Label: {gt}")

            st.markdown("---")
            st.subheader(" Ask AI Analyst (Groq)")
            if st.button("Generate Explanation"):
                if not groq_api_key:
                    st.warning(" Please enter your Groq API Key in the "
                               "sidebar first.")
                else:
                    try:
                        client = Groq(api_key=groq_api_key)
                        prompt = f"""
You are a cybersecurity analyst.
A network packet window was scored as: {prediction} (score={pred_proba:.3f}).
Packet Technical Details:
{packet.to_string()}
Please explain succinctly why the model might consider this an {prediction}.
"""
                        with st.spinner("Groq is analyzing the packet..."):
                            completion = client.chat.completions.create(
                                model="llama-3.3-70b-versatile",
                                messages=[
                                    {"role": "user", "content": prompt}],
                                temperature=0.6,
                            )
                        st.info(completion.choices[0].message.content)
                    except Exception as e:
                        st.error(f"API Error: {e}")
else:
    st.info(" Waiting for model training. Click **'Train Model Now'** "
            "in the sidebar.")