utsavNagar commited on
Commit
0f94f14
·
verified ·
1 Parent(s): 973990d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -18
app.py CHANGED
@@ -1,28 +1,52 @@
1
  import gradio as gr
2
- import pickle
3
  import pandas as pd
 
 
 
 
 
 
 
 
4
 
5
  # Load model
6
- with open("ids_model.pkl", "rb") as f:
7
  model = pickle.load(f)
8
 
9
- def predict(sample):
10
- df = pd.DataFrame([sample])
11
-
12
- # Force numeric conversion
13
- df = df.apply(pd.to_numeric, errors='coerce').fillna(0)
 
 
 
 
 
 
 
 
 
14
 
15
- pred = model.predict(df)[0]
16
- return "🔴 Attack Detected" if pred > 0.5 else "🟢 Normal Traffic"
 
17
 
18
- inputs = [
19
- gr.Number(label=col) for col in range(41) # but better replace with real feature names
20
- ]
 
 
 
 
 
21
 
22
- gr.Interface(
23
- fn=predict,
24
- inputs=inputs,
25
  outputs="text",
26
- title="CyberSecurity Intrusion Detection System (IDS)",
27
- description="Upload network flow values and detect intrusions using LightGBM."
28
- ).launch()
 
 
 
1
  import gradio as gr
 
2
  import pandas as pd
3
+ import pickle
4
+ from huggingface_hub import hf_hub_download
5
+
6
+ # Download model from Hugging Face Hub
7
+ model_path = hf_hub_download(
8
+ repo_id="utsavNagar/cyberids-ml",
9
+ filename="ids_model.pkl"
10
+ )
11
 
12
  # Load model
13
+ with open(model_path, "rb") as f:
14
  model = pickle.load(f)
15
 
16
+ # All NSL-KDD features (41 features)
17
+ columns = [
18
+ "duration","protocol_type","service","flag","src_bytes","dst_bytes","land",
19
+ "wrong_fragment","urgent","hot","num_failed_logins","logged_in",
20
+ "num_compromised","root_shell","su_attempted","num_root",
21
+ "num_file_creations","num_shells","num_access_files","num_outbound_cmds",
22
+ "is_host_login","is_guest_login","count","srv_count","serror_rate",
23
+ "srv_serror_rate","rerror_rate","srv_rerror_rate","same_srv_rate",
24
+ "diff_srv_rate","srv_diff_host_rate","dst_host_count","dst_host_srv_count",
25
+ "dst_host_same_srv_rate","dst_host_diff_srv_rate",
26
+ "dst_host_same_src_port_rate","dst_host_srv_diff_host_rate",
27
+ "dst_host_serror_rate","dst_host_srv_serror_rate","dst_host_rerror_rate",
28
+ "dst_host_srv_rerror_rate"
29
+ ]
30
 
31
+ def predict_intrusion(*inputs):
32
+ data = pd.DataFrame([inputs], columns=columns)
33
+ data = data.apply(pd.to_numeric, errors="coerce").fillna(0)
34
 
35
+ pred = model.predict(data)[0]
36
+ if pred > 0.5:
37
+ return "🔴 ATTACK DETECTED"
38
+ else:
39
+ return "🟢 NORMAL TRAFFIC"
40
+
41
+ # Build Gradio UI with 41 numeric inputs
42
+ inputs_ui = [gr.Number(label=col) for col in columns]
43
 
44
+ app = gr.Interface(
45
+ fn=predict_intrusion,
46
+ inputs=inputs_ui,
47
  outputs="text",
48
+ title="CyberSecurity IDS - NSL KDD",
49
+ description="Enter feature values to detect intrusion using LightGBM model."
50
+ )
51
+
52
+ app.launch()