File size: 3,573 Bytes
7baa7c6
0f94f14
7baa7c6
 
 
 
 
 
 
 
0f94f14
7baa7c6
 
 
0f94f14
 
7baa7c6
 
0f94f14
e647820
0f94f14
7baa7c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e647820
7baa7c6
3139f2e
7baa7c6
 
 
 
e647820
7baa7c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f94f14
7baa7c6
 
e647820
7baa7c6
 
 
 
 
 
 
 
 
0f94f14
 
7baa7c6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import logging
import os
import pickle
from collections.abc import Mapping

import gradio as gr
import pandas as pd
from dotenv import load_dotenv
from huggingface_hub import InferenceClient, hf_hub_download

# Load environment
# load_dotenv() pulls variables from a local .env file (if any) into the
# process environment; the Hugging Face token is then read from HF_TOKEN and
# is None when that variable is not set.
load_dotenv()
hf_token = os.getenv("HF_TOKEN")

# ------------------------------------------
# Load IDS Model
# ------------------------------------------
# hf_hub_download yields a local path to the fetched artifact, which is then
# unpickled into the model object that predict_intrusion calls .predict() on.
# SECURITY: pickle.load can execute arbitrary code embedded in the file. Only
# load this artifact from a repository you trust (consider pinning a revision).
model_path = hf_hub_download(
    filename="ids_model.pkl",
    repo_id="utsavNagar/cyberids-ml",
    token=hf_token,
)

with open(model_path, "rb") as model_file:
    ids_model = pickle.load(model_file)

# ------------------------------------------
# IDS Prediction Logic
# ------------------------------------------
def predict_intrusion(data_dict, *, model=None, threshold=0.5):
    """Classify a single network-connection record as "Attack" or "Normal".

    Args:
        data_dict: Mapping of feature name -> value for one record. The model
            input columns follow the mapping's iteration order.
        model: Object exposing ``predict(DataFrame)``. Defaults to the
            module-level ``ids_model`` loaded at import time.
        threshold: Model outputs strictly greater than this are labelled
            "Attack"; everything else is "Normal".

    Returns:
        The string "Attack" or "Normal".
    """
    if model is None:
        model = ids_model

    frame = pd.DataFrame([data_dict])
    # Blanks (None) and non-numeric entries are coerced to NaN and then to 0,
    # so the model always receives a fully numeric row.
    frame = frame.apply(pd.to_numeric, errors="coerce").fillna(0)

    # NOTE(review): ``predict`` returns a class label or score, not
    # necessarily a probability; the comparison assumes a numeric output
    # (e.g. 0/1 labels) — TODO confirm against the trained model.
    score = model.predict(frame)[0]
    return "Attack" if score > threshold else "Normal"

# ------------------------------------------
# LLM Setup
# ------------------------------------------
LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"  # Hub id of the chat model

# Shared inference client used for report generation; it passes the
# HF_TOKEN value (possibly None) as its token.
client = InferenceClient(
    model=LLM_MODEL,
    token=hf_token,
)

# ------------------------------------------
# Chat-based explanation from LLM
# ------------------------------------------
def generate_report(
    features,
    prediction,
    *,
    llm_client=None,
    max_tokens=400,
    temperature=0.5,
):
    """Ask the chat LLM to write an incident analysis for one IDS verdict.

    Args:
        features: The record that was classified. A mapping is rendered one
            ``name: value`` pair per line; anything else is embedded via str().
        prediction: The IDS verdict (e.g. "Attack" / "Normal") to explain.
        llm_client: Object exposing ``chat_completion(...)``. Defaults to the
            module-level ``client``.
        max_tokens: Generation length cap forwarded to ``chat_completion``.
        temperature: Sampling temperature forwarded to ``chat_completion``.

    Returns:
        The model's reply text, or "[LLM Error] <details>" if the request or
        response handling fails, so the UI shows a message instead of crashing.
    """
    if llm_client is None:
        llm_client = client

    # One feature per line reads far better (for the model and for humans
    # inspecting prompts) than a raw dict repr.
    if isinstance(features, Mapping):
        feature_text = "\n".join(
            f"{name}: {value}" for name, value in features.items()
        )
    else:
        feature_text = str(features)

    prompt = f"""
You are a cybersecurity analyst.

Network Data:
{feature_text}

IDS Prediction: {prediction}

Provide a clear and professional analysis including:
1. Attack or Normal?
2. Why the IDS believes this.
3. Most likely attack type (DoS, Probe, R2L, U2R).
4. Severity level.
5. Recommended actions.
6. Brief incident summary (2–3 sentences).
"""

    try:
        resp = llm_client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return resp.choices[0].message["content"]
    except Exception as e:
        # Deliberate UI boundary: report the failure in the output box, but
        # keep the traceback in the logs instead of discarding it.
        logging.getLogger(__name__).exception("LLM report generation failed")
        return f"[LLM Error] {e}"

# ------------------------------------------
# Wrapper for Gradio
# ------------------------------------------
def analyze(*values, **inputs):
    """Run the IDS model and the LLM analyst on one record submitted from the UI.

    Gradio's ``Interface`` calls ``fn`` with the input components' values
    positionally, in the order of its ``inputs`` list. Positional ``values``
    are therefore paired with ``NSL_FEATURES`` by position. Keyword arguments
    (the original calling convention) are still accepted and take precedence
    over positional values for the same name.

    Returns:
        ``(prediction, report)`` for the two output text boxes.

    Raises:
        ValueError: if positional values are given but their count does not
            match ``NSL_FEATURES``.
    """
    if values:
        if len(values) != len(NSL_FEATURES):
            raise ValueError(
                f"Expected {len(NSL_FEATURES)} feature values, got {len(values)}"
            )
        inputs = {**dict(zip(NSL_FEATURES, values)), **inputs}

    prediction = predict_intrusion(inputs)
    report = generate_report(inputs, prediction)
    return prediction, report

# ------------------------------------------
# Gradio UI
# ------------------------------------------
# Feature names shown in the UI; each becomes one numeric input field, in
# this order.
NSL_FEATURES = [
    "duration","protocol_type","service","flag","src_bytes","dst_bytes",
    "land","wrong_fragment","urgent","hot","num_failed_logins","logged_in",
    "num_compromised","root_shell","su_attempted","num_root","num_file_creations",
    "num_shells","num_access_files","num_outbound_cmds","is_host_login",
    "is_guest_login","count","srv_count","serror_rate","srv_serror_rate",
    "rerror_rate","srv_rerror_rate","same_srv_rate","diff_srv_rate",
    "srv_diff_host_rate","dst_host_count","dst_host_srv_count",
    "dst_host_same_srv_rate","dst_host_diff_srv_rate",
    "dst_host_same_src_port_rate","dst_host_srv_diff_host_rate",
    "dst_host_serror_rate","dst_host_srv_serror_rate","dst_host_rerror_rate",
    "dst_host_srv_rerror_rate"
]

# NOTE(review): in the public NSL-KDD data, protocol_type, service and flag
# hold strings (e.g. "tcp"); numeric fields assume the model expects an
# encoded form of them — TODO confirm against the training pipeline.
# A comprehension avoids leaking a loop variable into module scope.
feature_inputs = [gr.Number(label=name) for name in NSL_FEATURES]

interface = gr.Interface(
    fn=analyze,
    inputs=feature_inputs,
    outputs=[
        gr.Textbox(label="Prediction (Attack / Normal)", interactive=False),
        gr.Textbox(label="AI-Generated Incident Report", lines=12, interactive=False),
    ],
    title="Cybersecurity Intrusion Detection System (IDS + AI Analyst)",
    description="Detect network intrusions using a Machine Learning IDS model and get a full explanation via an LLM.",
)

if __name__ == "__main__":
    # Start the Gradio web server only when run as a script.
    interface.launch()