MrUtakata commited on
Commit
faa66dd
·
verified ·
1 Parent(s): 511e3f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -13,7 +13,6 @@ mi_features = joblib.load("selected_mi_features.pkl")
13
  roa_features = joblib.load("selected_roa_features.pkl")
14
  class_names = joblib.load("class_names.pkl")
15
 
16
- # Manually defined feature names (from your dataset)
17
  FEATURE_COLUMNS = [
18
  "flow_duration", "header_length", "protocol_type", "duration", "rate", "srate", "drate",
19
  "fin_flag_number", "syn_flag_number", "rst_flag_number", "psh_flag_number", "ack_flag_number",
@@ -33,28 +32,33 @@ default_row = (
33
  "9.5,10.39230485,0,0,0,141.55"
34
  )
35
 
36
- user_input = st.text_area("Paste the feature row here (no label column)", default_row, height=150)
37
 
38
  if st.button("🔍 Predict Attack Type"):
39
  try:
40
- # Auto-detect delimiter (tab or comma)
41
  delimiter = '\t' if '\t' in user_input else ','
42
 
43
- # Parse input row
44
- input_list = [float(x.strip()) for x in user_input.strip().split(delimiter)]
 
 
 
 
 
 
 
45
 
46
  if len(input_list) != len(FEATURE_COLUMNS):
47
- st.error(f"🚫 Expected {len(FEATURE_COLUMNS)} values but got {len(input_list)}.")
48
  else:
49
  input_df = pd.DataFrame([input_list], columns=FEATURE_COLUMNS)
50
 
51
- # Preprocessing
52
  scaled = scaler.transform(input_df)
53
  var_filtered = var_thresh.transform(scaled)
54
  mi_selected = pd.DataFrame(var_filtered, columns=np.array(FEATURE_COLUMNS)[var_thresh.get_support()])
55
  final_features = mi_selected[mi_features].iloc[:, roa_features]
56
 
57
- # Prediction
58
  prediction = model.predict(final_features)
59
  predicted_label = label_encoder.inverse_transform(prediction)[0]
60
 
 
13
  roa_features = joblib.load("selected_roa_features.pkl")
14
  class_names = joblib.load("class_names.pkl")
15
 
 
16
  FEATURE_COLUMNS = [
17
  "flow_duration", "header_length", "protocol_type", "duration", "rate", "srate", "drate",
18
  "fin_flag_number", "syn_flag_number", "rst_flag_number", "psh_flag_number", "ack_flag_number",
 
32
  "9.5,10.39230485,0,0,0,141.55"
33
  )
34
 
35
+ user_input = st.text_area("Paste the feature row here (label column optional)", default_row, height=150)
36
 
37
  if st.button("🔍 Predict Attack Type"):
38
  try:
 
39
  delimiter = '\t' if '\t' in user_input else ','
40
 
41
+ input_parts = user_input.strip().split(delimiter)
42
+
43
+ # Handle optional label at the end
44
+ try:
45
+ _ = float(input_parts[-1]) # If this fails, it's a string label
46
+ except ValueError:
47
+ input_parts = input_parts[:-1] # Drop the label
48
+
49
+ input_list = [float(x.strip()) for x in input_parts]
50
 
51
  if len(input_list) != len(FEATURE_COLUMNS):
52
+ st.error(f"🚫 Expected {len(FEATURE_COLUMNS)} features but got {len(input_list)}.")
53
  else:
54
  input_df = pd.DataFrame([input_list], columns=FEATURE_COLUMNS)
55
 
56
+ # Preprocessing pipeline
57
  scaled = scaler.transform(input_df)
58
  var_filtered = var_thresh.transform(scaled)
59
  mi_selected = pd.DataFrame(var_filtered, columns=np.array(FEATURE_COLUMNS)[var_thresh.get_support()])
60
  final_features = mi_selected[mi_features].iloc[:, roa_features]
61
 
 
62
  prediction = model.predict(final_features)
63
  predicted_label = label_encoder.inverse_transform(prediction)[0]
64