""" model.py -------- WAFClassifier — a tiny, CPU-optimised multi-label classifier for HTTP request threat detection. Inputs ------ input_ids : LongTensor [B, seq_len] BPE token ids (max 128) attention_mask : LongTensor [B, seq_len] 1=real token, 0=padding numeric_features : FloatTensor [B, 6] hand-crafted numeric signals Outputs ------- label_probs : FloatTensor [B, 7] per-label sigmoid probabilities order: clean, xss, sqli, path_traversal, command_injection, scanner, spam_bot (matches config.json label_names) risk_score : FloatTensor [B, 1] continuous [0, 1] risk estimate Design rationale ---------------- - Conv1D encoder: 10-50x faster than self-attention on CPU for short sequences. Two depthwise-separable-style conv layers capture local n-gram patterns (SQL keywords, XSS angle-bracket patterns, path traversal dots, etc.) without the quadratic cost of attention. - Global max pooling collapses variable sequence length to a fixed vector, making the ONNX graph fully static-shape-friendly on the channel axis. - Separate numeric projector for hand-crafted signals (body length, special char ratios, etc.) that are cheap to compute at request time. - Fusion MLP kept intentionally small (160→128→64) for sub-3ms CPU inference. - Two output heads share all representations — no extra compute cost. - Parameter count target: < 2M. Actual: ~1.3M (see print_param_count()). - All ops are ONNX opset-17 compatible. No control flow, no Python-level branching in the forward pass. """ from __future__ import annotations import json from pathlib import Path from typing import Tuple import torch import torch.nn as nn # --------------------------------------------------------------------------- # Label ordering (canonical — must match data_pipeline.py) # --------------------------------------------------------------------------- LABEL_NAMES = [ "clean", "xss", "sqli", "path_traversal", "command_injection", "scanner", "spam_bot", ] NUM_LABELS = len(LABEL_NAMES) # 7 # --------------------------------------------------------------------------- # Default config — overridden by config.json at training time # --------------------------------------------------------------------------- DEFAULT_CONFIG = { "vocab_size": 8192, "embedding_dim": 128, "num_numeric_features": 6, "num_labels": NUM_LABELS, "dropout": 0.1, "max_seq_len": 128, # Conv encoder "conv_channels": 128, "conv_kernel_size": 3, # Fusion MLP "mlp_hidden": 128, "mlp_out": 64, } # --------------------------------------------------------------------------- # Model # --------------------------------------------------------------------------- class WAFClassifier(nn.Module): """ Low-latency WAF request classifier. Parameters ---------- config : dict Must contain the keys defined in DEFAULT_CONFIG. Load from config_v3.json at training time. """ def __init__(self, config: dict) -> None: super().__init__() vocab_size = config["vocab_size"] embedding_dim = config["embedding_dim"] num_numeric = config["num_numeric_features"] num_labels = config["num_labels"] dropout = config["dropout"] conv_ch = config["conv_channels"] conv_k = config["conv_kernel_size"] mlp_hidden = config["mlp_hidden"] mlp_out = config["mlp_out"] # ------------------------------------------------------------------ # 1. Token embedding [B, S] → [B, S, embedding_dim] # padding_idx=0 keeps PAD vectors zeroed and out of gradient flow. # ------------------------------------------------------------------ self.embedding = nn.Embedding( vocab_size, embedding_dim, padding_idx=0 ) # ------------------------------------------------------------------ # 2. Lightweight CNN text encoder # Two Conv1d layers with same-padding preserve sequence length so # the subsequent global-max-pool can always reduce to [B, ch, 1]. # # Using BatchNorm1d instead of LayerNorm keeps the inference path # fast (BN fuses into a single multiply-add per channel in ONNX). # ------------------------------------------------------------------ pad = conv_k // 2 # "same" padding for odd kernel sizes self.conv_encoder = nn.Sequential( # Layer 1: project embedding_dim → conv_ch nn.Conv1d(embedding_dim, conv_ch, kernel_size=conv_k, padding=pad), nn.BatchNorm1d(conv_ch), nn.ReLU(inplace=True), # Layer 2: refine features, same channel width nn.Conv1d(conv_ch, conv_ch, kernel_size=conv_k, padding=pad), nn.BatchNorm1d(conv_ch), nn.ReLU(inplace=True), # Global max pool → [B, conv_ch, 1] nn.AdaptiveMaxPool1d(1), ) # ------------------------------------------------------------------ # 3. Numeric feature projector [B, num_numeric] → [B, 32] # Small MLP; 32-dim gives enough capacity without dominating. # ------------------------------------------------------------------ self.numeric_proj = nn.Sequential( nn.Linear(num_numeric, 32), nn.ReLU(inplace=True), ) # ------------------------------------------------------------------ # 4. Fusion MLP [B, conv_ch+32] → [B, mlp_out] # Dropout applied before the second layer — only active in training. # ------------------------------------------------------------------ fusion_in = conv_ch + 32 # 128 + 32 = 160 self.fusion_mlp = nn.Sequential( nn.Linear(fusion_in, mlp_hidden), nn.ReLU(inplace=True), nn.Dropout(p=dropout), nn.Linear(mlp_hidden, mlp_out), nn.ReLU(inplace=True), ) # ------------------------------------------------------------------ # 5. Output heads (no activation — raw logits for training stability) # Sigmoid is applied in forward() for inference / ONNX export. # ------------------------------------------------------------------ self.label_head = nn.Linear(mlp_out, num_labels) # → [B, 7] logits self.risk_head = nn.Linear(mlp_out, 1) # → [B, 1] logit # ------------------------------------------------------------------ # Weight initialisation # ------------------------------------------------------------------ self._init_weights() # ------------------------------------------------------------------ # Initialisation # ------------------------------------------------------------------ def _init_weights(self) -> None: """Kaiming-uniform for linear/conv; uniform for embeddings (default).""" for module in self.modules(): if isinstance(module, (nn.Linear, nn.Conv1d)): nn.init.kaiming_uniform_(module.weight, nonlinearity="relu") if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.BatchNorm1d): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) # ------------------------------------------------------------------ # Forward pass # ------------------------------------------------------------------ def forward( self, input_ids: torch.Tensor, # [B, S] Long or Int32 attention_mask: torch.Tensor, # [B, S] Long or Int32 (1/0) numeric_features: torch.Tensor, # [B, 6] Float ) -> Tuple[torch.Tensor, torch.Tensor]: """ Returns ------- label_probs : [B, 7] float32, sigmoid-activated per-label probs risk_score : [B, 1] float32, sigmoid-activated risk in [0, 1] Notes ----- input_ids and attention_mask can be int32 (as produced by the data_pipeline tokenizer) or int64 — both are accepted because nn.Embedding accepts any integer dtype in PyTorch 2+, and the explicit .long() cast ensures ONNX opset-17 compatibility. """ # -- Token embeddings + mask application ------------------------- x = self.embedding(input_ids.long()) # [B, S, E] # Zero out padding positions so they cannot contribute to max-pool. mask = attention_mask.long().unsqueeze(-1).float() # [B, S, 1] x = x * mask # [B, S, E] # -- Conv encoder ------------------------------------------------ # Conv1d expects channel-first: [B, E, S] x = x.permute(0, 2, 1).contiguous() # [B, E, S] x = self.conv_encoder(x) # [B, conv_ch, 1] x = x.squeeze(-1) # [B, conv_ch] # -- Numeric projector ------------------------------------------- n = self.numeric_proj(numeric_features) # [B, 32] # -- Fusion MLP -------------------------------------------------- combined = torch.cat([x, n], dim=1) # [B, 160] features = self.fusion_mlp(combined) # [B, 64] # -- Output heads ------------------------------------------------ label_logits = self.label_head(features) # [B, 7] label_probs = torch.sigmoid(label_logits) # [B, 7] risk_logit = self.risk_head(features) # [B, 1] risk_score = torch.sigmoid(risk_logit) # [B, 1] return label_probs, risk_score # --------------------------------------------------------------------------- # Helper utilities # --------------------------------------------------------------------------- def build_model(config: dict | None = None) -> WAFClassifier: """Instantiate WAFClassifier from a config dict (or DEFAULT_CONFIG).""" cfg = DEFAULT_CONFIG.copy() if config: cfg.update(config) return WAFClassifier(cfg) def load_config(config_path: str | Path) -> dict: """Load config.json and merge with DEFAULT_CONFIG.""" cfg = DEFAULT_CONFIG.copy() path = Path(config_path) if path.exists(): with open(path, "r") as fh: overrides = json.load(fh) cfg.update(overrides) else: print(f"[WARN] config.json not found at {path}; using defaults.") return cfg def print_param_count(model: nn.Module) -> int: """Print and return total trainable parameter count.""" total = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"WAFClassifier trainable parameters: {total:,}") # Breakdown by component breakdown = { "embedding": sum(p.numel() for p in model.embedding.parameters()), "conv_encoder": sum(p.numel() for p in model.conv_encoder.parameters()), "numeric_proj": sum(p.numel() for p in model.numeric_proj.parameters()), "fusion_mlp": sum(p.numel() for p in model.fusion_mlp.parameters()), "label_head": sum(p.numel() for p in model.label_head.parameters()), "risk_head": sum(p.numel() for p in model.risk_head.parameters()), } for name, count in breakdown.items(): print(f" {name:<16}: {count:>10,}") return total # --------------------------------------------------------------------------- # Quick sanity check (run directly: python model.py) # --------------------------------------------------------------------------- if __name__ == "__main__": torch.manual_seed(42) cfg = DEFAULT_CONFIG.copy() model = WAFClassifier(cfg) model.eval() total = print_param_count(model) assert total < 2_000_000, f"Model too large: {total:,} params" B, S = 4, 128 ids = torch.randint(0, cfg["vocab_size"], (B, S)) mask = torch.ones(B, S, dtype=torch.long) mask[:, 100:] = 0 # simulate padding num = torch.randn(B, cfg["num_numeric_features"]) with torch.no_grad(): probs, risk = model(ids, mask, num) assert probs.shape == (B, NUM_LABELS), f"Bad probs shape: {probs.shape}" assert risk.shape == (B, 1), f"Bad risk shape: {risk.shape}" assert probs.min() >= 0.0 and probs.max() <= 1.0 assert risk.min() >= 0.0 and risk.max() <= 1.0 print(f"\nForward pass OK | label_probs: {probs.shape} risk_score: {risk.shape}") print(f"Label probs (first example): {probs[0].tolist()}") print(f"Risk score (first example): {risk[0].item():.4f}")