File size: 3,324 Bytes
99a8d8b
 
 
 
 
017407d
49264dc
017407d
 
99a8d8b
 
 
49264dc
99a8d8b
49264dc
 
 
 
 
 
 
 
 
 
 
 
017407d
49264dc
 
 
 
 
 
017407d
49264dc
 
 
 
 
 
017407d
 
49264dc
017407d
49264dc
 
017407d
49264dc
 
017407d
49264dc
 
017407d
 
 
 
 
 
99a8d8b
 
017407d
49264dc
 
 
017407d
 
49264dc
 
 
 
 
017407d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99a8d8b
017407d
49264dc
017407d
99a8d8b
 
017407d
 
 
 
 
 
 
 
 
 
 
 
49264dc
99a8d8b
 
 
49264dc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import json

import gradio as gr
import joblib
import pandas as pd
from huggingface_hub import hf_hub_download

# =====================================================
# Load CICIDS2018 Model
# =====================================================
# Download the serialized CICIDS2018 pipeline from the Hugging Face Hub
# (hf_hub_download hands back a local file path) and deserialize it.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code; this trusts the remote repo. Consider pinning `revision=` to a
# known commit.
rf_model_path = hf_hub_download("CodebaseAi/netraids-ml-models", "rf_pipeline.joblib")
rf_model = joblib.load(rf_model_path)

# =====================================================
# Load Training Artifacts (SAFE)
# =====================================================
# Optional metadata stored next to the CICIDS2018 pipeline. Loading is
# best-effort: any failure leaves `artifacts` empty so the app still starts,
# and FEATURE_COLUMNS / CLASS_MAPPING then stay None.
artifacts = {}
try:
    artifacts_path = hf_hub_download(
        repo_id="CodebaseAi/netraids-ml-models",
        filename="training_artifacts.joblib"
    )
    artifacts = joblib.load(artifacts_path)
except Exception as e:
    print("Artifacts not loaded:", e)

# The key lookups below assume a dict; any other object (a bare list, a
# model, ...) would make `key in artifacts` misbehave or raise.
if not isinstance(artifacts, dict):
    print("Artifacts not loaded: expected a dict, got", type(artifacts).__name__)
    artifacts = {}


def _first_artifact(source, keys):
    """Return ``source[key]`` for the first ``key`` in ``keys`` present, else None."""
    for key in keys:
        if key in source:
            return source[key]
    return None


def _as_column_list(columns):
    """Normalise a feature-column artifact to a plain list, or None.

    A tuple would be read by ``df[...]`` as one column key, and ``list()``
    of a bare string would split it into characters, so both are handled.
    """
    if columns is None:
        return None
    if isinstance(columns, str):
        return [columns]
    return list(columns)


def _as_index_to_label(mapping):
    """Normalise a class-mapping artifact to ``{predicted value: label}``.

    - dict: copied unchanged.
    - pandas Series: converted with ``to_dict()``, so index labels become keys.
    - list, tuple, or anything with ``tolist()`` (e.g. numpy array, pandas
      Index): enumerated from 0.
    - Unrecognised types yield None, meaning no label translation.
    """
    if mapping is None:
        return None
    if isinstance(mapping, dict):
        return dict(mapping)
    if hasattr(mapping, "to_dict"):
        return dict(mapping.to_dict())
    if hasattr(mapping, "tolist"):
        mapping = mapping.tolist()
    if isinstance(mapping, (list, tuple)):
        return dict(enumerate(mapping))
    print("Class mapping ignored: unsupported type", type(mapping).__name__)
    return None


# Infer feature columns / class mapping from whichever known key is present.
FEATURE_COLUMNS = _as_column_list(
    _first_artifact(artifacts, ["feature_columns", "features", "columns", "X_columns"])
)
CLASS_MAPPING = _as_index_to_label(
    _first_artifact(artifacts, ["class_mapping", "label_mapping", "classes", "target_mapping"])
)

# =====================================================
# Load BCC-Darknet Model (5-class)
# =====================================================
# The classifier, the transformer applied before it, and the encoder used to
# decode its output. They are downloaded from the Hub and deserialized one
# after another, in that order.
# NOTE(review): joblib.load unpickles remote files, which can execute
# arbitrary code if the repo is compromised; consider pinning `revision=`.
darknet_model, darknet_scaler, darknet_encoder = (
    joblib.load(hf_hub_download("CodebaseAi/netraids-ml-models", artifact_name))
    for artifact_name in (
        "realtime_model.pkl",
        "realtime_scaler.pkl",
        "realtime_encoder.pkl",
    )
)

# =====================================================
# Prediction Router
# =====================================================
# Dropdown values accepted by predict(); these must match the UI choices.
_CICIDS_CHOICE = "CICIDS2018 (13 Classes)"
_DARKNET_CHOICE = "BCC-Darknet (5 Classes)"


def _align_columns(df, columns, dataset):
    """Return ``df`` restricted to, and ordered by, ``columns``.

    ``columns`` of None leaves ``df`` unchanged. Missing columns raise a
    ``gr.Error`` that names them, instead of pandas' bare ``KeyError``.
    """
    if columns is None:
        return df
    if isinstance(columns, str):
        columns = [columns]
    columns = list(columns)  # tuples/arrays -> list so df[...] selects columns
    missing = [col for col in columns if col not in df.columns]
    if missing:
        shown = ", ".join(map(str, missing[:10]))
        more = f" (+{len(missing) - 10} more)" if len(missing) > 10 else ""
        raise gr.Error(f"{dataset}: missing feature(s): {shown}{more}")
    return df[columns]


def _decode_label(pred, mapping):
    """Translate a raw model output into a label string using ``mapping``.

    ``mapping`` may be:
    - None: no translation.
    - a dict keyed by the raw prediction, or by its string form as left
      behind by a JSON round-trip.
    - a sequence of class names indexed by the prediction.

    Anything that cannot be translated falls back to ``str(pred)``.
    """
    if mapping is None:
        return str(pred)
    if isinstance(mapping, dict):
        if pred in mapping:
            return str(mapping[pred])
        return str(mapping.get(str(pred), pred))
    try:
        labels = list(mapping)
        idx = int(pred)
    except (TypeError, ValueError):
        return str(pred)
    if 0 <= idx < len(labels):
        return str(labels[idx])
    return str(pred)


def predict(model_choice, features: dict):
    """Classify one network-flow record with the model selected in the UI.

    Args:
        model_choice: Dropdown value, either "CICIDS2018 (13 Classes)" or
            "BCC-Darknet (5 Classes)".
        features: One record as a dict of feature name -> value. A JSON string
            encoding such an object is also accepted.

    Returns:
        A dict with keys "dataset" and "prediction" (the label as a str).

    Raises:
        gr.Error: if the input is not a non-empty JSON object, required
            feature columns are missing, or the dataset selection is unknown.
            gr.Error shows its message in the UI instead of a generic failure.
    """
    if isinstance(features, str):
        try:
            features = json.loads(features)
        except json.JSONDecodeError as err:
            raise gr.Error(f"Features are not valid JSON: {err}") from err
    if not isinstance(features, dict) or not features:
        raise gr.Error(
            "Features must be a non-empty JSON object of feature name -> value."
        )

    df = pd.DataFrame([features])

    if model_choice == _CICIDS_CHOICE:
        # Prefer the column order saved with the training artifacts. Otherwise
        # fall back to the names the fitted pipeline recorded, if any
        # (sklearn `feature_names_in_`), so the JSON key order never matters.
        columns = FEATURE_COLUMNS
        if columns is None:
            columns = getattr(rf_model, "feature_names_in_", None)
        df = _align_columns(df, columns, "CICIDS2018")

        pred_idx = rf_model.predict(df)[0]

        return {
            "dataset": "CICIDS2018",
            "prediction": _decode_label(pred_idx, CLASS_MAPPING),
        }

    if model_choice == _DARKNET_CHOICE:
        # A scaler fitted on a DataFrame rejects columns in a different order,
        # so reorder to its recorded training columns when it has them.
        df = _align_columns(
            df, getattr(darknet_scaler, "feature_names_in_", None), "BCC-Darknet"
        )
        X_scaled = darknet_scaler.transform(df)
        pred_encoded = darknet_model.predict(X_scaled)[0]
        pred_label = darknet_encoder.inverse_transform([pred_encoded])[0]

        return {
            "dataset": "BCC-Darknet",
            "prediction": str(pred_label),
        }

    # No selection (None) or an unexpected value: refuse instead of silently
    # running an arbitrary model.
    raise gr.Error(f"Unknown dataset selection: {model_choice!r}")

# =====================================================
# Gradio UI
# =====================================================
# The two inputs are passed positionally to predict(model_choice, features).
# The dropdown defaults to the CICIDS2018 model, so submitting without
# touching it selects a defined model instead of sending an empty selection.
# NOTE(review): Gradio documents gr.JSON as primarily a display component that
# does not accept user input. Confirm that users can actually enter features
# in the browser; API clients can still send JSON.
app = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(
            choices=[
                "CICIDS2018 (13 Classes)",
                "BCC-Darknet (5 Classes)"
            ],
            value="CICIDS2018 (13 Classes)",
            label="Select Dataset"
        ),
        gr.JSON(label="Network Traffic Features (JSON)")
    ],
    outputs=gr.JSON(label="Detection Result"),
    title="NetraIDS – Dual-Model Network Intrusion Detection",
    description="Robust cloud-deployed ML inference with dataset-specific pipelines"
)

app.launch()