| | import streamlit as st |
| | import pandas as pd |
| | import numpy as np |
| | from sklearn.ensemble import RandomForestClassifier |
| | from sklearn.model_selection import train_test_split |
| | from sklearn.metrics import accuracy_score |
| | from groq import Groq |
| | import os |
| |
|
| | |
| | st.set_page_config(page_title="AI-NIDS Student Project", layout="wide") |
| |
|
| | st.title("AI-Based Network Intrusion Detection System") |
| | st.markdown(""" |
| | **Student Project**: This system uses **Random Forest** to detect Network attacks and **Groq AI** to explain the packets. |
| | """) |
| |
|
| | |
| | DATA_FILE = "Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv" |
| |
|
| | |
| | st.sidebar.header("1. Settings") |
| | groq_api_key = st.sidebar.text_input("Groq API Key (starts with gsk_)", type="password") |
| | st.sidebar.caption("[Get a free key here](https://console.groq.com/keys)") |
| |
|
| | st.sidebar.header("2. Model Training") |
| |
|
| | @st.cache_data |
| | def load_data(filepath): |
| | try: |
| | df = pd.read_csv(filepath, nrows=15000) |
| | df.columns = df.columns.str.strip() |
| | df.replace([np.inf, -np.inf], np.nan, inplace=True) |
| | df.dropna(inplace=True) |
| | return df |
| | except FileNotFoundError: |
| | return None |
| |
|
| | def train_model(df): |
| | features = ['Flow Duration', 'Total Fwd Packets', 'Total Backward Packets', |
| | 'Total Length of Fwd Packets', 'Fwd Packet Length Max', |
| | 'Flow IAT Mean', 'Flow IAT Std', 'Flow Packets/s'] |
| | target = 'Label' |
| | |
| | missing_cols = [c for c in features if c not in df.columns] |
| | if missing_cols: |
| | st.error(f"Missing columns in CSV: {missing_cols}") |
| | return None, 0, [], None, None |
| |
|
| | X = df[features] |
| | y = df[target] |
| | |
| | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) |
| | |
| | clf = RandomForestClassifier(n_estimators=10, max_depth=10, random_state=42) |
| | clf.fit(X_train, y_train) |
| | |
| | score = accuracy_score(y_test, clf.predict(X_test)) |
| | return clf, score, features, X_test, y_test |
| |
|
| | |
| | df = load_data(DATA_FILE) |
| |
|
| | if df is None: |
| | st.error(f"Error: File '{DATA_FILE}' not found. Please upload it to the Files tab.") |
| | st.stop() |
| |
|
| | st.sidebar.success(f"Dataset Loaded: {len(df)} rows") |
| |
|
| | if st.sidebar.button("Train Model Now"): |
| | with st.spinner("Training model..."): |
| | clf, accuracy, feature_names, X_test, y_test = train_model(df) |
| | if clf: |
| | st.session_state['model'] = clf |
| | st.session_state['features'] = feature_names |
| | st.session_state['X_test'] = X_test |
| | st.session_state['y_test'] = y_test |
| | st.sidebar.success(f"Training Complete! Accuracy: {accuracy:.2%}") |
| |
|
| | st.header("3. Threat Analysis Dashboard") |
| |
|
| | if 'model' in st.session_state: |
| | col1, col2 = st.columns(2) |
| | |
| | with col1: |
| | st.subheader("Simulation") |
| | st.info("Pick a random packet from the test data to simulate live traffic.") |
| | |
| | if st.button("🎲 Capture Random Packet"): |
| | random_idx = np.random.randint(0, len(st.session_state['X_test'])) |
| | packet_data = st.session_state['X_test'].iloc[random_idx] |
| | actual_label = st.session_state['y_test'].iloc[random_idx] |
| | |
| | st.session_state['current_packet'] = packet_data |
| | st.session_state['actual_label'] = actual_label |
| | |
| | if 'current_packet' in st.session_state: |
| | packet = st.session_state['current_packet'] |
| | |
| | with col1: |
| | st.write("**Packet Header Info:**") |
| | st.dataframe(packet, use_container_width=True) |
| |
|
| | with col2: |
| | st.subheader("AI Detection Result") |
| | prediction = st.session_state['model'].predict([packet])[0] |
| | |
| | if prediction == "BENIGN": |
| | st.success(f" STATUS: **SAFE (BENIGN)**") |
| | else: |
| | st.error(f"🚨 STATUS: **ATTACK DETECTED ({prediction})**") |
| | |
| | st.caption(f"Ground Truth Label: {st.session_state['actual_label']}") |
| |
|
| | st.markdown("---") |
| | st.subheader(" Ask AI Analyst (Groq)") |
| | |
| | if st.button("Generate Explanation"): |
| | if not groq_api_key: |
| | st.warning(" Please enter your Groq API Key in the sidebar first.") |
| | else: |
| | try: |
| | client = Groq(api_key=groq_api_key) |
| | |
| | prompt = f""" |
| | You are a cybersecurity analyst. |
| | A network packet was detected as: {prediction}. |
| | |
| | Packet Technical Details: |
| | {packet.to_string()} |
| | |
| | Please explain: |
| | 1. Why these specific values (like Flow Duration or Packet Length) might indicate {prediction}. |
| | 2. If it is BENIGN, explain why it looks normal. |
| | 3. Keep the answer short and simple for a student. |
| | """ |
| |
|
| | with st.spinner("Groq is analyzing the packet..."): |
| | completion = client.chat.completions.create( |
| | model="llama-3.3-70b-versatile", |
| | messages=[ |
| | {"role": "user", "content": prompt} |
| | ], |
| | temperature=0.6, |
| | ) |
| | st.info(completion.choices[0].message.content) |
| | |
| | except Exception as e: |
| | st.error(f"API Error: {e}") |
| | else: |
| | st.info(" Waiting for model training. Click **'Train Model Now'** in the sidebar.") |