LiteTAE: Lightweight Transformer Autoencoder for Network Zero-Day Attack Detection

LiteTAE is an unsupervised, lightweight Transformer-based Autoencoder designed for Network Intrusion Detection Systems (NIDS). It targets complex, previously unseen (zero-day) attacks, aiming for high precision and a low False Positive Rate (FPR).

The model scores each sample by its reconstruction error. The Top-K Feature Error is the mean of the K largest per-feature squared errors. A temporal moving average across consecutive samples then smooths this score to suppress background network noise.
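Written as formulas, the scoring above looks roughly like the block below. The notation is our own, not the authors'. x and x̂ are the model input and its reconstruction for sample i and feature j. T_K(i) is the set of K features with the largest errors for sample i. W_i is a centered window of width w, truncated at the ends of the sequence of N samples. This mirrors the inference script's centered pandas rolling mean with K = 5 and w = 5 (odd w). Treating the alert rule as a strict comparison against the stored threshold τ is an assumption.

```latex
e_{i,j} = \left(\hat{x}_{i,j} - x_{i,j}\right)^2, \qquad
s_i = \frac{1}{K} \sum_{j \in \mathcal{T}_K(i)} e_{i,j}, \qquad
\tilde{s}_i = \frac{1}{|W_i|} \sum_{t \in W_i} s_t,
\qquad
W_i = \left\{\, t : |t - i| \le \lfloor w/2 \rfloor,\ 1 \le t \le N \,\right\},
\qquad
\text{alert}_i = \left[\, \tilde{s}_i > \tau \,\right]
```

Averaging only the K largest errors keeps a strong deviation in a few features from being diluted by many well-reconstructed ones. The window average then damps one-sample spikes relative to sustained deviations.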

Final Performance Results

Evaluated on the CSE-CIC-IDS2018 dataset, the optimized configuration (Top-K with K=5, smoothing window=5) achieved the following results without any supervised training:

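| Metric | Value | Interpretation |
|---|---|---|
| AUC-ROC | 97.62% | Near-ideal separation of benign and attack scores |
| Accuracy | 92.92% | High overall reliability |
| Precision | 93.63% | High trust in alerts: about 94 in 100 alerts are real attacks |
| TPR (Recall) | 90.58% | Catches about 9 out of 10 attacks |
| FPR (False Positive Rate) | 5.14% | Low noise: about 1 in 20 benign samples is flagged |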

Confusion Matrix (55,000 evaluated samples: 25,000 attack, 30,000 benign)

  • True Positives (TP): 22,646 (Attacks Correctly Identified)
  • True Negatives (TN): 28,459 (Normal Traffic Correctly Identified)
  • False Positives (FP): 1,541 (Normal Traffic Flagged as Attack)
  • False Negatives (FN): 2,354 (Attacks Missed)
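The threshold-dependent metrics in the table follow directly from the four confusion-matrix counts. The short check below recomputes them. AUC-ROC cannot be recovered this way, because it needs the per-sample anomaly scores rather than a single thresholded outcome.

```python
# Confusion-matrix counts as reported above (CSE-CIC-IDS2018 evaluation).
TP, TN, FP, FN = 22_646, 28_459, 1_541, 2_354


def summarize(tp: int, tn: int, fp: int, fn: int) -> dict:
    """Derive threshold-dependent metrics from confusion-matrix counts."""
    total = tp + tn + fp + fn
    return {
        "accuracy": (tp + tn) / total,
        "precision": tp / (tp + fp),  # share of alerts that are real attacks
        "tpr": tp / (tp + fn),        # recall / detection rate
        "fpr": fp / (fp + tn),        # share of benign samples flagged
    }


metrics = summarize(TP, TN, FP, FN)
for name, value in metrics.items():
    print(f"{name:>9}: {value:.2%}")
```

This prints 92.92%, 93.63%, 90.58% and 5.14%, matching the table.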

How to Use & Inference Script

To run inference with the pre-trained weights (lite_tae_full.pt), you must apply the same Top-K Feature Error (K=5) and temporal smoothing (window=5) filters that this model was optimized with.

import torch
import numpy as np
import pandas as pd

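# 1. Load the entire package (weights plus metadata).
# weights_only=False may be required on PyTorch >= 2.6 if the file contains
# pickled non-tensor objects; only use it for checkpoints you trust.
checkpoint = torch.load('lite_tae_full.pt', map_location=torch.device('cpu'), weights_only=False)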

# 2. Extract hyperparameters and metadata
threshold = checkpoint['threshold']
scaler_center = np.array(checkpoint['scaler_c'])
scaler_scale = np.array(checkpoint['scaler_s'])
feat_cols = checkpoint['feat_cols']

print("Successfully loaded LiteTAE checkpoint.")
print(f"Configured Operational Threshold: {threshold:.6f}")

# 3. Top-K Reconstruction Error Function
def compute_lite_tae_errors(original_tensor, reconstructed_tensor, top_k=5):
    # Both tensors have shape (n_samples, n_features).
    # Calculate squared error for each feature independently
    per_feature_error = (reconstructed_tensor - original_tensor) ** 2

    # Keep only the top_k largest feature errors per sample and average them
    top_errs, _ = torch.topk(per_feature_error, k=top_k, dim=1)
    return top_errs.mean(dim=1).detach().cpu().numpy()

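# 4. Temporal Smoothing Filter
# Centered rolling mean over consecutive samples (rows must be in time order).
# center=True also uses the next window // 2 samples; edge windows shrink (min_periods=1).
def apply_temporal_smoothing(errors, window=5):
    return pd.Series(errors).rolling(window=window, center=True, min_periods=1).mean().values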


Install the dependencies with:

pip install torch numpy pandas scikit-learn
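The inference script defines the two post-processing helpers but does not show the model forward pass, the use of the stored scaler, or the threshold comparison. The sketch below wires those steps together under several assumptions. Inputs are scaled as (x - scaler_c) / scaler_s, as scikit-learn's RobustScaler.transform does. Rows are in temporal order. A sample is flagged when its smoothed score is strictly above the threshold. `reconstruct` is a hypothetical stand-in for the LiteTAE forward pass. The helpers are dependency-light NumPy ports of compute_lite_tae_errors and apply_temporal_smoothing. The demo data and the 0.5 threshold are made up for illustration.

```python
from typing import Callable

import numpy as np


def robust_scale(x: np.ndarray, center: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Assumed preprocessing: (x - center) / scale, like RobustScaler.transform."""
    return (x - center) / scale


def top_k_error(x: np.ndarray, x_hat: np.ndarray, k: int = 5) -> np.ndarray:
    """Mean of the k largest per-feature squared errors for each row."""
    per_feature = (x_hat - x) ** 2
    # np.partition moves the k largest values of each row into the last k columns.
    return np.partition(per_feature, -k, axis=1)[:, -k:].mean(axis=1)


def centered_moving_average(errors: np.ndarray, window: int = 5) -> np.ndarray:
    """Centered rolling mean with shrinking edge windows (odd window only),
    equivalent to pandas rolling(center=True, min_periods=1).mean()."""
    if window % 2 == 0:
        raise ValueError("window must be odd")
    n, half = len(errors), window // 2
    csum = np.concatenate(([0.0], np.cumsum(errors)))
    idx = np.arange(n)
    lo = np.clip(idx - half, 0, n)
    hi = np.clip(idx + half + 1, 0, n)
    return (csum[hi] - csum[lo]) / (hi - lo)


def score_batch(x_raw, reconstruct: Callable[[np.ndarray], np.ndarray],
                center, scale, threshold, k=5, window=5):
    """Scale -> reconstruct -> Top-K error -> smoothing -> threshold."""
    x = robust_scale(x_raw, center, scale)
    raw = top_k_error(x, reconstruct(x), k)
    smoothed = centered_moving_average(raw, window)
    return raw, smoothed, smoothed > threshold  # strict ">" is an assumption


# --- Demo with synthetic data (all values are made up) ---
n_rows, n_feats = 20, 8
center, scale = np.zeros(n_feats), np.ones(n_feats)
x_raw = np.zeros((n_rows, n_feats))
x_raw[4, 0] = 3.0      # isolated one-sample spike (noise)
x_raw[10:15, 0] = 3.0  # sustained five-sample burst (attack-like)


def reconstruct(x: np.ndarray) -> np.ndarray:
    # Hypothetical model stand-in: always reconstructs "benign" (all zeros).
    return np.zeros_like(x)


raw, smoothed, flags = score_batch(x_raw, reconstruct, center, scale, threshold=0.5)
print(np.flatnonzero(raw > 0.5))  # rows 4 and 10-14 exceed 0.5 before smoothing
print(np.flatnonzero(flags))      # rows 9-15 after smoothing; the spike is damped
```

In this toy run, smoothing removes the isolated spike at row 4 but keeps the five-sample burst. It also spreads the alert to rows 9 and 15, which are neighbors rather than attack rows. Because the window is centered, a streaming deployment would have to wait for the next window // 2 samples before scoring each one. A trailing window would avoid that delay, but it would produce scores the stored threshold was presumably not tuned on.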