| """ |
| model.py |
| -------- |
| WAFClassifier — a tiny, CPU-optimised multi-label classifier for HTTP request |
| threat detection. |
| |
| Inputs |
| ------ |
input_ids        : LongTensor  [B, seq_len]  BPE token ids, seq_len <= 128
attention_mask   : LongTensor  [B, seq_len]  1 = real token, 0 = padding
numeric_features : FloatTensor [B, 6]        hand-crafted numeric signals
| |
| Outputs |
| ------- |
| label_probs : FloatTensor [B, 7] per-label sigmoid probabilities |
| order: clean, xss, sqli, path_traversal, command_injection, |
| scanner, spam_bot (matches config.json label_names) |
| risk_score : FloatTensor [B, 1] continuous [0, 1] risk estimate |
| |
| Design rationale |
| ---------------- |
- Conv1D encoder: 10-50x faster than self-attention on CPU for short sequences.
  Two stacked Conv1d layers capture local n-gram patterns (SQL keywords, XSS
  angle-bracket patterns, path traversal dots, etc.) without the quadratic
  cost of attention.
- Global max pooling collapses the variable sequence length to a fixed-size
  vector, so the exported ONNX graph's output shapes do not depend on seq_len.
| - Separate numeric projector for hand-crafted signals (body length, special |
| char ratios, etc.) that are cheap to compute at request time. |
| - Fusion MLP kept intentionally small (160→128→64) for sub-3ms CPU inference. |
| - Two output heads share all representations — no extra compute cost. |
- Parameter count target: < 2M. Actual with DEFAULT_CONFIG: ~1.2M
  (see print_param_count()).
| - All ops are ONNX opset-17 compatible. No control flow, no Python-level |
| branching in the forward pass. |
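
Example
-------
Illustrative usage sketch (tensor names follow the Inputs section above;
build_model() is defined at the bottom of this file)::

    model = build_model()
    model.eval()
    with torch.no_grad():
        label_probs, risk_score = model(input_ids, attention_mask, numeric_features)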
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
| from typing import Tuple |
|
|
| import torch |
| import torch.nn as nn |


# Canonical label order. Must stay in sync with label_names in config.json;
# the label_probs output documented above follows this order.
| LABEL_NAMES = [ |
| "clean", |
| "xss", |
| "sqli", |
| "path_traversal", |
| "command_injection", |
| "scanner", |
| "spam_bot", |
| ] |
| NUM_LABELS = len(LABEL_NAMES) |


| DEFAULT_CONFIG = { |
| "vocab_size": 8192, |
| "embedding_dim": 128, |
| "num_numeric_features": 6, |
| "num_labels": NUM_LABELS, |
| "dropout": 0.1, |
| "max_seq_len": 128, |
| |
| "conv_channels": 128, |
| "conv_kernel_size": 3, |
| |
| "mlp_hidden": 128, |
| "mlp_out": 64, |
| } |


| class WAFClassifier(nn.Module): |
| """ |
| Low-latency WAF request classifier. |
| |
| Parameters |
| ---------- |
| config : dict |
| Must contain the keys defined in DEFAULT_CONFIG. |
        Load from config.json at training time (see load_config()).
| """ |
|
|
| def __init__(self, config: dict) -> None: |
| super().__init__() |
|
|
| vocab_size = config["vocab_size"] |
| embedding_dim = config["embedding_dim"] |
| num_numeric = config["num_numeric_features"] |
| num_labels = config["num_labels"] |
| dropout = config["dropout"] |
| conv_ch = config["conv_channels"] |
| conv_k = config["conv_kernel_size"] |
| mlp_hidden = config["mlp_hidden"] |
| mlp_out = config["mlp_out"] |

        # Token embedding. padding_idx=0 zero-initialises the PAD row and
        # excludes it from gradient updates; _init_weights() leaves the
        # embedding at its default (normal) init.
| self.embedding = nn.Embedding( |
| vocab_size, embedding_dim, padding_idx=0 |
| ) |

        # Convolutional encoder over the token axis. padding=conv_k // 2
        # preserves the sequence length for the default odd kernel size, and
        # the global max pool at the end collapses it to one vector per request.
| pad = conv_k // 2 |
|
|
| self.conv_encoder = nn.Sequential( |
| |
| nn.Conv1d(embedding_dim, conv_ch, kernel_size=conv_k, padding=pad), |
| nn.BatchNorm1d(conv_ch), |
| nn.ReLU(inplace=True), |
| |
| nn.Conv1d(conv_ch, conv_ch, kernel_size=conv_k, padding=pad), |
| nn.BatchNorm1d(conv_ch), |
| nn.ReLU(inplace=True), |
| |
| nn.AdaptiveMaxPool1d(1), |
| ) |

        # Projector for the hand-crafted numeric request signals.
| self.numeric_proj = nn.Sequential( |
| nn.Linear(num_numeric, 32), |
| nn.ReLU(inplace=True), |
| ) |

        # Fusion of pooled token features with the projected numeric features.
| fusion_in = conv_ch + 32 |
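        # With defaults: 128 (conv channels) + 32 (numeric projection) = 160,
        # giving the 160 -> 128 -> 64 fusion MLP noted in the module docstring.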
|
|
| self.fusion_mlp = nn.Sequential( |
| nn.Linear(fusion_in, mlp_hidden), |
| nn.ReLU(inplace=True), |
| nn.Dropout(p=dropout), |
| nn.Linear(mlp_hidden, mlp_out), |
| nn.ReLU(inplace=True), |
| ) |

        # Two output heads over the shared fused representation.
| self.label_head = nn.Linear(mlp_out, num_labels) |
| self.risk_head = nn.Linear(mlp_out, 1) |
| self._init_weights() |
    def _init_weights(self) -> None:
        """Kaiming-uniform for Linear/Conv1d weights; BatchNorm reset to
        weight=1, bias=0; the embedding keeps PyTorch's default (normal) init."""
| for module in self.modules(): |
| if isinstance(module, (nn.Linear, nn.Conv1d)): |
| nn.init.kaiming_uniform_(module.weight, nonlinearity="relu") |
| if module.bias is not None: |
| nn.init.zeros_(module.bias) |
| elif isinstance(module, nn.BatchNorm1d): |
| nn.init.ones_(module.weight) |
| nn.init.zeros_(module.bias) |
| def forward( |
| self, |
| input_ids: torch.Tensor, |
| attention_mask: torch.Tensor, |
| numeric_features: torch.Tensor, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| """ |
| Returns |
| ------- |
| label_probs : [B, 7] float32, sigmoid-activated per-label probs |
| risk_score : [B, 1] float32, sigmoid-activated risk in [0, 1] |
| |
| Notes |
| ----- |
        input_ids and attention_mask may be int32 (as produced by the
        data_pipeline tokenizer) or int64: the explicit .long() casts below
        normalise both to int64 before the embedding lookup and keep the
        exported ONNX (opset-17) graph's integer inputs consistent.
| """ |
        # Embed tokens: [B, S] -> [B, S, embedding_dim].
| x = self.embedding(input_ids.long()) |
        # Zero out embeddings at padded positions.
| mask = attention_mask.long().unsqueeze(-1).float() |
| x = x * mask |

        # Conv1d expects channels-first: [B, S, E] -> [B, E, S].
| x = x.permute(0, 2, 1).contiguous() |
| x = self.conv_encoder(x) |
| x = x.squeeze(-1) |

        # Project the numeric features: [B, num_numeric_features] -> [B, 32].
| n = self.numeric_proj(numeric_features) |

        # Fuse token and numeric representations.
| combined = torch.cat([x, n], dim=1) |
| features = self.fusion_mlp(combined) |

        # Multi-label head: independent sigmoid per label.
| label_logits = self.label_head(features) |
| label_probs = torch.sigmoid(label_logits) |

        # Scalar risk head.
| risk_logit = self.risk_head(features) |
| risk_score = torch.sigmoid(risk_logit) |
|
|
| return label_probs, risk_score |


| def build_model(config: dict | None = None) -> WAFClassifier: |
| """Instantiate WAFClassifier from a config dict (or DEFAULT_CONFIG).""" |
| cfg = DEFAULT_CONFIG.copy() |
| if config: |
| cfg.update(config) |
| return WAFClassifier(cfg) |


| def load_config(config_path: str | Path) -> dict: |
| """Load config.json and merge with DEFAULT_CONFIG.""" |
| cfg = DEFAULT_CONFIG.copy() |
| path = Path(config_path) |
| if path.exists(): |
| with open(path, "r") as fh: |
| overrides = json.load(fh) |
| cfg.update(overrides) |
| else: |
        print(f"[WARN] config file not found at {path}; using defaults.")
| return cfg |
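

# Typical call sites (illustrative): training code is expected to run
#   cfg = load_config("config.json")
#   model = build_model(cfg)
# while quick experiments can call build_model() with no arguments and get the
# DEFAULT_CONFIG values.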


| def print_param_count(model: nn.Module) -> int: |
| """Print and return total trainable parameter count.""" |
| total = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| print(f"WAFClassifier trainable parameters: {total:,}") |
| |
| breakdown = { |
| "embedding": sum(p.numel() for p in model.embedding.parameters()), |
| "conv_encoder": sum(p.numel() for p in model.conv_encoder.parameters()), |
| "numeric_proj": sum(p.numel() for p in model.numeric_proj.parameters()), |
| "fusion_mlp": sum(p.numel() for p in model.fusion_mlp.parameters()), |
| "label_head": sum(p.numel() for p in model.label_head.parameters()), |
| "risk_head": sum(p.numel() for p in model.risk_head.parameters()), |
| } |
| for name, count in breakdown.items(): |
| print(f" {name:<16}: {count:>10,}") |
| return total |


| if __name__ == "__main__": |
| torch.manual_seed(42) |
| cfg = DEFAULT_CONFIG.copy() |
| model = WAFClassifier(cfg) |
| model.eval() |
|
|
| total = print_param_count(model) |
| assert total < 2_000_000, f"Model too large: {total:,} params" |
|
|
| B, S = 4, 128 |
| ids = torch.randint(0, cfg["vocab_size"], (B, S)) |
| mask = torch.ones(B, S, dtype=torch.long) |
| mask[:, 100:] = 0 |
| num = torch.randn(B, cfg["num_numeric_features"]) |
|
|
| with torch.no_grad(): |
| probs, risk = model(ids, mask, num) |
|
|
| assert probs.shape == (B, NUM_LABELS), f"Bad probs shape: {probs.shape}" |
| assert risk.shape == (B, 1), f"Bad risk shape: {risk.shape}" |
| assert probs.min() >= 0.0 and probs.max() <= 1.0 |
| assert risk.min() >= 0.0 and risk.max() <= 1.0 |
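
    # Optional latency spot-check against the sub-3ms CPU target from the
    # module docstring (illustrative; numbers vary by machine). Uncomment to run:
    #
    #   import time
    #   single = (ids[:1], mask[:1], num[:1])
    #   with torch.no_grad():
    #       for _ in range(20):    # warm-up
    #           model(*single)
    #       t0 = time.perf_counter()
    #       for _ in range(200):
    #           model(*single)
    #   print(f"Mean latency: {(time.perf_counter() - t0) / 200 * 1e3:.2f} ms")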
|
|
| print(f"\nForward pass OK | label_probs: {probs.shape} risk_score: {risk.shape}") |
| print(f"Label probs (first example): {probs[0].tolist()}") |
| print(f"Risk score (first example): {risk[0].item():.4f}") |
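
    # ONNX export sketch (illustrative; the real export script and file name
    # may differ). Batch size is left dynamic; seq_len stays fixed at
    # max_seq_len, matching the static-shape note in the module docstring.
    #
    #   torch.onnx.export(
    #       model,
    #       (ids[:1], mask[:1], num[:1]),
    #       "waf_classifier.onnx",
    #       input_names=["input_ids", "attention_mask", "numeric_features"],
    #       output_names=["label_probs", "risk_score"],
    #       opset_version=17,
    #       dynamic_axes={name: {0: "batch"} for name in
    #                     ["input_ids", "attention_mask", "numeric_features",
    #                      "label_probs", "risk_score"]},
    #   )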
|
|