File size: 10,911 Bytes
851c6ed
042de76
851c6ed
 
 
 
 
 
 
 
 
 
 
e59f971
 
851c6ed
 
e59f971
851c6ed
e59f971
851c6ed
 
 
 
 
 
 
 
 
 
 
 
e59f971
851c6ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e59f971
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
851c6ed
 
 
 
 
 
 
e59f971
851c6ed
 
 
 
 
e59f971
 
 
 
 
 
 
 
 
851c6ed
 
 
 
 
 
e59f971
851c6ed
 
 
 
 
e59f971
 
851c6ed
 
e59f971
851c6ed
 
 
 
 
 
e59f971
851c6ed
 
 
 
e59f971
 
 
851c6ed
e59f971
851c6ed
 
 
 
 
 
e59f971
851c6ed
 
 
 
 
 
 
 
 
e59f971
 
851c6ed
e59f971
 
851c6ed
 
e59f971
 
 
 
 
 
851c6ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e59f971
851c6ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
# GRADIO APPLICATION FOR HUGGING FACE SPACES
# Loads the trained CNN and scaler to provide a web interface for network anomaly prediction.

import os
import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
import gradio as gr
from tensorflow.keras.models import load_model
from sklearn.preprocessing import LabelEncoder

# --- Model & Scaler Configuration ---
H5_MODEL_FILE = "intrusion_detector_model.h5"
SCALER_FILE_NAME = "scaler.pkl"
# Decision threshold on the model's output probability.
# Threshold optimized in Cell 11 for better Attack Recall.
PREDICTION_THRESHOLD = 0.40
# Number of columns the scaler output must have; also the CNN input length.
# NOTE(review): MASTER_OHE_COLUMNS below has 104 entries (38 numeric + 3 protocol
# + 52 service + 11 flag), not 40. Unless the pickled "scaler" also reduces
# dimensionality, the FEATURE_COUNT check in predict_intrusion will fail —
# confirm both values against the training code.
FEATURE_COUNT = 40

# Pre-defined list of all feature names (41 raw features), in the order the
# Gradio input components supply them.
FEATURE_NAMES = [
    'duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes', 'land',
    'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in', 'num_compromised',
    'root_shell', 'su_attempted', 'num_root', 'num_file_creations', 'num_shells', 'num_access_files',
    'num_outbound_cmds', 'is_host_login', 'is_guest_login', 'count', 'srv_count', 'serror_rate',
    'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate', 'same_srv_rate', 'diff_srv_rate',
    'srv_diff_host_rate', 'dst_host_count', 'dst_host_srv_count', 'dst_host_same_srv_rate',
    'dst_host_diff_srv_rate', 'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate',
    'dst_host_serror_rate', 'dst_host_srv_serror_rate', 'dst_host_rerror_rate',
    'dst_host_srv_rerror_rate'
]

# Raw features that are one-hot encoded rather than fed to the scaler as-is.
_CATEGORICAL_FEATURES = ('protocol_type', 'service', 'flag')

# All possible service values; each becomes a 'service_<name>' one-hot column.
# Must be comprehensive and match the training columns for correct OHE alignment.
# NOTE(review): some entries (e.g. 'ntp_service', 'pm_srv', 'ctinrp') do not look
# like standard KDD'99 service names — verify against the training data's columns.
SERVICES = [
    'http', 'smtp', 'ftp_data', 'private', 'ecr_i', 'other', 'domain_u',
    'finger', 'telnet', 'ftp', 'pop_3', 'courier', 'eco_i', 'imap4',
    'domain_n', 'auth', 'time', 'shell', 'login', 'hostnames', 'ntp_service',
    'echo', 'discard', 'systat', 'ctf', 'ssh', 'iso_tsap', 'whois', 'remote_job',
    'sunrpc', 'rje', 'gopher', 'netbios_ssn', 'pm_srv', 'mtp', 'exec', 'klogin',
    'kshell', 'daytime', 'message', 'icmp', 'netstat', 'Z39_50', 'bgp', 'nnsp',
    'ctinrp', 'IRC', 'urp_i', 'pop_2', 'aol', 'rev_telnet', 'tftp_u'
]

# All possible flag values.
FLAGS = [
    'SF', 'S0', 'REJ', 'RSTO', 'SH', 'S1', 'S2', 'RSTOS0', 'S3', 'OTH', 'RSTR'
]

# All possible protocol types.
PROTOCOLS = ['tcp', 'udp', 'icmp']

# --- Define ALL Expected OHE Columns ---
# Names follow pandas.get_dummies' '<prefix>_<value>' convention.
PROTOCOL_OHE = [f'protocol_type_{p}' for p in PROTOCOLS]
FLAG_OHE = [f'flag_{f}' for f in FLAGS]
SERVICE_OHE = [f'service_{s}' for s in SERVICES]

# Every non-categorical raw feature, in FEATURE_NAMES order (38 columns).
# Derived instead of hand-copied so the two lists cannot drift apart.
NUMERICAL_BINARY_COLS = [
    name for name in FEATURE_NAMES if name not in _CATEGORICAL_FEATURES
]

# Full column layout handed to the scaler: numeric columns first, then the
# one-hot blocks for protocol, service and flag. Order matters.
MASTER_OHE_COLUMNS = NUMERICAL_BINARY_COLS + PROTOCOL_OHE + SERVICE_OHE + FLAG_OHE


# Global artifacts, filled in at startup by load_artifacts(); None = not loaded.
model = scaler = label_encoder = None
# Class name -> integer code of the binary target.
MAPPING = dict(normal=0, anomaly=1)


# --- Model Loading and Initialization (CRITICAL STEP) ---

def load_artifacts():
    """Load the scaler and Keras model into module globals and build the label encoder.

    Sets the globals ``scaler``, ``model`` and ``label_encoder``. Progress and
    errors are reported with ``print`` so they appear in the Space's logs.

    Returns:
        bool: True when every artifact loaded; False when a file is missing or
        fails to load. On failure, artifacts loaded by earlier steps keep their
        values (e.g. ``scaler`` stays set if only the model fails to load).
    """
    global model, scaler, label_encoder

    print("--- Starting Artifact Loading ---")

    # Check for file existence first (each path is checked once and the result
    # reused, so the log reports exactly what the check saw).
    scaler_exists = os.path.exists(SCALER_FILE_NAME)
    model_exists = os.path.exists(H5_MODEL_FILE)
    if not (scaler_exists and model_exists):
        print("CRITICAL ERROR: One or both files are missing in the current directory:")
        print(f"  Expected Scaler: {SCALER_FILE_NAME} (Exists: {scaler_exists})")
        print(f"  Expected Model: {H5_MODEL_FILE} (Exists: {model_exists})")
        print("Please ensure both files are uploaded to the root of your Hugging Face Space.")
        return False

    # 1. Load Scaler.
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code;
    # SCALER_FILE_NAME must only ever point at a trusted artifact.
    try:
        scaler = joblib.load(SCALER_FILE_NAME)
        print(f"✓ Scaler loaded from {SCALER_FILE_NAME}")
    except Exception as e:  # startup boundary: report any failure and bail out
        print(f"Error loading scaler. Check file format or compatibility: {e}")
        return False

    # 2. Load Model (Keras H5 format).
    # compile=False: the model is only used for predict(), so the training
    # configuration does not need to be restored; this also helps deployment stability.
    try:
        model = load_model(H5_MODEL_FILE, compile=False)
        print(f"✓ Model loaded from {H5_MODEL_FILE}")
    except Exception as e:  # startup boundary: report any failure and bail out
        print(f"Error loading model. Check Keras version compatibility: {e}")
        return False

    # 3. Initialize Label Encoder.
    # NOTE: LabelEncoder sorts its classes, so classes_ == ['anomaly', 'normal']
    # (anomaly -> 0, normal -> 1), which is the REVERSE of MAPPING's codes.
    # Do not use it to decode 0/1 values that follow MAPPING.
    label_encoder = LabelEncoder()
    label_encoder.fit(list(MAPPING.keys()))
    print("✓ Label Encoder initialized.")
    print("--- Artifact Loading Complete ---")
    return True

# Load artifacts on startup. The boolean result is intentionally not acted on:
# on failure load_artifacts() has already printed the cause, and
# predict_intrusion() returns a "Model Not Loaded" error while model/scaler are None.
load_artifacts()


# --- Prediction Function ---

def predict_intrusion(*inputs):
    """
    Classify one network connection record as NORMAL or ANOMALY.

    Args:
        *inputs: The 41 raw feature values, in the same order as FEATURE_NAMES
            (the order of the Gradio input components).

    Returns:
        tuple[str, str]: A message describing the verdict (styled HTML) or an
        error, and the model's output probability formatted with 4 decimals
        ("N/A" when no prediction was made).
    """
    if model is None or scaler is None:
        return "<h2 style='color: red; text-align: center;'>FATAL ERROR: Model Not Loaded. See Logs.</h2>", "N/A"

    # Guard against direct calls with the wrong arity (Gradio passes exactly one
    # value per input component) instead of failing with an IndexError.
    if len(inputs) != len(FEATURE_NAMES):
        return f"INPUT ERROR: Expected {len(FEATURE_NAMES)} features, got {len(inputs)}.", "N/A"

    # A cleared input field arrives as None. Left alone it would typically
    # become NaN, flow through scaler and model, compare False against the
    # threshold, and produce a bogus verdict instead of an error.
    missing = [name for name, value in zip(FEATURE_NAMES, inputs) if value is None]
    if missing:
        return f"INPUT ERROR: Missing value(s) for: {', '.join(missing)}", "N/A"

    # 1. One-row DataFrame keyed by raw feature name.
    df = pd.DataFrame({name: [value] for name, value in zip(FEATURE_NAMES, inputs)})

    # 2. Apply One-Hot Encoding (OHE) for categorical features; get_dummies
    #    replaces the raw categorical columns with '<col>_<value>' columns.
    categorical_cols = ['protocol_type', 'service', 'flag']
    df = pd.get_dummies(df, columns=categorical_cols, prefix=categorical_cols)

    # 3. Re-align columns to the training layout. reindex fills dummy columns
    #    absent from this row with 0 and drops anything not in
    #    MASTER_OHE_COLUMNS, so no separate drop step is needed.
    df_aligned = df.reindex(columns=MASTER_OHE_COLUMNS, fill_value=0)

    # 4. Scale, verify width, and reshape to (1, FEATURE_COUNT, 1) for the CNN.
    try:
        data_scaled = scaler.transform(df_aligned)
    except ValueError as e:
        # sklearn transformers raise ValueError when the column count or
        # names differ from what they were fitted on.
        return f"SCALER ERROR: {e}", "N/A"

    if data_scaled.shape[1] != FEATURE_COUNT:
        return f"SCALER ERROR: Expected {FEATURE_COUNT} features, got {data_scaled.shape[1]} after scaling.", "N/A"

    X_processed = data_scaled.reshape(1, FEATURE_COUNT, 1)

    # 5. Model output, interpreted as the probability of ANOMALY
    #    (MAPPING['anomaly'] == 1; the UI shows it as "Attack Probability").
    prediction_prob = float(model.predict(X_processed, verbose=0)[0][0])

    # 6. Apply the recall-oriented threshold directly to the probability.
    #    Do NOT decode through label_encoder.inverse_transform: LabelEncoder
    #    sorts its classes (anomaly -> 0, normal -> 1), the reverse of MAPPING,
    #    which flipped every verdict.
    is_anomaly = prediction_prob >= PREDICTION_THRESHOLD

    # 7. Determine result display.
    if is_anomaly:
        color = "red"
        message = f"🚨 ANOMALY DETECTED! (Confidence: {prediction_prob:.4f})"
    else:
        color = "green"
        message = f"🟢 Connection is NORMAL. (Confidence: {1 - prediction_prob:.4f})"

    # Gradio requires HTML to display styled text.
    html_output = f"<h2 style='color: {color}; text-align: center;'>{message}</h2>"

    return html_output, f"{prediction_prob:.4f}"


# --- Gradio Interface Definition ---

# Define input components corresponding to the 41 features. Order must match
# FEATURE_NAMES, because predict_intrusion(*inputs) maps values to feature
# names by position.
# Defaults describe a normal FTP-data connection, as promised by the interface
# description ("Default values are set for a NORMAL FTP data connection."),
# so the service default is 'ftp_data'.
input_components = [
    gr.Number(label='duration (float, sec)', value=0.0),
    gr.Dropdown(label='protocol_type', choices=PROTOCOLS, value='tcp'),
    gr.Dropdown(label='service', choices=SERVICES, value='ftp_data'),
    gr.Dropdown(label='flag', choices=FLAGS, value='SF'),
    gr.Number(label='src_bytes (int)', value=491),
    gr.Number(label='dst_bytes (int)', value=0),
    gr.Dropdown(label='land (binary)', choices=[0, 1], value=0),
    gr.Number(label='wrong_fragment (int)', value=0),
    gr.Number(label='urgent (int)', value=0),
    gr.Number(label='hot (int)', value=0),
    gr.Number(label='num_failed_logins (int)', value=0),
    gr.Dropdown(label='logged_in (binary)', choices=[0, 1], value=0),
    gr.Number(label='num_compromised (int)', value=0),
    gr.Dropdown(label='root_shell (binary)', choices=[0, 1], value=0),
    gr.Dropdown(label='su_attempted (binary)', choices=[0, 1], value=0),
    gr.Number(label='num_root (int)', value=0),
    gr.Number(label='num_file_creations (int)', value=0),
    gr.Number(label='num_shells (int)', value=0),
    gr.Number(label='num_access_files (int)', value=0),
    gr.Number(label='num_outbound_cmds (int)', value=0),
    gr.Dropdown(label='is_host_login (binary)', choices=[0, 1], value=0),
    gr.Dropdown(label='is_guest_login (binary)', choices=[0, 1], value=0),
    gr.Number(label='count (float)', value=2.0),
    gr.Number(label='srv_count (float)', value=2.0),
    gr.Number(label='serror_rate (float)', value=0.0),
    gr.Number(label='srv_serror_rate (float)', value=0.0),
    gr.Number(label='rerror_rate (float)', value=0.0),
    gr.Number(label='srv_rerror_rate (float)', value=0.0),
    gr.Number(label='same_srv_rate (float)', value=1.0),
    gr.Number(label='diff_srv_rate (float)', value=0.0),
    gr.Number(label='srv_diff_host_rate (float)', value=0.0),
    gr.Number(label='dst_host_count (float)', value=150.0),
    gr.Number(label='dst_host_srv_count (float)', value=25.0),
    gr.Number(label='dst_host_same_srv_rate (float)', value=0.17),
    gr.Number(label='dst_host_diff_srv_rate (float)', value=0.03),
    gr.Number(label='dst_host_same_src_port_rate (float)', value=0.17),
    gr.Number(label='dst_host_srv_diff_host_rate (float)', value=0.0),
    gr.Number(label='dst_host_serror_rate (float)', value=0.0),
    gr.Number(label='dst_host_srv_serror_rate (float)', value=0.0),
    gr.Number(label='dst_host_rerror_rate (float)', value=0.05),
    gr.Number(label='dst_host_srv_rerror_rate (float)', value=0.0),
]

# Output components, positionally matching predict_intrusion's return tuple:
# first the styled verdict, then the probability string.
output_components = [gr.HTML(label="Prediction Result")]
output_components.append(gr.Label(label="Attack Probability"))


# Wire the 41 inputs through predict_intrusion to the two outputs.
# NOTE(review): newer Gradio releases replace `allow_flagging` with
# `flagging_mode`; update this if the pinned Gradio version rejects it.
iface = gr.Interface(
    title="CNN Network Intrusion Detector (KDDCup'99)",
    description=(
        "Enter the 41 features of a network connection record to determine "
        "if it is a **Normal** connection or an **Anomaly (Attack)**. "
        "This model is a 1D Convolutional Neural Network (CNN) optimized for "
        f"high Attack Recall (using a prediction threshold of **{PREDICTION_THRESHOLD}**).<br>"
        "Default values are set for a NORMAL FTP data connection."
    ),
    fn=predict_intrusion,
    inputs=input_components,
    outputs=output_components,
    allow_flagging='never',
    live=False,
)

# Start the web server (Hugging Face Spaces executes this module directly).
iface.launch()