LiteTAE: Lightweight Transformer Autoencoder for Network Zero-Day Attack Detection
LiteTAE is an unsupervised, lightweight Transformer-based autoencoder for Network Intrusion Detection Systems (NIDS). It is designed to detect complex zero-day attacks with high precision and a low False Positive Rate (FPR); the benchmark below reports 93.63% precision at a 5.14% FPR.
The model scores each sample by its reconstruction error. It averages the K largest per-feature squared errors (the Top-K Feature Error) and then applies a temporal moving average (smoothing) across consecutive samples to suppress background network noise.
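To illustrate why a Top-K error reacts more strongly than a plain mean squared error, and what a moving average does to an isolated spike, here is a minimal NumPy sketch on made-up data. The helper names (`top_k_error`, `moving_average`) and the toy values are illustrative assumptions, not part of the released code.

```python
import numpy as np

def top_k_error(x, x_hat, k=5):
    """Mean of the k largest per-feature squared errors, one score per row."""
    per_feature = (x_hat - x) ** 2
    # np.partition moves the k largest values of each row into the last k columns
    top = np.partition(per_feature, -k, axis=1)[:, -k:]
    return top.mean(axis=1)

def moving_average(scores, window=5):
    """Centered moving average; near the edges it averages the samples available."""
    scores = np.asarray(scores, dtype=float)
    half = window // 2
    return np.array([scores[max(0, i - half): i + half + 1].mean()
                     for i in range(len(scores))])

# Toy sample (assumed data): 50 features, only 5 of them reconstructed badly
x = np.zeros((1, 50))
x_hat = x.copy()
x_hat[0, :5] = 1.0

plain_mse = ((x_hat - x) ** 2).mean(axis=1)  # diluted by the 45 well-reconstructed features
top5 = top_k_error(x, x_hat, k=5)            # keeps only the 5 worst features
print(plain_mse[0], top5[0])                 # 0.1 1.0

# An isolated one-sample spike is damped by the window-5 average
spike = [0, 0, 0, 10, 0, 0, 0]
print(moving_average(spike)[3])              # 2.0
```

The same averaging that damps a one-sample spike also spreads part of it onto neighbouring samples, so the window size trades noise suppression against temporal sharpness.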
Final Performance Results
Evaluated on the complex CSE-CIC-IDS2018 dataset, the optimized configuration achieved the following benchmark results without any supervised training:
| Metric | Value | Status |
|---|---|---|
| AUC-ROC | 97.62% | Near-Ideal Separation |
| Accuracy | 92.92% | High Overall Reliability |
| Precision | 93.63% | High Trust in Alerts |
| TPR (Recall) | 90.58% | Catches 9 out of 10 Attacks |
| FPR (False Positive Rate) | 5.14% | Production-Ready (Low Noise) |
Confusion Matrix
- True Positives (TP): 22,646 (Attacks Correctly Identified)
- True Negatives (TN): 28,459 (Normal Traffic Correctly Identified)
- False Positives (FP): 1,541 (Normal Traffic Flagged as Attack)
- False Negatives (FN): 2,354 (Attacks Missed)
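The table metrics (except AUC-ROC, which needs the raw scores) follow directly from these four counts. A short check, using only the numbers listed above; the function name `metrics_from_confusion` is an illustrative choice:

```python
def metrics_from_confusion(tp, tn, fp, fn):
    """Derive threshold-dependent metrics from confusion-matrix counts."""
    return {
        "accuracy": (tp + tn) / (tp + tn + fp + fn),
        "precision": tp / (tp + fp),
        "tpr": tp / (tp + fn),   # recall: share of attacks caught
        "fpr": fp / (fp + tn),   # share of normal traffic flagged
    }

m = metrics_from_confusion(tp=22_646, tn=28_459, fp=1_541, fn=2_354)
for name, value in m.items():
    print(f"{name:9s} {value:.2%}")
# accuracy  92.92%
# precision 93.63%
# tpr       90.58%
# fpr       5.14%
```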
How to Use & Inference Script
To run inference with the pre-trained weights (lite_tae_full.pt), apply the same Top-K Feature Error (K=5) and temporal smoothing (Window=5) settings used in the optimized configuration; the reported results depend on both.
```python
import torch
import numpy as np
import pandas as pd

# 1. Load the entire package.
# Note: PyTorch 2.6+ defaults torch.load to weights_only=True; if loading fails, pass
# weights_only=False, and only for checkpoint files you trust.
checkpoint = torch.load('lite_tae_full.pt', map_location=torch.device('cpu'))

# 2. Extract hyperparameters and metadata
threshold = checkpoint['threshold']
scaler_center = np.array(checkpoint['scaler_c'])
scaler_scale = np.array(checkpoint['scaler_s'])
feat_cols = checkpoint['feat_cols']
print("Successfully loaded LiteTAE model.")
print(f"Configured Operational Threshold: {threshold:.6f}")

# 3. High-Resolution Reconstruction Error Function
def compute_lite_tae_errors(original_tensor, reconstructed_tensor, top_k=5):
    # Calculate squared error for each feature independently
    # (dim=1 below assumes tensors shaped (n_samples, n_features))
    per_feature_error = (reconstructed_tensor - original_tensor) ** 2
    # Extract only the top-K highest feature anomalies
    top_errs, _ = torch.topk(per_feature_error, k=top_k, dim=1)
    return top_errs.mean(dim=1).detach().cpu().numpy()

# 4. Temporal Smoothing Filter (expects errors in time order)
def apply_temporal_smoothing(errors, window=5):
    return pd.Series(errors).rolling(window=window, center=True, min_periods=1).mean().values
```
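The script above loads the metadata and defines the two filters but does not show them chained together. The following is a minimal sketch of one way to do that, under several assumptions that the original does not confirm: that `scaler_c` and `scaler_s` are a per-feature center and scale applied as `(x - center) / scale`, that the model maps a `(n_samples, n_features)` tensor to a reconstruction of the same shape, and that a sample is flagged when its smoothed score exceeds the threshold. The function `score_flows`, the stand-in model `ClampStandIn`, and the demo values are hypothetical; how the real model object is obtained from the checkpoint is not covered here.

```python
import numpy as np
import pandas as pd
import torch

def score_flows(model, features, center, scale, threshold, top_k=5, window=5):
    """Return (smoothed_scores, is_attack) for a time-ordered 2-D feature array."""
    # Assumed scaling: subtract the per-feature center, divide by the per-feature scale
    x = torch.as_tensor((features - center) / scale, dtype=torch.float32)
    model.eval()
    with torch.no_grad():
        x_hat = model(x)
        per_feature = (x_hat - x) ** 2
        # Mean of the top-K largest per-feature errors for each sample
        top_errs, _ = torch.topk(per_feature, k=top_k, dim=1)
        raw = top_errs.mean(dim=1).cpu().numpy()
    # Centered moving average over consecutive samples
    smoothed = (pd.Series(raw)
                .rolling(window=window, center=True, min_periods=1)
                .mean()
                .to_numpy())
    return smoothed, smoothed > threshold

class ClampStandIn(torch.nn.Module):
    """Hypothetical stand-in for the trained model: clamps inputs to [-1, 1]."""
    def forward(self, x):
        return x.clamp(-1.0, 1.0)

# Demo on made-up data: 7 samples, 8 features, one sample far outside the clamp range
features = np.zeros((7, 8))
features[3, :5] = 3.0
smoothed, flags = score_flows(
    ClampStandIn(), features,
    center=np.zeros(8), scale=np.ones(8),
    threshold=0.5,  # hypothetical value; use checkpoint['threshold'] in practice
)
print(smoothed.round(2), flags)
```

With the real checkpoint, the call would presumably take `df[feat_cols].to_numpy()` as `features` together with `scaler_center`, `scaler_scale`, and `threshold` from the script above. Because the smoothing runs over consecutive rows, the input should be kept in capture order.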
The script requires the following packages, which can be installed with: pip install torch numpy pandas scikit-learn