IotNetworkMalwareDetection / src /streamlit_app.py
ATllll's picture
Update src/streamlit_app.py
dd609d0 verified
raw
history blame
6.38 kB
import streamlit as st
import pandas as pd
import numpy as np
import joblib
import os
import ipaddress
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers.legacy import SGD
# ==== File Paths (update if needed) ====
# All artifacts are expected to sit alongside this script.
MODEL_FILE = "model1.h5"
WEIGHTS_FILE = "weights.h5"
SCALER_FILE = "standard_scaler1.pkl"
LABEL_ENCODER_FILE = "label_encoder1.pkl"

# One fitted LabelEncoder per categorical feature, keyed by column name.
ENCODER_PATHS = {
    "proto": "categorical_label_encoder_proto1.pkl",
    "conn_state": "categorical_label_encoder_conn_state1.pkl",
    "history": "categorical_label_encoder_history1.pkl",
}

# Example CSV offered for download in the UI.
SAMPLE_FILE = "sample_input.csv"
# ==== Class for Malware Prediction ====
class MalwareClassifier:
    """Loads a trained Keras model plus its preprocessing artifacts and
    classifies network-traffic flow records.

    Raises:
        FileNotFoundError: on construction, if any required artifact file
            is missing from the working directory.
    """

    def __init__(self):
        # Fail fast with a clear message before touching TF/joblib loaders.
        required = [MODEL_FILE, WEIGHTS_FILE, SCALER_FILE, LABEL_ENCODER_FILE]
        for path in required + list(ENCODER_PATHS.values()):
            if not os.path.exists(path):
                raise FileNotFoundError(f"Missing file: {path}")
        # Architecture and weights are stored separately; compile only after
        # both are loaded. Optimizer/loss matter only if the model is re-trained.
        self.model = load_model(MODEL_FILE, compile=False)
        self.model.load_weights(WEIGHTS_FILE)
        self.model.compile(optimizer=SGD(), loss="categorical_crossentropy", metrics=["accuracy"])
        self.scaler = joblib.load(SCALER_FILE)
        self.label_encoder = joblib.load(LABEL_ENCODER_FILE)
        self.encoders = {k: joblib.load(v) for k, v in ENCODER_PATHS.items()}

    def _validate_numeric_column(self, col_name, values):
        """Ensure every entry in *values* is a non-negative integer.

        Raises:
            ValueError: with a column-specific message on the first violation.
        """
        as_str = values.astype(str)
        # BUG FIX: the original ran the isdigit() check first, so negatives
        # like "-5" were misreported as "Non-integer" (isdigit() rejects the
        # sign), and the subsequent `values < 0` could raise TypeError on a
        # string-dtype column. Detect the sign explicitly, before isdigit().
        if as_str.str.match(r"^-\d+$").any():
            raise ValueError(f"Negative value found in column {col_name}")
        if not as_str.str.isdigit().all():
            raise ValueError(f"Non-integer value found in column {col_name}")

    def _validate_ip_address(self, col_name, values):
        """Ensure every entry parses as an IPv4/IPv6 address.

        Raises:
            ValueError: naming the column and the first offending value.
        """
        for value in values:
            try:
                ipaddress.ip_address(value)
            except ValueError:
                raise ValueError(f"Invalid IP address in column {col_name}: {value}")

    def _validate_input_data(self, data):
        """Check that *data* is non-empty, has all required columns, and that
        port/packet/byte columns are numeric and host columns are valid IPs.
        """
        required_columns = {
            'id.orig_h', 'id.orig_p', 'id.resp_h', 'id.resp_p',
            'proto', 'conn_state', 'history',
            'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes'
        }
        if data.empty:
            raise ValueError("CSV is empty.")
        missing = required_columns - set(data.columns)
        if missing:
            raise ValueError(f"Missing required columns: {missing}")
        for col in data.columns:
            if col in {'id.orig_p', 'id.resp_p', 'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes'}:
                self._validate_numeric_column(col, data[col])
            elif col in {'id.orig_h', 'id.resp_h'}:
                self._validate_ip_address(col, data[col])

    def _encode_data(self, df):
        """Return a copy of *df* with categorical columns label-encoded and
        IP-address columns converted to their integer representation."""
        df = df.copy()  # do not mutate the caller's frame
        for col in ('proto', 'conn_state', 'history'):
            df[col] = self.encoders[col].transform(df[col])
        for ip_col in ('id.orig_h', 'id.resp_h'):
            df[ip_col] = df[ip_col].apply(lambda x: int(ipaddress.ip_address(x)))
        return df

    def _preprocess_data(self, df):
        """Validate, encode, and scale *df* into the model's feature matrix.

        Column order below must match the order the scaler/model were fit on.
        """
        self._validate_input_data(df)
        df = self._encode_data(df)
        model_columns = [
            'id.orig_h', 'id.orig_p', 'id.resp_h', 'id.resp_p',
            'proto', 'conn_state', 'history',
            'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes'
        ]
        return self.scaler.transform(df[model_columns])

    def predict(self, df):
        """Classify each row of *df*.

        Returns:
            list[dict]: one entry per row, with "result" (predicted class
            label) and "scores" (per-class probability strings).
        """
        preprocessed = self._preprocess_data(df)
        preds = self.model.predict(preprocessed)
        results = []
        for pred in preds:
            top_label = self.label_encoder.inverse_transform([np.argmax(pred)])[0]
            # Renamed the comprehension variable: the original reused `label`,
            # shadowing the predicted class name and confusing readers.
            scores = {cls: f"{score:.6f}" for cls, score in zip(self.label_encoder.classes_, pred)}
            results.append({"result": top_label, "scores": scores})
        return results
# ==== Streamlit UI ====
@st.cache_resource
def _get_classifier():
    """Build the classifier once per server process.

    Model/weights/encoder loading is slow; without caching, Streamlit would
    reload everything on every widget interaction (each script rerun).
    """
    return MalwareClassifier()


def main():
    """Streamlit entry point: upload a CSV of flow records and show predictions."""
    st.set_page_config(page_title="Malware Detection System", page_icon="🛡️")
    st.title("🛡️ Malware Detection System")
    st.markdown("Upload a CSV file with network traffic logs to detect malware.")

    # 🔽 Sample CSV download — uses the module-level SAMPLE_FILE constant
    # (the original redefined it locally and also re-imported `os`).
    st.markdown("📄 Need a sample file to test? Download below:")
    if os.path.exists(SAMPLE_FILE):
        with open(SAMPLE_FILE, "rb") as f:
            st.download_button(
                label="📥 Download Sample CSV",
                data=f,
                file_name="sample_input.csv",
                mime="text/csv",
            )
    else:
        st.warning("⚠️ Sample CSV not found. Please place 'sample_input.csv' in the app directory.")

    try:
        classifier = _get_classifier()
    except Exception as e:
        st.error(f"❌ Model loading failed: {e}")
        return

    uploaded_file = st.file_uploader("📂 Drag & drop or select a CSV file", type=["csv"])
    if uploaded_file is None:
        return
    try:
        df = pd.read_csv(uploaded_file, delimiter=',')
        df.columns = df.columns.str.strip()  # normalize column names
        required_prediction_columns = [
            'id.orig_h', 'id.orig_p', 'id.resp_h', 'id.resp_p',
            'proto', 'conn_state', 'history',
            'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes'
        ]
        # Check for missing columns before touching the model.
        missing_columns = set(required_prediction_columns) - set(df.columns)
        if missing_columns:
            st.error(f"❌ Missing required columns for prediction: {missing_columns}")
            st.write("📋 Detected columns:", df.columns.tolist())  # helpful debug info
            return
        # Extract only the model's feature columns, then predict.
        prediction_input = df[required_prediction_columns].copy()
        predictions = classifier.predict(prediction_input)
        st.success("✅ Prediction complete!")
        # Display one section per input row.
        for i, result in enumerate(predictions):
            st.subheader(f"Prediction {i + 1}")
            st.write(f"**Predicted Label:** {result['result']}")
            st.json(result['scores'])
    except Exception as e:
        st.error(f"❌ Error during prediction: {e}")


if __name__ == "__main__":
    main()