hari6677 commited on
Commit
b748463
·
verified ·
1 Parent(s): c96de15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -0
app.py CHANGED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import pandas as pd
4
+ import tensorflow as tf
5
+ import joblib
6
+ import pickle
7
+ import os
8
+
9
+ # --- 1. CONFIGURATION AND FILE LOADING ---
10
+
11
+ # Define file paths (assuming you'll upload your improved model)
12
+ MODEL_PATH = 'improved_intrusion_detection_model.h5'
13
+ SCALER_PATH = 'standard_scaler.pkl'
14
+ FEATURE_NAMES_PATH = 'feature_names.pkl'
15
+
16
+ # Define the 41 original raw features expected from the user input
17
+ # NOTE: This list needs to be manually defined based on the KDD dataset structure.
18
+ # The 'feature_names.pkl' you provided contains the FINAL 119 feature names.
19
+ RAW_41_FEATURES = [
20
+ 'duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes',
21
+ 'land', 'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in',
22
+ 'num_compromised', 'root_shell', 'su_attempted', 'num_root', 'num_file_creations',
23
+ 'num_shells', 'num_access_files', 'num_outbound_cmds', 'is_host_login',
24
+ 'is_guest_login', 'count', 'srv_count', 'serror_rate', 'srv_serror_rate',
25
+ 'rerror_rate', 'srv_rerror_rate', 'same_srv_rate', 'diff_srv_rate',
26
+ 'srv_diff_host_rate', 'dst_host_count', 'dst_host_srv_count',
27
+ 'dst_host_same_srv_rate', 'dst_host_diff_srv_rate', 'dst_host_same_src_port_rate',
28
+ 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate', 'dst_host_srv_serror_rate',
29
+ 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate'
30
+ ]
31
+
32
+ # Identify categorical columns from the raw features
33
+ CATEGORICAL_COLS = ['protocol_type', 'service', 'flag']
34
+ NUMERICAL_COLS = [col for col in RAW_41_FEATURES if col not in CATEGORICAL_COLS]
35
+
36
+
37
+ try:
38
+ # Load Model (assuming it's in the directory)
39
+ model = tf.keras.models.load_model(MODEL_PATH)
40
+
41
+ # Load Preprocessing Objects
42
+ scaler = joblib.load(SCALER_PATH)
43
+
44
+ # Load final 119 feature names list
45
+ # The feature_names.pkl file contains the FINAL 119 column names, including OHE columns.
46
+ with open(FEATURE_NAMES_PATH, 'rb') as f:
47
+ FINAL_119_COLUMNS = pickle.load(f).tolist()
48
+
49
+ # --- Derived Configuration ---
50
+ # The final columns must match the scaler's feature count
51
+ if scaler.n_features_in_ != len(FINAL_119_COLUMNS):
52
+ raise ValueError(f"Scaler expects {scaler.n_features_in_} features, but feature_names.pkl has {len(FINAL_119_COLUMNS)}. Check file consistency.")
53
+
54
+ except (FileNotFoundError, ValueError) as e:
55
+ print(f"FATAL ERROR: Failed to load required file or file inconsistent: {e}")
56
+ print("Please ensure your improved model (.h5) and all .pkl files are in the same folder.")
57
+ raise
58
+
59
+ # --- 2. PREDICTION FUNCTION ---
60
+
61
+ def predict_attack(*raw_input_values):
62
+ """
63
+ Processes the 41 raw user inputs, converts them to 119 scaled features, and predicts.
64
+ """
65
+ if len(raw_input_values) != len(RAW_41_FEATURES):
66
+ return f'<h1 style="color:red; font-size:24px;">Input Error: Expected {len(RAW_41_FEATURES)} features, received {len(raw_input_values)}.</h1>'
67
+
68
+ # 1. Create a raw DataFrame from the user input
69
+ raw_df = pd.DataFrame([raw_input_values], columns=RAW_41_FEATURES)
70
+
71
+ # Ensure numerical columns are numeric type
72
+ for col in NUMERICAL_COLS:
73
+ raw_df[col] = pd.to_numeric(raw_df[col], errors='coerce').fillna(0.0)
74
+
75
+
76
+ # 2. One-Hot Encoding
77
+ # Use pandas get_dummies on the categorical columns
78
+ df_encoded = pd.get_dummies(raw_df, columns=CATEGORICAL_COLS, dtype=float)
79
+
80
+ # 3. Align and Reorder Features to match the 119 FINAL_119_COLUMNS list
81
+ # This crucial step ensures the exact order and column presence (filling missing with 0)
82
+ X_processed = df_encoded.reindex(columns=FINAL_119_COLUMNS, fill_value=0)
83
+
84
+ # Convert to NumPy array
85
+ X_array = X_processed.values.astype(np.float32)
86
+
87
+ # 4. Standard Scaling (on the entire 119-feature vector)
88
+ X_scaled = scaler.transform(X_array)
89
+
90
+ # 5. Reshape for CNN (1, 119, 1)
91
+ X_cnn = X_scaled.reshape((1, X_scaled.shape[1], 1))
92
+
93
+ # 6. Predict
94
+ prediction = model.predict(X_cnn, verbose=0)
95
+
96
+ # Determine result (binary classification threshold 0.5)
97
+ probability = prediction[0][0]
98
+
99
+ if probability > 0.5:
100
+ # Detected as Attack
101
+ result = f"🚨 ATTACK DETECTED! (Probability: {probability*100:.2f}%)"
102
+ color = "red"
103
+ else:
104
+ # Detected as Normal
105
+ result = f"✅ Normal Traffic (Probability: {(1 - probability)*100:.2f}%)"
106
+ color = "green"
107
+
108
+ return f'<h1 style="color:{color}; font-size:24px;">{result}</h1>'
109
+
110
+ # --- 3. GRADIO INTERFACE SETUP ---
111
+
112
+ # Use placeholders for the categorical choices since we don't have the categorical map file
113
+ # This assumes the user will input valid strings like 'tcp', 'http', 'SF'.
114
+ # For a robust deployed app, you should load the unique categorical values.
115
+ # For demonstration, we'll use simple Textboxes or common examples.
116
+ input_components = []
117
+ for name in RAW_41_FEATURES:
118
+ if name in NUMERICAL_COLS:
119
+ input_components.append(gr.Number(label=name, value=0.0))
120
+ elif name == 'protocol_type':
121
+ input_components.append(gr.Dropdown(label=name, choices=['tcp', 'udp', 'icmp'], value='tcp'))
122
+ elif name == 'flag':
123
+ input_components.append(gr.Dropdown(label=name, choices=['SF', 'S0', 'REJ', 'RSTR', 'OTH'], value='SF'))
124
+ elif name == 'service':
125
+ # Service has 70+ values; using Textbox is best unless all choices are loaded
126
+ input_components.append(gr.Textbox(label=name, value='http'))
127
+ else:
128
+ input_components.append(gr.Textbox(label=name, value='0'))
129
+
130
+ # Example Neptune DoS attack vector: [0, tcp, private, S0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 10, 1, 1, 0, 0, 0.04, 0.06, 0, 255, 10, 0.04, 0.06, 0, 0, 1, 1, 0, 0]
131
+ example_attack_data = [
132
+ 0.0, 'tcp', 'private', 'S0', 0.0, 0.0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
133
+ 255, 10, 1.0, 1.0, 0.0, 0.0, 0.04, 0.06, 0.0, 255, 10, 0.04, 0.06, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0
134
+ ]
135
+
136
+
137
+ # Gradio Interface
138
+ iface = gr.Interface(
139
+ fn=predict_attack,
140
+ inputs=input_components,
141
+ outputs=gr.HTML(label="Prediction Result"),
142
+ title="KDD Intrusion Detection System (CNN)",
143
+ description="Enter the 41 raw features of a network connection. The model predicts if the traffic is 'normal' or an 'attack'.",
144
+ examples=[example_attack_data]
145
+ )
146
+
147
+ # Launch the app
148
+ if __name__ == "__main__":
149
+ iface.launch(share=False)