Spaces:
Sleeping
Sleeping
| import os | |
| import pickle | |
| import pandas as pd | |
| from dotenv import load_dotenv | |
| from huggingface_hub import hf_hub_download, InferenceClient | |
| import gradio as gr | |
# ------------------------------------------
# Environment / credentials
# ------------------------------------------
# Read variables from a .env file into the process environment, then pick up
# the Hugging Face access token (None when HF_TOKEN is unset).
load_dotenv()
hf_token = os.environ.get("HF_TOKEN")
# ------------------------------------------
# Load IDS Model
# ------------------------------------------
# SECURITY: pickle.load can execute arbitrary code embedded in the file, and the
# file is downloaded from a remote Hub repository. Only load from a repository
# you control, and set IDS_MODEL_REVISION to a known-good commit hash so a later
# push to the repo cannot silently change what gets unpickled. When the variable
# is unset, revision=None makes hf_hub_download use the repo's default revision.
model_path = hf_hub_download(
    repo_id="utsavNagar/cyberids-ml",
    filename="ids_model.pkl",
    token=hf_token,
    revision=os.getenv("IDS_MODEL_REVISION"),
)
with open(model_path, "rb") as model_file:
    ids_model = pickle.load(model_file)
# ------------------------------------------
# IDS Prediction Logic
# ------------------------------------------
def predict_intrusion(data_dict, model=None):
    """Classify a single connection record as "Attack" or "Normal".

    Args:
        data_dict: Mapping of feature name -> value for one record. Values are
            coerced to numbers; anything non-numeric (or missing/None) becomes 0.
        model: Object with a ``predict(DataFrame)`` method. Defaults to the
            module-level ``ids_model`` loaded at import time.

    Returns:
        "Attack" or "Normal". A numeric model output is thresholded at 0.5
        (strictly greater -> "Attack"). A string output is treated as a class
        label: "normal" (case-insensitive) -> "Normal", any other label ->
        "Attack" (assumes the model was trained with a "normal" class — confirm).
    """
    if model is None:
        model = ids_model

    # One-row frame; coerce every column to numeric so stray text/None can't
    # reach the model.
    df = pd.DataFrame([data_dict])
    df = df.apply(pd.to_numeric, errors="coerce").fillna(0)

    # Models fitted on a DataFrame (e.g. scikit-learn estimators) record the
    # column names/order they were trained with and reject anything else.
    # Align to that schema: drop unknown columns, add missing ones as 0.
    expected_columns = getattr(model, "feature_names_in_", None)
    if expected_columns is not None:
        df = df.reindex(columns=list(expected_columns), fill_value=0)

    raw = model.predict(df)[0]

    # Classifiers may return label strings, which cannot be compared to 0.5.
    if isinstance(raw, str):
        return "Normal" if raw.strip().lower() == "normal" else "Attack"
    return "Attack" if raw > 0.5 else "Normal"
# ------------------------------------------
# LLM Setup
# ------------------------------------------
# The chat model id can be overridden with the LLM_MODEL environment variable
# (useful if the default model is not served for your account/provider);
# otherwise Mistral-7B-Instruct-v0.2 is used.
LLM_MODEL = os.getenv("LLM_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
client = InferenceClient(model=LLM_MODEL, token=hf_token)
# ------------------------------------------
# Chat-based explanation from LLM
# ------------------------------------------
def generate_report(features, prediction):
    """Ask the hosted chat model for a written analysis of one IDS verdict.

    Args:
        features: The feature mapping for the record; interpolated into the
            prompt using its ``str()`` form.
        prediction: The IDS label for the record (e.g. "Attack" / "Normal").

    Returns:
        The model's reply text, or a string beginning with "[LLM Error]" if the
        request or reading the reply raises any exception.
    """
    prompt_lines = (
        "",
        "You are a cybersecurity analyst.",
        "Network Data:",
        f"{features}",
        f"IDS Prediction: {prediction}",
        "Provide a clear and professional analysis including:",
        "1. Attack or Normal?",
        "2. Why the IDS believes this.",
        "3. Most likely attack type (DoS, Probe, R2L, U2R).",
        "4. Severity level.",
        "5. Recommended actions.",
        "6. Brief incident summary (2–3 sentences).",
        "",
    )
    prompt = "\n".join(prompt_lines)
    conversation = [{"role": "user", "content": prompt}]

    try:
        reply = client.chat_completion(
            messages=conversation,
            max_tokens=400,
            temperature=0.5,
        )
        return reply.choices[0].message["content"]
    except Exception as err:
        # Best-effort for the UI: show the failure as text instead of crashing.
        return f"[LLM Error] {err}"
# ------------------------------------------
# Wrapper for Gradio
# ------------------------------------------
def analyze(*values, **inputs):
    """Classify one connection record and produce an LLM incident report.

    ``gr.Interface`` calls its ``fn`` with one positional argument per input
    component, in the order of its ``inputs`` list, so positional ``values``
    are paired with ``NSL_FEATURES`` by position. Keyword arguments (the
    original ``analyze(**inputs)`` calling convention) are still accepted and
    take precedence over a positional value for the same feature name.

    Returns:
        ``(prediction, report)``: the IDS label and the LLM-generated text.

    Raises:
        ValueError: if more positional values are given than there are
            feature names in ``NSL_FEATURES``.
    """
    # A keyword-only (**inputs) signature rejects Gradio's positional values
    # with TypeError, so accept positionals and name them here.
    if len(values) > len(NSL_FEATURES):
        raise ValueError(
            f"analyze() got {len(values)} positional values but only "
            f"{len(NSL_FEATURES)} features are defined"
        )
    features = dict(zip(NSL_FEATURES, values))
    features.update(inputs)

    prediction = predict_intrusion(features)
    report = generate_report(features, prediction)
    return prediction, report
# ------------------------------------------
# Gradio UI
# ------------------------------------------
# Order matters: Gradio passes the input values to `analyze` positionally, in
# this order, and `analyze` maps them back to these names by position.
NSL_FEATURES = [
    "duration", "protocol_type", "service", "flag", "src_bytes", "dst_bytes",
    "land", "wrong_fragment", "urgent", "hot", "num_failed_logins", "logged_in",
    "num_compromised", "root_shell", "su_attempted", "num_root", "num_file_creations",
    "num_shells", "num_access_files", "num_outbound_cmds", "is_host_login",
    "is_guest_login", "count", "srv_count", "serror_rate", "srv_serror_rate",
    "rerror_rate", "srv_rerror_rate", "same_srv_rate", "diff_srv_rate",
    "srv_diff_host_rate", "dst_host_count", "dst_host_srv_count",
    "dst_host_same_srv_rate", "dst_host_diff_srv_rate",
    "dst_host_same_src_port_rate", "dst_host_srv_diff_host_rate",
    "dst_host_serror_rate", "dst_host_srv_serror_rate", "dst_host_rerror_rate",
    "dst_host_srv_rerror_rate",
]

# NOTE(review): protocol_type, service and flag look like categorical fields
# (text values in NSL-KDD), but they are collected here as numbers; confirm
# they match how the model was trained to encode them.
feature_inputs = [gr.Number(label=name) for name in NSL_FEATURES]

interface = gr.Interface(
    fn=analyze,
    inputs=feature_inputs,
    outputs=[
        gr.Textbox(label="Prediction (Attack / Normal)", interactive=False),
        gr.Textbox(label="AI-Generated Incident Report", lines=12, interactive=False),
    ],
    title="Cybersecurity Intrusion Detection System (IDS + AI Analyst)",
    description="Detect network intrusions using a Machine Learning IDS model and get a full explanation via an LLM.",
)

if __name__ == "__main__":
    interface.launch()