|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import joblib |
|
|
from tensorflow.keras.models import load_model |
|
|
|
|
|
|
|
|
# Load the trained network and its preprocessing artifacts once, at import
# time. All paths are relative to the current working directory.
try:
    model = load_model('improved_intrusion_detection_model.h5')
    # The joblib artifacts are read in a fixed order: the scaler, the label
    # encoder, then the feature-name sequence used to build and order the
    # model's input columns.
    scaler, le, feature_names = (
        joblib.load(artifact_path)
        for artifact_path in (
            'final_standard_scaler.pkl',
            'final_label_encoder.pkl',
            'final_feature_names.pkl',
        )
    )
    print("Model and preprocessing components loaded successfully.")
except Exception as load_error:
    # Report the cause on stdout, then abort start-up by re-raising.
    print(f"Error loading files: {load_error}")
    raise
|
|
|
|
|
|
|
|
def predict_intrusion(
    duration, src_bytes, dst_bytes, land, wrong_fragment, urgent, hot,
    num_failed_logins, logged_in, num_compromised, root_shell, su_attempted,
    num_root, num_file_creations, num_shells, num_access_files,
    num_outbound_cmds, is_host_login, is_guest_login, count, srv_count,
    serror_rate, srv_serror_rate, rerror_rate, srv_rerror_rate, same_srv_rate,
    diff_srv_rate, srv_diff_host_rate, dst_host_count, dst_host_srv_count,
    dst_host_same_srv_rate, dst_host_diff_srv_rate, dst_host_same_src_port_rate,
    dst_host_srv_diff_host_rate, dst_host_serror_rate, dst_host_srv_serror_rate,
    dst_host_rerror_rate, dst_host_srv_rerror_rate,
    protocol_type, service, flag
):
    """Classify one connection record and return a human-readable verdict.

    The 38 numeric arguments are written into the columns of the same name;
    ``protocol_type``, ``service`` and ``flag`` are one-hot encoded by setting
    the column ``<field>_<value>`` to 1.0. Any name that is not present in the
    module-level ``feature_names`` is silently skipped, so an unknown
    category leaves all of that field's one-hot columns at 0.0.

    The row is then scaled with the module-level ``scaler``, reshaped to
    ``(1, n_features, 1)`` and passed to the module-level ``model``.

    Returns:
        A message string. It starts with the intrusion banner when the decoded
        class is ``'attack'`` and with the normal-traffic banner otherwise,
        and always includes the attack probability as a percentage.
    """
    numerical_features = {
        'duration': duration, 'src_bytes': src_bytes, 'dst_bytes': dst_bytes,
        'land': land, 'wrong_fragment': wrong_fragment, 'urgent': urgent,
        'hot': hot, 'num_failed_logins': num_failed_logins,
        'logged_in': logged_in, 'num_compromised': num_compromised,
        'root_shell': root_shell, 'su_attempted': su_attempted,
        'num_root': num_root, 'num_file_creations': num_file_creations,
        'num_shells': num_shells, 'num_access_files': num_access_files,
        'num_outbound_cmds': num_outbound_cmds, 'is_host_login': is_host_login,
        'is_guest_login': is_guest_login, 'count': count, 'srv_count': srv_count,
        'serror_rate': serror_rate, 'srv_serror_rate': srv_serror_rate,
        'rerror_rate': rerror_rate, 'srv_rerror_rate': srv_rerror_rate,
        'same_srv_rate': same_srv_rate, 'diff_srv_rate': diff_srv_rate,
        'srv_diff_host_rate': srv_diff_host_rate, 'dst_host_count': dst_host_count,
        'dst_host_srv_count': dst_host_srv_count,
        'dst_host_same_srv_rate': dst_host_same_srv_rate,
        'dst_host_diff_srv_rate': dst_host_diff_srv_rate,
        'dst_host_same_src_port_rate': dst_host_same_src_port_rate,
        'dst_host_srv_diff_host_rate': dst_host_srv_diff_host_rate,
        'dst_host_serror_rate': dst_host_serror_rate,
        'dst_host_srv_serror_rate': dst_host_srv_serror_rate,
        'dst_host_rerror_rate': dst_host_rerror_rate,
        'dst_host_srv_rerror_rate': dst_host_srv_rerror_rate,
    }

    # Start from an all-zero single-row frame with one column per feature.
    input_df = pd.DataFrame({name: [0.0] for name in feature_names})

    for col, val in numerical_features.items():
        if col in input_df.columns:
            input_df[col] = val

    # One-hot encode the categorical fields; unknown categories are ignored.
    categorical_fields = (
        ('protocol_type', protocol_type),
        ('service', service),
        ('flag', flag),
    )
    for prefix, category in categorical_fields:
        one_hot_col = f'{prefix}_{category}'
        if one_hot_col in input_df.columns:
            input_df[one_hot_col] = 1.0

    # Re-select to guarantee the column order given by feature_names.
    input_df = input_df[feature_names]

    X_scaled = scaler.transform(input_df.values)
    # (1, n_features) -> (1, n_features, 1): adds a trailing axis of length 1.
    # NOTE(review): presumably the channel axis of a 1D CNN -- confirm.
    X_cnn = X_scaled.reshape(1, X_scaled.shape[1], 1)

    # The first output value is thresholded at 0.5 to select encoded class
    # index 1, so it is the score of le.classes_[1].
    positive_proba = float(model.predict(X_cnn)[0][0])
    prediction_class = int(positive_proba > 0.5)
    decoded_prediction = le.classes_[prediction_class]

    # The message reports an *attack* probability. When 'attack' is encoded
    # as index 0 the score belongs to the other class and must be
    # complemented; otherwise the score is reported unchanged.
    encoded_labels = list(le.classes_)
    if 'attack' in encoded_labels and encoded_labels.index('attack') == 0:
        attack_proba = 1.0 - positive_proba
    else:
        attack_proba = positive_proba
    probability_percent = f"{attack_proba * 100:.2f}%"

    if decoded_prediction == 'attack':
        result_message = f"🚨 INTRUSION DETECTED! (Attack Probability: {probability_percent})"
    else:
        result_message = f"✅ Normal Traffic. (Attack Probability: {probability_percent})"

    return result_message
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# UI controls for the simplified form. The list order matters:
# predict_wrapper receives the values positionally, in exactly this order.
inputs = [
    # 0-2: basic connection counters.
    gr.Number(value=0, label="Duration (seconds)"),
    gr.Number(value=491, label="Source Bytes"),
    gr.Number(value=0, label="Destination Bytes"),
    # 3-5: categorical fields.
    gr.Dropdown(
        choices=['tcp', 'udp', 'icmp'],
        value='tcp',
        label="Protocol Type",
    ),
    gr.Dropdown(
        choices=['ftp_data', 'http', 'private', 'domain_u', 'other'],
        value='ftp_data',
        label="Service",
    ),
    gr.Dropdown(
        choices=['SF', 'S0', 'REJ', 'RSTO'],
        value='SF',
        label="Flag",
    ),
    # 6: binary flag exposed as a 0/1 slider.
    gr.Slider(minimum=0, maximum=1, step=1, value=1, label="Logged In (1=Yes, 0=No)"),
    # 7-8: connection counts.
    gr.Number(value=2, label="Count (connections to same host)"),
    gr.Number(value=2, label="Service Count (connections to same service)"),
    # 9-10: rates in [0, 1].
    gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="Same Service Rate"),
    gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0, label="SError Rate"),
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_wrapper(*args):
    """Adapt the 11 simplified form values to ``predict_intrusion``.

    The values must arrive in the order of the ``inputs`` list: duration,
    source bytes, destination bytes, protocol type, service, flag, logged-in,
    count, service count, same-service rate, SError rate.

    Each value is passed to the ``predict_intrusion`` parameter its form
    label names. Every other parameter is sent as 0.0. Arguments are passed
    by keyword, so the mapping cannot drift if positions change.

    Raises:
        ValueError: if the number of values is not exactly 11.

    Returns:
        Whatever ``predict_intrusion`` returns (its result message).
    """
    import inspect  # local import: not provided by the module's import block

    ui_field_count = 11
    if len(args) != ui_field_count:
        raise ValueError(
            f"predict_wrapper expects {ui_field_count} values, got {len(args)}"
        )

    (
        duration, src_bytes, dst_bytes,
        protocol_type, service, flag,
        logged_in, count, srv_count,
        same_srv_rate, serror_rate,
    ) = args

    # Default every parameter of predict_intrusion to 0.0, then overwrite the
    # ones the form supplies. Reading the names from the signature keeps the
    # argument count in sync with the target function.
    call_kwargs = dict.fromkeys(
        inspect.signature(predict_intrusion).parameters, 0.0
    )
    call_kwargs.update(
        duration=duration,
        src_bytes=src_bytes,
        dst_bytes=dst_bytes,
        logged_in=logged_in,
        count=count,
        srv_count=srv_count,
        same_srv_rate=same_srv_rate,
        serror_rate=serror_rate,
        protocol_type=protocol_type,
        service=service,
        flag=flag,
    )

    return predict_intrusion(**call_kwargs)
|
|
|
|
|
|
|
|
|
|
|
# Build the Gradio interface: the form controls in `inputs` feed
# predict_wrapper, and its returned message is shown in a Label component.
demo = gr.Interface(
    title="Intrusion Detection System (KDD Cup '99)",
    description="Predicts whether a network connection is 'normal' or an 'attack' using a 1D CNN model. The model achieved 99.28% accuracy on the test set.",
    fn=predict_wrapper,
    inputs=inputs,
    outputs=gr.Label(label="Intrusion Detection Result"),
)

# Launch the app only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()