MrUtakata commited on
Commit
f6e4e35
·
verified ·
1 Parent(s): 397163e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -86
app.py CHANGED
@@ -1,90 +1,51 @@
1
  import streamlit as st
2
  import pandas as pd
3
- import pickle
4
  import joblib
 
5
 
6
- # Load model artifacts
7
- @st.cache_resource
8
- def load_model_artifacts():
9
- with open('features_to_drop.pkl', 'rb') as f:
10
- features_to_drop = pickle.load(f)
11
- with open('category_encodings.pkl', 'rb') as f:
12
- category_encodings = pickle.load(f)
13
- model = joblib.load('xgb_model.pkl')
14
- return features_to_drop, category_encodings, model
15
-
16
- # Preprocessing function
17
- def preprocess_input(df, features_to_drop, category_encodings):
18
- df = df.copy()
19
- drop_cols = list(features_to_drop.intersection(df.columns))
20
- df.drop(columns=drop_cols, inplace=True, errors='ignore')
21
- for col, cats in category_encodings.items():
22
- if col in df.columns:
23
- df[col] = df[col].astype(str)
24
- df[col] = pd.Categorical(df[col], categories=cats)
25
- df[col] = df[col].cat.codes
26
- df = df.fillna(0)
27
- return df
28
-
29
- # UI layout
30
- st.set_page_config("IDS - Selected Features Input", layout="wide")
31
- st.title("Intrusion Detection System (IDS) - Manual Test Input")
32
-
33
- st.markdown("Enter selected features for a single row of network flow data:")
34
-
35
- with st.form("selected_input_form"):
36
- col1, col2 = st.columns(2)
37
- with col1:
38
- dur = st.number_input("Duration of Flow (`dur`)", min_value=0.0, step=0.001)
39
- proto = st.selectbox("Protocol (`proto`)", ["tcp", "udp", "icmp", "-"])
40
- service = st.text_input("Service (`service`)", value="-")
41
- state = st.selectbox("State (`state`)", ["CON", "INT", "FIN", "REQ", "-"])
42
- sbytes = st.number_input("Source Bytes (`sbytes`)", min_value=0)
43
- spkts = st.number_input("Source Packets (`spkts`)", min_value=0)
44
- sttl = st.number_input("Source TTL (`sttl`)", min_value=0)
45
- sload = st.number_input("Source Load (`sload`)", min_value=0.0)
46
-
47
- with col2:
48
- dbytes = st.number_input("Destination Bytes (`dbytes`)", min_value=0)
49
- dpkts = st.number_input("Destination Packets (`dpkts`)", min_value=0)
50
- sloss = st.number_input("Source Packet Loss (`sloss`)", min_value=0)
51
- swin = st.number_input("Source TCP Window Size (`swin`)", min_value=0)
52
- trans_depth = st.number_input("Transaction Depth (`trans_depth`)", min_value=0)
53
- ct_flw_http_mthd = st.text_input("HTTP Method (`ct_flw_http_mthd`)", value="-")
54
- is_ftp_login = st.selectbox("Is FTP Login (`is_ftp_login`)", [0, 1])
55
- attack_cat = st.text_input("(Optional) Ground Truth Label (`attack_cat`)", value="")
56
-
57
- submitted = st.form_submit_button("Run IDS Prediction")
58
-
59
- if submitted:
60
- user_input = pd.DataFrame([{
61
- "dur": dur,
62
- "proto": proto,
63
- "service": service,
64
- "state": state,
65
- "sbytes": sbytes,
66
- "dbytes": dbytes,
67
- "spkts": spkts,
68
- "dpkts": dpkts,
69
- "sttl": sttl,
70
- "sload": sload,
71
- "sloss": sloss,
72
- "swin": swin,
73
- "trans_depth": trans_depth,
74
- "ct_flw_http_mthd": ct_flw_http_mthd,
75
- "is_ftp_login": is_ftp_login,
76
- "attack_cat": attack_cat
77
- }])
78
-
79
- features_to_drop, category_encodings, model = load_model_artifacts()
80
- processed_input = preprocess_input(user_input, features_to_drop, category_encodings)
81
-
82
- if processed_input is not None:
83
- prediction = model.predict(processed_input)[0]
84
- st.success(f"Prediction: {prediction}")
85
- st.markdown("""
86
- - **13** → Normal Traffic
87
- - **Other values** → Intrusion Category
88
- """)
89
- else:
90
- st.error("Preprocessing failed. Please check your input.")
 
1
  import streamlit as st
2
  import pandas as pd
3
+ import numpy as np
4
  import joblib
5
+ import pickle
6
 
7
+ # ---------------------------
8
+ # Load trained model and metadata
9
+ # ---------------------------
10
+ model = joblib.load('xgb_model.pkl')
11
+
12
+ with open('features_to_drop.pkl', 'rb') as f:
13
+ features_to_drop = pickle.load(f)
14
+
15
+ with open('category_encodings.pkl', 'rb') as f:
16
+ category_encodings = pickle.load(f)
17
+
18
+ # ---------------------------
19
+ # App UI
20
+ # ---------------------------
21
+ st.title("🚨 Intrusion Detection In Networks")
22
+ st.markdown("Paste a single row of preprocessed NB15 features (comma-separated) below to predict the intrusion type.")
23
+
24
+ user_input = st.text_area("Input (1 row from NB15_combined_preprocessed.csv):", height=150)
25
+
26
+ if st.button("Predict"):
27
+ try:
28
+ # Convert user input string into a DataFrame
29
+ row = [x.strip() for x in user_input.strip().split(",")]
30
+ input_df = pd.DataFrame([row])
31
+
32
+ # Load a sample row to get correct column names and data types
33
+ sample_df = pd.read_csv('NB15_combined_preprocessed.csv', nrows=1)
34
+ columns = sample_df.drop('attack_cat', axis=1).columns.tolist()
35
+
36
+ # Ensure correct number of columns
37
+ if len(row) != len(columns):
38
+ st.error(f"Expected {len(columns)} values, but got {len(row)}.")
39
+ else:
40
+ input_df.columns = columns
41
+ # Convert to appropriate types
42
+ input_df = input_df.apply(pd.to_numeric, errors='coerce')
43
+
44
+ # Drop high-correlation features
45
+ input_df = input_df.drop(columns=features_to_drop, errors='ignore')
46
+
47
+ # Predict
48
+ prediction = model.predict(input_df)[0]
49
+ st.success(f"🛡️ Predicted Intrusion Category: **{prediction}**")
50
+ except Exception as e:
51
+ st.error(f"⚠️ Error processing input: {e}")