# GoatAI-space / app.py
# Last commit: a873281 "Update app.py" by FlameF0X (verified)
import gradio as gr
import tensorflow as tf
import numpy as np
import pandas as pd
import joblib
import os
from huggingface_hub import hf_hub_download
# --- Configuration ---
# Hugging Face Hub repository that hosts the trained artifacts.
REPO_ID = "netgoat-ai/GoatAI"

# Artifact file names; each is looked for in the working directory first and
# downloaded from REPO_ID otherwise.
MODEL_FILENAME, SCALER_FILENAME = "goatai.keras", "scaler.pkl"

# Initial value of the reconstruction-error (MSE) threshold slider.
DEFAULT_THRESHOLD = 3e-3

# Order in which the comma-separated input features must be supplied.
FEATURE_NAMES = [
    "Flow Duration",
    "Total Fwd Packets",
    "Total Backward Packets",
    "Packet Length Mean",
    "Flow IAT Mean",
    "Fwd Flag Count",
]
# --- Load Resources ---
def load_file(filename, repo_id):
    """Resolve *filename* to a readable local path.

    A file at *filename* (relative paths resolve against the current working
    directory) is used as-is. Otherwise the file is fetched from the Hugging
    Face Hub repository *repo_id* via ``hf_hub_download``.

    Args:
        filename: File name to look for locally and inside the repository.
        repo_id: Hub repository used as the download fallback.

    Returns:
        The path to use, or ``None`` when the download failed. Failures are
        reported with ``print`` rather than raised.
    """
    if not os.path.exists(filename):
        print(f"'{filename}' not found locally. Attempting download from {repo_id}...")
        try:
            local_path = hf_hub_download(repo_id=repo_id, filename=filename)
        except Exception as err:
            # Best-effort: callers interpret None as "resource unavailable".
            print(f"Could not download {filename}: {err}")
            return None
        print(f"Successfully downloaded to: {local_path}")
        return local_path

    print(f"Found local file: {filename}")
    return filename
# 1. Load Model
# `model` stays None if the file cannot be obtained or fails to deserialize;
# predict() checks for None before using it.
model = None
model_path = load_file(MODEL_FILENAME, REPO_ID)
if model_path:
    try:
        model = tf.keras.models.load_model(model_path)
    except Exception as load_err:
        print(f"Error loading model: {load_err}")
    else:
        print("Model loaded successfully.")
# 2. Load Scaler
# `scaler` stays None when unavailable; predict() then feeds raw values to the
# model, so inputs are expected to be normalized already.
scaler = None
scaler_path = load_file(SCALER_FILENAME, REPO_ID)
if not scaler_path:
    print("Warning: Scaler not available. Input data must be pre-normalized.")
else:
    try:
        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code. scaler_path may come from hf_hub_download, so
        # REPO_ID must point at a trusted repository.
        scaler = joblib.load(scaler_path)
    except Exception as load_err:
        print(f"Error loading scaler: {load_err}")
    else:
        print("Scaler loaded successfully.")
# --- Simulation Logic (Based on the dataset code) ---
def generate_benign_sample():
    """Return a synthetic BENIGN traffic sample as a comma-separated string.

    Six values are drawn from NumPy's global RNG, in the same order as the
    input features, and each is formatted with two decimals:

    * flow duration in [50_000, 60_000_000)
    * forward packets in [10, 100)
    * backward packets = forward packets + a value in [5, 50)
    * mean packet length = |N(500, 200)|
    * mean inter-arrival time = |N(100_000, 50_000)|
    * flag value: 1 with probability 0.05, else 0
    """
    rng = np.random
    flow_duration = rng.randint(50000, 60000000)
    fwd_packets = rng.randint(10, 100)
    # Backward count always exceeds the forward count by 5-49 packets.
    bwd_packets = fwd_packets + rng.randint(5, 50)
    mean_pkt_len = abs(rng.normal(loc=500, scale=200))
    mean_iat = abs(rng.normal(loc=100000, scale=50000))
    flag = rng.choice([0, 1], p=[0.95, 0.05])

    sample = (flow_duration, fwd_packets, bwd_packets, mean_pkt_len, mean_iat, flag)
    return ", ".join(map("{:.2f}".format, sample))
def generate_attack_sample():
    """Return a synthetic ATTACK traffic sample as a comma-separated string.

    Six values are drawn from NumPy's global RNG, in the same order as the
    input features, and each is formatted with two decimals. Compared with the
    benign generator, flows are short, forward volume is huge, replies are
    almost absent, packet size is nearly constant and packets arrive fast.
    """
    draws = (
        np.random.randint(100, 10000),            # flow duration in [100, 10_000)
        np.random.randint(500, 50000),            # forward packets: huge volume
        np.random.randint(0, 5),                  # backward packets: (almost) no response
        np.random.normal(loc=1200, scale=10),     # mean packet length: ~fixed size
        np.random.exponential(scale=100),         # mean inter-arrival time: very small
        np.random.choice([0, 1], p=[0.1, 0.9]),   # flag value: 1 with probability 0.9
    )
    return ", ".join("{:.2f}".format(value) for value in draws)
# --- Prediction Logic ---
def predict(csv_text, threshold):
    """Score one traffic feature vector and label it benign or attack.

    The comma-separated input is parsed into floats and checked against the
    expected feature count. It is scaled with the module-level ``scaler``
    when one was loaded, then reconstructed by the autoencoder ``model``.
    The mean squared error between the (scaled) input and its reconstruction
    is the anomaly score. A score above *threshold* is reported as a DDoS
    attack.

    Args:
        csv_text: Comma-separated feature values in ``FEATURE_NAMES`` order.
            ``None`` is treated like an empty string.
        threshold: Reconstruction MSE above which the sample is flagged.

    Returns:
        ``(status_html, loss_value, log_text)`` for the three Gradio outputs.
        Error paths do not raise. They return ``0.0`` as the loss and
        describe the problem in the HTML and log fields.
    """
    if model is None:
        return "System Error: Model not loaded.", 0.0, f"Could not find {MODEL_FILENAME}."

    # 1. Parse input. Only the float conversion is guarded here, so a
    #    ValueError raised later (scaler, model) is not mislabelled as a
    #    formatting problem.
    try:
        data_list = [float(x.strip()) for x in (csv_text or "").split(",") if x.strip()]
    except ValueError:
        return "<h3 style='color: orange;'>Format Error</h3>", 0.0, "Could not convert input to numbers."

    # 2. Validate dimensions against the declared feature list.
    n_expected = len(FEATURE_NAMES)
    if len(data_list) != n_expected:
        return (
            "<h3 style='color: orange;'>Input Error</h3>",
            0.0,
            f"Expected {n_expected} features, got {len(data_list)}.\n"
            f"Required: {', '.join(FEATURE_NAMES)}",
        )

    data_array = np.array([data_list])
    # float() accepts "nan" and "inf". A NaN score compares False against the
    # threshold and would be shown as benign, so reject such inputs up front.
    if not np.isfinite(data_array).all():
        return (
            "<h3 style='color: orange;'>Format Error</h3>",
            0.0,
            "Input contains non-finite values (NaN or Infinity).",
        )

    try:
        # 3. Scale data (only when a scaler was loaded).
        if scaler is not None:
            try:
                processed_data = scaler.transform(data_array)
            except ValueError as ve:
                return f"Scaling Error: {ve}", 0.0, "Dimension mismatch."
        else:
            processed_data = data_array

        # 4. Reconstruct and score. verbose=0 stops Keras from printing a
        #    progress bar to the server log on every request.
        reconstructions = model.predict(processed_data, verbose=0)
        loss = tf.keras.losses.mse(reconstructions, processed_data)
        loss_value = float(loss[0])
    except Exception as e:
        return "<h2 style='color: orange;'>System Error</h2>", 0.0, str(e)

    if np.isnan(loss_value):
        # NaN > threshold is False, so this would otherwise be labelled benign.
        return (
            "<h2 style='color: orange;'>System Error</h2>",
            0.0,
            "Model produced an undefined (NaN) reconstruction error.",
        )

    # 5. Threshold logic.
    if loss_value > threshold:
        label = "⚠️ DDoS Attack Detected"
        status_color = "#ff4b4b"  # Red
        desc = "High reconstruction error indicates anomalous pattern."
    else:
        label = "✅ Benign Traffic"
        status_color = "#2b9348"  # Green
        desc = "Low reconstruction error indicates normal pattern."

    # `{status_color}20` appends an alpha byte to the hex colour for a tinted background.
    result_html = f"""
    <div style="text-align: center; padding: 20px; background-color: {status_color}20; border-radius: 10px; border: 2px solid {status_color};">
        <h2 style="color: {status_color}; margin: 0;">{label}</h2>
        <p style="margin-top: 5px; color: #555;">{desc}</p>
    </div>
    """
    log = f"Input Features: {len(data_list)}\nScaled Values: {processed_data[0]}"
    return result_html, loss_value, log
# --- Gradio Interface ---
# Two-column layout: input generation and settings on the left, results on the
# right. The emoji below replace mojibake (UTF-8 bytes mis-decoded) found in
# the previous labels.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
        # 🐐 GoatAI DDoS Detector

        **Model:** `{REPO_ID}` (Autoencoder)

        This system analyzes {len(FEATURE_NAMES)} network traffic features to detect DDoS attacks.
        It expects comma-separated values in the following order:

        `{', '.join(FEATURE_NAMES)}`
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Generate / Input Data")
            with gr.Row():
                btn_benign = gr.Button("Simulate Benign User 👤", size="sm")
                btn_attack = gr.Button("Simulate DDoS Bot 🤖", size="sm", variant="stop")
            input_text = gr.Textbox(
                # Feature count derived from FEATURE_NAMES so the label and
                # the validation in predict() cannot drift apart.
                label=f"Traffic Feature Vector ({len(FEATURE_NAMES)} values)",
                placeholder="e.g. 50000, 50, 60, 500, 100000, 0",
                lines=2,
            )
            gr.Markdown("### 2. Settings")
            threshold_slider = gr.Slider(
                minimum=0.0001,
                maximum=0.1,
                value=DEFAULT_THRESHOLD,
                step=0.0001,
                label="Sensitivity Threshold (MSE)",
                info="Lower = More sensitive (flags more traffic as attacks).",
            )
            predict_btn = gr.Button("Analyze Traffic", variant="primary", size="lg")
        with gr.Column(scale=1):
            gr.Markdown("### 3. Analysis Results")
            output_label = gr.HTML(label="Status")
            output_loss = gr.Number(label="Anomaly Score (MSE)", precision=6)
            output_log = gr.Textbox(label="Debug Log", lines=4)

    # Wire up buttons: the simulators only fill the textbox; analysis runs on demand.
    btn_benign.click(fn=generate_benign_sample, outputs=input_text)
    btn_attack.click(fn=generate_attack_sample, outputs=input_text)
    predict_btn.click(
        fn=predict,
        inputs=[input_text, threshold_slider],
        outputs=[output_label, output_loss, output_log],
    )

    # Examples (valid 6-feature inputs); clicking one fills both inputs.
    gr.Examples(
        examples=[
            ["55000, 20, 25, 520.5, 95000, 0", 0.003],  # Benign-like
            ["200, 5000, 0, 1200, 50, 1", 0.003],  # Attack-like
        ],
        inputs=[input_text, threshold_slider],
        label="Quick Examples",
    )

if __name__ == "__main__":
    demo.launch()