utsavNagar commited on
Commit
7baa7c6
·
verified ·
1 Parent(s): 5f69d39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -39
app.py CHANGED
@@ -1,53 +1,114 @@
1
- import gradio as gr
2
- import pandas as pd
3
  import pickle
4
- from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
5
 
6
- # Download model from Hugging Face Hub
 
 
7
  model_path = hf_hub_download(
8
  repo_id="utsavNagar/cyberids-ml",
9
- filename="ids_model.pkl"
 
10
  )
11
 
12
- # Load the model
13
  with open(model_path, "rb") as f:
14
- model = pickle.load(f)
15
-
16
- # Feature names (41 features of NSL-KDD)
17
- columns = [
18
- "duration","protocol_type","service","flag","src_bytes","dst_bytes","land",
19
- "wrong_fragment","urgent","hot","num_failed_logins","logged_in",
20
- "num_compromised","root_shell","su_attempted","num_root",
21
- "num_file_creations","num_shells","num_access_files","num_outbound_cmds",
22
- "is_host_login","is_guest_login","count","srv_count","serror_rate",
23
- "srv_serror_rate","rerror_rate","srv_rerror_rate","same_srv_rate",
24
- "diff_srv_rate","srv_diff_host_rate","dst_host_count","dst_host_srv_count",
25
- "dst_host_same_srv_rate","dst_host_diff_srv_rate",
26
- "dst_host_same_src_port_rate","dst_host_srv_diff_host_rate",
27
- "dst_host_serror_rate","dst_host_srv_serror_rate",
28
- "dst_host_rerror_rate","dst_host_srv_rerror_rate"
29
- ]
30
 
31
- def predict_intrusion(*inputs):
32
- df = pd.DataFrame([inputs], columns=columns)
33
- df = df.apply(pd.to_numeric, errors='coerce').fillna(0)
34
 
35
- pred = model.predict(df)[0] # LightGBM outputs probability
 
 
 
36
 
37
- if pred > 0.5:
38
- return "🔴 ATTACK DETECTED"
39
- else:
40
- return "🟢 NORMAL TRAFFIC"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- # Dynamic UI (41 numeric inputs)
43
- inputs_ui = [gr.Number(label=col) for col in columns]
44
 
45
- app = gr.Interface(
46
- fn=predict_intrusion,
47
- inputs=inputs_ui,
48
- outputs="text",
49
- title="CyberSecurity Intrusion Detection System",
50
- description="Upload network feature values (NSL-KDD dataset) to detect malicious intrusions using LightGBM."
 
 
 
51
  )
52
 
53
- app.launch()
 
 
1
+ import os
 
2
  import pickle
3
+ import pandas as pd
4
+ from dotenv import load_dotenv
5
+ from huggingface_hub import hf_hub_download, InferenceClient
6
+ import gradio as gr
7
+
8
+ # Load environment
9
+ load_dotenv()
10
+ hf_token = os.getenv("HF_TOKEN")
11
 
12
+ # ------------------------------------------
13
+ # Load IDS Model
14
+ # ------------------------------------------
15
  model_path = hf_hub_download(
16
  repo_id="utsavNagar/cyberids-ml",
17
+ filename="ids_model.pkl",
18
+ token=hf_token
19
  )
20
 
 
21
  with open(model_path, "rb") as f:
22
+ ids_model = pickle.load(f)
23
+
24
+ # ------------------------------------------
25
+ # IDS Prediction Logic
26
+ # ------------------------------------------
27
+ def predict_intrusion(data_dict):
28
+ df = pd.DataFrame([data_dict])
29
+ df = df.apply(pd.to_numeric, errors="coerce").fillna(0)
30
+ prob = ids_model.predict(df)[0]
31
+ return "Attack" if prob > 0.5 else "Normal"
32
+
33
+ # ------------------------------------------
34
+ # LLM Setup
35
+ # ------------------------------------------
36
+ LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
 
37
 
38
+ client = InferenceClient(model=LLM_MODEL, token=hf_token)
 
 
39
 
40
+ # ------------------------------------------
41
+ # Chat-based explanation from LLM
42
+ # ------------------------------------------
43
+ def generate_report(features, prediction):
44
 
45
+ prompt = f"""
46
+ You are a cybersecurity analyst.
47
+
48
+ Network Data:
49
+ {features}
50
+
51
+ IDS Prediction: {prediction}
52
+
53
+ Provide a clear and professional analysis including:
54
+ 1. Attack or Normal?
55
+ 2. Why the IDS believes this.
56
+ 3. Most likely attack type (DoS, Probe, R2L, U2R).
57
+ 4. Severity level.
58
+ 5. Recommended actions.
59
+ 6. Brief incident summary (2–3 sentences).
60
+ """
61
+
62
+ try:
63
+ resp = client.chat_completion(
64
+ messages=[{"role": "user", "content": prompt}],
65
+ max_tokens=400,
66
+ temperature=0.5
67
+ )
68
+ return resp.choices[0].message["content"]
69
+ except Exception as e:
70
+ return f"[LLM Error] {e}"
71
+
72
+ # ------------------------------------------
73
+ # Wrapper for Gradio
74
+ # ------------------------------------------
75
+ def analyze(**inputs):
76
+ prediction = predict_intrusion(inputs)
77
+ report = generate_report(inputs, prediction)
78
+ return prediction, report
79
+
80
+ # ------------------------------------------
81
+ # Gradio UI
82
+ # ------------------------------------------
83
+ feature_inputs = []
84
+
85
+ NSL_FEATURES = [
86
+ "duration","protocol_type","service","flag","src_bytes","dst_bytes",
87
+ "land","wrong_fragment","urgent","hot","num_failed_logins","logged_in",
88
+ "num_compromised","root_shell","su_attempted","num_root","num_file_creations",
89
+ "num_shells","num_access_files","num_outbound_cmds","is_host_login",
90
+ "is_guest_login","count","srv_count","serror_rate","srv_serror_rate",
91
+ "rerror_rate","srv_rerror_rate","same_srv_rate","diff_srv_rate",
92
+ "srv_diff_host_rate","dst_host_count","dst_host_srv_count",
93
+ "dst_host_same_srv_rate","dst_host_diff_srv_rate",
94
+ "dst_host_same_src_port_rate","dst_host_srv_diff_host_rate",
95
+ "dst_host_serror_rate","dst_host_srv_serror_rate","dst_host_rerror_rate",
96
+ "dst_host_srv_rerror_rate"
97
+ ]
98
 
99
+ for f in NSL_FEATURES:
100
+ feature_inputs.append(gr.Number(label=f))
101
 
102
+ interface = gr.Interface(
103
+ fn=analyze,
104
+ inputs=feature_inputs,
105
+ outputs=[
106
+ gr.Textbox(label="Prediction (Attack / Normal)", interactive=False),
107
+ gr.Textbox(label="AI-Generated Incident Report", lines=12, interactive=False)
108
+ ],
109
+ title="Cybersecurity Intrusion Detection System (IDS + AI Analyst)",
110
+ description="Detect network intrusions using a Machine Learning IDS model and get a full explanation via an LLM."
111
  )
112
 
113
+ if __name__ == "__main__":
114
+ interface.launch()