Spaces:
Sleeping
Sleeping
Commit
·
5ef57e6
1
Parent(s):
47b79b6
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,96 +2,67 @@ import gradio as gr
|
|
| 2 |
import joblib
|
| 3 |
import requests
|
| 4 |
import os
|
| 5 |
-
|
| 6 |
-
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier, AdaBoostClassifier
|
| 7 |
-
from sklearn.tree import DecisionTreeClassifier
|
| 8 |
|
| 9 |
# Load the saved models
|
| 10 |
rf_model = joblib.load('rf_model.pkl')
|
| 11 |
-
dt_model = joblib.load('decision_tree_model.pkl')
|
| 12 |
-
bagging_model = joblib.load('model_bagging.pkl')
|
| 13 |
-
ada_model = joblib.load('model_adaboost.pkl')
|
| 14 |
|
| 15 |
-
# Define the feature names
|
| 16 |
feature_names = [
|
| 17 |
-
"
|
| 18 |
-
"src_bytes", "dst_bytes", "conn_state", "missed_bytes", "src_pkts",
|
| 19 |
-
"src_ip_bytes", "dst_pkts", "dst_ip_bytes", "dns_query", "dns_qclass",
|
| 20 |
-
"dns_qtype", "dns_rcode", "dns_AA", "dns_RD", "dns_RA", "dns_rejected",
|
| 21 |
-
"ssl_version", "ssl_cipher", "ssl_resumed", "ssl_established", "ssl_subject",
|
| 22 |
-
"ssl_issuer", "http_trans_depth", "http_method", "http_uri", "http_version",
|
| 23 |
-
"http_request_body_len", "http_response_body_len", "http_status_code",
|
| 24 |
-
"http_user_agent", "http_orig_mime_types", "http_resp_mime_types",
|
| 25 |
-
"weird_name", "weird_addl", "weird_notice", "label"
|
| 26 |
]
|
| 27 |
|
| 28 |
class_labels = {
|
| 29 |
0: "normal",
|
| 30 |
1: "backdoor",
|
| 31 |
2: "ddos",
|
| 32 |
-
3: "
|
| 33 |
-
4: "
|
| 34 |
-
5: "
|
| 35 |
-
6: "
|
| 36 |
-
7: "
|
| 37 |
-
8: "xss",
|
| 38 |
-
9: "mitm"
|
| 39 |
}
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
if len(feature_values) != len(feature_names) - 1:
|
| 44 |
-
return "Please fill in all the required feature values."
|
| 45 |
|
| 46 |
-
|
|
|
|
| 47 |
try:
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
return "Please enter valid numerical values for all fields, including the label."
|
| 52 |
-
|
| 53 |
-
# Add the label to feature values
|
| 54 |
-
feature_values.append(label_value)
|
| 55 |
-
|
| 56 |
-
# Choose the model based on user selection
|
| 57 |
-
if model_choice == "Random Forest":
|
| 58 |
-
model = rf_model
|
| 59 |
-
elif model_choice == "Decision Tree":
|
| 60 |
-
model = dt_model
|
| 61 |
-
elif model_choice == "Bagging Classifier":
|
| 62 |
-
model = bagging_model
|
| 63 |
-
elif model_choice == "AdaBoost Classifier":
|
| 64 |
-
model = ada_model
|
| 65 |
-
else:
|
| 66 |
-
return "Invalid model choice!"
|
| 67 |
-
|
| 68 |
-
# Predict the class (multi-class classification)
|
| 69 |
-
prediction = model.predict([feature_values])
|
| 70 |
-
predicted_class = prediction[0] # Get the predicted class (an integer between 0-9)
|
| 71 |
-
|
| 72 |
-
# Notify the user of the detected attack or normal traffic
|
| 73 |
-
if predicted_class == 0:
|
| 74 |
-
return "No Intrusion Detected"
|
| 75 |
-
else:
|
| 76 |
-
return f"Intrusion Detected: {class_labels.get(predicted_class, 'Unknown Attack')}"
|
| 77 |
-
|
| 78 |
-
# Create Gradio input fields for each feature (excluding label initially)
|
| 79 |
-
inputs = [gr.Textbox(label=feature_name) for feature_name in feature_names[:-1]] # Exclude the last one (label) from inputs
|
| 80 |
|
| 81 |
-
#
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
-
#
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
-
# Create
|
| 88 |
iface = gr.Interface(
|
| 89 |
-
fn=detect_intrusion,
|
| 90 |
-
inputs=
|
|
|
|
|
|
|
| 91 |
outputs="text",
|
| 92 |
title="Intrusion Detection System",
|
| 93 |
-
description=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
)
|
| 95 |
|
| 96 |
-
# Launch the interface
|
| 97 |
-
iface.launch(
|
|
|
|
| 2 |
import joblib
|
| 3 |
import requests
|
| 4 |
import os
|
| 5 |
+
from sklearn.ensemble import RandomForestClassifier
|
|
|
|
|
|
|
| 6 |
|
| 7 |
# Load the saved models
|
| 8 |
rf_model = joblib.load('rf_model.pkl')
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
+
# Define the feature names (excluding the target column 'type')
|
| 11 |
feature_names = [
|
| 12 |
+
"date", "time", "door_state", "sphone_signal", "label"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
]
|
| 14 |
|
| 15 |
class_labels = {
|
| 16 |
0: "normal",
|
| 17 |
1: "backdoor",
|
| 18 |
2: "ddos",
|
| 19 |
+
3: "injection",
|
| 20 |
+
4: "password",
|
| 21 |
+
5: "ransomware",
|
| 22 |
+
6: "scanning",
|
| 23 |
+
7: "xss",
|
|
|
|
|
|
|
| 24 |
}
|
| 25 |
|
| 26 |
+
# Placeholder model (replace with actual Random Forest model object)
|
| 27 |
+
rf_model = None # Load the actual trained Random Forest model here
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
def detect_intrusion(file):
|
| 30 |
+
# Read the uploaded log file as a CSV or structured data
|
| 31 |
try:
|
| 32 |
+
log_data = pd.read_csv(file.name) # Use file.name to get the path for reading
|
| 33 |
+
except Exception as e:
|
| 34 |
+
return f"Error reading file: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
# Check if all required feature columns are in the log file
|
| 37 |
+
missing_features = [feature for feature in feature_names if feature not in log_data.columns]
|
| 38 |
+
if missing_features:
|
| 39 |
+
return f"Missing features in file: {', '.join(missing_features)}"
|
| 40 |
+
|
| 41 |
+
# Extract the feature values (excluding the 'type' column which is the target)
|
| 42 |
+
feature_values = log_data[feature_names].astype(float).values
|
| 43 |
|
| 44 |
+
# Predict the class (multi-class classification) for each row in the log file
|
| 45 |
+
predictions = rf_model.predict(feature_values)
|
| 46 |
+
|
| 47 |
+
# Return only the 'Prediction' and 'label' columns
|
| 48 |
+
return log_data[['Prediction']].head().to_string()
|
| 49 |
|
| 50 |
+
# Create a Gradio interface
|
| 51 |
iface = gr.Interface(
|
| 52 |
+
fn=detect_intrusion,
|
| 53 |
+
inputs=[
|
| 54 |
+
gr.File(label="Upload Log File (CSV format)") # File input
|
| 55 |
+
],
|
| 56 |
outputs="text",
|
| 57 |
title="Intrusion Detection System",
|
| 58 |
+
description=("""
|
| 59 |
+
Upload a CSV log file containing the following features:
|
| 60 |
+
date, time, door_state, sphone_signal, label (without the 'type' column).
|
| 61 |
+
Example file structure:
|
| 62 |
+
date,time,door_state,sphone_signal,label
|
| 63 |
+
2025-03-12,10:45:00,1,-85,normal
|
| 64 |
+
""")
|
| 65 |
)
|
| 66 |
|
| 67 |
+
# Launch the interface locally for testing
|
| 68 |
+
iface.launch()
|