SilverDragon9 commited on
Commit
5ef57e6
·
1 Parent(s): 47b79b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -70
app.py CHANGED
@@ -2,96 +2,67 @@ import gradio as gr
2
  import joblib
3
  import requests
4
  import os
5
-
6
- from sklearn.ensemble import RandomForestClassifier, BaggingClassifier, AdaBoostClassifier
7
- from sklearn.tree import DecisionTreeClassifier
8
 
9
  # Load the saved models
10
  rf_model = joblib.load('rf_model.pkl')
11
- dt_model = joblib.load('decision_tree_model.pkl')
12
- bagging_model = joblib.load('model_bagging.pkl')
13
- ada_model = joblib.load('model_adaboost.pkl')
14
 
15
- # Define the feature names
16
  feature_names = [
17
- "src_ip", "src_port", "dst_ip", "dst_port", "proto", "service", "duration",
18
- "src_bytes", "dst_bytes", "conn_state", "missed_bytes", "src_pkts",
19
- "src_ip_bytes", "dst_pkts", "dst_ip_bytes", "dns_query", "dns_qclass",
20
- "dns_qtype", "dns_rcode", "dns_AA", "dns_RD", "dns_RA", "dns_rejected",
21
- "ssl_version", "ssl_cipher", "ssl_resumed", "ssl_established", "ssl_subject",
22
- "ssl_issuer", "http_trans_depth", "http_method", "http_uri", "http_version",
23
- "http_request_body_len", "http_response_body_len", "http_status_code",
24
- "http_user_agent", "http_orig_mime_types", "http_resp_mime_types",
25
- "weird_name", "weird_addl", "weird_notice", "label"
26
  ]
27
 
28
  class_labels = {
29
  0: "normal",
30
  1: "backdoor",
31
  2: "ddos",
32
- 3: "dos",
33
- 4: "injection",
34
- 5: "password",
35
- 6: "ransomware",
36
- 7: "scanning",
37
- 8: "xss",
38
- 9: "mitm"
39
  }
40
 
41
- def detect_intrusion(*feature_values, label_value, model_choice="Random Forest"):
42
- # Ensure the length of feature_values matches feature_names (excluding label)
43
- if len(feature_values) != len(feature_names) - 1:
44
- return "Please fill in all the required feature values."
45
 
46
- # Convert the input values to floats and match them with feature names
 
47
  try:
48
- feature_values = [float(value) for value in feature_values]
49
- label_value = int(label_value) # Label should be an integer
50
- except ValueError:
51
- return "Please enter valid numerical values for all fields, including the label."
52
-
53
- # Add the label to feature values
54
- feature_values.append(label_value)
55
-
56
- # Choose the model based on user selection
57
- if model_choice == "Random Forest":
58
- model = rf_model
59
- elif model_choice == "Decision Tree":
60
- model = dt_model
61
- elif model_choice == "Bagging Classifier":
62
- model = bagging_model
63
- elif model_choice == "AdaBoost Classifier":
64
- model = ada_model
65
- else:
66
- return "Invalid model choice!"
67
-
68
- # Predict the class (multi-class classification)
69
- prediction = model.predict([feature_values])
70
- predicted_class = prediction[0] # Get the predicted class (an integer between 0-9)
71
-
72
- # Notify the user of the detected attack or normal traffic
73
- if predicted_class == 0:
74
- return "No Intrusion Detected"
75
- else:
76
- return f"Intrusion Detected: {class_labels.get(predicted_class, 'Unknown Attack')}"
77
-
78
- # Create Gradio input fields for each feature (excluding label initially)
79
- inputs = [gr.Textbox(label=feature_name) for feature_name in feature_names[:-1]] # Exclude the last one (label) from inputs
80
 
81
- # Add label input field
82
- inputs.append(gr.Textbox(label="label"))
 
 
 
 
 
83
 
84
- # Add model choice dropdown
85
- inputs.append(gr.Dropdown(choices=["Random Forest", "Decision Tree", "Bagging Classifier", "AdaBoost Classifier"], label="Select Model"))
 
 
 
86
 
87
- # Create the Gradio interface
88
  iface = gr.Interface(
89
- fn=detect_intrusion,
90
- inputs=inputs, # Pass the list of inputs directly
 
 
91
  outputs="text",
92
  title="Intrusion Detection System",
93
- description="Fill in the blank fields for the network traffic features, the label value (0-9), and choose the model to detect intrusions."
 
 
 
 
 
 
94
  )
95
 
96
- # Launch the interface with a public shareable link
97
- iface.launch(share=True)
 
2
  import joblib
3
  import requests
4
  import os
5
+ from sklearn.ensemble import RandomForestClassifier
 
 
6
 
7
  # Load the saved models
8
  rf_model = joblib.load('rf_model.pkl')
 
 
 
9
 
10
+ # Define the feature names (excluding the target column 'type')
11
  feature_names = [
12
+ "date", "time", "door_state", "sphone_signal", "label"
 
 
 
 
 
 
 
 
13
  ]
14
 
15
  class_labels = {
16
  0: "normal",
17
  1: "backdoor",
18
  2: "ddos",
19
+ 3: "injection",
20
+ 4: "password",
21
+ 5: "ransomware",
22
+ 6: "scanning",
23
+ 7: "xss",
 
 
24
  }
25
 
26
+ # Placeholder model (replace with actual Random Forest model object)
27
+ rf_model = None # Load the actual trained Random Forest model here
 
 
28
 
29
+ def detect_intrusion(file):
30
+ # Read the uploaded log file as a CSV or structured data
31
  try:
32
+ log_data = pd.read_csv(file.name) # Use file.name to get the path for reading
33
+ except Exception as e:
34
+ return f"Error reading file: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ # Check if all required feature columns are in the log file
37
+ missing_features = [feature for feature in feature_names if feature not in log_data.columns]
38
+ if missing_features:
39
+ return f"Missing features in file: {', '.join(missing_features)}"
40
+
41
+ # Extract the feature values (excluding the 'type' column which is the target)
42
+ feature_values = log_data[feature_names].astype(float).values
43
 
44
+ # Predict the class (multi-class classification) for each row in the log file
45
+ predictions = rf_model.predict(feature_values)
46
+
47
+ # Return only the 'Prediction' and 'label' columns
48
+ return log_data[['Prediction']].head().to_string()
49
 
50
+ # Create a Gradio interface
51
  iface = gr.Interface(
52
+ fn=detect_intrusion,
53
+ inputs=[
54
+ gr.File(label="Upload Log File (CSV format)") # File input
55
+ ],
56
  outputs="text",
57
  title="Intrusion Detection System",
58
+ description=("""
59
+ Upload a CSV log file containing the following features:
60
+ date, time, door_state, sphone_signal, label (without the 'type' column).
61
+ Example file structure:
62
+ date,time,door_state,sphone_signal,label
63
+ 2025-03-12,10:45:00,1,-85,normal
64
+ """)
65
  )
66
 
67
+ # Launch the interface locally for testing
68
+ iface.launch()