JasHugF committed on
Commit
ddc5c50
·
verified ·
1 Parent(s): 359227b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -0
app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import numpy as np
4
+
5
# 1. Load the trained model.
# The .keras archive holds the complete saved model, so a single call restores it.
MODEL_PATH = "adversarially_trained_model.keras"
model = tf.keras.models.load_model(MODEL_PATH)

# Log the output shape at startup so the number of model outputs can be
# checked against the number of class labels defined below.
print("output_shape:", model.output_shape)

# Generic placeholder labels "class_0" .. "class_22".
# NOTE(review): presumably index i corresponds to the i-th attack label used
# during training -- confirm against the training notebook.
NUM_CLASSES = 23
CLASS_NAMES = ["class_" + str(index) for index in range(NUM_CLASSES)]
print(CLASS_NAMES)
11
+
12
# Numeric feature columns: cast to float and passed through unchanged.
NUMERIC_COLS = (
    "duration", "src_bytes", "dst_bytes", "land", "wrong_fragment", "urgent",
    "hot", "num_failed_logins", "logged_in", "num_compromised", "root_shell",
    "su_attempted", "num_root", "num_file_creations", "num_shells",
    "num_access_files", "num_outbound_cmds", "is_host_login", "is_guest_login",
    "count", "srv_count", "serror_rate", "srv_serror_rate", "rerror_rate",
    "srv_rerror_rate", "same_srv_rate", "diff_srv_rate", "srv_diff_host_rate",
    "dst_host_count", "dst_host_srv_count", "dst_host_same_srv_rate",
    "dst_host_diff_srv_rate", "dst_host_same_src_port_rate",
    "dst_host_srv_diff_host_rate", "dst_host_serror_rate",
    "dst_host_srv_serror_rate", "dst_host_rerror_rate", "dst_host_srv_rerror_rate",
)

# Categorical feature columns: one-hot encoded with pandas.get_dummies.
CATEGORICAL_COLS = ("protocol_type", "service", "flag")


def preprocess_single_record(record: dict, scaler, df_reference: "pd.DataFrame") -> np.ndarray:
    """Turn one raw record into a scaled feature row.

    Steps: build a one-row frame from ``record``, fill every column of
    ``df_reference`` that the record lacks, cast the numeric columns to float,
    one-hot encode the categorical columns, align the result to
    ``df_reference.columns`` and pass it through ``scaler.transform``.

    Args:
        record: Mapping of column name to value for a single sample. Keys that
            are not columns of ``df_reference`` are dropped.
        scaler: Any object exposing ``transform(frame)``; its return value is
            returned unchanged.
        df_reference: Frame whose columns define the fill values, the column
            order and the final column alignment.

    Returns:
        Whatever ``scaler.transform`` returns for the aligned one-row frame.

    Raises:
        KeyError: If one of ``NUMERIC_COLS`` / ``CATEGORICAL_COLS`` is not a
            column of ``df_reference``.
    """
    # pandas is imported lazily here, so the annotation in the signature is
    # quoted: an unquoted ``pd.DataFrame`` is evaluated when the ``def``
    # statement runs and raised NameError because ``pd`` is not bound yet.
    import pandas as pd

    df_input = pd.DataFrame([record])

    # Fill columns the record does not provide: 0.0 for float64 reference
    # columns, otherwise the most frequent reference value.
    for col in df_reference.columns:
        if col in df_input.columns:
            continue
        if df_reference[col].dtype == 'float64':
            df_input[col] = 0.0
        else:
            # mode() is empty both for a zero-row reference and for a column
            # holding only nulls; indexing it blindly raised in the latter case.
            modes = df_reference[col].mode()
            df_input[col] = modes.iloc[0] if not modes.empty else 'unknown'

    # Keep exactly the reference columns, in reference order.
    df_input = df_input[df_reference.columns]

    # Numeric part as floats, categorical part as one-hot indicator columns.
    df_numeric = df_input[list(NUMERIC_COLS)].astype(float)
    df_categorical = pd.get_dummies(df_input[list(CATEGORICAL_COLS)])
    df_final = pd.concat([df_numeric, df_categorical], axis=1)

    # Align with the reference columns; columns absent from df_final are
    # created with 0 and columns not in the reference are dropped.
    # NOTE(review): if df_reference holds the raw (un-encoded) columns, this
    # discards the one-hot columns built above and zero-fills the categorical
    # columns -- confirm which schema df_reference is meant to carry.
    df_final = df_final.reindex(columns=df_reference.columns, fill_value=0)

    return scaler.transform(df_final)
56
+
57
+ # 2. Create the prediction function
58
def predict_from_array(input_text):
    """Classify one sample given as a comma-separated string of numbers.

    Args:
        input_text: Text such as ``"1.5, 0.8, -2.3"``; each comma-separated
            token is stripped and parsed with ``float``.

    Returns:
        On success, a dict mapping each class name to a float probability.
        On invalid input, a human-readable error string (errors are returned,
        not raised).
    """
    # 3. Parse the input string into a 1-D float array.
    try:
        float_values = [float(token.strip()) for token in input_text.split(',')]
    except ValueError as e:
        return f"Input Error: Please enter numbers only. Details: {e}"
    except (AttributeError, TypeError) as e:
        # Non-string input (e.g. None) has no usable .split().
        return f"Error: {e}"

    # 4. Shape the array as a batch containing a single row of features.
    input_array = np.array(float_values).reshape(1, -1)

    # reshape(1, -1) on a 1-D array cannot fail, so a wrong feature count has
    # to be detected explicitly before it reaches the model.
    expected = _expected_feature_count()
    if expected is not None and input_array.shape[1] != expected:
        return (
            "Shape Error: Model expects a different number of features. "
            f"Got {len(float_values)}. Details: expected {expected}."
        )

    # 5. Make prediction and flatten the first (only) row of outputs.
    predictions = model.predict(input_array)
    scores = np.array(predictions[0]).reshape(-1)

    # 6. Format the output
    # --- OPTION A: For Classification ---
    # Align class names to model outputs: pad with generic names when the
    # model has more outputs than CLASS_NAMES, truncate when it has fewer.
    num_outputs = len(scores)
    if len(CLASS_NAMES) < num_outputs:
        aligned_class_names = CLASS_NAMES + [
            f"class_{i}" for i in range(len(CLASS_NAMES), num_outputs)
        ]
    else:
        aligned_class_names = CLASS_NAMES[:num_outputs]

    prob_scores = _to_probabilities(scores)
    return {name: float(prob) for name, prob in zip(aligned_class_names, prob_scores)}


def _expected_feature_count():
    """Return the feature count from ``model.input_shape``, or None if unknown."""
    try:
        input_shape = model.input_shape
    except (AttributeError, ValueError):
        return None
    # Only a plain tuple ending in an int is trusted; anything else (e.g. a
    # list of shapes or an undefined last dimension) disables the check.
    if isinstance(input_shape, tuple) and input_shape and isinstance(input_shape[-1], int):
        return input_shape[-1]
    return None


def _to_probabilities(scores):
    """Normalize scores to probabilities unless they already sum to ~1."""
    already_normalized = (
        np.isfinite(scores).all()
        and (scores >= 0).all()
        and np.isclose(scores.sum(), 1.0, atol=1e-3)
    )
    if already_normalized:
        # Re-applying softmax to probabilities would flatten the distribution.
        return scores

    # Numerically stable softmax; fall back to the raw scores if it degenerates.
    exp_scores = np.exp(scores - np.max(scores))
    if np.isfinite(exp_scores).all() and exp_scores.sum() > 0:
        return exp_scores / np.sum(exp_scores)
    return scores
112
+
113
+ # --- OPTION B: For Regression (or single value output) ---
114
+ # Use this if your model outputs a single number
115
+ # predicted_value = float(scores[0])
116
+ # return f"Predicted Value: {predicted_value:.4f}"
117
+
118
+ # --- OPTION C: For Raw Array Output ---
119
+ # Just return the raw prediction scores
120
+ # return str(scores)
121
+
122
+
123
+ # --- Gradio UI ---
124
+
125
+ # 7. Define the Gradio Interface
126
# Input component: a free-text box for the comma-separated feature values.
_features_input = gr.Textbox(
    label="Input Features",
    placeholder="Enter comma-separated float values, e.g., 1.5, 0.8, -2.3, 4.0",
)

# Output component: gr.Label renders the class -> confidence dict returned by
# the prediction function and shows the three highest-scoring classes.
# Swap in gr.Textbox(label="Prediction") if the function returns plain text.
_prediction_output = gr.Label(num_top_classes=3, label="Attack Type Predictions")

# 7. Define the Gradio Interface
demo = gr.Interface(
    fn=predict_from_array,
    inputs=_features_input,
    outputs=_prediction_output,
    title="Network Intrusion Detection System (NIDS)",
    description="Adversarially trained model for network attack classification. Input network traffic features as comma-separated values to detect attack types: normal, dos, probe, r2l, u2r, or other.",
)

# 8. Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()