# Streamlit malware-detection app (Hugging Face Space).
import ipaddress
import os

import joblib
import numpy as np
import pandas as pd
import streamlit as st
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers.legacy import SGD
# ==== File Paths ====
# Serialized model and preprocessing artifacts, all expected under `src/`
# relative to the process working directory.
MODEL_FILE = "src/model1.h5"                    # Keras model (loaded with compile=False)
WEIGHTS_FILE = "src/weights.h5"                 # trained weights, restored separately
SCALER_FILE = "src/standard_scaler1.pkl"        # fitted scaler (joblib pickle)
LABEL_ENCODER_FILE = "src/label_encoder1.pkl"   # encoder for the target labels
# Per-column categorical encoders, keyed by the input column they transform.
ENCODER_PATHS = {
    "proto": "src/categorical_label_encoder_proto1.pkl",
    "conn_state": "src/categorical_label_encoder_conn_state1.pkl",
    "history": "src/categorical_label_encoder_history1.pkl"
}
SAMPLE_FILE = "src/sample_input.csv"            # example CSV offered for download
# ==== Class for Malware Prediction ====
class MalwareClassifier:
    """Classify network-connection records with a pre-trained Keras model.

    Loads the model architecture, weights, scaler, target label encoder and
    per-column categorical encoders from the paths declared at module level,
    and exposes :meth:`predict` for batch classification of DataFrames.
    """

    def __init__(self):
        """Load all artifacts.

        Raises:
            FileNotFoundError: if any required artifact file is missing.
        """
        core_files = [MODEL_FILE, WEIGHTS_FILE, SCALER_FILE, LABEL_ENCODER_FILE]
        for path in core_files + list(ENCODER_PATHS.values()):
            if not os.path.exists(path):
                raise FileNotFoundError(f"Missing file: {path}")
        # Architecture and weights are stored separately: load uncompiled,
        # restore weights, then compile for inference.
        self.model = load_model(MODEL_FILE, compile=False)
        self.model.load_weights(WEIGHTS_FILE)
        self.model.compile(optimizer=SGD(), loss="categorical_crossentropy", metrics=["accuracy"])
        self.scaler = joblib.load(SCALER_FILE)
        self.label_encoder = joblib.load(LABEL_ENCODER_FILE)
        self.encoders = {k: joblib.load(v) for k, v in ENCODER_PATHS.items()}

    def _validate_numeric_column(self, col_name, values):
        """Raise ValueError unless *values* holds only non-negative integers.

        BUG FIX: the original compared raw (possibly string-dtype) values
        against 0, which raises TypeError instead of the intended ValueError
        for digit-string columns, and reported negatives as "Non-integer".
        Coercing with pd.to_numeric makes both checks dtype-safe.
        """
        numeric = pd.to_numeric(values, errors="coerce")
        # NaN marks unparseable entries; a non-zero remainder marks fractions.
        if numeric.isna().any() or (numeric % 1 != 0).any():
            raise ValueError(f"Non-integer value found in column {col_name}")
        if (numeric < 0).any():
            raise ValueError(f"Negative value found in column {col_name}")

    def _validate_ip_address(self, col_name, values):
        """Raise ValueError if any entry is not a valid IPv4/IPv6 address."""
        for value in values:
            try:
                ipaddress.ip_address(value)
            except ValueError:
                raise ValueError(f"Invalid IP address in column {col_name}: {value}")

    def _validate_input_data(self, data):
        """Check that *data* is non-empty, has every required column, and that
        port/counter columns are numeric and host columns hold IP addresses.

        Raises:
            ValueError: on the first validation failure.
        """
        required_columns = {
            'id.orig_h', 'id.orig_p', 'id.resp_h', 'id.resp_p',
            'proto', 'conn_state', 'history',
            'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes'
        }
        if data.empty:
            raise ValueError("CSV is empty.")
        missing = required_columns - set(data.columns)
        if missing:
            raise ValueError(f"Missing required columns: {missing}")
        numeric_cols = {'id.orig_p', 'id.resp_p', 'orig_pkts', 'orig_ip_bytes',
                        'resp_pkts', 'resp_ip_bytes'}
        for col in data.columns:
            if col in numeric_cols:
                self._validate_numeric_column(col, data[col])
            elif col in {'id.orig_h', 'id.resp_h'}:
                self._validate_ip_address(col, data[col])

    def _encode_data(self, df):
        """Label-encode the categorical columns and convert IP strings to ints.

        NOTE(review): the fitted encoders raise ValueError on categories not
        seen at training time — surfaced to the caller as a prediction error.
        """
        for col in ['proto', 'conn_state', 'history']:
            df[col] = self.encoders[col].transform(df[col])
        df['id.orig_h'] = df['id.orig_h'].apply(lambda x: int(ipaddress.ip_address(x)))
        df['id.resp_h'] = df['id.resp_h'].apply(lambda x: int(ipaddress.ip_address(x)))
        return df

    def _preprocess_data(self, df):
        """Validate, encode and scale *df*; return the model-ready array."""
        self._validate_input_data(df)
        # Copy before encoding so the caller's DataFrame is never mutated.
        df = self._encode_data(df.copy())
        model_columns = [
            'id.orig_h', 'id.orig_p', 'id.resp_h', 'id.resp_p',
            'proto', 'conn_state', 'history',
            'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes'
        ]
        return self.scaler.transform(df[model_columns])

    def predict(self, df):
        """Classify every row of *df*.

        Returns:
            list[dict]: one entry per row with the predicted "result" label
            and a "scores" mapping of class name -> probability string.
        """
        preprocessed = self._preprocess_data(df)
        preds = self.model.predict(preprocessed)
        results = []
        for pred in preds:
            label = self.label_encoder.inverse_transform([np.argmax(pred)])[0]
            # Distinct names in the comprehension avoid shadowing `label`.
            scores = {cls_name: f"{prob:.6f}"
                      for cls_name, prob in zip(self.label_encoder.classes_, pred)}
            results.append({"result": label, "scores": scores})
        return results
# ==== Streamlit UI ====
def main():
    """Render the Streamlit page: sample download, CSV upload, predictions."""
    st.set_page_config(page_title="Malware Detection System", page_icon="π‘οΈ")
    st.title("π‘οΈ Malware Detection System")
    st.markdown("Upload a CSV file with network traffic logs to detect malware.")
    st.markdown("π Need a sample file to test? Download below:")

    # Offer the bundled sample CSV when it is present on disk.
    if os.path.exists(SAMPLE_FILE):
        with open(SAMPLE_FILE, "rb") as sample:
            st.download_button(
                label="π₯ Download Sample CSV",
                data=sample,
                file_name="sample_input.csv",  # bare filename — no `src/` prefix
                mime="text/csv"
            )
    else:
        st.warning("β οΈ 'sample_input.csv' not found in the 'src' directory.")

    # Stop rendering early if any model artifact fails to load.
    try:
        classifier = MalwareClassifier()
    except Exception as exc:
        st.error(f"β Model loading failed: {exc}")
        return

    upload = st.file_uploader("π Upload your network CSV", type=["csv"])
    if upload is None:
        return

    try:
        frame = pd.read_csv(upload)
        frame.columns = frame.columns.str.strip()  # tolerate padded headers
        expected = [
            'id.orig_h', 'id.orig_p', 'id.resp_h', 'id.resp_p',
            'proto', 'conn_state', 'history',
            'orig_pkts', 'orig_ip_bytes', 'resp_pkts', 'resp_ip_bytes'
        ]
        absent = set(expected) - set(frame.columns)
        if absent:
            st.error(f"β Missing required columns for prediction: {absent}")
            st.write("π Detected columns:", frame.columns.tolist())
            return
        outcomes = classifier.predict(frame[expected].copy())
        st.success("β Prediction complete!")
        for idx, outcome in enumerate(outcomes):
            st.subheader(f"Prediction {idx + 1}")
            st.write(f"**Predicted Label:** {outcome['result']}")
            st.json(outcome['scores'])
    except Exception as exc:
        st.error(f"β Error during prediction: {exc}")
# Script entry point: Streamlit executes this module top-to-bottom.
if __name__ == "__main__":
    main()