ATllll commited on
Commit
f72ebae
Β·
verified Β·
1 Parent(s): b884894

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +15 -26
src/streamlit_app.py CHANGED
@@ -8,20 +8,19 @@ import tensorflow as tf
8
  from tensorflow.keras.models import load_model
9
  from tensorflow.keras.optimizers.legacy import SGD
10
 
11
- # ==== File Paths (update if needed) ====
12
-
13
- MODEL_FILE = ("src/model1.h5")
14
- WEIGHTS_FILE = ("src/weights.h5")
15
- SCALER_FILE = ("src/standard_scaler1.pkl")
16
- LABEL_ENCODER_FILE = ("src/label_encoder1.pkl")
17
 
18
  ENCODER_PATHS = {
19
- "proto": ("src/categorical_label_encoder_proto1.pkl"),
20
- "conn_state": ("src/categorical_label_encoder_conn_state1.pkl"),
21
- "history": ("src/categorical_label_encoder_history1.pkl")
22
  }
23
 
24
- SAMPLE_FILE = "src/sample_input.csv"
25
 
26
  # ==== Class for Malware Prediction ====
27
  class MalwareClassifier:
@@ -101,15 +100,10 @@ class MalwareClassifier:
101
 
102
  # ==== Streamlit UI ====
103
  def main():
104
- import os
105
  st.set_page_config(page_title="Malware Detection System", page_icon="πŸ›‘οΈ")
106
  st.title("πŸ›‘οΈ Malware Detection System")
107
  st.markdown("Upload a CSV file with network traffic logs to detect malware.")
108
 
109
- # πŸ› οΈ DEBUG INFO
110
-
111
- # πŸ”½ Sample CSV download
112
- SAMPLE_FILE = "src/sample_input.csv"
113
  st.markdown("πŸ“„ Need a sample file to test? Download below:")
114
 
115
  if os.path.exists(SAMPLE_FILE):
@@ -117,11 +111,11 @@ def main():
117
  st.download_button(
118
  label="πŸ“₯ Download Sample CSV",
119
  data=f,
120
- file_name="src/sample_input.csv",
121
  mime="text/csv"
122
  )
123
  else:
124
- st.warning("⚠️ Sample CSV not found. Please place 'sample_input.csv' in the app directory.")
125
 
126
  try:
127
  classifier = MalwareClassifier()
@@ -129,12 +123,12 @@ def main():
129
  st.error(f"❌ Model loading failed: {e}")
130
  return
131
 
132
- uploaded_file = st.file_uploader("πŸ“‚ Drag & drop or select a CSV file", type=["csv"])
133
 
134
  if uploaded_file is not None:
135
  try:
136
- df = pd.read_csv(uploaded_file, delimiter=',')
137
- df.columns = df.columns.str.strip() # βœ… Normalize column names
138
 
139
  required_prediction_columns = [
140
  'id.orig_h', 'id.orig_p', 'id.resp_h', 'id.resp_p',
@@ -142,21 +136,16 @@ def main():
142
  'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes'
143
  ]
144
 
145
- # βœ… Check for missing columns
146
  missing_columns = set(required_prediction_columns) - set(df.columns)
147
  if missing_columns:
148
  st.error(f"❌ Missing required columns for prediction: {missing_columns}")
149
- st.write("πŸ“‹ Detected columns:", df.columns.tolist()) # Helpful debug info
150
  return
151
 
152
- # βœ… Extract necessary columns
153
  prediction_input = df[required_prediction_columns].copy()
154
-
155
- # βœ… Predict
156
  predictions = classifier.predict(prediction_input)
157
  st.success("βœ… Prediction complete!")
158
 
159
- # βœ… Display results
160
  for i, result in enumerate(predictions):
161
  st.subheader(f"Prediction {i + 1}")
162
  st.write(f"**Predicted Label:** {result['result']}")
 
8
  from tensorflow.keras.models import load_model
9
  from tensorflow.keras.optimizers.legacy import SGD
10
 
11
+ # ==== File Paths ====
12
+ MODEL_FILE = "src/model1.h5"
13
+ WEIGHTS_FILE = "src/weights.h5"
14
+ SCALER_FILE = "src/standard_scaler1.pkl"
15
+ LABEL_ENCODER_FILE = "src/label_encoder1.pkl"
 
16
 
17
  ENCODER_PATHS = {
18
+ "proto": "src/categorical_label_encoder_proto1.pkl",
19
+ "conn_state": "src/categorical_label_encoder_conn_state1.pkl",
20
+ "history": "src/categorical_label_encoder_history1.pkl"
21
  }
22
 
23
+ SAMPLE_FILE = "src/sample_input.csv"
24
 
25
  # ==== Class for Malware Prediction ====
26
  class MalwareClassifier:
 
100
 
101
  # ==== Streamlit UI ====
102
  def main():
 
103
  st.set_page_config(page_title="Malware Detection System", page_icon="πŸ›‘οΈ")
104
  st.title("πŸ›‘οΈ Malware Detection System")
105
  st.markdown("Upload a CSV file with network traffic logs to detect malware.")
106
 
 
 
 
 
107
  st.markdown("πŸ“„ Need a sample file to test? Download below:")
108
 
109
  if os.path.exists(SAMPLE_FILE):
 
111
  st.download_button(
112
  label="πŸ“₯ Download Sample CSV",
113
  data=f,
114
+ file_name="sample_input.csv", # βœ… Do NOT include `src/` in file_name
115
  mime="text/csv"
116
  )
117
  else:
118
+ st.warning("⚠️ 'sample_input.csv' not found in the 'src' directory.")
119
 
120
  try:
121
  classifier = MalwareClassifier()
 
123
  st.error(f"❌ Model loading failed: {e}")
124
  return
125
 
126
+ uploaded_file = st.file_uploader("πŸ“‚ Upload your network CSV", type=["csv"])
127
 
128
  if uploaded_file is not None:
129
  try:
130
+ df = pd.read_csv(uploaded_file)
131
+ df.columns = df.columns.str.strip()
132
 
133
  required_prediction_columns = [
134
  'id.orig_h', 'id.orig_p', 'id.resp_h', 'id.resp_p',
 
136
  'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes'
137
  ]
138
 
 
139
  missing_columns = set(required_prediction_columns) - set(df.columns)
140
  if missing_columns:
141
  st.error(f"❌ Missing required columns for prediction: {missing_columns}")
142
+ st.write("πŸ“‹ Detected columns:", df.columns.tolist())
143
  return
144
 
 
145
  prediction_input = df[required_prediction_columns].copy()
 
 
146
  predictions = classifier.predict(prediction_input)
147
  st.success("βœ… Prediction complete!")
148
 
 
149
  for i, result in enumerate(predictions):
150
  st.subheader(f"Prediction {i + 1}")
151
  st.write(f"**Predicted Label:** {result['result']}")