|
|
import json
import logging
import pickle

import gradio as gr
import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import load_model
|
|
|
|
|
|
|
|
# Artifacts written by the training run. The paths are relative, so they are
# resolved against the current working directory, and everything is loaded
# once when this module is imported.
model = load_model('multiclass_cnn_model_fast.h5')
scaler, label_encoder = (
    joblib.load('multiclass_standard_scaler.pkl'),
    joblib.load('multiclass_label_encoder.pkl'),
)

# NOTE(review): pickle.load (like joblib.load above) can execute code embedded
# in the file; only load artifacts produced by your own training pipeline.
with open('multiclass_feature_names.pkl', 'rb') as f:
    feature_names = pickle.load(f)
|
|
|
|
|
class NetworkIntrusionDetector:
    """Classify single network-traffic records with a trained multiclass model.

    Bundles the model with its fitted scaler, label encoder and expected feature
    order, so callers can pass a plain ``{feature_name: value}`` mapping.
    """

    def __init__(self, model, scaler, label_encoder, feature_names):
        """Store the prediction artifacts.

        Args:
            model: object exposing ``predict(x, verbose=0)`` that returns one row
                of class scores per input row (a Keras model in this app).
            scaler: fitted transformer exposing ``transform(frame)``.
            label_encoder: fitted encoder exposing ``classes_`` and
                ``inverse_transform``.
            feature_names: column order expected by the scaler and model.
        """
        self.model = model
        self.scaler = scaler
        self.label_encoder = label_encoder
        self.feature_names = feature_names

    def preprocess(self, features_dict):
        """Turn one feature mapping into a scaled model input.

        Columns are reordered to ``self.feature_names``. Missing features are
        filled with 0, and keys not in ``self.feature_names`` are dropped.

        Returns:
            ndarray of shape ``(1, n_columns, 1)``, where ``n_columns`` is the
            scaler's output width.
        """
        features_df = pd.DataFrame([features_dict])
        features_aligned = features_df.reindex(columns=self.feature_names, fill_value=0)
        features_scaled = self.scaler.transform(features_aligned)
        # Add batch and trailing channel axes. This is presumably the
        # (batch, steps, channels) layout of a Conv1D input -- confirm against
        # the model.
        return features_scaled.reshape(1, -1, 1)

    def predict(self, **input_features):
        """Classify one record and render the result as Markdown.

        Args:
            **input_features: feature values keyed by feature name.

        Returns:
            A Markdown report (status, predicted class, confidence, top-3
            classes), or ``"❌ Error: <message>"`` if any step raises.
        """
        try:
            features_cnn = self.preprocess(input_features)
            # A single input row, so keep only the single row of class scores.
            probabilities = self.model.predict(features_cnn, verbose=0)[0]
            classes = self.label_encoder.classes_
            if len(probabilities) != len(classes):
                raise ValueError(
                    f"model returned {len(probabilities)} scores "
                    f"but the label encoder has {len(classes)} classes"
                )
            class_idx = int(np.argmax(probabilities))
            confidence = float(probabilities[class_idx])
            attack_type = self.label_encoder.inverse_transform([class_idx])[0]
            all_probs = {label: float(prob) for label, prob in zip(classes, probabilities)}
            top_3 = sorted(all_probs.items(), key=lambda item: item[1], reverse=True)[:3]
            return self._format_report(attack_type, confidence, top_3)
        except Exception as e:
            # UI boundary: keep the traceback in the server log and show the
            # message to the user instead of crashing the Gradio handler.
            logging.getLogger(__name__).exception("Prediction failed")
            return f"❌ Error: {e}"

    @staticmethod
    def _format_report(attack_type, confidence, top_3):
        """Render a prediction as Markdown.

        Any class other than ``'normal'`` is reported as malicious.
        """
        is_malicious = attack_type != 'normal'
        status = "🚨 MALICIOUS TRAFFIC" if is_malicious else "✅ NORMAL TRAFFIC"
        color = "red" if is_malicious else "green"
        action = "🚨 Immediate investigation required!" if is_malicious else "✅ No action needed"
        lines = [
            "## 🔍 **Network Traffic Analysis Result**",
            "",
            f"**Status:** <span style='color:{color}; font-weight:bold'>{status}</span>",
            "",
            f"**Classification:** {attack_type}",
            "",
            f"**Confidence:** {confidence:.4f}",
            "",
            "### Top 3 Predictions:",
            "",
            # Emit a real Markdown list. Plain lines joined by single newlines
            # would be merged into one paragraph when rendered.
            *(f"- {label}: {prob:.4f}" for label, prob in top_3),
            "",
            f"**Action Recommended:** {action}",
        ]
        return "\n".join(lines)
|
|
|
|
|
|
|
|
# Single shared detector instance; the Gradio callback routes every request through it.
detector = NetworkIntrusionDetector(
    model=model,
    scaler=scaler,
    label_encoder=label_encoder,
    feature_names=feature_names,
)
|
|
|
|
|
|
|
|
# Example traffic records shown under the UI. The keys presumably follow the
# NSL-KDD / KDD Cup '99 feature names plus one-hot protocol/service/flag
# columns -- confirm against multiclass_feature_names.pkl. Both records list
# the same keys, so either can be looked up by any feature name without a
# KeyError.
sample_normal = {
    'duration': 0, 'src_bytes': 0, 'dst_bytes': 0, 'land': 0, 'wrong_fragment': 0,
    'urgent': 0, 'hot': 0, 'num_failed_logins': 0, 'logged_in': 1, 'num_compromised': 0,
    'root_shell': 0, 'su_attempted': 0, 'num_root': 0, 'num_file_creations': 0,
    'num_shells': 0, 'num_access_files': 0, 'num_outbound_cmds': 0, 'is_host_login': 0,
    'is_guest_login': 0, 'count': 2, 'srv_count': 2, 'serror_rate': 0.0,
    'srv_serror_rate': 0.0, 'rerror_rate': 0.0, 'srv_rerror_rate': 0.0,
    'same_srv_rate': 1.0, 'diff_srv_rate': 0.0, 'srv_diff_host_rate': 0.0,
    'dst_host_count': 150, 'dst_host_srv_count': 150, 'dst_host_same_srv_rate': 1.0,
    'dst_host_diff_srv_rate': 0.0, 'dst_host_same_src_port_rate': 0.0,
    'dst_host_srv_diff_host_rate': 0.0, 'dst_host_serror_rate': 0.0,
    'dst_host_srv_serror_rate': 0.0, 'dst_host_rerror_rate': 0.0,
    'dst_host_srv_rerror_rate': 0.0, 'protocol_type_icmp': 0, 'protocol_type_tcp': 1,
    'protocol_type_udp': 0, 'service_http': 1, 'service_other': 0,
    'flag_SF': 1, 'flag_S0': 0,
}

# Presumably a SYN-flood-like record (S0 flag, SYN error rates of 1.0) -- confirm
# with the model's predicted label.
sample_attack = {
    'duration': 0, 'src_bytes': 1032, 'dst_bytes': 0, 'land': 0, 'wrong_fragment': 0,
    'urgent': 0, 'hot': 0, 'num_failed_logins': 0, 'logged_in': 0, 'num_compromised': 0,
    'root_shell': 0, 'su_attempted': 0, 'num_root': 0, 'num_file_creations': 0,
    'num_shells': 0, 'num_access_files': 0, 'num_outbound_cmds': 0, 'is_host_login': 0,
    'is_guest_login': 0, 'count': 1, 'srv_count': 1, 'serror_rate': 1.0,
    'srv_serror_rate': 1.0, 'rerror_rate': 0.0, 'srv_rerror_rate': 0.0,
    'same_srv_rate': 1.0, 'diff_srv_rate': 0.0, 'srv_diff_host_rate': 0.0,
    'dst_host_count': 255, 'dst_host_srv_count': 255, 'dst_host_same_srv_rate': 1.0,
    'dst_host_diff_srv_rate': 0.0, 'dst_host_same_src_port_rate': 0.0,
    'dst_host_srv_diff_host_rate': 0.0, 'dst_host_serror_rate': 1.0,
    'dst_host_srv_serror_rate': 1.0, 'dst_host_rerror_rate': 0.0,
    'dst_host_srv_rerror_rate': 0.0, 'protocol_type_icmp': 0, 'protocol_type_tcp': 1,
    'protocol_type_udp': 0, 'service_http': 0, 'service_other': 1,
    'flag_SF': 0, 'flag_S0': 1,
}
|
|
|
|
|
|
|
|
def create_interface():
    """Build the Gradio front end for the module-level ``detector``.

    Only a subset of the model's features gets a widget. The remaining features
    are left out of the request, and the detector's preprocessing fills them
    with 0 (``DataFrame.reindex(..., fill_value=0)``).

    Returns:
        gr.Interface: the configured interface; it is not launched here.
    """
    # (feature name, widget) pairs. Keeping each name next to its widget
    # guarantees that the positional values Gradio passes to predict_attack,
    # and the example rows below, line up with the inputs.
    fields = [
        ('duration', gr.Number(label="Duration", value=0)),
        ('src_bytes', gr.Number(label="Source Bytes", value=0)),
        ('dst_bytes', gr.Number(label="Destination Bytes", value=0)),
        ('land', gr.Number(label="Land (0/1)", value=0)),
        ('wrong_fragment', gr.Number(label="Wrong Fragment", value=0)),

        ('hot', gr.Number(label="Hot", value=0)),
        ('num_failed_logins', gr.Number(label="Num Failed Logins", value=0)),
        ('logged_in', gr.Number(label="Logged In (0/1)", value=0)),
        ('num_compromised', gr.Number(label="Num Compromised", value=0)),
        ('root_shell', gr.Number(label="Root Shell (0/1)", value=0)),

        ('count', gr.Number(label="Count", value=1)),
        ('srv_count', gr.Number(label="Service Count", value=1)),
        ('serror_rate', gr.Slider(0, 1, label="Error Rate", value=0)),
        ('srv_serror_rate', gr.Slider(0, 1, label="Service Error Rate", value=0)),
        ('same_srv_rate', gr.Slider(0, 1, label="Same Service Rate", value=1)),

        # One-hot protocol columns
        ('protocol_type_tcp', gr.Radio([0, 1], label="Protocol TCP", value=1)),
        ('protocol_type_udp', gr.Radio([0, 1], label="Protocol UDP", value=0)),
        ('protocol_type_icmp', gr.Radio([0, 1], label="Protocol ICMP", value=0)),

        # One-hot service columns
        ('service_http', gr.Radio([0, 1], label="Service HTTP", value=1)),
        ('service_other', gr.Radio([0, 1], label="Service Other", value=0)),

        # One-hot flag columns
        ('flag_SF', gr.Radio([0, 1], label="Flag SF", value=1)),
        ('flag_S0', gr.Radio([0, 1], label="Flag S0", value=0)),
    ]
    field_names = [name for name, _ in fields]
    inputs = [widget for _, widget in fields]

    def predict_attack(*args):
        """Pair Gradio's positional input values with feature names and classify."""
        return detector.predict(**dict(zip(field_names, args)))

    # Build example rows by feature name, in widget order. A sample that omits
    # a column gets 0 for it, matching the detector's own fill value.
    examples = [
        [sample.get(name, 0) for name in field_names]
        for sample in (sample_normal, sample_attack)
    ]

    description = (
        "**Detect malicious network traffic in real-time using AI**\n\n"
        "This system can identify 40+ different types of network attacks including:\n\n"
        "- DoS attacks (neptune, smurf, teardrop)\n"
        "- Probing attacks (portsweep, nmap, satan)\n"
        "- R2L attacks (guess_passwd, warezclient)\n"
        "- U2R attacks (buffer_overflow, rootkit)\n\n"
        "*Enter network traffic features below to analyze:*"
    )

    return gr.Interface(
        fn=predict_attack,
        inputs=inputs,
        outputs=gr.Markdown(),
        title="🚨 Network Intrusion Detection System",
        description=description,
        examples=examples,
        theme="soft",
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    demo = create_interface()
    # share=True also requests a publicly shareable Gradio link; drop it if the
    # app should not be exposed outside the local server.
    demo.launch(share=True)