import gradio as gr
import numpy as np
import pandas as pd
import joblib
from tensorflow.keras.models import load_model

# --- 1. Load Model and Preprocessing Components ---
try:
    # Load the trained model
    model = load_model('improved_intrusion_detection_model.h5')
    # Load the scaler, label encoder, and feature names.
    # NOTE(review): joblib.load unpickles its input - only ever point these
    # paths at trusted, locally produced artifacts.
    scaler = joblib.load('final_standard_scaler.pkl')
    le = joblib.load('final_label_encoder.pkl')
    feature_names = joblib.load('final_feature_names.pkl')
    print("Model and preprocessing components loaded successfully.")
except Exception as e:
    print(f"Error loading files: {e}")
    # The app cannot serve predictions without these files, so fail fast.
    raise


# --- 2. Define the Prediction Function ---
def predict_intrusion(
    duration, src_bytes, dst_bytes, land, wrong_fragment, urgent, hot,
    num_failed_logins, logged_in, num_compromised, root_shell, su_attempted,
    num_root, num_file_creations, num_shells, num_access_files,
    num_outbound_cmds, is_host_login, is_guest_login, count, srv_count,
    serror_rate, srv_serror_rate, rerror_rate, srv_rerror_rate,
    same_srv_rate, diff_srv_rate, srv_diff_host_rate, dst_host_count,
    dst_host_srv_count, dst_host_same_srv_rate, dst_host_diff_srv_rate,
    dst_host_same_src_port_rate, dst_host_srv_diff_host_rate,
    dst_host_serror_rate, dst_host_srv_serror_rate, dst_host_rerror_rate,
    dst_host_srv_rerror_rate,
    # Categorical features: raw string values, one-hot encoded below
    protocol_type, service, flag
):
    """Classify one connection record as normal traffic or an attack.

    Takes the 38 numerical features of a connection record plus the three
    categorical features (``protocol_type``, ``service``, ``flag``). It
    builds a one-row DataFrame whose columns are exactly ``feature_names``,
    scales it with ``scaler``, reshapes it to ``(1, n_features, 1)`` for the
    CNN ``model``, and thresholds the model's single output at 0.5.

    Numerical features that are not in ``feature_names`` are ignored.
    Categorical values with no matching one-hot column (e.g. a service not
    seen in training) leave every column of that category at 0.0.

    Returns:
        str: a human-readable verdict that includes the attack probability.
    """
    # --- A. One-row frame with every training column initialised to 0.0 ---
    input_df = pd.DataFrame({name: [0.0] for name in feature_names})

    numerical_features = {
        'duration': duration,
        'src_bytes': src_bytes,
        'dst_bytes': dst_bytes,
        'land': land,
        'wrong_fragment': wrong_fragment,
        'urgent': urgent,
        'hot': hot,
        'num_failed_logins': num_failed_logins,
        'logged_in': logged_in,
        'num_compromised': num_compromised,
        'root_shell': root_shell,
        'su_attempted': su_attempted,
        'num_root': num_root,
        'num_file_creations': num_file_creations,
        'num_shells': num_shells,
        'num_access_files': num_access_files,
        'num_outbound_cmds': num_outbound_cmds,
        'is_host_login': is_host_login,
        'is_guest_login': is_guest_login,
        'count': count,
        'srv_count': srv_count,
        'serror_rate': serror_rate,
        'srv_serror_rate': srv_serror_rate,
        'rerror_rate': rerror_rate,
        'srv_rerror_rate': srv_rerror_rate,
        'same_srv_rate': same_srv_rate,
        'diff_srv_rate': diff_srv_rate,
        'srv_diff_host_rate': srv_diff_host_rate,
        'dst_host_count': dst_host_count,
        'dst_host_srv_count': dst_host_srv_count,
        'dst_host_same_srv_rate': dst_host_same_srv_rate,
        'dst_host_diff_srv_rate': dst_host_diff_srv_rate,
        'dst_host_same_src_port_rate': dst_host_same_src_port_rate,
        'dst_host_srv_diff_host_rate': dst_host_srv_diff_host_rate,
        'dst_host_serror_rate': dst_host_serror_rate,
        'dst_host_srv_serror_rate': dst_host_srv_serror_rate,
        'dst_host_rerror_rate': dst_host_rerror_rate,
        'dst_host_srv_rerror_rate': dst_host_srv_rerror_rate,
    }

    # Copy in only the numerical features the model was trained on.
    for col, val in numerical_features.items():
        if col in input_df.columns:
            input_df[col] = val

    # Set the one-hot indicator column "<category>_<value>" when it exists.
    for prefix, value in (
        ('protocol_type', protocol_type),
        ('service', service),
        ('flag', flag),
    ):
        one_hot_col = f'{prefix}_{value}'
        if one_hot_col in input_df.columns:
            input_df[one_hot_col] = 1.0

    # Ensure the order of columns matches the training data
    input_df = input_df[feature_names]

    # --- B. Scale the input data ---
    X_scaled = scaler.transform(input_df.values)

    # --- C. Reshape for the CNN: (batch=1, n_features, channels=1) ---
    X_cnn = X_scaled.reshape(1, X_scaled.shape[1], 1)

    # --- D. Predict ---
    # verbose=0 stops Keras printing a progress bar on every request.
    prediction_proba = float(model.predict(X_cnn, verbose=0)[0][0])

    # --- E. Decode Result ---
    # The output is thresholded at 0.5 and mapped to le.classes_[1] when
    # above it, i.e. it is treated as the probability of classes_[1].
    prediction_class = int(prediction_proba > 0.5)
    decoded_prediction = le.classes_[prediction_class]

    # Do not assume 'attack' is class 1. If `le` is an sklearn LabelEncoder
    # (presumably), classes_ is sorted, so 'attack' < 'normal' puts 'attack'
    # at index 0. In that case the raw output is P(normal).
    classes = list(le.classes_)
    attack_index = classes.index('attack') if 'attack' in classes else 1
    attack_proba = prediction_proba if attack_index == 1 else 1.0 - prediction_proba

    # Format the output for Gradio
    probability_percent = f"{attack_proba * 100:.2f}%"

    if decoded_prediction == 'attack':
        result_message = f"🚨 INTRUSION DETECTED! (Attack Probability: {probability_percent})"
    else:
        result_message = f"✅ Normal Traffic. (Attack Probability: {probability_percent})"

    return result_message


# --- 3. Gradio Interface Setup ---
# Simplified form exposing only the most critical features. The order of this
# list is the positional order in which Gradio passes values to predict_wrapper.
inputs = [
    gr.Number(label="Duration (seconds)", value=0),
    gr.Number(label="Source Bytes", value=491),
    gr.Number(label="Destination Bytes", value=0),
    gr.Dropdown(label="Protocol Type", choices=['tcp', 'udp', 'icmp'], value='tcp'),
    gr.Dropdown(label="Service", choices=['ftp_data', 'http', 'private', 'domain_u', 'other'], value='ftp_data'),
    gr.Dropdown(label="Flag", choices=['SF', 'S0', 'REJ', 'RSTO'], value='SF'),
    gr.Slider(0, 1, label="Logged In (1=Yes, 0=No)", step=1, value=1),
    gr.Number(label="Count (connections to same host)", value=2),
    gr.Number(label="Service Count (connections to same service)", value=2),
    gr.Slider(0.0, 1.0, label="Same Service Rate", step=0.01, value=1.0),
    gr.Slider(0.0, 1.0, label="SError Rate", step=0.01, value=0.0),
]
# predict_intrusion takes 38 numerical + 3 categorical parameters; features
# not shown above must be filled in with fixed defaults by the wrapper.
# The form exposes 11 inputs: 8 numerical features plus the 3 categoricals.
# The wrapper below forwards them to predict_intrusion and fixes every other
# numerical feature at 0.0. srv_serror_rate is the one exception: it mirrors
# serror_rate.
def predict_wrapper(*args):
    """Adapt the simplified Gradio form to predict_intrusion's full signature.

    Gradio passes the component values positionally, in the order of the
    module-level ``inputs`` list:

        0 duration, 1 src_bytes, 2 dst_bytes, 3 protocol_type, 4 service,
        5 flag, 6 logged_in, 7 count, 8 srv_count, 9 same_srv_rate,
        10 serror_rate

    Every predict_intrusion parameter not listed here is passed as 0.0.

    Returns:
        Whatever predict_intrusion returns (the result message string).

    Raises:
        ValueError: if the number of values is not exactly 11.
    """
    # Local import: only needed to enumerate predict_intrusion's parameters.
    import inspect

    (duration, src_bytes, dst_bytes, protocol_type, service, flag,
     logged_in, count, srv_count, same_srv_rate, serror_rate) = args

    # Start from 0.0 for every parameter, then overlay the UI values. Passing
    # by keyword means the call cannot drift out of sync with the count or
    # order of predict_intrusion's 41 positional parameters.
    kwargs = dict.fromkeys(inspect.signature(predict_intrusion).parameters, 0.0)
    kwargs.update(
        duration=duration,
        src_bytes=src_bytes,
        dst_bytes=dst_bytes,
        logged_in=logged_in,
        count=count,
        srv_count=srv_count,
        same_srv_rate=same_srv_rate,
        serror_rate=serror_rate,
        # Simplification: assume srv_serror_rate equals serror_rate.
        srv_serror_rate=serror_rate,
        protocol_type=protocol_type,
        service=service,
        flag=flag,
    )

    # Call the main prediction function
    return predict_intrusion(**kwargs)


# Define the Gradio interface
demo = gr.Interface(
    fn=predict_wrapper,
    inputs=inputs,
    outputs=gr.Label(label="Intrusion Detection Result"),
    title="Intrusion Detection System (KDD Cup '99)",
    description="Predicts whether a network connection is 'normal' or an 'attack' using a 1D CNN model. The model achieved 99.28% accuracy on the test set.",
)

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()