Commit
·
9a09d32
1
Parent(s):
e489d41
Update Sniffer_AI.py
Browse files- Sniffer_AI.py +46 -19
Sniffer_AI.py
CHANGED
|
@@ -13,6 +13,19 @@ dt_model = joblib.load('decision_tree_model.pkl')
|
|
| 13 |
bagging_model = joblib.load('model_bagging.pkl')
|
| 14 |
ada_model = joblib.load('model_adaboost.pkl')
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
class_labels = {
|
| 17 |
0: "normal",
|
| 18 |
1: "backdoor",
|
|
@@ -26,39 +39,53 @@ class_labels = {
|
|
| 26 |
9: "mitm"
|
| 27 |
}
|
| 28 |
|
| 29 |
-
def detect_intrusion(
|
| 30 |
-
#
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
# Choose the model based on user selection
|
| 34 |
if model_choice == "Random Forest":
|
| 35 |
model = rf_model
|
| 36 |
elif model_choice == "Decision Tree":
|
| 37 |
-
model =
|
| 38 |
elif model_choice == "Bagging Classifier":
|
| 39 |
-
model =
|
| 40 |
elif model_choice == "AdaBoost Classifier":
|
| 41 |
-
model =
|
| 42 |
else:
|
| 43 |
return "Invalid model choice!"
|
| 44 |
-
|
| 45 |
# Predict the class (multi-class classification)
|
| 46 |
-
prediction = model.predict(
|
| 47 |
-
predicted_class = prediction[0] # Get the predicted class (an integer between 0-
|
| 48 |
-
|
| 49 |
-
#
|
| 50 |
if predicted_class == 0:
|
| 51 |
return "No Intrusion Detected"
|
| 52 |
else:
|
| 53 |
return f"Intrusion Detected: {class_labels.get(predicted_class, 'Unknown Attack')}"
|
| 54 |
|
| 55 |
-
# Create
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
# Launch the interface locally for testing
|
| 64 |
iface.launch()
|
|
|
|
| 13 |
bagging_model = joblib.load('model_bagging.pkl')
|
| 14 |
ada_model = joblib.load('model_adaboost.pkl')
|
| 15 |
|
| 16 |
+
# Define the feature names
|
| 17 |
+
feature_names = [
|
| 18 |
+
"src_ip", "src_port", "dst_ip", "dst_port", "proto", "service", "duration",
|
| 19 |
+
"src_bytes", "dst_bytes", "conn_state", "missed_bytes", "src_pkts",
|
| 20 |
+
"src_ip_bytes", "dst_pkts", "dst_ip_bytes", "dns_query", "dns_qclass",
|
| 21 |
+
"dns_qtype", "dns_rcode", "dns_AA", "dns_RD", "dns_RA", "dns_rejected",
|
| 22 |
+
"ssl_version", "ssl_cipher", "ssl_resumed", "ssl_established", "ssl_subject",
|
| 23 |
+
"ssl_issuer", "http_trans_depth", "http_method", "http_uri", "http_version",
|
| 24 |
+
"http_request_body_len", "http_response_body_len", "http_status_code",
|
| 25 |
+
"http_user_agent", "http_orig_mime_types", "http_resp_mime_types",
|
| 26 |
+
"weird_name", "weird_addl", "weird_notice", "label"
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
class_labels = {
|
| 30 |
0: "normal",
|
| 31 |
1: "backdoor",
|
|
|
|
| 39 |
9: "mitm"
|
| 40 |
}
|
| 41 |
|
| 42 |
+
def detect_intrusion(feature_values, model_choice="Random Forest"):
|
| 43 |
+
# Ensure the length of feature_values matches feature_names
|
| 44 |
+
if len(feature_values) != len(feature_names):
|
| 45 |
+
return "Please fill in all the required feature values."
|
| 46 |
+
|
| 47 |
+
# Convert the input values to floats and match them with feature names
|
| 48 |
+
try:
|
| 49 |
+
feature_values = [float(value) for value in feature_values]
|
| 50 |
+
except ValueError:
|
| 51 |
+
return "Please enter valid numerical values for all fields."
|
| 52 |
+
|
| 53 |
# Choose the model based on user selection
|
| 54 |
if model_choice == "Random Forest":
|
| 55 |
model = rf_model
|
| 56 |
elif model_choice == "Decision Tree":
|
| 57 |
+
model = dt_model
|
| 58 |
elif model_choice == "Bagging Classifier":
|
| 59 |
+
model = bagging_model
|
| 60 |
elif model_choice == "AdaBoost Classifier":
|
| 61 |
+
model = ada_model
|
| 62 |
else:
|
| 63 |
return "Invalid model choice!"
|
| 64 |
+
|
| 65 |
# Predict the class (multi-class classification)
|
| 66 |
+
prediction = model.predict([feature_values])
|
| 67 |
+
predicted_class = prediction[0] # Get the predicted class (an integer between 0-9)
|
| 68 |
+
|
| 69 |
+
# Notify the user of the detected attack or normal traffic
|
| 70 |
if predicted_class == 0:
|
| 71 |
return "No Intrusion Detected"
|
| 72 |
else:
|
| 73 |
return f"Intrusion Detected: {class_labels.get(predicted_class, 'Unknown Attack')}"
|
| 74 |
|
| 75 |
+
# Create Gradio input fields for each feature
|
| 76 |
+
inputs = [gr.Textbox(label=feature_name) for feature_name in feature_names[:-1]] # Exclude "label" field from inputs
|
| 77 |
+
|
| 78 |
+
# Add model choice dropdown
|
| 79 |
+
inputs.append(gr.Dropdown(choices=["Random Forest", "Decision Tree", "Bagging Classifier", "AdaBoost Classifier"], label="Select Model"))
|
| 80 |
+
|
| 81 |
+
# Create the Gradio interface
|
| 82 |
+
iface = gr.Interface(
|
| 83 |
+
fn=detect_intrusion,
|
| 84 |
+
inputs=inputs,
|
| 85 |
+
outputs="text",
|
| 86 |
+
title="Intrusion Detection System",
|
| 87 |
+
description="Fill in the blank fields for the network traffic features, and choose the model to detect intrusions."
|
| 88 |
+
)
|
| 89 |
|
| 90 |
# Launch the interface locally for testing
|
| 91 |
iface.launch()
|