# app.py — Hugging Face Space by utsavNagar
# Last commit: 7baa7c6 "Update app.py" (verified)
import os
import pickle
import pandas as pd
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download, InferenceClient
import gradio as gr
# Populate the process environment via python-dotenv, then read the Hugging
# Face access token (None when HF_TOKEN is unset).
load_dotenv()
hf_token = os.environ.get("HF_TOKEN")

# ------------------------------------------
# Load IDS Model
# ------------------------------------------
# Download the serialized IDS model from the Hub repository and unpickle it.
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the
# file; only load artifacts from a repository you control and trust.
model_path = hf_hub_download(
    repo_id="utsavNagar/cyberids-ml",
    filename="ids_model.pkl",
    token=hf_token,
)
with open(model_path, mode="rb") as model_file:
    ids_model = pickle.load(model_file)
# ------------------------------------------
# IDS Prediction Logic
# ------------------------------------------
def predict_intrusion(data_dict, *, model=None, threshold=0.5):
    """Classify a single network record as ``"Attack"`` or ``"Normal"``.

    Every value is coerced to a number; anything that cannot be parsed as a
    number (e.g. ``None`` or a non-numeric string) becomes ``0``.

    Args:
        data_dict: Mapping of feature name -> value for one record. The
            resulting one-row DataFrame has its columns in the mapping's
            iteration order.
        model: Object exposing ``predict(DataFrame)``. Defaults to the
            module-level ``ids_model``.
        threshold: Model outputs strictly greater than this value are
            reported as ``"Attack"``.

    Returns:
        ``"Attack"`` or ``"Normal"``.
    """
    if model is None:
        model = ids_model

    df = pd.DataFrame([data_dict])
    # Coerce each column to numeric; unparseable values become NaN, then 0.
    df = df.apply(pd.to_numeric, errors="coerce").fillna(0)

    # NOTE(review): predict() is treated as returning something comparable to
    # ``threshold`` (a 0/1 label or a probability-like score). A classifier
    # with string labels would raise TypeError here — confirm against how
    # ids_model.pkl was trained.
    score = model.predict(df)[0]
    return "Attack" if score > threshold else "Normal"
# ------------------------------------------
# LLM Setup
# ------------------------------------------
# Hub model id of the chat LLM that writes the incident report, and the
# inference client bound to it (authenticated with the same HF token).
LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
client = InferenceClient(LLM_MODEL, token=hf_token)
# ------------------------------------------
# Chat-based explanation from LLM
# ------------------------------------------
def generate_report(features, prediction, *, max_tokens=400, temperature=0.5):
    """Ask the hosted chat LLM for a written analysis of one IDS verdict.

    Args:
        features: The network record that was classified; interpolated into
            the prompt using its ``str()`` form.
        prediction: The IDS verdict string (e.g. ``"Attack"`` / ``"Normal"``).
        max_tokens: Maximum number of tokens the LLM may generate.
        temperature: Sampling temperature passed to the chat endpoint.

    Returns:
        The report text (an empty string if the model returned no content),
        or ``"[LLM Error] <message>"`` if the request or response handling
        raised, so the UI always has something to display.
    """
    import logging

    prompt = (
        "You are a cybersecurity analyst.\n"
        "Network Data:\n"
        f"{features}\n"
        f"IDS Prediction: {prediction}\n"
        "Provide a clear and professional analysis including:\n"
        "1. Attack or Normal?\n"
        "2. Why the IDS believes this.\n"
        "3. Most likely attack type (DoS, Probe, R2L, U2R).\n"
        "4. Severity level.\n"
        "5. Recommended actions.\n"
        "6. Brief incident summary (2–3 sentences).\n"
    )
    try:
        resp = client.chat_completion(
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
        )
        # Attribute access is the documented way to read the reply from a
        # ChatCompletionOutput; content may be None, so normalise to "".
        return resp.choices[0].message.content or ""
    except Exception as exc:
        # UI boundary: keep the app responsive, but record the traceback.
        logging.getLogger(__name__).exception("LLM report generation failed")
        return f"[LLM Error] {exc}"
# ------------------------------------------
# Wrapper for Gradio
# ------------------------------------------
def analyze(*values, **inputs):
    """Classify one network record and produce an LLM incident report.

    ``gr.Interface`` calls its ``fn`` with one positional argument per input
    component, so positional ``values`` are paired, in order, with the names
    in ``NSL_FEATURES``. Keyword arguments are also accepted (and override a
    positional value for the same feature), so features can still be passed
    by name.

    Args:
        *values: Feature values in ``NSL_FEATURES`` order (the Gradio path).
        **inputs: Feature values keyed by feature name.

    Returns:
        ``(prediction, report)``: the ``"Attack"``/``"Normal"`` verdict and
        the generated report text.

    Raises:
        TypeError: If positional values are given but their count differs
            from ``len(NSL_FEATURES)`` (a partial list would be ambiguous).
    """
    if values and len(values) != len(NSL_FEATURES):
        raise TypeError(
            f"analyze() expected {len(NSL_FEATURES)} positional feature "
            f"values, got {len(values)}"
        )
    record = dict(zip(NSL_FEATURES, values))
    record.update(inputs)

    prediction = predict_intrusion(record)
    report = generate_report(record, prediction)
    return prediction, report
# ------------------------------------------
# Gradio UI
# ------------------------------------------
# NSL-KDD feature names. Gradio passes the input values to ``fn``
# positionally, in the same order as these names.
NSL_FEATURES = [
    "duration", "protocol_type", "service", "flag", "src_bytes", "dst_bytes",
    "land", "wrong_fragment", "urgent", "hot", "num_failed_logins", "logged_in",
    "num_compromised", "root_shell", "su_attempted", "num_root", "num_file_creations",
    "num_shells", "num_access_files", "num_outbound_cmds", "is_host_login",
    "is_guest_login", "count", "srv_count", "serror_rate", "srv_serror_rate",
    "rerror_rate", "srv_rerror_rate", "same_srv_rate", "diff_srv_rate",
    "srv_diff_host_rate", "dst_host_count", "dst_host_srv_count",
    "dst_host_same_srv_rate", "dst_host_diff_srv_rate",
    "dst_host_same_src_port_rate", "dst_host_srv_diff_host_rate",
    "dst_host_serror_rate", "dst_host_srv_serror_rate", "dst_host_rerror_rate",
    "dst_host_srv_rerror_rate",
]

# One numeric input box per feature, labelled with the feature name.
# NOTE(review): protocol_type, service and flag are categorical in NSL-KDD;
# as gr.Number inputs they presumably expect the same integer encoding the
# model was trained on — confirm, and consider stating it in the UI.
feature_inputs = [gr.Number(label=feature) for feature in NSL_FEATURES]

interface = gr.Interface(
    fn=analyze,
    inputs=feature_inputs,
    outputs=[
        gr.Textbox(label="Prediction (Attack / Normal)", interactive=False),
        gr.Textbox(label="AI-Generated Incident Report", lines=12, interactive=False),
    ],
    title="Cybersecurity Intrusion Detection System (IDS + AI Analyst)",
    description="Detect network intrusions using a Machine Learning IDS model and get a full explanation via an LLM.",
)

if __name__ == "__main__":
    # Start the Gradio app only when run as a script.
    interface.launch()