import gradio as gr
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import load_model
import joblib
import pickle
import json

# Load model and preprocessing artifacts produced at training time.
# NOTE(review): joblib/pickle deserialisation can execute arbitrary code; only
# load these files from a trusted source (never user-supplied uploads).
model = load_model('multiclass_cnn_model_fast.h5')
scaler = joblib.load('multiclass_standard_scaler.pkl')
label_encoder = joblib.load('multiclass_label_encoder.pkl')
with open('multiclass_feature_names.pkl', 'rb') as f:
    # Ordered column names the scaler/model were fitted on.
    feature_names = pickle.load(f)


class NetworkIntrusionDetector:
    """Run single-record inference with the CNN model and its preprocessing artifacts.

    Args:
        model: Keras model whose ``predict`` returns one probability row per sample.
        scaler: fitted scaler exposing ``transform`` (applied before the model).
        label_encoder: fitted encoder exposing ``inverse_transform`` and ``classes_``.
        feature_names: ordered column names the scaler/model expect.
    """

    def __init__(self, model, scaler, label_encoder, feature_names):
        self.model = model
        self.scaler = scaler
        self.label_encoder = label_encoder
        self.feature_names = feature_names

    def preprocess(self, features_dict):
        """Turn one feature mapping into a model-ready array of shape (1, n_features, 1).

        Columns are aligned to ``self.feature_names``: any expected feature not in
        ``features_dict`` is filled with 0, and keys not in ``feature_names`` are dropped.
        """
        features_df = pd.DataFrame([features_dict])
        features_aligned = features_df.reindex(columns=self.feature_names, fill_value=0)
        features_scaled = self.scaler.transform(features_aligned)
        # Trailing channel axis for the 1-D CNN input.
        return features_scaled.reshape(1, -1, 1)

    def predict(self, **input_features):
        """Classify one traffic record and return a Markdown report.

        Args:
            **input_features: feature name -> numeric value.

        Returns:
            A Markdown string with the predicted class, its confidence and the top-3
            class probabilities. Any exception is reported as an ``"❌ Error: ..."``
            string instead of being raised, so the UI always gets displayable text.
        """
        try:
            features_cnn = self.preprocess(input_features)

            prediction = self.model.predict(features_cnn, verbose=0)
            probs = prediction[0]
            class_idx = np.argmax(prediction, axis=1)[0]
            confidence = np.max(prediction)
            attack_type = self.label_encoder.inverse_transform([class_idx])[0]

            # Map every class label to its probability and keep the three largest.
            all_probs = {
                label: float(prob)
                for label, prob in zip(self.label_encoder.classes_, probs)
            }
            top_3 = sorted(all_probs.items(), key=lambda item: item[1], reverse=True)[:3]

            # NOTE(review): assumes the benign class is encoded as exactly 'normal';
            # confirm against the label encoder's classes_.
            is_malicious = attack_type != 'normal'
            status = "🚨 MALICIOUS TRAFFIC" if is_malicious else "✅ NORMAL TRAFFIC"
            action = (
                '🚨 Immediate investigation required!'
                if is_malicious
                else '✅ No action needed'
            )
            top_lines = "\n".join(f"- {label}: {prob:.4f}" for label, prob in top_3)

            # Blank lines between sections so Markdown renders separate paragraphs.
            return (
                "## 🔍 **Network Traffic Analysis Result**\n\n"
                f"**Status:** {status}\n\n"
                f"**Classification:** {attack_type}\n\n"
                f"**Confidence:** {confidence:.4f}\n\n"
                "### Top 3 Predictions:\n\n"
                f"{top_lines}\n\n"
                f"**Action Recommended:** {action}\n"
            )
        except Exception as e:
            # UI boundary: surface the failure to the user rather than crash the app.
            return f"❌ Error: {str(e)}"


# Initialize detector
detector = NetworkIntrusionDetector(model, scaler, label_encoder, feature_names)

# Common network traffic examples (feature name -> value). Keys not exposed in the
# UI are ignored when building examples; exposed keys missing here default to 0.
sample_normal = {
    'duration': 0, 'src_bytes': 0, 'dst_bytes': 0, 'land': 0, 'wrong_fragment': 0,
    'urgent': 0, 'hot': 0, 'num_failed_logins': 0, 'logged_in': 1,
    'num_compromised': 0, 'root_shell': 0, 'su_attempted': 0, 'num_root': 0,
    'num_file_creations': 0, 'num_shells': 0, 'num_access_files': 0,
    'num_outbound_cmds': 0, 'is_host_login': 0, 'is_guest_login': 0,
    'count': 2, 'srv_count': 2, 'serror_rate': 0.0, 'srv_serror_rate': 0.0,
    'rerror_rate': 0.0, 'srv_rerror_rate': 0.0, 'same_srv_rate': 1.0,
    'diff_srv_rate': 0.0, 'srv_diff_host_rate': 0.0, 'dst_host_count': 150,
    'dst_host_srv_count': 150, 'dst_host_same_srv_rate': 1.0,
    'dst_host_diff_srv_rate': 0.0, 'dst_host_same_src_port_rate': 0.0,
    'dst_host_srv_diff_host_rate': 0.0, 'dst_host_serror_rate': 0.0,
    'dst_host_srv_serror_rate': 0.0, 'dst_host_rerror_rate': 0.0,
    'dst_host_srv_rerror_rate': 0.0, 'protocol_type_icmp': 0,
    'protocol_type_tcp': 1, 'protocol_type_udp': 0, 'service_http': 1,
    'service_other': 0, 'flag_SF': 1,
}

sample_attack = {
    'duration': 0, 'src_bytes': 1032, 'dst_bytes': 0, 'land': 0, 'wrong_fragment': 0,
    'urgent': 0, 'hot': 0, 'num_failed_logins': 0, 'logged_in': 0,
    'num_compromised': 0, 'root_shell': 0, 'su_attempted': 0, 'num_root': 0,
    'num_file_creations': 0, 'num_shells': 0, 'num_access_files': 0,
    'num_outbound_cmds': 0, 'is_host_login': 0, 'is_guest_login': 0,
    'count': 1, 'srv_count': 1, 'serror_rate': 1.0, 'srv_serror_rate': 1.0,
    'rerror_rate': 0.0, 'srv_rerror_rate': 0.0, 'same_srv_rate': 1.0,
    'diff_srv_rate': 0.0, 'srv_diff_host_rate': 0.0, 'dst_host_count': 255,
    'dst_host_srv_count': 255, 'dst_host_same_srv_rate': 1.0,
    'dst_host_diff_srv_rate': 0.0, 'dst_host_same_src_port_rate': 0.0,
    'dst_host_srv_diff_host_rate': 0.0, 'dst_host_serror_rate': 1.0,
    'dst_host_srv_serror_rate': 1.0, 'dst_host_rerror_rate': 0.0,
    'dst_host_srv_rerror_rate': 0.0, 'protocol_type_icmp': 0,
    'protocol_type_tcp': 1, 'protocol_type_udp': 0, 'service_http': 0,
    'service_other': 1, 'flag_S0': 1,
}


def create_interface():
    """Build the Gradio interface for the intrusion detector.

    The form exposes a subset of the model's features. Every other feature in
    ``feature_names`` is sent to the model as 0.

    Returns:
        A ``gr.Interface`` ready to ``launch()``.
    """
    # (feature key, input component) pairs. This single list fixes the order of
    # the components, of the callback's positional args, and of example rows, so
    # the three can never drift apart.
    # NOTE(review): the one-hot keys (protocol_type_*, service_*, flag_*) must
    # match the training columns in feature_names; unmatched keys are dropped.
    ui_fields = [
        # Basic features
        ('duration', gr.Number(label="Duration", value=0)),
        ('src_bytes', gr.Number(label="Source Bytes", value=0)),
        ('dst_bytes', gr.Number(label="Destination Bytes", value=0)),
        ('land', gr.Number(label="Land (0/1)", value=0)),
        ('wrong_fragment', gr.Number(label="Wrong Fragment", value=0)),
        # Connection features
        ('hot', gr.Number(label="Hot", value=0)),
        ('num_failed_logins', gr.Number(label="Num Failed Logins", value=0)),
        ('logged_in', gr.Number(label="Logged In (0/1)", value=0)),
        ('num_compromised', gr.Number(label="Num Compromised", value=0)),
        ('root_shell', gr.Number(label="Root Shell (0/1)", value=0)),
        # Rate features
        ('count', gr.Number(label="Count", value=1)),
        ('srv_count', gr.Number(label="Service Count", value=1)),
        ('serror_rate', gr.Slider(0, 1, label="Error Rate", value=0)),
        ('srv_serror_rate', gr.Slider(0, 1, label="Service Error Rate", value=0)),
        ('same_srv_rate', gr.Slider(0, 1, label="Same Service Rate", value=1)),
        # Protocol type
        ('protocol_type_tcp', gr.Radio([0, 1], label="Protocol TCP", value=1)),
        ('protocol_type_udp', gr.Radio([0, 1], label="Protocol UDP", value=0)),
        ('protocol_type_icmp', gr.Radio([0, 1], label="Protocol ICMP", value=0)),
        # Service type
        ('service_http', gr.Radio([0, 1], label="Service HTTP", value=1)),
        ('service_other', gr.Radio([0, 1], label="Service Other", value=0)),
        # Flag
        ('flag_SF', gr.Radio([0, 1], label="Flag SF", value=1)),
        ('flag_S0', gr.Radio([0, 1], label="Flag S0", value=0)),
    ]
    feature_keys = [key for key, _ in ui_fields]
    inputs = [component for _, component in ui_fields]

    def predict_attack(*args):
        """Map positional UI values onto feature names and run the detector."""
        # A cleared gr.Number arrives as None; treat it as 0 like unexposed features.
        features_dict = {
            key: 0 if value is None else value
            for key, value in zip(feature_keys, args)
        }
        # Fill missing features with 0
        for feature in feature_names:
            features_dict.setdefault(feature, 0)
        return detector.predict(**features_dict)

    # Example rows must follow the component order; absent keys default to 0.
    examples = [
        [sample.get(key, 0) for key in feature_keys]
        for sample in (sample_normal, sample_attack)
    ]

    iface = gr.Interface(
        fn=predict_attack,
        inputs=inputs,
        outputs=gr.Markdown(),
        title="🚨 Network Intrusion Detection System",
        description="""**Detect malicious network traffic in real-time using AI**

This system can identify 40+ different types of network attacks including:
- DoS attacks (neptune, smurf, teardrop)
- Probing attacks (portsweep, nmap, satan)
- R2L attacks (guess_passwd, warezclient)
- U2R attacks (buffer_overflow, rootkit)

*Enter network traffic features below to analyze:*""",
        examples=examples,
        theme="soft",
    )
    return iface


# Launch app
if __name__ == "__main__":
    demo = create_interface()
    # share=True asks Gradio for a public share link in addition to the local server.
    demo.launch(share=True)