hari6677 committed on
Commit
e59f971
·
verified ·
1 Parent(s): 619ca58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -38
app.py CHANGED
@@ -11,13 +11,13 @@ from tensorflow.keras.models import load_model
11
  from sklearn.preprocessing import LabelEncoder
12
 
13
  # --- Model & Scaler Configuration ---
14
- H5_MODEL_FILE = "https://huggingface.co/spaces/hari6677/intrusion_detector_model/commit/db539a078811615764b2c84aacb85b1a3f804a3c"
15
- SCALER_FILE_NAME = "https://huggingface.co/spaces/hari6677/intrusion_detector_model/commit/db539a078811615764b2c84aacb85b1a3f804a3c"
16
  # Threshold optimized in Cell 11 for better Attack Recall
17
  PREDICTION_THRESHOLD = 0.40
18
- FEATURE_COUNT = 40 # Expected number of features after one-hot encoding
19
 
20
- # Pre-defined list of all feature names, used to create the input DataFrame
21
  FEATURE_NAMES = [
22
  'duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes', 'land',
23
  'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in', 'num_compromised',
@@ -30,8 +30,7 @@ FEATURE_NAMES = [
30
  'dst_host_srv_rerror_rate'
31
  ]
32
 
33
- # List of all possible service values (simplified for demo input)
34
- # NOTE: In a real system, you would need the full list from your training data.
35
  SERVICES = [
36
  'http', 'smtp', 'ftp_data', 'private', 'ecr_i', 'other', 'domain_u',
37
  'finger', 'telnet', 'ftp', 'pop_3', 'courier', 'eco_i', 'imap4',
@@ -50,6 +49,25 @@ FLAGS = [
50
  # List of all possible protocol types
51
  PROTOCOLS = ['tcp', 'udp', 'icmp']
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  # Global artifacts
54
  model = None
55
  scaler = None
@@ -57,50 +75,61 @@ label_encoder = None
57
  MAPPING = {'normal': 0, 'anomaly': 1}
58
 
59
 
60
- # --- Model Loading and Initialization ---
61
 
62
  def load_artifacts():
63
  """Loads the trained model and scaler globally."""
64
  global model, scaler, label_encoder
65
 
66
- print("Loading model and scaler for Gradio app...")
 
 
 
 
 
 
 
 
67
 
68
  # 1. Load Scaler
69
  try:
70
  scaler = joblib.load(SCALER_FILE_NAME)
71
  print(f"✓ Scaler loaded from {SCALER_FILE_NAME}")
72
  except Exception as e:
73
- print(f"Error loading scaler: {e}")
74
  return False
75
 
76
  # 2. Load Model
77
  try:
78
  # Load in Keras H5 format
79
- model = load_model(H5_MODEL_FILE)
 
80
  print(f"✓ Model loaded from {H5_MODEL_FILE}")
81
  except Exception as e:
82
- print(f"Error loading model: {e}")
83
  return False
84
 
85
  # 3. Initialize Label Encoder
86
  label_encoder = LabelEncoder()
87
  label_encoder.fit(list(MAPPING.keys()))
88
  print("✓ Label Encoder initialized.")
 
89
  return True
90
 
91
  # Load artifacts on startup
92
  if not load_artifacts():
93
- print("CRITICAL: Failed to load model artifacts. Prediction will not work.")
94
- # Exit or handle error appropriately in production
 
95
 
96
- # --- Prediction Function ---
97
 
98
  def predict_intrusion(*inputs):
99
  """
100
  Takes 41 raw network features, preprocesses them, and makes a prediction.
101
  """
102
  if model is None or scaler is None:
103
- return "ERROR: Model not loaded. Check server logs.", "N/A"
104
 
105
  # 1. Create a dictionary from the inputs
106
  raw_input_dict = {FEATURE_NAMES[i]: [inputs[i]] for i in range(len(FEATURE_NAMES))}
@@ -110,31 +139,19 @@ def predict_intrusion(*inputs):
110
  categorical_cols = ['protocol_type', 'service', 'flag']
111
  df = pd.get_dummies(df, columns=categorical_cols, prefix=categorical_cols)
112
 
113
- # 3. Re-align columns to match training data (CRITICAL STEP)
114
- # This creates a zero-filled array of the 40 expected features,
115
- # then populates them with the values from the current input.
116
- expected_features = [
117
- 'duration', 'src_bytes', 'dst_bytes', 'land', 'wrong_fragment', 'urgent', 'hot',
118
- 'num_failed_logins', 'logged_in', 'num_compromised', 'root_shell', 'su_attempted',
119
- 'num_root', 'num_file_creations', 'num_shells', 'num_access_files', 'num_outbound_cmds',
120
- 'is_host_login', 'is_guest_login', 'count', 'srv_count', 'serror_rate', 'srv_serror_rate',
121
- 'rerror_rate', 'srv_rerror_rate', 'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate',
122
- 'dst_host_count', 'dst_host_srv_count', 'dst_host_same_srv_rate', 'dst_host_diff_srv_rate',
123
- 'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate',
124
- 'dst_host_srv_serror_rate', 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate',
125
- 'protocol_type_icmp', 'protocol_type_tcp', 'protocol_type_udp', # Protocol one-hots
126
- # NOTE: A real deployment needs ALL 1-hot columns defined.
127
- # For this demo, we rely on the scaler.transform() to handle alignment.
128
- ]
129
 
130
- # We must ensure the final feature set has 40 columns before scaling
131
- if df.shape[1] != FEATURE_COUNT:
132
- # A full-scale alignment is too complex for this demo, so we'll
133
- # rely on the subsequent scaling step to fit the 40 columns.
134
- pass
135
 
136
  # 4. Scale and Reshape for CNN
137
- data_scaled = scaler.transform(df)
 
 
 
 
 
138
  X_processed = data_scaled.reshape(1, FEATURE_COUNT, 1)
139
 
140
  # 5. Predict probability
@@ -161,7 +178,7 @@ def predict_intrusion(*inputs):
161
  return html_output, f"{prediction_prob:.4f}"
162
 
163
 
164
- # --- Gradio Interface Definition ---
165
 
166
  # Define input components corresponding to the 41 features
167
  input_components = [
 
11
  from sklearn.preprocessing import LabelEncoder
12
 
13
  # --- Model & Scaler Configuration ---
14
+ H5_MODEL_FILE = "intrusion_detector_model.h5"
15
+ SCALER_FILE_NAME = "scaler.pkl"
16
  # Threshold optimized in Cell 11 for better Attack Recall
17
  PREDICTION_THRESHOLD = 0.40
18
+ FEATURE_COUNT = 40
19
 
20
+ # Pre-defined list of all feature names (41 raw features)
21
  FEATURE_NAMES = [
22
  'duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes', 'land',
23
  'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in', 'num_compromised',
 
30
  'dst_host_srv_rerror_rate'
31
  ]
32
 
33
+ # List of all possible service values (Must be comprehensive for correct OHE alignment)
 
34
  SERVICES = [
35
  'http', 'smtp', 'ftp_data', 'private', 'ecr_i', 'other', 'domain_u',
36
  'finger', 'telnet', 'ftp', 'pop_3', 'courier', 'eco_i', 'imap4',
 
49
  # List of all possible protocol types
50
  PROTOCOLS = ['tcp', 'udp', 'icmp']
51
 
52
+ # --- Define ALL Expected OHE Columns ---
53
+ PROTOCOL_OHE = [f'protocol_type_{p}' for p in PROTOCOLS]
54
+ FLAG_OHE = [f'flag_{f}' for f in FLAGS]
55
+ SERVICE_OHE = [f'service_{s}' for s in SERVICES]
56
+
57
+ NUMERICAL_BINARY_COLS = [
58
+ 'duration', 'src_bytes', 'dst_bytes', 'land', 'wrong_fragment', 'urgent', 'hot',
59
+ 'num_failed_logins', 'logged_in', 'num_compromised', 'root_shell', 'su_attempted',
60
+ 'num_root', 'num_file_creations', 'num_shells', 'num_access_files', 'num_outbound_cmds',
61
+ 'is_host_login', 'is_guest_login', 'count', 'srv_count', 'serror_rate', 'srv_serror_rate',
62
+ 'rerror_rate', 'srv_rerror_rate', 'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate',
63
+ 'dst_host_count', 'dst_host_srv_count', 'dst_host_same_srv_rate', 'dst_host_diff_srv_rate',
64
+ 'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate',
65
+ 'dst_host_srv_serror_rate', 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate'
66
+ ]
67
+
68
+ MASTER_OHE_COLUMNS = NUMERICAL_BINARY_COLS + PROTOCOL_OHE + SERVICE_OHE + FLAG_OHE
69
+
70
+
71
  # Global artifacts
72
  model = None
73
  scaler = None
 
75
  MAPPING = {'normal': 0, 'anomaly': 1}
76
 
77
 
78
+ # --- Model Loading and Initialization (CRITICAL STEP) ---
79
 
80
  def load_artifacts():
81
  """Loads the trained model and scaler globally."""
82
  global model, scaler, label_encoder
83
 
84
+ print("--- Starting Artifact Loading ---")
85
+
86
+ # Check for file existence first
87
+ if not os.path.exists(SCALER_FILE_NAME) or not os.path.exists(H5_MODEL_FILE):
88
+ print(f"CRITICAL ERROR: One or both files are missing in the current directory:")
89
+ print(f" Expected Scaler: {SCALER_FILE_NAME} (Exists: {os.path.exists(SCALER_FILE_NAME)})")
90
+ print(f" Expected Model: {H5_MODEL_FILE} (Exists: {os.path.exists(H5_MODEL_FILE)})")
91
+ print("Please ensure both files are uploaded to the root of your Hugging Face Space.")
92
+ return False
93
 
94
  # 1. Load Scaler
95
  try:
96
  scaler = joblib.load(SCALER_FILE_NAME)
97
  print(f"✓ Scaler loaded from {SCALER_FILE_NAME}")
98
  except Exception as e:
99
+ print(f"Error loading scaler. Check file format or compatibility: {e}")
100
  return False
101
 
102
  # 2. Load Model
103
  try:
104
  # Load in Keras H5 format
105
+ # Setting compile=False often helps with deployment stability
106
+ model = load_model(H5_MODEL_FILE, compile=False)
107
  print(f"✓ Model loaded from {H5_MODEL_FILE}")
108
  except Exception as e:
109
+ print(f"Error loading model. Check Keras version compatibility: {e}")
110
  return False
111
 
112
  # 3. Initialize Label Encoder
113
  label_encoder = LabelEncoder()
114
  label_encoder.fit(list(MAPPING.keys()))
115
  print("✓ Label Encoder initialized.")
116
+ print("--- Artifact Loading Complete ---")
117
  return True
118
 
119
  # Load artifacts on startup
120
  if not load_artifacts():
121
+ # If loading failed, the prediction function will return the error message
122
+ pass
123
+
124
 
125
+ # --- Prediction Function (Same as before) ---
126
 
127
  def predict_intrusion(*inputs):
128
  """
129
  Takes 41 raw network features, preprocesses them, and makes a prediction.
130
  """
131
  if model is None or scaler is None:
132
+ return "<h2 style='color: red; text-align: center;'>FATAL ERROR: Model Not Loaded. See Logs.</h2>", "N/A"
133
 
134
  # 1. Create a dictionary from the inputs
135
  raw_input_dict = {FEATURE_NAMES[i]: [inputs[i]] for i in range(len(FEATURE_NAMES))}
 
139
  categorical_cols = ['protocol_type', 'service', 'flag']
140
  df = pd.get_dummies(df, columns=categorical_cols, prefix=categorical_cols)
141
 
142
+ # 3. Re-align columns to match training data (CRITICAL FIX)
143
+ df_aligned = df.reindex(columns=MASTER_OHE_COLUMNS, fill_value=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
+ # Drop the redundant categorical columns (if they weren't dropped by get_dummies)
146
+ df_aligned = df_aligned.drop(columns=['protocol_type', 'service', 'flag'], errors='ignore')
 
 
 
147
 
148
  # 4. Scale and Reshape for CNN
149
+ data_scaled = scaler.transform(df_aligned)
150
+
151
+ # Check shape to ensure correct feature count before reshaping
152
+ if data_scaled.shape[1] != FEATURE_COUNT:
153
+ return f"SCALER ERROR: Expected {FEATURE_COUNT} features, got {data_scaled.shape[1]} after scaling.", "N/A"
154
+
155
  X_processed = data_scaled.reshape(1, FEATURE_COUNT, 1)
156
 
157
  # 5. Predict probability
 
178
  return html_output, f"{prediction_prob:.4f}"
179
 
180
 
181
+ # --- Gradio Interface Definition (Same as before) ---
182
 
183
  # Define input components corresponding to the 41 features
184
  input_components = [