Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +15 -26
src/streamlit_app.py
CHANGED
|
@@ -8,20 +8,19 @@ import tensorflow as tf
|
|
| 8 |
from tensorflow.keras.models import load_model
|
| 9 |
from tensorflow.keras.optimizers.legacy import SGD
|
| 10 |
|
| 11 |
-
# ==== File Paths
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
LABEL_ENCODER_FILE = ("src/label_encoder1.pkl")
|
| 17 |
|
| 18 |
ENCODER_PATHS = {
|
| 19 |
-
"proto":
|
| 20 |
-
"conn_state":
|
| 21 |
-
"history":
|
| 22 |
}
|
| 23 |
|
| 24 |
-
SAMPLE_FILE = "src/sample_input.csv"
|
| 25 |
|
| 26 |
# ==== Class for Malware Prediction ====
|
| 27 |
class MalwareClassifier:
|
|
@@ -101,15 +100,10 @@ class MalwareClassifier:
|
|
| 101 |
|
| 102 |
# ==== Streamlit UI ====
|
| 103 |
def main():
|
| 104 |
-
import os
|
| 105 |
st.set_page_config(page_title="Malware Detection System", page_icon="π‘οΈ")
|
| 106 |
st.title("π‘οΈ Malware Detection System")
|
| 107 |
st.markdown("Upload a CSV file with network traffic logs to detect malware.")
|
| 108 |
|
| 109 |
-
# π οΈ DEBUG INFO
|
| 110 |
-
|
| 111 |
-
# π½ Sample CSV download
|
| 112 |
-
SAMPLE_FILE = "src/sample_input.csv"
|
| 113 |
st.markdown("π Need a sample file to test? Download below:")
|
| 114 |
|
| 115 |
if os.path.exists(SAMPLE_FILE):
|
|
@@ -117,11 +111,11 @@ def main():
|
|
| 117 |
st.download_button(
|
| 118 |
label="π₯ Download Sample CSV",
|
| 119 |
data=f,
|
| 120 |
-
file_name="
|
| 121 |
mime="text/csv"
|
| 122 |
)
|
| 123 |
else:
|
| 124 |
-
st.warning("β οΈ
|
| 125 |
|
| 126 |
try:
|
| 127 |
classifier = MalwareClassifier()
|
|
@@ -129,12 +123,12 @@ def main():
|
|
| 129 |
st.error(f"β Model loading failed: {e}")
|
| 130 |
return
|
| 131 |
|
| 132 |
-
uploaded_file = st.file_uploader("π
|
| 133 |
|
| 134 |
if uploaded_file is not None:
|
| 135 |
try:
|
| 136 |
-
df = pd.read_csv(uploaded_file
|
| 137 |
-
df.columns = df.columns.str.strip()
|
| 138 |
|
| 139 |
required_prediction_columns = [
|
| 140 |
'id.orig_h', 'id.orig_p', 'id.resp_h', 'id.resp_p',
|
|
@@ -142,21 +136,16 @@ def main():
|
|
| 142 |
'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes'
|
| 143 |
]
|
| 144 |
|
| 145 |
-
# β
Check for missing columns
|
| 146 |
missing_columns = set(required_prediction_columns) - set(df.columns)
|
| 147 |
if missing_columns:
|
| 148 |
st.error(f"β Missing required columns for prediction: {missing_columns}")
|
| 149 |
-
st.write("π Detected columns:", df.columns.tolist())
|
| 150 |
return
|
| 151 |
|
| 152 |
-
# β
Extract necessary columns
|
| 153 |
prediction_input = df[required_prediction_columns].copy()
|
| 154 |
-
|
| 155 |
-
# β
Predict
|
| 156 |
predictions = classifier.predict(prediction_input)
|
| 157 |
st.success("β
Prediction complete!")
|
| 158 |
|
| 159 |
-
# β
Display results
|
| 160 |
for i, result in enumerate(predictions):
|
| 161 |
st.subheader(f"Prediction {i + 1}")
|
| 162 |
st.write(f"**Predicted Label:** {result['result']}")
|
|
|
|
| 8 |
from tensorflow.keras.models import load_model
|
| 9 |
from tensorflow.keras.optimizers.legacy import SGD
|
| 10 |
|
| 11 |
+
# ==== File Paths ====
|
| 12 |
+
MODEL_FILE = "src/model1.h5"
|
| 13 |
+
WEIGHTS_FILE = "src/weights.h5"
|
| 14 |
+
SCALER_FILE = "src/standard_scaler1.pkl"
|
| 15 |
+
LABEL_ENCODER_FILE = "src/label_encoder1.pkl"
|
|
|
|
| 16 |
|
| 17 |
ENCODER_PATHS = {
|
| 18 |
+
"proto": "src/categorical_label_encoder_proto1.pkl",
|
| 19 |
+
"conn_state": "src/categorical_label_encoder_conn_state1.pkl",
|
| 20 |
+
"history": "src/categorical_label_encoder_history1.pkl"
|
| 21 |
}
|
| 22 |
|
| 23 |
+
SAMPLE_FILE = "src/sample_input.csv"
|
| 24 |
|
| 25 |
# ==== Class for Malware Prediction ====
|
| 26 |
class MalwareClassifier:
|
|
|
|
| 100 |
|
| 101 |
# ==== Streamlit UI ====
|
| 102 |
def main():
|
|
|
|
| 103 |
st.set_page_config(page_title="Malware Detection System", page_icon="π‘οΈ")
|
| 104 |
st.title("π‘οΈ Malware Detection System")
|
| 105 |
st.markdown("Upload a CSV file with network traffic logs to detect malware.")
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
st.markdown("π Need a sample file to test? Download below:")
|
| 108 |
|
| 109 |
if os.path.exists(SAMPLE_FILE):
|
|
|
|
| 111 |
st.download_button(
|
| 112 |
label="π₯ Download Sample CSV",
|
| 113 |
data=f,
|
| 114 |
+
file_name="sample_input.csv", # β
Do NOT include `src/` in file_name
|
| 115 |
mime="text/csv"
|
| 116 |
)
|
| 117 |
else:
|
| 118 |
+
st.warning("β οΈ 'sample_input.csv' not found in the 'src' directory.")
|
| 119 |
|
| 120 |
try:
|
| 121 |
classifier = MalwareClassifier()
|
|
|
|
| 123 |
st.error(f"β Model loading failed: {e}")
|
| 124 |
return
|
| 125 |
|
| 126 |
+
uploaded_file = st.file_uploader("π Upload your network CSV", type=["csv"])
|
| 127 |
|
| 128 |
if uploaded_file is not None:
|
| 129 |
try:
|
| 130 |
+
df = pd.read_csv(uploaded_file)
|
| 131 |
+
df.columns = df.columns.str.strip()
|
| 132 |
|
| 133 |
required_prediction_columns = [
|
| 134 |
'id.orig_h', 'id.orig_p', 'id.resp_h', 'id.resp_p',
|
|
|
|
| 136 |
'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes'
|
| 137 |
]
|
| 138 |
|
|
|
|
| 139 |
missing_columns = set(required_prediction_columns) - set(df.columns)
|
| 140 |
if missing_columns:
|
| 141 |
st.error(f"β Missing required columns for prediction: {missing_columns}")
|
| 142 |
+
st.write("π Detected columns:", df.columns.tolist())
|
| 143 |
return
|
| 144 |
|
|
|
|
| 145 |
prediction_input = df[required_prediction_columns].copy()
|
|
|
|
|
|
|
| 146 |
predictions = classifier.predict(prediction_input)
|
| 147 |
st.success("β
Prediction complete!")
|
| 148 |
|
|
|
|
| 149 |
for i, result in enumerate(predictions):
|
| 150 |
st.subheader(f"Prediction {i + 1}")
|
| 151 |
st.write(f"**Predicted Label:** {result['result']}")
|