it-ny / app.py
hari6677's picture
Create app.py
3dc7452 verified
import gradio as gr
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import load_model
import joblib
import pickle
import json
# Load the trained CNN and its preprocessing artifacts once, at import time, so
# every request reuses them. Paths are resolved against the working directory.
# NOTE: joblib/pickle loading can execute arbitrary code embedded in the file;
# only ship artifacts produced by a trusted training run.
model = load_model('multiclass_cnn_model_fast.h5')
scaler, label_encoder = (
    joblib.load(artifact_path)
    for artifact_path in ('multiclass_standard_scaler.pkl', 'multiclass_label_encoder.pkl')
)
with open('multiclass_feature_names.pkl', 'rb') as f:
    feature_names = pickle.load(f)
class NetworkIntrusionDetector:
    """Run the multiclass CNN on one traffic record and render a Markdown report.

    Args:
        model: Trained model; ``model.predict(x, verbose=0)`` is expected to
            return an array of shape ``(1, n_classes)`` for a single sample.
        scaler: Fitted transformer whose ``transform`` is applied to the
            feature frame after it has been aligned to ``feature_names``.
        label_encoder: Fitted encoder exposing ``classes_`` and
            ``inverse_transform``; its class order is assumed to match the
            model's output columns (TODO confirm against the training code).
        feature_names: Ordered column names the scaler/model were fitted on.
    """

    # Class label treated as benign; every other predicted label is reported
    # as malicious.
    NORMAL_LABEL = 'normal'
    # How many ranked class probabilities to list in the report.
    TOP_K = 3

    def __init__(self, model, scaler, label_encoder, feature_names):
        self.model = model
        self.scaler = scaler
        self.label_encoder = label_encoder
        self.feature_names = feature_names

    def preprocess(self, features_dict):
        """Turn a ``{feature_name: value}`` mapping into the model's input array.

        Columns are reordered to ``self.feature_names``. Names missing from
        ``features_dict`` are filled with 0, and names not in
        ``self.feature_names`` are dropped. The scaled row is reshaped to
        ``(1, n_features, 1)``.
        """
        features_df = pd.DataFrame([features_dict])
        features_aligned = features_df.reindex(columns=self.feature_names, fill_value=0)
        features_scaled = self.scaler.transform(features_aligned)
        return features_scaled.reshape(1, -1, 1)

    def predict(self, **input_features):
        """Classify one record given as keyword arguments ``feature_name=value``.

        Returns:
            A Markdown report string. Any exception raised while preprocessing,
            predicting or formatting is caught and returned as
            ``"❌ Error: <message>"`` so the UI shows it instead of crashing.
        """
        try:
            features_cnn = self.preprocess(input_features)
            probs = self.model.predict(features_cnn, verbose=0)[0]
            class_idx = int(np.argmax(probs))
            confidence = float(probs[class_idx])
            attack_type = self.label_encoder.inverse_transform([class_idx])[0]
            # Stable sort keeps encoder order among equal probabilities.
            ranked = sorted(
                zip(self.label_encoder.classes_, (float(p) for p in probs)),
                key=lambda item: item[1],
                reverse=True,
            )
            return self._format_report(attack_type, confidence, ranked[: self.TOP_K])
        except Exception as e:
            # UI boundary: surface the failure to the user rather than raising.
            return f"❌ Error: {e}"

    def _format_report(self, attack_type, confidence, top_predictions):
        """Render one classification result as a Markdown string."""
        is_malicious = attack_type != self.NORMAL_LABEL
        if is_malicious:
            status, color = "🚨 MALICIOUS TRAFFIC", "red"
            action = "🚨 Immediate investigation required!"
        else:
            status, color = "✅ NORMAL TRAFFIC", "green"
            action = "✅ No action needed"
        # Build from unindented lines separated by blank lines: indented lines
        # would render as code blocks, and adjacent lines would merge into one
        # paragraph in Markdown.
        lines = [
            "## 🔍 **Network Traffic Analysis Result**",
            "",
            f"**Status:** <span style='color:{color}; font-weight:bold'>{status}</span>",
            "",
            f"**Classification:** {attack_type}",
            "",
            f"**Confidence:** {confidence:.4f}",
            "",
            f"### Top {len(top_predictions)} Predictions:",
            "",
            *(f"- {label}: {prob:.4f}" for label, prob in top_predictions),
            "",
            f"**Action Recommended:** {action}",
        ]
        return "\n".join(lines)
# One shared detector instance, reused by every Gradio callback invocation.
detector = NetworkIntrusionDetector(
    model=model,
    scaler=scaler,
    label_encoder=label_encoder,
    feature_names=feature_names,
)
# Example traffic records offered in the UI. ``sample_normal`` has benign-looking
# values (logged in, HTTP, flag SF); ``sample_attack`` has every SYN-error rate
# at 1.0 and flag S0 set. Features a record omits are sent to the model as 0.
sample_normal = dict(
    duration=0, src_bytes=0, dst_bytes=0, land=0, wrong_fragment=0,
    urgent=0, hot=0, num_failed_logins=0, logged_in=1, num_compromised=0,
    root_shell=0, su_attempted=0, num_root=0, num_file_creations=0,
    num_shells=0, num_access_files=0, num_outbound_cmds=0, is_host_login=0,
    is_guest_login=0, count=2, srv_count=2, serror_rate=0.0,
    srv_serror_rate=0.0, rerror_rate=0.0, srv_rerror_rate=0.0,
    same_srv_rate=1.0, diff_srv_rate=0.0, srv_diff_host_rate=0.0,
    dst_host_count=150, dst_host_srv_count=150, dst_host_same_srv_rate=1.0,
    dst_host_diff_srv_rate=0.0, dst_host_same_src_port_rate=0.0,
    dst_host_srv_diff_host_rate=0.0, dst_host_serror_rate=0.0,
    dst_host_srv_serror_rate=0.0, dst_host_rerror_rate=0.0,
    dst_host_srv_rerror_rate=0.0, protocol_type_icmp=0, protocol_type_tcp=1,
    protocol_type_udp=0, service_http=1, service_other=0, flag_SF=1,
)
sample_attack = dict(
    duration=0, src_bytes=1032, dst_bytes=0, land=0, wrong_fragment=0,
    urgent=0, hot=0, num_failed_logins=0, logged_in=0, num_compromised=0,
    root_shell=0, su_attempted=0, num_root=0, num_file_creations=0,
    num_shells=0, num_access_files=0, num_outbound_cmds=0, is_host_login=0,
    is_guest_login=0, count=1, srv_count=1, serror_rate=1.0,
    srv_serror_rate=1.0, rerror_rate=0.0, srv_rerror_rate=0.0,
    same_srv_rate=1.0, diff_srv_rate=0.0, srv_diff_host_rate=0.0,
    dst_host_count=255, dst_host_srv_count=255, dst_host_same_srv_rate=1.0,
    dst_host_diff_srv_rate=0.0, dst_host_same_src_port_rate=0.0,
    dst_host_srv_diff_host_rate=0.0, dst_host_serror_rate=1.0,
    dst_host_srv_serror_rate=1.0, dst_host_rerror_rate=0.0,
    dst_host_srv_rerror_rate=0.0, protocol_type_icmp=0, protocol_type_tcp=1,
    protocol_type_udp=0, service_http=0, service_other=1, flag_S0=1,
)
# Create Gradio interface
def create_interface():
    """Build the Gradio interface for the intrusion detector.

    Only a subset of the model's features is exposed as input widgets; every
    other name in ``feature_names`` is sent to the detector as 0.

    Returns:
        The configured ``gr.Interface`` (not yet launched).
    """
    # (feature name, widget) pairs in display order. Keeping each feature name
    # next to its widget guarantees that the positional values Gradio passes to
    # ``predict_attack`` and the example rows below refer to the same features.
    fields = [
        # Basic features
        ('duration', gr.Number(label="Duration", value=0)),
        ('src_bytes', gr.Number(label="Source Bytes", value=0)),
        ('dst_bytes', gr.Number(label="Destination Bytes", value=0)),
        ('land', gr.Number(label="Land (0/1)", value=0)),
        ('wrong_fragment', gr.Number(label="Wrong Fragment", value=0)),
        # Connection features
        ('hot', gr.Number(label="Hot", value=0)),
        ('num_failed_logins', gr.Number(label="Num Failed Logins", value=0)),
        ('logged_in', gr.Number(label="Logged In (0/1)", value=0)),
        ('num_compromised', gr.Number(label="Num Compromised", value=0)),
        ('root_shell', gr.Number(label="Root Shell (0/1)", value=0)),
        # Rate features
        ('count', gr.Number(label="Count", value=1)),
        ('srv_count', gr.Number(label="Service Count", value=1)),
        ('serror_rate', gr.Slider(0, 1, label="Error Rate", value=0)),
        ('srv_serror_rate', gr.Slider(0, 1, label="Service Error Rate", value=0)),
        ('same_srv_rate', gr.Slider(0, 1, label="Same Service Rate", value=1)),
        # Protocol type
        ('protocol_type_tcp', gr.Radio([0, 1], label="Protocol TCP", value=1)),
        ('protocol_type_udp', gr.Radio([0, 1], label="Protocol UDP", value=0)),
        ('protocol_type_icmp', gr.Radio([0, 1], label="Protocol ICMP", value=0)),
        # Service type
        ('service_http', gr.Radio([0, 1], label="Service HTTP", value=1)),
        ('service_other', gr.Radio([0, 1], label="Service Other", value=0)),
        # Flag
        ('flag_SF', gr.Radio([0, 1], label="Flag SF", value=1)),
        ('flag_S0', gr.Radio([0, 1], label="Flag S0", value=0)),
    ]
    input_names = [name for name, _ in fields]
    inputs = [widget for _, widget in fields]

    def predict_attack(*args):
        """Map positional widget values to feature names and run the detector."""
        # Every model feature defaults to 0; the exposed widgets override it.
        features_dict = {feature: 0 for feature in feature_names}
        for name, value in zip(input_names, args):
            # A cleared gr.Number arrives as None; treat it like the default 0.
            features_dict[name] = 0 if value is None else value
        return detector.predict(**features_dict)

    # Example rows must be positional in widget order. Look values up by name
    # (the sample dicts use a different key order and omit some flags).
    examples = [
        [sample.get(name, 0) for name in input_names]
        for sample in (sample_normal, sample_attack)
    ]

    description = (
        "**Detect malicious network traffic in real-time using AI**\n\n"
        "This system can identify 40+ different types of network attacks including:\n"
        "- DoS attacks (neptune, smurf, teardrop)\n"
        "- Probing attacks (portsweep, nmap, satan)\n"
        "- R2L attacks (guess_passwd, warezclient)\n"
        "- U2R attacks (buffer_overflow, rootkit)\n"
        "\n"
        "*Enter network traffic features below to analyze:*"
    )

    return gr.Interface(
        fn=predict_attack,
        inputs=inputs,
        outputs=gr.Markdown(),
        title="🚨 Network Intrusion Detection System",
        description=description,
        examples=examples,
        theme="soft",
    )
# Launch app: when run as a script, build the UI and serve it, requesting a
# public share link.
if __name__ == "__main__":
    demo = create_interface()
    launch_options = {"share": True}
    demo.launch(**launch_options)