# File size: 3,818 Bytes
# 96adf77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import joblib
import numpy as np
import tensorflow as tf
import pandas as pd
import gradio as gr
import os

# --- Configuration ---
# Raw input columns that are one-hot encoded before scaling.
# Kept as a list because it is passed to pandas as a column selection.
CATEGORICAL_COLS = ["protocol_type", "service", "flag"]

# Model outputs at or above this value are mapped to encoded class 1.
FINAL_THRESHOLD = 0.7

# Artifact files read once at startup (relative paths).
LABEL_ENCODER_PATH = "label_encoder.pkl"
FEATURES_PATH = "feature_names.pkl"
SCALER_PATH = "standard_scaler.pkl"
MODEL_PATH = "improved_intrusion_detection_model_SIMPLIFIED.h5"

# --- Load Artifacts ---
# The model and preprocessors are loaded once, at import time, so every
# prediction reuses the same objects.
# NOTE: joblib.load unpickles the files; only ship artifacts you trust.
try:
    model = tf.keras.models.load_model(MODEL_PATH)
    scaler = joblib.load(SCALER_PATH)
    final_features = joblib.load(FEATURES_PATH)
    label_encoder = joblib.load(LABEL_ENCODER_PATH)
except Exception as e:
    # Top-level boundary: the app cannot serve predictions without these
    # files, so report the cause and stop with a non-zero status. A bare
    # exit() would report success (status 0) and depends on the helper that
    # the `site` module injects for interactive sessions.
    print(f"Error loading model artifacts: {e}")
    raise SystemExit(1) from e
else:
    print("Model and preprocessors loaded successfully.")

def preprocess_and_predict(*raw_input_features):
    """Classify one connection record from its raw feature values.

    The values are one-hot encoded, aligned to the saved training feature
    layout, scaled, reshaped for the CNN and scored by the model.

    Args:
        *raw_input_features: One value per entry of
            ``raw_input_features_names``, in the same order. A numeric
            value of ``None`` (a blank number field) is treated as 0;
            surrounding whitespace is stripped from categorical strings.

    Returns:
        A ``(label, details)`` tuple. ``label`` is the decoded class name
        as a plain ``str`` (the UI text describes the classes as 'normal'
        and 'attack'). ``details`` is a string carrying the raw model
        output formatted to four decimals.

    Raises:
        ValueError: If the number of values does not match the number of
            expected feature names.
    """
    expected_count = len(raw_input_features_names)
    if len(raw_input_features) != expected_count:
        raise ValueError(
            f"Expected {expected_count} feature values, "
            f"got {len(raw_input_features)}."
        )

    # 1. Build a single-row frame from a dict so numeric columns keep a
    #    numeric dtype (a mixed-type Series would make everything object).
    record = {}
    for name, value in zip(raw_input_features_names, raw_input_features):
        if name in CATEGORICAL_COLS:
            # Stray whitespace would yield a dummy column that matches no
            # training feature and is silently dropped by the reindex below.
            record[name] = value.strip() if isinstance(value, str) else value
        else:
            # A cleared number field arrives as None; fall back to the UI
            # default of 0 instead of feeding NaN to the scaler.
            record[name] = 0.0 if value is None else value
    df_raw = pd.DataFrame([record])

    # 2. One-hot encode the categorical features.
    df_encoded = pd.get_dummies(df_raw, columns=CATEGORICAL_COLS)

    # 3. Align columns with the training layout. Dummy columns missing from
    #    this single row are added as 0; columns unknown to training
    #    (e.g. an unseen category value) are dropped.
    df_encoded = df_encoded.reindex(columns=final_features, fill_value=0)

    # 4. Scale with the scaler fitted at training time.
    X_scaled = scaler.transform(df_encoded)

    # 5. Reshape for the CNN input: (1 sample, n_features, 1 channel).
    X_cnn = X_scaled.reshape(X_scaled.shape[0], X_scaled.shape[1], 1)

    # 6. Predict; the first output of the first (only) sample is the score.
    y_pred_proba = model.predict(X_cnn, verbose=0)[0][0]

    # 7. Apply the fixed threshold and decode the integer class.
    prediction_int = int(y_pred_proba >= FINAL_THRESHOLD)
    final_label = label_encoder.inverse_transform([prediction_int])[0]

    # The number shown is the raw model output (the score for encoded
    # class 1), regardless of which label was selected. The label is
    # converted from a NumPy scalar to a plain str for the UI.
    return str(final_label), f"Confidence: {y_pred_proba:.4f}"


# --- Gradio Interface Setup ---

# Names of the 41 raw input features (no 'label' column), in the order the
# UI presents them and the prediction function receives their values.
raw_input_features_names = [
    "duration",
    "protocol_type",
    "service",
    "flag",
    "src_bytes",
    "dst_bytes",
    "land",
    "wrong_fragment",
    "urgent",
    "hot",
    "num_failed_logins",
    "logged_in",
    "num_compromised",
    "root_shell",
    "su_attempted",
    "num_root",
    "num_file_creations",
    "num_shells",
    "num_access_files",
    "num_outbound_cmds",
    "is_host_login",
    "is_guest_login",
    "count",
    "srv_count",
    "serror_rate",
    "srv_serror_rate",
    "rerror_rate",
    "srv_rerror_rate",
    "same_srv_rate",
    "diff_srv_rate",
    "srv_diff_host_rate",
    "dst_host_count",
    "dst_host_srv_count",
    "dst_host_same_srv_rate",
    "dst_host_diff_srv_rate",
    "dst_host_same_src_port_rate",
    "dst_host_srv_diff_host_rate",
    "dst_host_serror_rate",
    "dst_host_srv_serror_rate",
    "dst_host_rerror_rate",
    "dst_host_srv_rerror_rate",
]

# Create Gradio inputs corresponding to the feature types: a text box for
# each categorical column, a number field (default 0) for everything else.
#
# Each categorical box gets an example value from its own column. Using
# 'tcp' for all three would leave 'service' and 'flag' with a value that
# matches no one-hot column, so the untouched default form would be scored
# with those groups all zero.
# NOTE(review): 'http' and 'SF' are the usual KDD-style example values --
# confirm they appear among the trained columns in feature_names.pkl.
_CATEGORICAL_DEFAULTS = {
    'protocol_type': 'tcp',
    'service': 'http',
    'flag': 'SF',
}

inputs = [
    gr.Textbox(label=name, value=_CATEGORICAL_DEFAULTS.get(name, 'tcp'))
    if name in CATEGORICAL_COLS
    else gr.Number(label=name, value=0)
    for name in raw_input_features_names
]

# Output widgets, in the order the prediction function returns its values:
# the decoded label first, then the details string.
prediction_output = gr.Label(label="Prediction")
details_output = gr.Textbox(label="Details")

# Wire the prediction function to the input form; flagging is disabled.
iface = gr.Interface(
    fn=preprocess_and_predict,
    inputs=inputs,
    outputs=[prediction_output, details_output],
    title="CNN Network Intrusion Detector (KDD)",
    description="Enter the 41 raw network traffic features to classify the connection as 'normal' or 'attack'. Optimized with 0.7 threshold.",
    allow_flagging="never",
)

if __name__ == "__main__":
    # Start the web UI only when executed as a script, not on import.
    iface.launch()