# Source listing metadata: file size 9,527 bytes, revision 2aed8d6.
# (The hosting page's line-number gutter, 1-207, was dropped; it is not program text.)
import gradio as gr
import numpy as np
import pandas as pd
import joblib
from tensorflow.keras.models import load_model

# --- 1. Load Model and Preprocessing Components ---
# Everything below is loaded once at import time; any failure is reported on
# stdout and then re-raised so the app never starts half-initialised.
try:
    # Trained Keras model
    model = load_model('improved_intrusion_detection_model.h5')

    # Scaler, label encoder and feature-name list, loaded in that order
    scaler, le, feature_names = (
        joblib.load(artifact_path)
        for artifact_path in (
            'final_standard_scaler.pkl',
            'final_label_encoder.pkl',
            'final_feature_names.pkl',
        )
    )
except Exception as load_error:
    print(f"Error loading files: {load_error}")
    # The app cannot run without these files, so propagate the failure
    raise
else:
    print("Model and preprocessing components loaded successfully.")

# --- 2. Define the Prediction Function ---
def predict_intrusion(
    duration, src_bytes, dst_bytes, land, wrong_fragment, urgent, hot,
    num_failed_logins, logged_in, num_compromised, root_shell, su_attempted,
    num_root, num_file_creations, num_shells, num_access_files,
    num_outbound_cmds, is_host_login, is_guest_login, count, srv_count,
    serror_rate, srv_serror_rate, rerror_rate, srv_rerror_rate, same_srv_rate,
    diff_srv_rate, srv_diff_host_rate, dst_host_count, dst_host_srv_count,
    dst_host_same_srv_rate, dst_host_diff_srv_rate, dst_host_same_src_port_rate,
    dst_host_srv_diff_host_rate, dst_host_serror_rate, dst_host_srv_serror_rate,
    dst_host_rerror_rate, dst_host_srv_rerror_rate,
    # Categorical features - passed as their raw string values
    protocol_type, service, flag
):
    """Classify one connection record and return a human-readable verdict.

    The 38 numerical arguments are written into the matching columns of a
    single-row frame laid out according to the module-level ``feature_names``;
    the three categorical arguments switch on the one-hot columns
    ``protocol_type_<value>``, ``service_<value>`` and ``flag_<value>``.
    Any column named in ``feature_names`` that receives no value stays at 0.0,
    and any value whose column is not in ``feature_names`` is ignored.

    Args:
        duration ... dst_host_srv_rerror_rate: Numerical feature values.
            ``None`` (what a cleared Gradio number field produces) is treated
            as 0.0, the same default used for features that are not supplied.
        protocol_type: Protocol name used to pick the one-hot column.
        service: Service name used to pick the one-hot column.
        flag: Connection flag used to pick the one-hot column.

    Returns:
        A message string containing the verdict and the model score formatted
        as a percentage.
    """
    numerical_features = {
        'duration': duration, 'src_bytes': src_bytes, 'dst_bytes': dst_bytes,
        'land': land, 'wrong_fragment': wrong_fragment, 'urgent': urgent,
        'hot': hot, 'num_failed_logins': num_failed_logins,
        'logged_in': logged_in, 'num_compromised': num_compromised,
        'root_shell': root_shell, 'su_attempted': su_attempted,
        'num_root': num_root, 'num_file_creations': num_file_creations,
        'num_shells': num_shells, 'num_access_files': num_access_files,
        'num_outbound_cmds': num_outbound_cmds, 'is_host_login': is_host_login,
        'is_guest_login': is_guest_login, 'count': count, 'srv_count': srv_count,
        'serror_rate': serror_rate, 'srv_serror_rate': srv_serror_rate,
        'rerror_rate': rerror_rate, 'srv_rerror_rate': srv_rerror_rate,
        'same_srv_rate': same_srv_rate, 'diff_srv_rate': diff_srv_rate,
        'srv_diff_host_rate': srv_diff_host_rate, 'dst_host_count': dst_host_count,
        'dst_host_srv_count': dst_host_srv_count,
        'dst_host_same_srv_rate': dst_host_same_srv_rate,
        'dst_host_diff_srv_rate': dst_host_diff_srv_rate,
        'dst_host_same_src_port_rate': dst_host_same_src_port_rate,
        'dst_host_srv_diff_host_rate': dst_host_srv_diff_host_rate,
        'dst_host_serror_rate': dst_host_serror_rate,
        'dst_host_srv_serror_rate': dst_host_srv_serror_rate,
        'dst_host_rerror_rate': dst_host_rerror_rate,
        'dst_host_srv_rerror_rate': dst_host_srv_rerror_rate
    }
    one_hot_columns = (
        f'protocol_type_{protocol_type}',
        f'service_{service}',
        f'flag_{flag}',
    )

    # --- A. Build the feature row -------------------------------------------
    # Start from an all-zero template covering every training column (119
    # according to the original author's note), then fill in what we have.
    row = dict.fromkeys(feature_names, 0.0)

    for col, val in numerical_features.items():
        if col in row:
            # A cleared numeric field arrives as None; leaving it in place
            # would make the scaler fail, so fall back to the 0.0 default.
            row[col] = 0.0 if val is None else float(val)

    for col in one_hot_columns:
        # Categories without a matching training column are left all-zero.
        if col in row:
            row[col] = 1.0

    # Explicit column list keeps the order identical to the training data
    input_df = pd.DataFrame([row], columns=list(feature_names))

    # --- B. Scale the input data ---
    X_scaled = scaler.transform(input_df.values)

    # --- C. Reshape for CNN (3D input): (1 sample, n_features, 1 channel) ---
    X_cnn = X_scaled.reshape(1, X_scaled.shape[1], 1)

    # --- D. Predict ---
    # verbose=0 keeps Keras from printing a progress bar for every request
    prediction_proba = float(model.predict(X_cnn, verbose=0)[0][0])

    # --- E. Decode Result ---
    # Threshold at 0.5
    prediction_class = int(prediction_proba > 0.5)

    # NOTE(review): this assumes class index 0 is 'normal' and index 1 is
    # 'attack'. scikit-learn's LabelEncoder stores classes_ in sorted order,
    # so confirm the encoder really was fitted with that mapping.
    decoded_prediction = le.classes_[prediction_class]

    # Format the output for Gradio
    probability_percent = f"{prediction_proba * 100:.2f}%"

    if decoded_prediction == 'attack':
        result_message = f"🚨 INTRUSION DETECTED! (Attack Probability: {probability_percent})"
    else:
        result_message = f"✅ Normal Traffic. (Attack Probability: {probability_percent})"

    return result_message

# --- 3. Gradio Interface Setup ---

# Create simplified input components for the most critical features
# UI controls for the demo. Their order here is the positional order in which
# the values are handed to the prediction callback.
inputs = [
    gr.Number(value=0, label="Duration (seconds)"),
    gr.Number(value=491, label="Source Bytes"),
    gr.Number(value=0, label="Destination Bytes"),
    gr.Dropdown(
        choices=['tcp', 'udp', 'icmp'],
        value='tcp',
        label="Protocol Type",
    ),
    gr.Dropdown(
        choices=['ftp_data', 'http', 'private', 'domain_u', 'other'],
        value='ftp_data',
        label="Service",
    ),
    gr.Dropdown(
        choices=['SF', 'S0', 'REJ', 'RSTO'],
        value='SF',
        label="Flag",
    ),
    gr.Slider(minimum=0, maximum=1, step=1, value=1, label="Logged In (1=Yes, 0=No)"),
    gr.Number(value=2, label="Count (connections to same host)"),
    gr.Number(value=2, label="Service Count (connections to same service)"),
    gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="Same Service Rate"),
    gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0, label="SError Rate"),
]

# Add more numerical inputs (optional, for a complete interface)
# NOTE: predict_intrusion takes 38 numerical inputs + 3 categoricals (41 in total).
# For simplicity, the Gradio interface shows only the 11 most important inputs
# (8 numerical and the 3 categorical ones); the remaining 30 numerical features
# are set to default values.

# For a full interface, you would need to list all 41 raw features here!
# For a practical demo, the current 11 inputs and the fixed defaults will work.

# predict_intrusion must still receive ALL 38 numerical features (plus the 3 categoricals).
# To make this demo work without 38 explicit numerical inputs, we use a wrapper:

def predict_wrapper(*args):
    """Adapt the 11 simplified Gradio inputs to ``predict_intrusion``.

    ``predict_intrusion`` needs 38 numerical features plus 3 categorical
    ones. The UI only collects 8 numerical values, so every other numerical
    feature is sent as 0.0. The call is made with keyword arguments so the
    mapping cannot drift out of step with the parameter order.

    Args:
        *args: The UI values in the order of the ``inputs`` list:
            duration, src_bytes, dst_bytes, protocol_type, service, flag,
            logged_in, count, srv_count, same_srv_rate, serror_rate.
            Extra trailing values are ignored.

    Returns:
        Whatever ``predict_intrusion`` returns for the completed feature set.

    Raises:
        IndexError: If fewer than 11 values are supplied.
    """
    expected = 11
    if len(args) < expected:
        raise IndexError(
            f"predict_wrapper expects {expected} inputs, got {len(args)}"
        )

    (
        duration, src_bytes, dst_bytes,
        protocol_type, service, flag,
        logged_in, count, srv_count,
        same_srv_rate, serror_rate,
    ) = args[:expected]

    # The 38 numerical parameters of predict_intrusion; all default to 0.0.
    numeric_names = (
        'duration', 'src_bytes', 'dst_bytes', 'land', 'wrong_fragment',
        'urgent', 'hot', 'num_failed_logins', 'logged_in', 'num_compromised',
        'root_shell', 'su_attempted', 'num_root', 'num_file_creations',
        'num_shells', 'num_access_files', 'num_outbound_cmds',
        'is_host_login', 'is_guest_login', 'count', 'srv_count',
        'serror_rate', 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate',
        'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate',
        'dst_host_count', 'dst_host_srv_count', 'dst_host_same_srv_rate',
        'dst_host_diff_srv_rate', 'dst_host_same_src_port_rate',
        'dst_host_srv_diff_host_rate', 'dst_host_serror_rate',
        'dst_host_srv_serror_rate', 'dst_host_rerror_rate',
        'dst_host_srv_rerror_rate',
    )
    features = dict.fromkeys(numeric_names, 0.0)

    features.update(
        duration=duration,
        src_bytes=src_bytes,
        dst_bytes=dst_bytes,
        logged_in=logged_in,
        count=count,
        srv_count=srv_count,
        serror_rate=serror_rate,
        # Simplification: the UI has one SError slider, reused for the
        # per-service variant of the same rate.
        srv_serror_rate=serror_rate,
        same_srv_rate=same_srv_rate,
        # rerror_rate / srv_rerror_rate have no UI control and stay at 0.0.
    )

    # Call the main prediction function
    return predict_intrusion(
        **features,
        protocol_type=protocol_type,
        service=service,
        flag=flag,
    )


# Assemble the Gradio interface: the 11 controls in `inputs` feed
# `predict_wrapper`, whose returned message is shown in a single label.
_result_view = gr.Label(label="Intrusion Detection Result")

demo = gr.Interface(
    title="Intrusion Detection System (KDD Cup '99)",
    description="Predicts whether a network connection is 'normal' or an 'attack' using a 1D CNN model. The model achieved 99.28% accuracy on the test set.",
    fn=predict_wrapper,
    inputs=inputs,
    outputs=_result_view,
)

# Start the web app only when this file is run as a script, not on import
if __name__ == "__main__":
    demo.launch()