File size: 10,919 Bytes
304df69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aec9367
 
304df69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
"""
Gradio app for NSL-KDD binary intrusion detection demo (MVP)
Expecting these files in the same repo/root of the Space:
  - nsl_kdd_tf_model.h5       (optional; if present will be used)
  - scaler.pkl                (optional; sklearn StandardScaler, must match model training)
  - columns.json              (optional; list of feature column names used by the model)

If artifacts are missing, the app will instruct you how to add them and offers a quick fallback
where you can upload a CSV and the app will train a lightweight sklearn model for demo purposes.
"""

import io
import json
import os
import tempfile
import traceback
from typing import List, Tuple

import numpy as np
import pandas as pd

import gradio as gr

# optional heavy import guarded
TF_AVAILABLE = True
try:
    import tensorflow as tf
except Exception:
    TF_AVAILABLE = False

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
import joblib

# artifact filenames
MODEL_FILE = "nsl_kdd_tf_model.h5"
SCALER_FILE = "scaler.pkl"
COLUMNS_FILE = "columns.json"

# helper: load artifacts if exist
def load_artifacts():
    model = None
    scaler = None
    columns = None
    model_type = None

    # load columns.json if present
    if os.path.exists(COLUMNS_FILE):
        with open(COLUMNS_FILE, "r", encoding="utf-8") as f:
            columns = json.load(f)

    # load scaler if present
    if os.path.exists(SCALER_FILE):
        try:
            scaler = joblib.load(SCALER_FILE)
        except Exception:
            try:
                scaler = joblib.load(open(SCALER_FILE, "rb"))
            except Exception:
                scaler = None

    # load TF model if present and TF available
    if os.path.exists(MODEL_FILE) and TF_AVAILABLE:
        try:
            model = tf.keras.models.load_model(MODEL_FILE)
            model_type = "tensorflow"
        except Exception:
            model = None

    return model, scaler, columns, model_type

MODEL, SCALER, COLUMNS, MODEL_TYPE = load_artifacts()

def model_available_message() -> str:
    if MODEL is not None and SCALER is not None and COLUMNS is not None:
        return "✅ Pretrained TensorFlow model and artifacts loaded. Ready to predict."
    pieces = []
    if MODEL is None:
        pieces.append(f"Missing `{MODEL_FILE}`")
    if SCALER is None:
        pieces.append(f"Missing `{SCALER_FILE}`")
    if COLUMNS is None:
        pieces.append(f"Missing `{COLUMNS_FILE}`")
    msg = "⚠️ Artifacts missing: " + ", ".join(pieces) + ".\n\n"
    msg += "To run the TF model, add those files to the Space repository (same folder as app.py).\n"
    msg += "Alternatively, upload a CSV of NSL-KDD records (the app will train a quick sklearn model for demo).\n\n"
    msg += "columns.json should be a JSON array of feature names that match the model input (same as X_train.columns).\n"
    return msg

# utility: preprocess input dataframe into model-ready X using columns & scaler
def prepare_X_from_df(df: pd.DataFrame, expected_columns: List[str], scaler_obj) -> np.ndarray:
    # Align columns: fill missing with 0
    X = df.reindex(columns=expected_columns, fill_value=0)
    # Ensure numeric type
    X = X.apply(pd.to_numeric, errors="coerce").fillna(0.0)
    if scaler_obj is not None:
        Xs = scaler_obj.transform(X)
    else:
        # if no scaler provided, return raw numpy
        Xs = X.values.astype(np.float32)
    return Xs

def predict_batch_from_df(df: pd.DataFrame) -> Tuple[pd.DataFrame, str]:
    """
    returns (result_df, status_message)
    result_df contains prob and predicted class per row
    """
    try:
        if MODEL is not None and SCALER is not None and COLUMNS is not None and MODEL_TYPE == "tensorflow":
            Xs = prepare_X_from_df(df, COLUMNS, SCALER)
            probs = MODEL.predict(Xs).ravel()
            preds = (probs >= 0.5).astype(int)
            out = df.copy()
            out["_pred_prob"] = probs
            out["_pred_class"] = preds
            return out, "Predictions from TensorFlow model"
        else:
            # fallback: train a quick logistic regression on uploaded data if contains label
            if 'label' in df.columns or 'label_bin' in df.columns:
                # If label present, run quick preprocess similar to notebook: create X (one-hot for cats)
                # Identify expected categorical columns if present
                cats = ['protocol_type', 'service', 'flag']
                col_names = df.columns.tolist()
                # We'll try to mimic preprocess from notebook: numeric vs cats
                num_cols = [c for c in col_names if c not in cats + ['label','label_bin']]
                X_num = df[num_cols].apply(pd.to_numeric, errors='coerce').fillna(0.0)
                X_cat = pd.get_dummies(df[cats], drop_first=True)
                X = pd.concat([X_num, X_cat], axis=1)
                y = df['label_bin'] if 'label_bin' in df.columns else df['label'].apply(lambda s: 0 if str(s).strip().lower()=="normal" else 1)
                # minimal scaler + logistic
                scaler_local = StandardScaler()
                Xs = scaler_local.fit_transform(X)
                clf = LogisticRegression(max_iter=200)
                clf.fit(Xs, y)
                probs = clf.predict_proba(Xs)[:,1]
                preds = (probs >= 0.5).astype(int)
                out = df.copy()
                out["_pred_prob"] = probs
                out["_pred_class"] = preds
                return out, "Trained temporary LogisticRegression on uploaded CSV (used 'label' or 'label_bin' for training)."
            else:
                return pd.DataFrame(), "Cannot fallback: artifacts missing and uploaded CSV does not contain 'label' or 'label_bin' to train a temporary model."
    except Exception as e:
        tb = traceback.format_exc()
        return pd.DataFrame(), f"Prediction error: {e}\n\n{tb}"

def predict_single(sample_text: str) -> str:
    """
    sample_text: CSV row or JSON dict representing one row with same columns as columns.json
    returns a readable string with probability and class
    """
    try:
        if not sample_text:
            return "No input provided."
        # try JSON first
        try:
            d = json.loads(sample_text)
            if isinstance(d, dict):
                df = pd.DataFrame([d])
            else:
                return "JSON must represent an object/dict for single sample."
        except Exception:
            # try CSV row
            try:
                df = pd.read_csv(pd.compat.StringIO(sample_text), header=None)
                # if no header, user probably pasted values: cannot map to columns
                if COLUMNS is not None and df.shape[1] == len(COLUMNS):
                    df.columns = COLUMNS
                else:
                    return "CSV input detected but header/column count mismatch. Prefer JSON object keyed by column names."
            except Exception:
                return "Could not parse input. Paste a JSON object like {\"duration\":0, \"protocol_type\":\"tcp\", ...} or upload a CSV row with header."

        # Now we have df; run batch predict logic but for a single row
        if MODEL is not None and SCALER is not None and COLUMNS is not None and MODEL_TYPE == "tensorflow":
            Xs = prepare_X_from_df(df, COLUMNS, SCALER)
            prob = float(MODEL.predict(Xs)[0,0])
            pred = int(prob >= 0.5)
            return f"Pred prob: {prob:.4f} — predicted class: {pred} (0=normal, 1=attack)"
        else:
            return "Model artifacts not present in Space. Upload `nsl_kdd_tf_model.h5`, `scaler.pkl`, and `columns.json` to use the TensorFlow model. Alternatively upload a labelled CSV to train a quick demo model."
    except Exception as e:
        tb = traceback.format_exc()
        return f"Error: {e}\n\n{tb}"

# Gradio UI components
with gr.Blocks(title="NSL-KDD Intrusion Detection — Demo MVP") as demo:
    gr.Markdown("# NSL-KDD Intrusion Detection — Demo (MVP)\n"
                "Upload your artifacts (`nsl_kdd_tf_model.h5`, `scaler.pkl`, `columns.json`) to the Space to use the TensorFlow model.\n"
                "Or upload a labelled CSV (contains `label` or `label_bin`) and the app will train a quick logistic regression for demo.\n\n"
                "Columns expected: the original notebook used 41 numeric features with one-hot for `protocol_type`, `service`, `flag`.\n"
                )
    status = gr.Textbox(label="Status / Artifact check", value=model_available_message(), interactive=False)
    with gr.Row():
        with gr.Column(scale=2):
            file_input = gr.File(label="Upload CSV for batch prediction or for training fallback", file_types=['.csv'])
            sample_input = gr.Textbox(label="Single-sample input (JSON object)", placeholder='{"duration":0, "protocol_type":"tcp", ...}', lines=6)
            predict_button = gr.Button("Predict single sample")
            batch_button = gr.Button("Run batch (on uploaded CSV)")

        with gr.Column(scale=1):
            out_table = gr.Dataframe(label="Batch predictions (if any)")

            single_out = gr.Textbox(label="Single sample result", interactive=False)

    # Example / help
    example_text = json.dumps({
        "duration": 0,
        "protocol_type": "tcp",
        "service": "http",
        "flag": "SF",
        "src_bytes": 181,
        "dst_bytes": 5450
    }, indent=2)
    gr.Markdown("**Example single-sample JSON (fill in more NSL-KDD fields if you have them):**")
    gr.Code(example_text, language="json")

    # Callbacks
    def on_predict_single(sample_text):
        return predict_single(sample_text)

    def on_batch_predict(file_obj):
        if file_obj is None:
            return pd.DataFrame(), "No file uploaded."
        try:
            # read uploaded CSV into DataFrame
            df = pd.read_csv(file_obj.name)
        except Exception:
            try:
                # fallback: try bytes
                df = pd.read_csv(file_obj)
            except Exception as e:
                return pd.DataFrame(), f"Could not read CSV: {e}"

        out_df, msg = predict_batch_from_df(df)
        if out_df.empty:
            return pd.DataFrame(), msg
        # Limit columns shown for readability
        display_df = out_df.copy()
        # move prediction columns to front if present
        for c in ["_pred_prob", "_pred_class"]:
            if c in display_df.columns:
                cols = [c] + [x for x in display_df.columns if x != c]
                display_df = display_df[cols]
        return display_df, msg

    predict_button.click(on_predict_single, inputs=[sample_input], outputs=[single_out])
    batch_button.click(on_batch_predict, inputs=[file_input], outputs=[out_table, status])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))