hari6677 commited on
Commit
96adf77
·
verified ·
1 Parent(s): c54fdc5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import joblib
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import pandas as pd
5
+ import gradio as gr
6
+ import os
7
+
8
+ # --- Configuration ---
9
+ MODEL_PATH = "improved_intrusion_detection_model_SIMPLIFIED.h5"
10
+ SCALER_PATH = "standard_scaler.pkl"
11
+ FEATURES_PATH = "feature_names.pkl"
12
+ LABEL_ENCODER_PATH = "label_encoder.pkl"
13
+ FINAL_THRESHOLD = 0.7
14
+ CATEGORICAL_COLS = ['protocol_type', 'service', 'flag']
15
+
16
+ # --- Load Artifacts ---
17
+ # The model and preprocessors are loaded once when the app starts
18
+ try:
19
+ model = tf.keras.models.load_model(MODEL_PATH)
20
+ scaler = joblib.load(SCALER_PATH)
21
+ final_features = joblib.load(FEATURES_PATH)
22
+ label_encoder = joblib.load(LABEL_ENCODER_PATH)
23
+ print("Model and preprocessors loaded successfully.")
24
+ except Exception as e:
25
+ print(f"Error loading model artifacts: {e}")
26
+ # Exit if essential files are missing
27
+ exit()
28
+
29
+ def preprocess_and_predict(*raw_input_features):
30
+ """
31
+ Takes raw inputs, preprocesses them exactly like the training data,
32
+ and returns the prediction.
33
+ """
34
+
35
+ # 1. Convert tuple of inputs to a single list/Series
36
+ input_data = pd.Series(raw_input_features, index=raw_input_features_names)
37
+
38
+ # Reshape for single sample processing
39
+ df_raw = pd.DataFrame([input_data])
40
+
41
+ # 2. One-Hot Encode Categorical Features
42
+ df_encoded = pd.get_dummies(df_raw, columns=CATEGORICAL_COLS)
43
+
44
+ # 3. Align columns with training data and fill missing features with 0
45
+ # This is CRUCIAL for deployment correctness.
46
+ df_encoded = df_encoded.reindex(columns=final_features, fill_value=0)
47
+
48
+ # 4. Scale Numerical Features
49
+ X_scaled = scaler.transform(df_encoded)
50
+
51
+ # 5. Reshape for CNN Input: (1 sample, 122 features, 1 channel)
52
+ X_cnn = X_scaled.reshape(X_scaled.shape[0], X_scaled.shape[1], 1)
53
+
54
+ # 6. Predict Probability
55
+ y_pred_proba = model.predict(X_cnn, verbose=0)[0][0]
56
+
57
+ # 7. Apply Fixed Threshold and Decode Label
58
+ if y_pred_proba >= FINAL_THRESHOLD:
59
+ prediction_int = 1
60
+ else:
61
+ prediction_int = 0
62
+
63
+ # Decode 0 or 1 back to 'normal' or 'attack'
64
+ final_label = label_encoder.inverse_transform([prediction_int])[0]
65
+
66
+ return final_label, f"Confidence: {y_pred_proba:.4f}"
67
+
68
+
69
+ # --- Gradio Interface Setup ---
70
+
71
+ # Create a list of the 41 feature names (excluding 'label') for the UI
72
+ raw_input_features_names = [
73
+ 'duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes',
74
+ 'land', 'wrong_fragment', 'urgent', 'hot', 'num_failed_logins',
75
+ 'logged_in', 'num_compromised', 'root_shell', 'su_attempted', 'num_root',
76
+ 'num_file_creations', 'num_shells', 'num_access_files', 'num_outbound_cmds',
77
+ 'is_host_login', 'is_guest_login', 'count', 'srv_count', 'serror_rate',
78
+ 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate', 'same_srv_rate',
79
+ 'diff_srv_rate', 'srv_diff_host_rate', 'dst_host_count', 'dst_host_srv_count',
80
+ 'dst_host_same_srv_rate', 'dst_host_diff_srv_rate', 'dst_host_same_src_port_rate',
81
+ 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate', 'dst_host_srv_serror_rate',
82
+ 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate'
83
+ ]
84
+
85
+ # Create Gradio inputs corresponding to the feature types
86
+ inputs = [
87
+ gr.Number(label=name, value=0) if name not in CATEGORICAL_COLS else
88
+ gr.Textbox(label=name, value='tcp') # Default example for categorical
89
+ for name in raw_input_features_names
90
+ ]
91
+
92
+ iface = gr.Interface(
93
+ fn=preprocess_and_predict,
94
+ inputs=inputs,
95
+ outputs=[gr.Label(label="Prediction"), gr.Textbox(label="Details")],
96
+ title="CNN Network Intrusion Detector (KDD)",
97
+ description="Enter the 41 raw network traffic features to classify the connection as 'normal' or 'attack'. Optimized with 0.7 threshold.",
98
+ allow_flagging="never"
99
+ )
100
+
101
+ if __name__ == "__main__":
102
+ iface.launch()