hari6677 committed on
Commit
2aed8d6
·
verified ·
1 Parent(s): cb76da2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -0
app.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import pandas as pd
4
+ import joblib
5
+ from tensorflow.keras.models import load_model
6
+
7
+ # --- 1. Load Model and Preprocessing Components ---
8
try:
    # Keras model used by predict_intrusion for inference.
    model = load_model('improved_intrusion_detection_model.h5')

    # Preprocessing artifacts: scaler, label encoder and the ordered list of
    # feature (column) names the prediction code builds its input row from.
    # NOTE(review): joblib.load unpickles data -- only load trusted files.
    scaler = joblib.load('final_standard_scaler.pkl')
    le = joblib.load('final_label_encoder.pkl')
    feature_names = joblib.load('final_feature_names.pkl')
except Exception as load_error:
    # Report the failure, then re-raise: the app cannot serve predictions
    # without every one of these artifacts.
    print(f"Error loading files: {load_error}")
    raise
else:
    print("Model and preprocessing components loaded successfully.")
23
+
24
+ # --- 2. Define the Prediction Function ---
25
def predict_intrusion(
    duration, src_bytes, dst_bytes, land, wrong_fragment, urgent, hot,
    num_failed_logins, logged_in, num_compromised, root_shell, su_attempted,
    num_root, num_file_creations, num_shells, num_access_files,
    num_outbound_cmds, is_host_login, is_guest_login, count, srv_count,
    serror_rate, srv_serror_rate, rerror_rate, srv_rerror_rate, same_srv_rate,
    diff_srv_rate, srv_diff_host_rate, dst_host_count, dst_host_srv_count,
    dst_host_same_srv_rate, dst_host_diff_srv_rate, dst_host_same_src_port_rate,
    dst_host_srv_diff_host_rate, dst_host_serror_rate, dst_host_srv_serror_rate,
    dst_host_rerror_rate, dst_host_srv_rerror_rate,
    # Categorical features, given as their raw string values
    protocol_type, service, flag
):
    """Classify a single connection record and return a display message.

    The 38 numerical arguments are written to the columns of the same name
    and the three categorical arguments switch on the one-hot columns
    ``protocol_type_<value>``, ``service_<value>`` and ``flag_<value>``.
    Any column listed in the module-level ``feature_names`` that is not
    addressed this way stays at 0.0 (the original notes say the model
    expects 119 columns -- not verifiable from this file).

    A numerical argument of ``None`` (what a cleared Gradio number field
    delivers) is treated as 0.0 instead of propagating into the scaler.

    Returns:
        A string announcing either an intrusion or normal traffic, together
        with the model's output formatted as a percentage.
    """
    numerical_features = {
        'duration': duration, 'src_bytes': src_bytes, 'dst_bytes': dst_bytes,
        'land': land, 'wrong_fragment': wrong_fragment, 'urgent': urgent,
        'hot': hot, 'num_failed_logins': num_failed_logins,
        'logged_in': logged_in, 'num_compromised': num_compromised,
        'root_shell': root_shell, 'su_attempted': su_attempted,
        'num_root': num_root, 'num_file_creations': num_file_creations,
        'num_shells': num_shells, 'num_access_files': num_access_files,
        'num_outbound_cmds': num_outbound_cmds, 'is_host_login': is_host_login,
        'is_guest_login': is_guest_login, 'count': count, 'srv_count': srv_count,
        'serror_rate': serror_rate, 'srv_serror_rate': srv_serror_rate,
        'rerror_rate': rerror_rate, 'srv_rerror_rate': srv_rerror_rate,
        'same_srv_rate': same_srv_rate, 'diff_srv_rate': diff_srv_rate,
        'srv_diff_host_rate': srv_diff_host_rate, 'dst_host_count': dst_host_count,
        'dst_host_srv_count': dst_host_srv_count,
        'dst_host_same_srv_rate': dst_host_same_srv_rate,
        'dst_host_diff_srv_rate': dst_host_diff_srv_rate,
        'dst_host_same_src_port_rate': dst_host_same_src_port_rate,
        'dst_host_srv_diff_host_rate': dst_host_srv_diff_host_rate,
        'dst_host_serror_rate': dst_host_serror_rate,
        'dst_host_srv_serror_rate': dst_host_srv_serror_rate,
        'dst_host_rerror_rate': dst_host_rerror_rate,
        'dst_host_srv_rerror_rate': dst_host_srv_rerror_rate
    }

    # --- A. Collect every value we know, keyed by column name ---
    # Missing (None) numerical inputs fall back to the same 0.0 default that
    # unspecified columns receive.
    known_values = {
        name: 0.0 if value is None else value
        for name, value in numerical_features.items()
    }
    # One-hot columns for the selected categories. Keys that do not occur in
    # feature_names are simply never looked up below.
    known_values[f'protocol_type_{protocol_type}'] = 1.0
    known_values[f'service_{service}'] = 1.0
    known_values[f'flag_{flag}'] = 1.0

    # Single row laid out in exactly the order of feature_names; everything
    # not supplied above defaults to 0.0.
    row = np.array(
        [[known_values.get(name, 0.0) for name in feature_names]],
        dtype=float,
    )

    # --- B. Scale the input data ---
    X_scaled = scaler.transform(row)

    # --- C. Reshape to (1, n_features, 1) for the model ---
    X_cnn = X_scaled.reshape(1, X_scaled.shape[1], 1)

    # --- D. Predict ---
    # Convert to a plain Python float so the percentage below is computed
    # and formatted in double precision.
    prediction_proba = float(model.predict(X_cnn)[0][0])

    # --- E. Decode Result ---
    # Threshold at 0.5, then map the class index back to its label.
    prediction_class = int(prediction_proba > 0.5)
    # NOTE(review): assumes the encoder's label for class 1 is 'attack' --
    # confirm against the fitted label encoder.
    decoded_prediction = le.classes_[prediction_class]

    # Format the output for Gradio
    probability_percent = f"{prediction_proba * 100:.2f}%"

    if decoded_prediction == 'attack':
        result_message = f"🚨 INTRUSION DETECTED! (Attack Probability: {probability_percent})"
    else:
        result_message = f"✅ Normal Traffic. (Attack Probability: {probability_percent})"

    return result_message
112
+
113
+ # --- 3. Gradio Interface Setup ---
114
+
115
+ # Create simplified input components for the most critical features
116
# Input widgets, in the positional order predict_wrapper reads them.
# Do not reorder without updating the wrapper.
inputs = [
    # 0-2: basic connection statistics
    gr.Number(label="Duration (seconds)", value=0),
    gr.Number(label="Source Bytes", value=491),
    gr.Number(label="Destination Bytes", value=0),
    # 3-5: categorical features (passed on as strings)
    gr.Dropdown(
        label="Protocol Type",
        choices=['tcp', 'udp', 'icmp'],
        value='tcp',
    ),
    gr.Dropdown(
        label="Service",
        choices=['ftp_data', 'http', 'private', 'domain_u', 'other'],
        value='ftp_data',
    ),
    gr.Dropdown(
        label="Flag",
        choices=['SF', 'S0', 'REJ', 'RSTO'],
        value='SF',
    ),
    # 6: binary flag exposed as a 0/1 slider
    gr.Slider(minimum=0, maximum=1, label="Logged In (1=Yes, 0=No)", step=1, value=1),
    # 7-8: connection counters
    gr.Number(label="Count (connections to same host)", value=2),
    gr.Number(label="Service Count (connections to same service)", value=2),
    # 9-10: rates in [0, 1]
    gr.Slider(minimum=0.0, maximum=1.0, label="Same Service Rate", step=0.01, value=1.0),
    gr.Slider(minimum=0.0, maximum=1.0, label="SError Rate", step=0.01, value=0.0),
]
129
+
130
+ # Add more numerical inputs (optional, for a complete interface)
131
+ # NOTE: The function takes 38 numerical inputs + 3 categoricals (41 in total). We'll simplify the interface.
132
+ # For simplicity, we are only showing the 11 most important inputs in the Gradio interface
133
+ # and setting defaults for the remaining 30 numerical features (all 3 categorical features are exposed).
134
+
135
+ # For a full interface, you would need to list all 41 raw features here!
136
+ # For a practical demo, the current 11 inputs and the fixed defaults will work.
137
+
138
+ # The function call below must pass ALL 38 numerical features (plus the 3 categoricals).
139
+ # To make this demo work without 38 explicit inputs, we use a wrapper:
140
+
141
# Names of the 38 numerical parameters of predict_intrusion, in signature order.
_NUMERICAL_FEATURE_NAMES = (
    'duration', 'src_bytes', 'dst_bytes', 'land', 'wrong_fragment', 'urgent',
    'hot', 'num_failed_logins', 'logged_in', 'num_compromised', 'root_shell',
    'su_attempted', 'num_root', 'num_file_creations', 'num_shells',
    'num_access_files', 'num_outbound_cmds', 'is_host_login',
    'is_guest_login', 'count', 'srv_count', 'serror_rate', 'srv_serror_rate',
    'rerror_rate', 'srv_rerror_rate', 'same_srv_rate', 'diff_srv_rate',
    'srv_diff_host_rate', 'dst_host_count', 'dst_host_srv_count',
    'dst_host_same_srv_rate', 'dst_host_diff_srv_rate',
    'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate',
    'dst_host_serror_rate', 'dst_host_srv_serror_rate',
    'dst_host_rerror_rate', 'dst_host_srv_rerror_rate',
)


def predict_wrapper(*args):
    """Expand the 11 simplified UI values into a full predict_intrusion call.

    ``args`` must hold, in the order of the ``inputs`` widget list:
    duration, src_bytes, dst_bytes, protocol_type, service, flag, logged_in,
    count, srv_count, same_srv_rate, serror_rate.

    Every numerical feature without a widget is sent as 0.0. As a
    simplification, ``srv_serror_rate`` mirrors the single "SError Rate"
    input; the REJ-error rates have no widget and therefore stay at 0.0.

    Returns:
        The result message produced by predict_intrusion.

    Raises:
        ValueError: if ``args`` does not contain exactly 11 values.
    """
    # Unpack by name rather than by index so each value is tied to the widget
    # that produced it (a wrong count fails loudly here).
    (
        duration, src_bytes, dst_bytes,
        protocol_type, service, flag,
        logged_in, count, srv_count,
        same_srv_rate, serror_rate,
    ) = args

    # Start from 0.0 for all 38 numerical features, then overlay what the
    # interface actually provides.
    numerical = dict.fromkeys(_NUMERICAL_FEATURE_NAMES, 0.0)
    numerical.update(
        duration=duration,
        src_bytes=src_bytes,
        dst_bytes=dst_bytes,
        logged_in=logged_in,
        count=count,
        srv_count=srv_count,
        serror_rate=serror_rate,
        srv_serror_rate=serror_rate,  # assumed equal to serror_rate for simplicity
        same_srv_rate=same_srv_rate,
    )

    # Keyword arguments keep this call correct regardless of parameter
    # positions in predict_intrusion's signature.
    return predict_intrusion(
        protocol_type=protocol_type,
        service=service,
        flag=flag,
        **numerical,
    )
194
+
195
+
196
+ # Define the Gradio interface
197
# Wire the simplified wrapper into a Gradio UI: the 11 widgets in `inputs`
# feed predict_wrapper, whose returned message is shown in a Label.
demo = gr.Interface(
    title="Intrusion Detection System (KDD Cup '99)",
    description="Predicts whether a network connection is 'normal' or an 'attack' using a 1D CNN model. The model achieved 99.28% accuracy on the test set.",
    fn=predict_wrapper,
    inputs=inputs,
    outputs=gr.Label(label="Intrusion Detection Result"),
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()