Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import tensorflow as tf | |
| import numpy as np | |
| import pandas as pd | |
| import joblib | |
| import os | |
| from huggingface_hub import hf_hub_download | |
# --- Configuration ---
# Hugging Face Hub repository queried when an artifact is missing locally.
REPO_ID = "netgoat-ai/GoatAI"

# Artifact file names: checked on local disk first, then requested from REPO_ID.
MODEL_FILENAME = "goatai.keras"
SCALER_FILENAME = "scaler.pkl"

# Initial position of the MSE threshold slider and value used by the examples.
DEFAULT_THRESHOLD = 0.003

# Input features, in the order the comma-separated vector must list them.
FEATURE_NAMES = [
    "Flow Duration",
    "Total Fwd Packets",
    "Total Backward Packets",
    "Packet Length Mean",
    "Flow IAT Mean",
    "Fwd Flag Count",
]
| # --- Load Resources --- | |
def load_file(filename, repo_id):
    """Resolve *filename* to a usable path, preferring a local copy.

    Args:
        filename: Path checked on local disk; also the file name requested
            from the Hub when no local copy exists.
        repo_id: Hugging Face Hub repository to download from.

    Returns:
        ``filename`` itself when it exists locally, otherwise the path
        returned by ``hf_hub_download``. ``None`` when the download attempt
        raises; the error is printed rather than propagated.
    """
    if not os.path.exists(filename):
        print(f"'{filename}' not found locally. Attempting download from {repo_id}...")
        try:
            hub_path = hf_hub_download(repo_id=repo_id, filename=filename)
            print(f"Successfully downloaded to: {hub_path}")
        except Exception as err:
            # Best effort: callers treat None as "resource unavailable".
            print(f"Could not download {filename}: {err}")
            return None
        return hub_path

    print(f"Found local file: {filename}")
    return filename
# 1. Load Model
# `model` stays None when the file cannot be located or deserialized;
# predict() checks for that and reports a system error instead of crashing.
model = None
model_path = load_file(MODEL_FILENAME, REPO_ID)
if model_path:
    try:
        model = tf.keras.models.load_model(model_path)
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")

# 2. Load Scaler
# SECURITY NOTE(review): joblib.load unpickles the file, and unpickling can
# execute arbitrary code. The scaler may come from the Hub, so REPO_ID must
# only ever point at a trusted repository.
scaler = None
scaler_path = load_file(SCALER_FILENAME, REPO_ID)
if scaler_path:
    try:
        scaler = joblib.load(scaler_path)
        print("Scaler loaded successfully.")
    except Exception as e:
        print(f"Error loading scaler: {e}")

# Warn whenever the scaler is unusable -- missing file OR failed load --
# because predict() feeds raw input to the model in both cases.
if scaler is None:
    print("Warning: Scaler not available. Input data must be pre-normalized.")
| # --- Simulation Logic (Based on the dataset code) --- | |
def generate_benign_sample():
    """Build one random BENIGN-looking feature vector as CSV text.

    Values are drawn in FEATURE_NAMES order: flow duration, forward packet
    count, backward packet count, mean packet length, mean inter-arrival
    time, and a SYN flag that is set 5% of the time.

    Returns:
        The six values joined by ", ", each formatted with two decimals.
    """
    rng = np.random
    flow_duration = rng.randint(50000, 60000000)
    forward = rng.randint(10, 100)
    # Backward traffic always exceeds forward traffic by 5-49 packets.
    backward = forward + rng.randint(5, 50)
    # abs() keeps the normally distributed draws non-negative.
    mean_pkt_len = abs(rng.normal(loc=500, scale=200))
    mean_iat = abs(rng.normal(loc=100000, scale=50000))
    syn = rng.choice([0, 1], p=[0.95, 0.05])

    features = (flow_duration, forward, backward, mean_pkt_len, mean_iat, syn)
    return ", ".join(format(value, ".2f") for value in features)
def generate_attack_sample():
    """Build one random ATTACK-looking feature vector as CSV text.

    Values are drawn in FEATURE_NAMES order and model a flood: short flow,
    very many forward packets, almost no backward packets, near-constant
    packet length, tiny inter-arrival times, and a SYN flag set 90% of the
    time.

    Returns:
        The six values joined by ", ", each formatted with two decimals.
    """
    rng = np.random
    flow_duration = rng.randint(100, 10000)
    forward = rng.randint(500, 50000)  # very high volume
    backward = rng.randint(0, 5)  # target barely responds
    mean_pkt_len = rng.normal(loc=1200, scale=10)  # roughly fixed size
    mean_iat = rng.exponential(scale=100)  # packets arrive back-to-back
    syn = rng.choice([0, 1], p=[0.1, 0.9])

    features = (flow_duration, forward, backward, mean_pkt_len, mean_iat, syn)
    return ", ".join(format(value, ".2f") for value in features)
| # --- Prediction Logic --- | |
def _parse_feature_vector(csv_text):
    """Turn comma-separated text into a list of floats.

    Blank fields (for example from a trailing comma) are skipped. ``None``
    is treated like an empty string. Raises ValueError when a non-blank
    field is not a number.
    """
    fields = (part.strip() for part in (csv_text or "").split(","))
    return [float(field) for field in fields if field]


def _render_status_card(label, status_color, desc):
    """Return the coloured HTML card shown in the status output."""
    return f"""
    <div style="text-align: center; padding: 20px; background-color: {status_color}20; border-radius: 10px; border: 2px solid {status_color};">
    <h2 style="color: {status_color}; margin: 0;">{label}</h2>
    <p style="margin-top: 5px; color: #555;">{desc}</p>
    </div>
    """


def predict(csv_text, threshold):
    """Classify one traffic feature vector by autoencoder reconstruction error.

    Args:
        csv_text: Comma-separated numeric values, one per entry of
            FEATURE_NAMES and in that order.
        threshold: Reconstruction MSE above which the sample is reported as
            an attack.

    Returns:
        A ``(status_html, loss_value, log_text)`` tuple matching the three
        Gradio outputs. On any failure ``loss_value`` is ``0.0`` and
        ``log_text`` explains the problem; nothing is raised to the caller.
    """
    if model is None:
        return "System Error: Model not loaded.", 0.0, f"Could not find {MODEL_FILENAME}."

    # 1. Parse Input -- only this step can legitimately be a "format" error,
    # so it gets its own narrow try block.
    try:
        data_list = _parse_feature_vector(csv_text)
    except ValueError:
        return "<h3 style='color: orange;'>Format Error</h3>", 0.0, "Could not convert input to numbers."

    # 2. Validate Dimensions (one value per named feature)
    expected = len(FEATURE_NAMES)
    if len(data_list) != expected:
        return (
            "<h3 style='color: orange;'>Input Error</h3>",
            0.0,
            f"Expected {expected} features, got {len(data_list)}.\nRequired: {', '.join(FEATURE_NAMES)}",
        )

    # float() accepts "nan" and "inf"; those would yield a NaN loss, and
    # NaN > threshold is False, i.e. the sample would be called benign.
    if not np.isfinite(data_list).all():
        return (
            "<h3 style='color: orange;'>Input Error</h3>",
            0.0,
            "Feature values must be finite numbers (no NaN or infinity).",
        )

    try:
        data_array = np.array([data_list])

        # 3. Scale Data (raw values are used when no scaler was loaded)
        if scaler is not None:
            try:
                processed_data = scaler.transform(data_array)
            except ValueError as ve:
                return f"Scaling Error: {str(ve)}", 0.0, "Dimension mismatch."
        else:
            processed_data = data_array

        # 4. Predict & Calculate Loss
        reconstructions = model.predict(processed_data)
        loss = tf.keras.losses.mse(reconstructions, processed_data)
        loss_value = float(loss[0])
    except Exception as e:
        # UI boundary: report the real failure instead of crashing the app.
        return "<h2 style='color: orange;'>System Error</h2>", 0.0, str(e)

    if not np.isfinite(loss_value):
        # Same NaN-compares-False hazard as above, on the model side.
        return (
            "<h2 style='color: orange;'>System Error</h2>",
            0.0,
            "Model produced a non-finite reconstruction error.",
        )

    # 5. Threshold Logic
    if loss_value > threshold:
        label = "⚠️ DDoS Attack Detected"
        status_color = "#ff4b4b"  # Red
        desc = "High reconstruction error indicates anomalous pattern."
    else:
        label = "✅ Benign Traffic"
        status_color = "#2b9348"  # Green
        desc = "Low reconstruction error indicates normal pattern."

    result_html = _render_status_card(label, status_color, desc)
    log = f"Input Features: {len(data_list)}\nScaled Values: {processed_data[0]}"
    return result_html, loss_value, log
| # --- Gradio Interface --- | |
# Two-column layout: inputs and settings on the left, results on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Built from explicit lines so the rendered Markdown does not depend on
    # how the source happens to be indented.
    gr.Markdown(
        "\n".join(
            [
                "# 🐐 GoatAI DDoS Detector",
                f"**Model:** `{REPO_ID}` (Autoencoder)",
                "",
                "This system analyzes 6 network traffic features to detect DDoS attacks.",
                "It expects comma-separated values in the following order:",
                "",
                f"`{', '.join(FEATURE_NAMES)}`",
            ]
        )
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Generate / Input Data")
            with gr.Row():
                btn_benign = gr.Button("Simulate Benign User 👤", size="sm")
                btn_attack = gr.Button("Simulate DDoS Bot 🤖", size="sm", variant="stop")
            input_text = gr.Textbox(
                label="Traffic Feature Vector (6 values)",
                placeholder="e.g. 50000, 50, 60, 500, 100000, 0",
                lines=2,
            )
            gr.Markdown("### 2. Settings")
            threshold_slider = gr.Slider(
                minimum=0.0001,
                maximum=0.1,
                value=DEFAULT_THRESHOLD,
                step=0.0001,
                label="Sensitivity Threshold (MSE)",
                info="Lower = More sensitive (flags more traffic as attacks).",
            )
            predict_btn = gr.Button("Analyze Traffic", variant="primary", size="lg")

        with gr.Column(scale=1):
            gr.Markdown("### 3. Analysis Results")
            output_label = gr.HTML(label="Status")
            output_loss = gr.Number(label="Anomaly Score (MSE)", precision=6)
            output_log = gr.Textbox(label="Debug Log", lines=4)

    # Wire up buttons: the simulators fill the input box, and the analyze
    # button feeds the text plus the slider value to predict().
    btn_benign.click(fn=generate_benign_sample, outputs=input_text)
    btn_attack.click(fn=generate_attack_sample, outputs=input_text)
    predict_btn.click(
        fn=predict,
        inputs=[input_text, threshold_slider],
        outputs=[output_label, output_loss, output_log],
    )

    # Examples (Valid 6-feature inputs)
    gr.Examples(
        examples=[
            ["55000, 20, 25, 520.5, 95000, 0", 0.003],  # Benign-like
            ["200, 5000, 0, 1200, 50, 1", 0.003],  # Attack-like
        ],
        inputs=[input_text, threshold_slider],
        label="Quick Examples",
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()