# File size: 5,841 Bytes
# d6f95f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import streamlit as st
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from groq import Groq
import os

# --- PAGE SETUP ---
# Wide layout plus the title and a short description shown at the top of the page.
st.set_page_config(layout="wide", page_title="AI-NIDS Student Project")

st.title("AI-Based Network Intrusion Detection System")
st.markdown(
    """
**Student Project**: This system uses **Random Forest** to detect Network attacks and **Groq AI** to explain the packets.
"""
)

# --- CONFIGURATION ---
# CSV file the app reads its training data from.
DATA_FILE = "Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv"

# --- SIDEBAR: SETTINGS ---
with st.sidebar:
    st.header("1. Settings")
    # The key is entered as a masked field and passed to the Groq client later.
    groq_api_key = st.text_input("Groq API Key (starts with gsk_)", type="password")
    st.caption("[Get a free key here](https://console.groq.com/keys)")

    st.header("2. Model Training")

@st.cache_data
def _read_dataset(filepath, nrows):
    """Read and clean the CSV; cached per ``(filepath, nrows)``.

    Raises:
        FileNotFoundError: if ``filepath`` does not exist. The error is raised
            here (not converted to ``None``) so that a missing file is never
            stored in the cache.
    """
    df = pd.read_csv(filepath, nrows=nrows)
    # Column headers may carry surrounding whitespace; normalise them.
    df.columns = df.columns.str.strip()
    # Treat +/-inf as missing, then drop every row with a missing value.
    return df.replace([np.inf, -np.inf], np.nan).dropna()


def load_data(filepath, nrows=15000):
    """Load the dataset used for training.

    Args:
        filepath: Path of the CSV file to read.
        nrows: Maximum number of rows to read from the file (default 15000).

    Returns:
        The cleaned DataFrame (stripped column names, no inf/NaN rows), or
        ``None`` when the file does not exist.
    """
    try:
        return _read_dataset(filepath, nrows)
    except FileNotFoundError:
        # Handled outside the cached helper: a cached ``None`` would keep
        # reporting "not found" even after the file has been added.
        return None

def train_model(df):
    """Train a Random Forest classifier on a fixed set of flow features.

    Args:
        df: DataFrame holding the feature columns listed below and a
            ``Label`` column used as the prediction target.

    Returns:
        A tuple ``(clf, score, features, X_test, y_test)`` where ``clf`` is
        the fitted classifier, ``score`` its accuracy on the held-out 30%
        split, ``features`` the list of feature column names, and
        ``X_test`` / ``y_test`` the held-out data.
        On invalid input an error is shown with ``st.error`` and
        ``(None, 0, [], None, None)`` is returned.
    """
    features = ['Flow Duration', 'Total Fwd Packets', 'Total Backward Packets',
                'Total Length of Fwd Packets', 'Fwd Packet Length Max',
                'Flow IAT Mean', 'Flow IAT Std', 'Flow Packets/s']
    target = 'Label'

    # Validate the target together with the features so a CSV without a
    # label column is reported instead of raising KeyError below.
    missing_cols = [c for c in features + [target] if c not in df.columns]
    if missing_cols:
        st.error(f"Missing columns in CSV: {missing_cols}")
        return None, 0, [], None, None

    # A 70/30 split needs at least one row on each side.
    if len(df) < 2:
        st.error("Not enough rows to train the model (need at least 2).")
        return None, 0, [], None, None

    X = df[features]
    y = df[target]

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

    clf = RandomForestClassifier(n_estimators=10, max_depth=10, random_state=42)
    clf.fit(X_train, y_train)

    score = accuracy_score(y_test, clf.predict(X_test))
    return clf, score, features, X_test, y_test

# --- APP LOGIC ---
df = load_data(DATA_FILE)

# load_data returns None when the CSV is missing; nothing else can run without it.
if df is None:
    st.error(f"Error: File '{DATA_FILE}' not found. Please upload it to the Files tab.")
    st.stop()

st.sidebar.success(f"Dataset Loaded: {len(df)} rows")

if st.sidebar.button("Train Model Now"):
    with st.spinner("Training model..."):
        clf, accuracy, feature_names, X_test, y_test = train_model(df)
        # train_model signals failure by returning None for the model, so
        # compare against None explicitly rather than relying on truthiness.
        if clf is not None:
            # Keep the model and the held-out split across Streamlit reruns.
            st.session_state['model'] = clf
            st.session_state['features'] = feature_names
            st.session_state['X_test'] = X_test
            st.session_state['y_test'] = y_test
            st.sidebar.success(f"Training Complete! Accuracy: {accuracy:.2%}")

st.header("3. Threat Analysis Dashboard")

# The dashboard is only available once a trained model is stored in the session.
if 'model' in st.session_state:
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Simulation")
        st.info("Pick a random packet from the test data to simulate live traffic.")

        if st.button("🎲 Capture Random Packet"):
            # Draw one held-out row together with its true label.
            random_idx = np.random.randint(0, len(st.session_state['X_test']))
            packet_data = st.session_state['X_test'].iloc[random_idx]
            actual_label = st.session_state['y_test'].iloc[random_idx]

            # Stored in the session so the selection survives later reruns
            # (e.g. when "Generate Explanation" is clicked).
            st.session_state['current_packet'] = packet_data
            st.session_state['actual_label'] = actual_label

    if 'current_packet' in st.session_state:
        packet = st.session_state['current_packet']

        with col1:
            st.write("**Packet Header Info:**")
            st.dataframe(packet, use_container_width=True)

        with col2:
            st.subheader("AI Detection Result")
            # Predict on a one-row DataFrame so the column names seen during
            # fit are passed along (a bare list drops them).
            packet_frame = packet.to_frame().T
            prediction = st.session_state['model'].predict(packet_frame)[0]

            if prediction == "BENIGN":
                st.success(" STATUS: **SAFE (BENIGN)**")
            else:
                st.error(f"🚨 STATUS: **ATTACK DETECTED ({prediction})**")

            st.caption(f"Ground Truth Label: {st.session_state['actual_label']}")

            st.markdown("---")
            st.subheader(" Ask AI Analyst (Groq)")

            if st.button("Generate Explanation"):
                if not groq_api_key:
                    st.warning(" Please enter your Groq API Key in the sidebar first.")
                else:
                    try:
                        client = Groq(api_key=groq_api_key)

                        # Prompt carries the predicted label and the raw feature values.
                        prompt = f"""
                        You are a cybersecurity analyst. 
                        A network packet was detected as: {prediction}.
                        
                        Packet Technical Details:
                        {packet.to_string()}
                        
                        Please explain:
                        1. Why these specific values (like Flow Duration or Packet Length) might indicate {prediction}.
                        2. If it is BENIGN, explain why it looks normal.
                        3. Keep the answer short and simple for a student.
                        """

                        with st.spinner("Groq is analyzing the packet..."):
                            completion = client.chat.completions.create(
                                model="llama-3.3-70b-versatile",
                                messages=[
                                    {"role": "user", "content": prompt}
                                ],
                                temperature=0.6,
                            )
                            st.info(completion.choices[0].message.content)

                    except Exception as e:
                        # UI boundary: show any client/API failure to the user
                        # instead of crashing the app.
                        st.error(f"API Error: {e}")
else:
    st.info(" Waiting for model training. Click **'Train Model Now'** in the sidebar.")