Update app.py
Browse files
app.py
CHANGED
|
@@ -11,13 +11,13 @@ from tensorflow.keras.models import load_model
|
|
| 11 |
from sklearn.preprocessing import LabelEncoder
|
| 12 |
|
| 13 |
# --- Model & Scaler Configuration ---
|
| 14 |
-
H5_MODEL_FILE = "
|
| 15 |
-
SCALER_FILE_NAME = "
|
| 16 |
# Threshold optimized in Cell 11 for better Attack Recall
|
| 17 |
PREDICTION_THRESHOLD = 0.40
|
| 18 |
-
FEATURE_COUNT = 40
|
| 19 |
|
| 20 |
-
# Pre-defined list of all feature names
|
| 21 |
FEATURE_NAMES = [
|
| 22 |
'duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes', 'land',
|
| 23 |
'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in', 'num_compromised',
|
|
@@ -30,8 +30,7 @@ FEATURE_NAMES = [
|
|
| 30 |
'dst_host_srv_rerror_rate'
|
| 31 |
]
|
| 32 |
|
| 33 |
-
# List of all possible service values (
|
| 34 |
-
# NOTE: In a real system, you would need the full list from your training data.
|
| 35 |
SERVICES = [
|
| 36 |
'http', 'smtp', 'ftp_data', 'private', 'ecr_i', 'other', 'domain_u',
|
| 37 |
'finger', 'telnet', 'ftp', 'pop_3', 'courier', 'eco_i', 'imap4',
|
|
@@ -50,6 +49,25 @@ FLAGS = [
|
|
| 50 |
# List of all possible protocol types
|
| 51 |
PROTOCOLS = ['tcp', 'udp', 'icmp']
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
# Global artifacts
|
| 54 |
model = None
|
| 55 |
scaler = None
|
|
@@ -57,50 +75,61 @@ label_encoder = None
|
|
| 57 |
MAPPING = {'normal': 0, 'anomaly': 1}
|
| 58 |
|
| 59 |
|
| 60 |
-
# --- Model Loading and Initialization ---
|
| 61 |
|
| 62 |
def load_artifacts():
|
| 63 |
"""Loads the trained model and scaler globally."""
|
| 64 |
global model, scaler, label_encoder
|
| 65 |
|
| 66 |
-
print("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
# 1. Load Scaler
|
| 69 |
try:
|
| 70 |
scaler = joblib.load(SCALER_FILE_NAME)
|
| 71 |
print(f"✓ Scaler loaded from {SCALER_FILE_NAME}")
|
| 72 |
except Exception as e:
|
| 73 |
-
print(f"Error loading scaler: {e}")
|
| 74 |
return False
|
| 75 |
|
| 76 |
# 2. Load Model
|
| 77 |
try:
|
| 78 |
# Load in Keras H5 format
|
| 79 |
-
|
|
|
|
| 80 |
print(f"✓ Model loaded from {H5_MODEL_FILE}")
|
| 81 |
except Exception as e:
|
| 82 |
-
print(f"Error loading model: {e}")
|
| 83 |
return False
|
| 84 |
|
| 85 |
# 3. Initialize Label Encoder
|
| 86 |
label_encoder = LabelEncoder()
|
| 87 |
label_encoder.fit(list(MAPPING.keys()))
|
| 88 |
print("✓ Label Encoder initialized.")
|
|
|
|
| 89 |
return True
|
| 90 |
|
| 91 |
# Load artifacts on startup
|
| 92 |
if not load_artifacts():
|
| 93 |
-
|
| 94 |
-
|
|
|
|
| 95 |
|
| 96 |
-
# --- Prediction Function ---
|
| 97 |
|
| 98 |
def predict_intrusion(*inputs):
|
| 99 |
"""
|
| 100 |
Takes 41 raw network features, preprocesses them, and makes a prediction.
|
| 101 |
"""
|
| 102 |
if model is None or scaler is None:
|
| 103 |
-
return "ERROR: Model
|
| 104 |
|
| 105 |
# 1. Create a dictionary from the inputs
|
| 106 |
raw_input_dict = {FEATURE_NAMES[i]: [inputs[i]] for i in range(len(FEATURE_NAMES))}
|
|
@@ -110,31 +139,19 @@ def predict_intrusion(*inputs):
|
|
| 110 |
categorical_cols = ['protocol_type', 'service', 'flag']
|
| 111 |
df = pd.get_dummies(df, columns=categorical_cols, prefix=categorical_cols)
|
| 112 |
|
| 113 |
-
# 3. Re-align columns to match training data (CRITICAL
|
| 114 |
-
|
| 115 |
-
# then populates them with the values from the current input.
|
| 116 |
-
expected_features = [
|
| 117 |
-
'duration', 'src_bytes', 'dst_bytes', 'land', 'wrong_fragment', 'urgent', 'hot',
|
| 118 |
-
'num_failed_logins', 'logged_in', 'num_compromised', 'root_shell', 'su_attempted',
|
| 119 |
-
'num_root', 'num_file_creations', 'num_shells', 'num_access_files', 'num_outbound_cmds',
|
| 120 |
-
'is_host_login', 'is_guest_login', 'count', 'srv_count', 'serror_rate', 'srv_serror_rate',
|
| 121 |
-
'rerror_rate', 'srv_rerror_rate', 'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate',
|
| 122 |
-
'dst_host_count', 'dst_host_srv_count', 'dst_host_same_srv_rate', 'dst_host_diff_srv_rate',
|
| 123 |
-
'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate',
|
| 124 |
-
'dst_host_srv_serror_rate', 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate',
|
| 125 |
-
'protocol_type_icmp', 'protocol_type_tcp', 'protocol_type_udp', # Protocol one-hots
|
| 126 |
-
# NOTE: A real deployment needs ALL 1-hot columns defined.
|
| 127 |
-
# For this demo, we rely on the scaler.transform() to handle alignment.
|
| 128 |
-
]
|
| 129 |
|
| 130 |
-
#
|
| 131 |
-
|
| 132 |
-
# A full-scale alignment is too complex for this demo, so we'll
|
| 133 |
-
# rely on the subsequent scaling step to fit the 40 columns.
|
| 134 |
-
pass
|
| 135 |
|
| 136 |
# 4. Scale and Reshape for CNN
|
| 137 |
-
data_scaled = scaler.transform(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
X_processed = data_scaled.reshape(1, FEATURE_COUNT, 1)
|
| 139 |
|
| 140 |
# 5. Predict probability
|
|
@@ -161,7 +178,7 @@ def predict_intrusion(*inputs):
|
|
| 161 |
return html_output, f"{prediction_prob:.4f}"
|
| 162 |
|
| 163 |
|
| 164 |
-
# --- Gradio Interface Definition ---
|
| 165 |
|
| 166 |
# Define input components corresponding to the 41 features
|
| 167 |
input_components = [
|
|
|
|
| 11 |
from sklearn.preprocessing import LabelEncoder
|
| 12 |
|
| 13 |
# --- Model & Scaler Configuration ---
|
| 14 |
+
H5_MODEL_FILE = "intrusion_detector_model.h5"  # trained Keras model, HDF5 format
SCALER_FILE_NAME = "scaler.pkl"                # fitted scaler, joblib pickle
# Threshold optimized in Cell 11 for better Attack Recall
PREDICTION_THRESHOLD = 0.40
# Number of columns the scaler/model expect AFTER preprocessing.
# NOTE(review): predict_intrusion takes 41 raw features but checks for 40
# scaled columns — confirm against the training notebook which column is dropped.
FEATURE_COUNT = 40
|
| 19 |
|
| 20 |
+
# Pre-defined list of all feature names (41 raw features)
|
| 21 |
FEATURE_NAMES = [
|
| 22 |
'duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes', 'land',
|
| 23 |
'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in', 'num_compromised',
|
|
|
|
| 30 |
'dst_host_srv_rerror_rate'
|
| 31 |
]
|
| 32 |
|
| 33 |
+
# List of all possible service values (Must be comprehensive for correct OHE alignment)
|
|
|
|
| 34 |
SERVICES = [
|
| 35 |
'http', 'smtp', 'ftp_data', 'private', 'ecr_i', 'other', 'domain_u',
|
| 36 |
'finger', 'telnet', 'ftp', 'pop_3', 'courier', 'eco_i', 'imap4',
|
|
|
|
| 49 |
# List of all possible protocol types (used to build the protocol_type_* OHE columns)
PROTOCOLS = ['tcp', 'udp', 'icmp']
|
| 51 |
|
| 52 |
+
# --- Define ALL Expected OHE Columns ---
# Column names that pd.get_dummies(..., prefix=<col>) produces for each
# categorical value; used later to re-align inference-time columns.
PROTOCOL_OHE = [f'protocol_type_{p}' for p in PROTOCOLS]
FLAG_OHE = [f'flag_{f}' for f in FLAGS]
SERVICE_OHE = [f'service_{s}' for s in SERVICES]

# The 38 numeric / binary columns that pass through without one-hot encoding.
NUMERICAL_BINARY_COLS = [
    'duration', 'src_bytes', 'dst_bytes', 'land', 'wrong_fragment', 'urgent', 'hot',
    'num_failed_logins', 'logged_in', 'num_compromised', 'root_shell', 'su_attempted',
    'num_root', 'num_file_creations', 'num_shells', 'num_access_files', 'num_outbound_cmds',
    'is_host_login', 'is_guest_login', 'count', 'srv_count', 'serror_rate', 'srv_serror_rate',
    'rerror_rate', 'srv_rerror_rate', 'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate',
    'dst_host_count', 'dst_host_srv_count', 'dst_host_same_srv_rate', 'dst_host_diff_srv_rate',
    'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate',
    'dst_host_srv_serror_rate', 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate'
]

# Full expected column order fed to scaler.transform().
# NOTE(review): len(MASTER_OHE_COLUMNS) = 38 + len(PROTOCOLS) + len(SERVICES)
# + len(FLAGS), which is far more than FEATURE_COUNT (40). Unless the scaler
# was fitted on a reduced column set, the shape check in predict_intrusion
# will trip — verify against the training pipeline.
MASTER_OHE_COLUMNS = NUMERICAL_BINARY_COLS + PROTOCOL_OHE + SERVICE_OHE + FLAG_OHE
|
| 69 |
+
|
| 70 |
+
|
| 71 |
# Global artifacts — populated by load_artifacts() at startup.
# (label_encoder is also a module global, declared outside this excerpt.)
model = None   # Keras model; stays None until load_artifacts() succeeds
scaler = None  # fitted scaler; stays None until load_artifacts() succeeds
MAPPING = {'normal': 0, 'anomaly': 1}  # class name -> integer label
|
| 76 |
|
| 77 |
|
| 78 |
+
# --- Model Loading and Initialization (CRITICAL STEP) ---
|
| 79 |
|
| 80 |
def load_artifacts():
    """Load the trained Keras model and fitted scaler into module globals.

    Also initializes the global LabelEncoder from MAPPING's class names so
    predictions can be translated back to labels.

    Returns:
        bool: True when every artifact loaded successfully, False otherwise.
        Failure details are printed so they show up in the deployment logs.
    """
    global model, scaler, label_encoder

    print("--- Starting Artifact Loading ---")

    # Fail fast with an actionable message when either artifact file is absent.
    if not os.path.exists(SCALER_FILE_NAME) or not os.path.exists(H5_MODEL_FILE):
        # FIX: dropped the needless f-prefix on a placeholder-free string (lint F541).
        print("CRITICAL ERROR: One or both files are missing in the current directory:")
        print(f"   Expected Scaler: {SCALER_FILE_NAME} (Exists: {os.path.exists(SCALER_FILE_NAME)})")
        print(f"   Expected Model: {H5_MODEL_FILE} (Exists: {os.path.exists(H5_MODEL_FILE)})")
        print("Please ensure both files are uploaded to the root of your Hugging Face Space.")
        return False

    # 1. Load Scaler (joblib pickle produced at training time)
    try:
        scaler = joblib.load(SCALER_FILE_NAME)
        print(f"✓ Scaler loaded from {SCALER_FILE_NAME}")
    except Exception as e:
        print(f"Error loading scaler. Check file format or compatibility: {e}")
        return False

    # 2. Load Model
    try:
        # Load in Keras H5 format.
        # compile=False skips optimizer/loss restoration (inference only) and
        # often sidesteps Keras version incompatibilities on deploy.
        model = load_model(H5_MODEL_FILE, compile=False)
        print(f"✓ Model loaded from {H5_MODEL_FILE}")
    except Exception as e:
        print(f"Error loading model. Check Keras version compatibility: {e}")
        return False

    # 3. Initialize Label Encoder on the known class names.
    label_encoder = LabelEncoder()
    label_encoder.fit(list(MAPPING.keys()))
    print("✓ Label Encoder initialized.")
    print("--- Artifact Loading Complete ---")
    return True
|
| 118 |
|
| 119 |
# Load artifacts on startup
|
| 120 |
if not load_artifacts():
|
| 121 |
+
# If loading failed, the prediction function will return the error message
|
| 122 |
+
pass
|
| 123 |
+
|
| 124 |
|
| 125 |
+
# --- Prediction Function (Same as before) ---
|
| 126 |
|
| 127 |
def predict_intrusion(*inputs):
|
| 128 |
"""
|
| 129 |
Takes 41 raw network features, preprocesses them, and makes a prediction.
|
| 130 |
"""
|
| 131 |
if model is None or scaler is None:
|
| 132 |
+
return "<h2 style='color: red; text-align: center;'>FATAL ERROR: Model Not Loaded. See Logs.</h2>", "N/A"
|
| 133 |
|
| 134 |
# 1. Create a dictionary from the inputs
|
| 135 |
raw_input_dict = {FEATURE_NAMES[i]: [inputs[i]] for i in range(len(FEATURE_NAMES))}
|
|
|
|
| 139 |
categorical_cols = ['protocol_type', 'service', 'flag']
|
| 140 |
df = pd.get_dummies(df, columns=categorical_cols, prefix=categorical_cols)
|
| 141 |
|
| 142 |
+
# 3. Re-align columns to match training data (CRITICAL FIX)
|
| 143 |
+
df_aligned = df.reindex(columns=MASTER_OHE_COLUMNS, fill_value=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
+
# Drop the redundant categorical columns (if they weren't dropped by get_dummies)
|
| 146 |
+
df_aligned = df_aligned.drop(columns=['protocol_type', 'service', 'flag'], errors='ignore')
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
# 4. Scale and Reshape for CNN
|
| 149 |
+
data_scaled = scaler.transform(df_aligned)
|
| 150 |
+
|
| 151 |
+
# Check shape to ensure correct feature count before reshaping
|
| 152 |
+
if data_scaled.shape[1] != FEATURE_COUNT:
|
| 153 |
+
return f"SCALER ERROR: Expected {FEATURE_COUNT} features, got {data_scaled.shape[1]} after scaling.", "N/A"
|
| 154 |
+
|
| 155 |
X_processed = data_scaled.reshape(1, FEATURE_COUNT, 1)
|
| 156 |
|
| 157 |
# 5. Predict probability
|
|
|
|
| 178 |
return html_output, f"{prediction_prob:.4f}"
|
| 179 |
|
| 180 |
|
| 181 |
+
# --- Gradio Interface Definition (Same as before) ---
|
| 182 |
|
| 183 |
# Define input components corresponding to the 41 features
|
| 184 |
input_components = [
|