"""
model.py
--------
WAFClassifier — a tiny, CPU-optimised multi-label classifier for HTTP request
threat detection.
Inputs
------
input_ids : LongTensor [B, seq_len] BPE token ids (max 128)
attention_mask : LongTensor [B, seq_len] 1=real token, 0=padding
numeric_features : FloatTensor [B, 6] hand-crafted numeric signals
Outputs
-------
label_probs : FloatTensor [B, 7] per-label sigmoid probabilities
order: clean, xss, sqli, path_traversal, command_injection,
scanner, spam_bot (matches config.json label_names)
risk_score : FloatTensor [B, 1] continuous [0, 1] risk estimate
Design rationale
----------------
- Conv1D encoder: 10-50x faster than self-attention on CPU for short sequences.
Two stacked Conv1d layers capture local n-gram patterns
(SQL keywords, XSS angle-bracket patterns, path traversal dots, etc.)
without the quadratic cost of attention.
- Global max pooling collapses the variable sequence length to a fixed-size
vector, so the pooled feature dimension stays static in the exported ONNX graph.
- Separate numeric projector for hand-crafted signals (body length, special
char ratios, etc.) that are cheap to compute at request time.
- Fusion MLP kept intentionally small (160→128→64) for sub-3ms CPU inference.
- Two output heads share all representations — no extra compute cost.
- Parameter count target: < 2M. Actual: ~1.3M (see print_param_count()).
- All ops are ONNX opset-17 compatible. No control flow, no Python-level
branching in the forward pass.
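Example
-------
Illustrative usage sketch (not part of the training pipeline; assumes the
DEFAULT_CONFIG values defined below)::

    import torch
    from model import build_model

    model = build_model().eval()
    ids = torch.randint(0, 8192, (1, 128))        # dummy BPE token ids
    mask = torch.ones(1, 128, dtype=torch.long)   # no padding
    nums = torch.zeros(1, 6)                      # dummy numeric signals
    with torch.no_grad():
        label_probs, risk_score = model(ids, mask, nums)  # [1, 7], [1, 1]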
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Tuple
import torch
import torch.nn as nn
# ---------------------------------------------------------------------------
# Label ordering (canonical — must match data_pipeline.py)
# ---------------------------------------------------------------------------
LABEL_NAMES = [
"clean",
"xss",
"sqli",
"path_traversal",
"command_injection",
"scanner",
"spam_bot",
]
NUM_LABELS = len(LABEL_NAMES) # 7
# ---------------------------------------------------------------------------
# Default config — overridden by config.json at training time
# ---------------------------------------------------------------------------
DEFAULT_CONFIG = {
"vocab_size": 8192,
"embedding_dim": 128,
"num_numeric_features": 6,
"num_labels": NUM_LABELS,
"dropout": 0.1,
"max_seq_len": 128,
# Conv encoder
"conv_channels": 128,
"conv_kernel_size": 3,
# Fusion MLP
"mlp_hidden": 128,
"mlp_out": 64,
}
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class WAFClassifier(nn.Module):
"""
Low-latency WAF request classifier.
Parameters
----------
config : dict
Must contain the keys defined in DEFAULT_CONFIG.
Load from config_v3.json at training time.
"""
def __init__(self, config: dict) -> None:
super().__init__()
vocab_size = config["vocab_size"]
embedding_dim = config["embedding_dim"]
num_numeric = config["num_numeric_features"]
num_labels = config["num_labels"]
dropout = config["dropout"]
conv_ch = config["conv_channels"]
conv_k = config["conv_kernel_size"]
mlp_hidden = config["mlp_hidden"]
mlp_out = config["mlp_out"]
# ------------------------------------------------------------------
# 1. Token embedding [B, S] → [B, S, embedding_dim]
# padding_idx=0 keeps PAD vectors zeroed and out of gradient flow.
# ------------------------------------------------------------------
self.embedding = nn.Embedding(
vocab_size, embedding_dim, padding_idx=0
)
# ------------------------------------------------------------------
# 2. Lightweight CNN text encoder
# Two Conv1d layers with same-padding preserve sequence length so
# the subsequent global-max-pool can always reduce to [B, ch, 1].
#
# Using BatchNorm1d instead of LayerNorm keeps the inference path
# fast (BN fuses into a single multiply-add per channel in ONNX).
# ------------------------------------------------------------------
pad = conv_k // 2 # "same" padding for odd kernel sizes
self.conv_encoder = nn.Sequential(
# Layer 1: project embedding_dim → conv_ch
nn.Conv1d(embedding_dim, conv_ch, kernel_size=conv_k, padding=pad),
nn.BatchNorm1d(conv_ch),
nn.ReLU(inplace=True),
# Layer 2: refine features, same channel width
nn.Conv1d(conv_ch, conv_ch, kernel_size=conv_k, padding=pad),
nn.BatchNorm1d(conv_ch),
nn.ReLU(inplace=True),
# Global max pool → [B, conv_ch, 1]
nn.AdaptiveMaxPool1d(1),
)
# ------------------------------------------------------------------
# 3. Numeric feature projector [B, num_numeric] → [B, 32]
# Small MLP; 32-dim gives enough capacity without dominating.
# ------------------------------------------------------------------
self.numeric_proj = nn.Sequential(
nn.Linear(num_numeric, 32),
nn.ReLU(inplace=True),
)
# ------------------------------------------------------------------
# 4. Fusion MLP [B, conv_ch+32] → [B, mlp_out]
# Dropout applied before the second layer — only active in training.
# ------------------------------------------------------------------
fusion_in = conv_ch + 32 # 128 + 32 = 160
self.fusion_mlp = nn.Sequential(
nn.Linear(fusion_in, mlp_hidden),
nn.ReLU(inplace=True),
nn.Dropout(p=dropout),
nn.Linear(mlp_hidden, mlp_out),
nn.ReLU(inplace=True),
)
# ------------------------------------------------------------------
# 5. Output heads (no activation — raw logits for training stability)
# Sigmoid is applied in forward() for inference / ONNX export.
# ------------------------------------------------------------------
self.label_head = nn.Linear(mlp_out, num_labels) # → [B, 7] logits
self.risk_head = nn.Linear(mlp_out, 1) # → [B, 1] logit
# ------------------------------------------------------------------
# Weight initialisation
# ------------------------------------------------------------------
self._init_weights()
# ------------------------------------------------------------------
# Initialisation
# ------------------------------------------------------------------
def _init_weights(self) -> None:
"""Kaiming-uniform for linear/conv; uniform for embeddings (default)."""
for module in self.modules():
if isinstance(module, (nn.Linear, nn.Conv1d)):
nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.BatchNorm1d):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
# ------------------------------------------------------------------
# Forward pass
# ------------------------------------------------------------------
def forward(
self,
input_ids: torch.Tensor, # [B, S] Long or Int32
attention_mask: torch.Tensor, # [B, S] Long or Int32 (1/0)
numeric_features: torch.Tensor, # [B, 6] Float
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns
-------
label_probs : [B, 7] float32, sigmoid-activated per-label probs
risk_score : [B, 1] float32, sigmoid-activated risk in [0, 1]
Notes
-----
        input_ids and attention_mask may arrive as int32 (as produced by the
        data_pipeline tokenizer) or int64. The explicit .long() casts below
        normalise the dtype before embedding and masking, which also keeps the
        exported ONNX opset-17 graph consistent regardless of the caller's dtype.
"""
# -- Token embeddings + mask application -------------------------
x = self.embedding(input_ids.long()) # [B, S, E]
        # Zero out padding embeddings so PAD tokens carry no signal into the conv encoder.
mask = attention_mask.long().unsqueeze(-1).float() # [B, S, 1]
x = x * mask # [B, S, E]
# -- Conv encoder ------------------------------------------------
# Conv1d expects channel-first: [B, E, S]
x = x.permute(0, 2, 1).contiguous() # [B, E, S]
x = self.conv_encoder(x) # [B, conv_ch, 1]
x = x.squeeze(-1) # [B, conv_ch]
# -- Numeric projector -------------------------------------------
n = self.numeric_proj(numeric_features) # [B, 32]
# -- Fusion MLP --------------------------------------------------
combined = torch.cat([x, n], dim=1) # [B, 160]
features = self.fusion_mlp(combined) # [B, 64]
# -- Output heads ------------------------------------------------
label_logits = self.label_head(features) # [B, 7]
label_probs = torch.sigmoid(label_logits) # [B, 7]
risk_logit = self.risk_head(features) # [B, 1]
risk_score = torch.sigmoid(risk_logit) # [B, 1]
return label_probs, risk_score
# ---------------------------------------------------------------------------
# Helper utilities
# ---------------------------------------------------------------------------
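# Illustrative helper (an addition for this sketch, not part of the original
# pipeline): one simple way to turn the sigmoid outputs of WAFClassifier into
# label names, using a flat 0.5 threshold per label.
def decode_labels(label_probs: torch.Tensor, threshold: float = 0.5) -> list[list[str]]:
    """Map [B, num_labels] sigmoid probabilities to per-example label-name lists."""
    decoded = []
    for row in label_probs:
        decoded.append(
            [LABEL_NAMES[i] for i, p in enumerate(row.tolist()) if p >= threshold]
        )
    return decoded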
def build_model(config: dict | None = None) -> WAFClassifier:
"""Instantiate WAFClassifier from a config dict (or DEFAULT_CONFIG)."""
cfg = DEFAULT_CONFIG.copy()
if config:
cfg.update(config)
return WAFClassifier(cfg)
def load_config(config_path: str | Path) -> dict:
"""Load config.json and merge with DEFAULT_CONFIG."""
cfg = DEFAULT_CONFIG.copy()
path = Path(config_path)
if path.exists():
with open(path, "r") as fh:
overrides = json.load(fh)
cfg.update(overrides)
else:
print(f"[WARN] config.json not found at {path}; using defaults.")
return cfg
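# Typical wiring (illustrative): merge config.json over the defaults, then build.
#     cfg = load_config("config.json")
#     model = build_model(cfg)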
def print_param_count(model: nn.Module) -> int:
"""Print and return total trainable parameter count."""
total = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"WAFClassifier trainable parameters: {total:,}")
# Breakdown by component
breakdown = {
"embedding": sum(p.numel() for p in model.embedding.parameters()),
"conv_encoder": sum(p.numel() for p in model.conv_encoder.parameters()),
"numeric_proj": sum(p.numel() for p in model.numeric_proj.parameters()),
"fusion_mlp": sum(p.numel() for p in model.fusion_mlp.parameters()),
"label_head": sum(p.numel() for p in model.label_head.parameters()),
"risk_head": sum(p.numel() for p in model.risk_head.parameters()),
}
for name, count in breakdown.items():
print(f" {name:<16}: {count:>10,}")
return total
# ---------------------------------------------------------------------------
# Quick sanity check (run directly: python model.py)
# ---------------------------------------------------------------------------
if __name__ == "__main__":
torch.manual_seed(42)
cfg = DEFAULT_CONFIG.copy()
model = WAFClassifier(cfg)
model.eval()
total = print_param_count(model)
assert total < 2_000_000, f"Model too large: {total:,} params"
B, S = 4, 128
ids = torch.randint(0, cfg["vocab_size"], (B, S))
mask = torch.ones(B, S, dtype=torch.long)
mask[:, 100:] = 0 # simulate padding
num = torch.randn(B, cfg["num_numeric_features"])
with torch.no_grad():
probs, risk = model(ids, mask, num)
assert probs.shape == (B, NUM_LABELS), f"Bad probs shape: {probs.shape}"
assert risk.shape == (B, 1), f"Bad risk shape: {risk.shape}"
assert probs.min() >= 0.0 and probs.max() <= 1.0
assert risk.min() >= 0.0 and risk.max() <= 1.0
print(f"\nForward pass OK | label_probs: {probs.shape} risk_score: {risk.shape}")
print(f"Label probs (first example): {probs[0].tolist()}")
print(f"Risk score (first example): {risk[0].item():.4f}")