File size: 5,841 Bytes
d6f95f8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | import streamlit as st
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from groq import Groq
import os
# --- PAGE SETUP ---
# Page-level configuration is the first Streamlit call made by this script.
st.set_page_config(layout="wide", page_title="AI-NIDS Student Project")
st.title("AI-Based Network Intrusion Detection System")
st.markdown(
    "\n"
    "**Student Project**: This system uses **Random Forest** to detect Network attacks "
    "and **Groq AI** to explain the packets.\n"
)

# --- CONFIGURATION ---
# CSV of labelled network flows; a bare filename, so it is resolved against the
# process's current working directory.
DATA_FILE = "Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv"

# --- SIDEBAR: SETTINGS ---
with st.sidebar:
    st.header("1. Settings")
    # Masked input; the value is only read later when an explanation is requested.
    groq_api_key = st.text_input(
        "Groq API Key (starts with gsk_)",
        type="password",
    )
    st.caption("[Get a free key here](https://console.groq.com/keys)")
    st.header("2. Model Training")
@st.cache_data
def load_data(filepath):
    """Read the first 15,000 rows of a CSV and discard unusable rows.

    Column names have surrounding whitespace stripped, positive and negative
    infinity are converted to NaN, and every row that contains a NaN in any
    column is dropped.

    Args:
        filepath: Path of the CSV file to read.

    Returns:
        The cleaned DataFrame, or None when the file does not exist.
    """
    try:
        raw = pd.read_csv(filepath, nrows=15000)
    except FileNotFoundError:
        return None

    raw.columns = raw.columns.str.strip()
    # Infinite values become NaN first so a single dropna() removes both kinds.
    return raw.replace([np.inf, -np.inf], np.nan).dropna()
def train_model(df):
    """Train a Random Forest classifier on a fixed set of flow features.

    The frame is split 70/30 into train and test partitions with a fixed seed,
    a 10-tree forest (max depth 10) is fitted on the training part, and its
    accuracy is measured on the held-out part.

    Args:
        df: DataFrame that must contain the eight feature columns listed below
            and a 'Label' column used as the target.

    Returns:
        A tuple ``(clf, score, features, X_test, y_test)``. When a required
        column is missing, or there are too few rows to split, an error is
        shown in the app and ``(None, 0, [], None, None)`` is returned.
    """
    features = ['Flow Duration', 'Total Fwd Packets', 'Total Backward Packets',
                'Total Length of Fwd Packets', 'Fwd Packet Length Max',
                'Flow IAT Mean', 'Flow IAT Std', 'Flow Packets/s']
    target = 'Label'

    # Validate the target together with the features: without this, a CSV that
    # lacks 'Label' would fail with a KeyError at df[target] instead of the
    # readable message below.
    missing_cols = [c for c in features + [target] if c not in df.columns]
    if missing_cols:
        st.error(f"Missing columns in CSV: {missing_cols}")
        return None, 0, [], None, None

    # A train/test split needs at least one row on each side.
    if len(df) < 2:
        st.error("Not enough rows in the dataset to train a model.")
        return None, 0, [], None, None

    X = df[features]
    y = df[target]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42
    )

    clf = RandomForestClassifier(n_estimators=10, max_depth=10, random_state=42)
    clf.fit(X_train, y_train)
    score = accuracy_score(y_test, clf.predict(X_test))
    return clf, score, features, X_test, y_test
# --- APP LOGIC ---
# Load the dataset up front; nothing else in the app can work without it.
df = load_data(DATA_FILE)
if df is None:
    st.error(f"Error: File '{DATA_FILE}' not found. Please upload it to the Files tab.")
    st.stop()

st.sidebar.success(f"Dataset Loaded: {len(df)} rows")

if st.sidebar.button("Train Model Now"):
    with st.spinner("Training model..."):
        clf, accuracy, feature_names, X_test, y_test = train_model(df)
        # train_model signals failure by returning None for the classifier, so
        # test for exactly that instead of relying on the estimator's truthiness.
        if clf is not None:
            # Keep the model and its held-out data in session state so they
            # are still available on later reruns of the script.
            st.session_state['model'] = clf
            st.session_state['features'] = feature_names
            st.session_state['X_test'] = X_test
            st.session_state['y_test'] = y_test
            st.sidebar.success(f"Training Complete! Accuracy: {accuracy:.2%}")
st.header("3. Threat Analysis Dashboard")

if 'model' in st.session_state:
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Simulation")
        st.info("Pick a random packet from the test data to simulate live traffic.")

        if st.button("🎲 Capture Random Packet"):
            # One positional index is used for both X_test and y_test so the
            # feature row and its ground-truth label stay paired.
            random_idx = np.random.randint(0, len(st.session_state['X_test']))
            st.session_state['current_packet'] = st.session_state['X_test'].iloc[random_idx]
            st.session_state['actual_label'] = st.session_state['y_test'].iloc[random_idx]

    if 'current_packet' in st.session_state:
        packet = st.session_state['current_packet']

        with col1:
            st.write("**Packet Header Info:**")
            st.dataframe(packet, use_container_width=True)

        with col2:
            st.subheader("AI Detection Result")
            # The model was fitted on a DataFrame with named columns. Predict on
            # a one-row DataFrame (features as columns) rather than a bare list,
            # so the feature names are passed along with the values.
            packet_frame = packet.to_frame().T
            prediction = st.session_state['model'].predict(packet_frame)[0]

            if prediction == "BENIGN":
                st.success(" STATUS: **SAFE (BENIGN)**")
            else:
                st.error(f"🚨 STATUS: **ATTACK DETECTED ({prediction})**")

            st.caption(f"Ground Truth Label: {st.session_state['actual_label']}")

        st.markdown("---")
        st.subheader(" Ask AI Analyst (Groq)")

        if st.button("Generate Explanation"):
            if not groq_api_key:
                st.warning(" Please enter your Groq API Key in the sidebar first.")
            else:
                try:
                    client = Groq(api_key=groq_api_key)
                    # Built from adjacent pieces so the prompt text does not
                    # pick up the indentation of the surrounding code.
                    prompt = (
                        "\n"
                        "You are a cybersecurity analyst.\n"
                        f"A network packet was detected as: {prediction}.\n"
                        "Packet Technical Details:\n"
                        f"{packet.to_string()}\n"
                        "Please explain:\n"
                        "1. Why these specific values (like Flow Duration or Packet Length) "
                        f"might indicate {prediction}.\n"
                        "2. If it is BENIGN, explain why it looks normal.\n"
                        "3. Keep the answer short and simple for a student.\n"
                    )
                    with st.spinner("Groq is analyzing the packet..."):
                        completion = client.chat.completions.create(
                            model="llama-3.3-70b-versatile",
                            messages=[
                                {"role": "user", "content": prompt},
                            ],
                            temperature=0.6,
                        )
                    st.info(completion.choices[0].message.content)
                except Exception as e:
                    # UI boundary: show any client/API failure to the user
                    # instead of crashing the page.
                    st.error(f"API Error: {e}")
else:
    st.info(" Waiting for model training. Click **'Train Model Now'** in the sidebar.")