# IotNetworkMalwareDetection — src/streamlit_app.py
# Streamlit front-end for the IoT network-traffic malware classifier.
import streamlit as st
import pandas as pd
import numpy as np
import joblib
import os
import ipaddress
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers.legacy import SGD
# ==== File Paths ====
# All serialized artifacts produced at training time live under src/.
MODEL_FILE = "src/model1.h5"                      # Keras architecture (HDF5)
WEIGHTS_FILE = "src/weights.h5"                   # trained weights
SCALER_FILE = "src/standard_scaler1.pkl"          # fitted feature scaler
LABEL_ENCODER_FILE = "src/label_encoder1.pkl"     # target-label encoder
# One categorical encoder pickle per categorical feature column.
ENCODER_PATHS = {
    col: f"src/categorical_label_encoder_{col}1.pkl"
    for col in ("proto", "conn_state", "history")
}
SAMPLE_FILE = "src/sample_input.csv"              # demo CSV offered for download
# ==== Class for Malware Prediction ====
class MalwareClassifier:
    """Wraps the trained Keras model and its preprocessing artifacts to
    classify network-flow records (Zeek/Bro-style conn log columns)."""

    def __init__(self):
        """Load the model, weights, scaler and label encoders from disk.

        Raises:
            FileNotFoundError: if any required artifact file is missing.
        """
        required = [MODEL_FILE, WEIGHTS_FILE, SCALER_FILE, LABEL_ENCODER_FILE]
        for path in required + list(ENCODER_PATHS.values()):
            if not os.path.exists(path):
                raise FileNotFoundError(f"Missing file: {path}")
        # Architecture and weights are stored separately; compile=False on
        # load, then compile after the weights are in place.  The optimizer
        # choice is irrelevant for inference -- compiling just silences
        # Keras "model was not compiled" warnings.
        self.model = load_model(MODEL_FILE, compile=False)
        self.model.load_weights(WEIGHTS_FILE)
        self.model.compile(optimizer=SGD(), loss="categorical_crossentropy", metrics=["accuracy"])
        self.scaler = joblib.load(SCALER_FILE)
        self.label_encoder = joblib.load(LABEL_ENCODER_FILE)
        self.encoders = {k: joblib.load(v) for k, v in ENCODER_PATHS.items()}

    def _validate_numeric_column(self, col_name, values):
        """Ensure every entry in *values* is a non-negative integer.

        Raises:
            ValueError: if any entry is non-integer or negative.
        """
        if not values.astype(str).str.isdigit().all():
            raise ValueError(f"Non-integer value found in column {col_name}")
        # BUGFIX: compare on a numeric view of the column.  The original
        # compared the raw column (`values < 0`), which raised an unhandled
        # TypeError (str < int) whenever a digit-only column arrived with
        # object/string dtype and therefore passed the isdigit gate above.
        if (pd.to_numeric(values) < 0).any():
            raise ValueError(f"Negative value found in column {col_name}")

    def _validate_ip_address(self, col_name, values):
        """Ensure every entry parses as a valid IPv4/IPv6 address.

        Raises:
            ValueError: naming the column and the first offending value.
        """
        for value in values:
            try:
                ipaddress.ip_address(value)
            except ValueError:
                raise ValueError(f"Invalid IP address in column {col_name}: {value}") from None

    def _validate_input_data(self, data):
        """Check that *data* is non-empty, has all required columns, and that
        port/count columns are numeric and host columns are IP addresses.

        Raises:
            ValueError: on empty input, missing columns, or invalid values.
        """
        required_columns = {
            'id.orig_h', 'id.orig_p', 'id.resp_h', 'id.resp_p',
            'proto', 'conn_state', 'history',
            'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes'
        }
        if data.empty:
            raise ValueError("CSV is empty.")
        missing = required_columns - set(data.columns)
        if missing:
            raise ValueError(f"Missing required columns: {missing}")
        numeric_cols = {'id.orig_p', 'id.resp_p', 'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes'}
        ip_cols = {'id.orig_h', 'id.resp_h'}
        for col in data.columns:
            if col in numeric_cols:
                self._validate_numeric_column(col, data[col])
            elif col in ip_cols:
                self._validate_ip_address(col, data[col])

    def _encode_data(self, df):
        """Encode categorical columns and convert IP strings to integers.

        Mutates *df* in place and returns it (callers pass a copy).

        Raises:
            ValueError: if a categorical column contains a value the fitted
                encoder has never seen.
        """
        for col in ('proto', 'conn_state', 'history'):
            try:
                df[col] = self.encoders[col].transform(df[col])
            except ValueError as err:
                # sklearn's "unseen label" message does not name the column;
                # re-raise with context so the UI error is actionable.
                raise ValueError(f"Unrecognized value in column {col}: {err}") from err
        for ip_col in ('id.orig_h', 'id.resp_h'):
            df[ip_col] = df[ip_col].apply(lambda x: int(ipaddress.ip_address(x)))
        return df

    def _preprocess_data(self, df):
        """Validate, encode and scale *df*; returns the scaled feature matrix
        with columns in the exact order the model was trained on."""
        self._validate_input_data(df)
        df = self._encode_data(df)
        model_columns = [
            'id.orig_h', 'id.orig_p', 'id.resp_h', 'id.resp_p',
            'proto', 'conn_state', 'history',
            'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes'
        ]
        return self.scaler.transform(df[model_columns])

    def predict(self, df):
        """Classify every row of *df*.

        Returns:
            list[dict]: one ``{"result": <label>, "scores": {class: prob}}``
            per input row; probabilities are formatted to 6 decimal places.
        """
        preprocessed = self._preprocess_data(df)
        preds = self.model.predict(preprocessed)
        results = []
        for pred in preds:
            predicted = self.label_encoder.inverse_transform([np.argmax(pred)])[0]
            # Renamed loop targets: the original comprehension reused `label`,
            # shadowing the predicted-label variable.
            scores = {cls: f"{score:.6f}" for cls, score in zip(self.label_encoder.classes_, pred)}
            results.append({"result": predicted, "scores": scores})
        return results
# ==== Streamlit UI ====
@st.cache_resource
def _load_classifier():
    """Build the classifier once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    without caching, the Keras model and all pickles were reloaded from
    disk on each rerun.  ``st.cache_resource`` keeps one shared instance.
    """
    return MalwareClassifier()


def main():
    """Render the Streamlit UI: sample-CSV download, file upload, and
    per-row malware predictions."""
    st.set_page_config(page_title="Malware Detection System", page_icon="🛡️")
    st.title("🛡️ Malware Detection System")
    st.markdown("Upload a CSV file with network traffic logs to detect malware.")
    st.markdown("📄 Need a sample file to test? Download below:")

    if os.path.exists(SAMPLE_FILE):
        with open(SAMPLE_FILE, "rb") as f:
            st.download_button(
                label="📥 Download Sample CSV",
                data=f,
                file_name="sample_input.csv",  # ✅ Do NOT include `src/` in file_name
                mime="text/csv"
            )
    else:
        st.warning("⚠️ 'sample_input.csv' not found in the 'src' directory.")

    try:
        # Cached: loads the model only on the first run (see _load_classifier).
        classifier = _load_classifier()
    except Exception as e:
        st.error(f"❌ Model loading failed: {e}")
        return

    uploaded_file = st.file_uploader("📂 Upload your network CSV", type=["csv"])
    if uploaded_file is not None:
        try:
            df = pd.read_csv(uploaded_file)
            # Header cells pasted from spreadsheets often carry stray spaces.
            df.columns = df.columns.str.strip()
            required_prediction_columns = [
                'id.orig_h', 'id.orig_p', 'id.resp_h', 'id.resp_p',
                'proto', 'conn_state', 'history',
                'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes'
            ]
            missing_columns = set(required_prediction_columns) - set(df.columns)
            if missing_columns:
                st.error(f"❌ Missing required columns for prediction: {missing_columns}")
                st.write("📋 Detected columns:", df.columns.tolist())
                return
            # Copy so the classifier's in-place encoding never touches the
            # DataFrame shown to the user.
            prediction_input = df[required_prediction_columns].copy()
            predictions = classifier.predict(prediction_input)
            st.success("✅ Prediction complete!")
            for i, result in enumerate(predictions):
                st.subheader(f"Prediction {i + 1}")
                st.write(f"**Predicted Label:** {result['result']}")
                st.json(result['scores'])
        except Exception as e:
            st.error(f"❌ Error during prediction: {e}")


if __name__ == "__main__":
    main()