# infoSecProject / app.py
# Source: AmnaHassan's Hugging Face Space — "Update app.py", commit ea4398d (verified)
import streamlit as st
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import joblib
import plotly.graph_objects as go
# -----------------------------
# Streamlit page configuration
# -----------------------------
st.set_page_config(
    page_title="Cyber Threat Detection Dashboard",
    page_icon="🛡️",
    layout="wide",
)

# -----------------------------
# Feature names (MUST match the column order used at training time —
# both input widgets and CSV reordering rely on this list)
# -----------------------------
FEATURE_COLUMNS = [
    "processId", "threadId", "parentProcessId", "userId",
    "mountNamespace", "argsNum", "returnValue",
]

# -----------------------------
# Page title
# -----------------------------
st.markdown("""
# 🛡️ Cyber Threat Detection Dashboard
**SOC-Style Deep Learning–Based Suspicious Activity Detection**
""")
st.markdown("---")
# -----------------------------
# Load scaler — cached so Streamlit's rerun-on-every-interaction does not
# re-read the pickle from disk each time a widget changes.
# -----------------------------
@st.cache_resource
def _load_scaler():
    """Load the fitted feature scaler from scaler.pkl once per session."""
    return joblib.load("scaler.pkl")

scaler = _load_scaler()
INPUT_DIM = len(FEATURE_COLUMNS)
# -----------------------------
# Model (EXACT training architecture) — cached so Streamlit reruns do not
# rebuild the network and reload the weights on every widget interaction.
# -----------------------------
@st.cache_resource
def _load_model():
    """Rebuild the training architecture and load the saved weights.

    Returns the network in eval mode: dropout disabled and BatchNorm using
    running statistics (required for single-sample inference).
    """
    net = nn.Sequential(
        nn.Linear(INPUT_DIM, 128),
        nn.BatchNorm1d(128),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(128, 64),
        nn.BatchNorm1d(64),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(64, 1),
    )
    # NOTE(review): torch.load unpickles arbitrary objects; if model.pth could
    # ever come from an untrusted source, pass weights_only=True (torch >= 1.13).
    net.load_state_dict(torch.load("model.pth", map_location="cpu"))
    net.eval()
    return net

model = _load_model()
# =============================
# SIDEBAR — analysis mode picker
# =============================
st.sidebar.header("🔍 Analysis Mode")
mode = st.sidebar.radio(
    "Choose input type",
    options=["Single Log Event", "Upload CSV (Batch Analysis)"],
)
# =============================
# ANALYSIS MODES
# =============================
def _risk_assessment(prob):
    """Map a suspicion probability to (risk_level, gauge_color, analyst_action).

    Thresholds: > 0.7 → HIGH, > 0.4 → MEDIUM, otherwise LOW. Shared by the
    single-event and batch views so the two modes can never disagree.
    """
    if prob > 0.7:
        return ("HIGH", "red",
                "Immediate investigation required. Isolate affected system.")
    if prob > 0.4:
        return ("MEDIUM", "orange",
                "Monitor closely. Correlate with other logs.")
    return ("LOW", "green",
            "No action required. Log for auditing.")


# -----------------------------
# SINGLE EVENT MODE
# -----------------------------
if mode == "Single Log Event":
    st.sidebar.subheader("Log Features")
    features = []
    for col in FEATURE_COLUMNS:
        val = st.sidebar.number_input(col, value=0)
        features.append(val)

    if st.sidebar.button("🚨 Analyze Event"):
        # Build a one-row frame with the training column names so the scaler
        # sees the same feature layout it was (presumably) fitted on —
        # avoids sklearn's feature-name mismatch warning for bare arrays.
        x = pd.DataFrame([features], columns=FEATURE_COLUMNS)
        x_scaled = scaler.transform(x)
        x_tensor = torch.tensor(x_scaled, dtype=torch.float32)
        with torch.no_grad():
            prob = torch.sigmoid(model(x_tensor)).item()

        risk, color, action = _risk_assessment(prob)

        col1, col2, col3 = st.columns(3)
        col1.metric("Risk Level", risk)
        col2.metric("Suspicion Probability", f"{prob:.2f}")
        col3.metric("Recommended Action", action)

        # Gauge visualization of the model's confidence (0–100 %).
        fig = go.Figure(go.Indicator(
            mode="gauge+number",
            value=prob * 100,
            title={'text': "Threat Confidence (%)"},
            gauge={
                'axis': {'range': [0, 100]},
                'bar': {'color': color},
                'steps': [
                    {'range': [0, 40], 'color': "lightgreen"},
                    {'range': [40, 70], 'color': "orange"},
                    {'range': [70, 100], 'color': "red"}
                ],
            }
        ))
        st.plotly_chart(fig, use_container_width=True)

# -----------------------------
# CSV UPLOAD MODE
# -----------------------------
else:
    st.subheader("📄 Batch Log Analysis")
    uploaded_file = st.file_uploader(
        "Upload CSV file (validation/test logs)",
        type=["csv"]
    )
    if uploaded_file:
        df = pd.read_csv(uploaded_file)
        # Drop the ground-truth label column if present — inference input only.
        if "sus_label" in df.columns:
            df = df.drop(columns=["sus_label"])
        # Validate that every required feature column exists before scoring.
        missing = set(FEATURE_COLUMNS) - set(df.columns)
        if missing:
            st.error(f"Missing required columns: {missing}")
        else:
            df = df[FEATURE_COLUMNS]  # enforce training column order
            # Pass the DataFrame (not .values) so column names reach the scaler.
            X_scaled = scaler.transform(df)
            X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
            with torch.no_grad():
                probs = torch.sigmoid(model(X_tensor)).numpy().flatten()
            df["suspicion_probability"] = probs
            # Same thresholds as single-event mode via the shared helper.
            df["risk_level"] = [_risk_assessment(p)[0] for p in probs]
            st.success("Batch analysis completed")
            st.dataframe(df, use_container_width=True)
            st.markdown("### 📊 Risk Distribution")
            st.bar_chart(df["risk_level"].value_counts())
# =============================
# Footer — context note for viewers of the dashboard
# =============================
st.markdown("---")
FOOTER_NOTE = (
    "This dashboard simulates a Security Operations Center (SOC) workflow by "
    "analyzing system logs using a deep learning model trained on the BETH dataset."
)
st.info(FOOTER_NOTE)