File size: 12,772 Bytes
59848dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
"""
model.py
--------
WAFClassifier — a tiny, CPU-optimised multi-label classifier for HTTP request
threat detection.

Inputs
------
  input_ids        : LongTensor  [B, seq_len]   BPE token ids (max 128)
  attention_mask   : LongTensor  [B, seq_len]   1=real token, 0=padding
  numeric_features : FloatTensor [B, 6]         hand-crafted numeric signals

Outputs
-------
  label_probs : FloatTensor [B, 7]   per-label sigmoid probabilities
                order: clean, xss, sqli, path_traversal, command_injection,
                       scanner, spam_bot  (matches config.json label_names)
  risk_score  : FloatTensor [B, 1]   continuous [0, 1] risk estimate

Design rationale
----------------
- Conv1D encoder: 10-50x faster than self-attention on CPU for short sequences.
  Two standard (dense) Conv1d layers capture local n-gram patterns
  (SQL keywords, XSS angle-bracket patterns, path traversal dots, etc.)
  without the quadratic cost of attention.
- Global max pooling collapses variable sequence length to a fixed vector,
  making the ONNX graph fully static-shape-friendly on the channel axis.
- Separate numeric projector for hand-crafted signals (body length, special
  char ratios, etc.) that are cheap to compute at request time.
- Fusion MLP kept intentionally small (160→128→64) for sub-3ms CPU inference.
- Two output heads share all representations — no extra compute cost.
- Parameter count target: < 2M.  Actual: ~1.18M with DEFAULT_CONFIG
  (see print_param_count()).
- All ops are ONNX opset-17 compatible.  No control flow, no Python-level
  branching in the forward pass.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Tuple

import torch
import torch.nn as nn

# ---------------------------------------------------------------------------
# Canonical label order - must stay in sync with data_pipeline.py
# ---------------------------------------------------------------------------
LABEL_NAMES = [
    "clean", "xss", "sqli", "path_traversal",
    "command_injection", "scanner", "spam_bot",
]
NUM_LABELS = len(LABEL_NAMES)  # 7

# ---------------------------------------------------------------------------
# Baseline hyper-parameters - values from config.json take precedence at
# training time
# ---------------------------------------------------------------------------
DEFAULT_CONFIG = dict(
    # Embedding / input sizes
    vocab_size=8192,
    embedding_dim=128,
    num_numeric_features=6,
    num_labels=NUM_LABELS,
    dropout=0.1,
    max_seq_len=128,
    # Conv encoder
    conv_channels=128,
    conv_kernel_size=3,
    # Fusion MLP
    mlp_hidden=128,
    mlp_out=64,
)


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class WAFClassifier(nn.Module):
    """
    Low-latency WAF request classifier.

    A token embedding feeds a two-layer Conv1d encoder with global max
    pooling; the pooled text vector is concatenated with a small projection
    of the numeric features and passed through a fusion MLP that drives two
    sigmoid heads (per-label probabilities and a scalar risk score).

    Parameters
    ----------
    config : dict
        Must contain the keys ``vocab_size``, ``embedding_dim``,
        ``num_numeric_features``, ``num_labels``, ``dropout``,
        ``conv_channels``, ``conv_kernel_size``, ``mlp_hidden`` and
        ``mlp_out``.  Optional key ``numeric_proj_dim`` sets the width of the
        numeric projection (default 32).  Other keys are ignored.
        Typically loaded from config.json at training time.

    Raises
    ------
    KeyError
        If any required key is missing from ``config``.
    """

    # Keys read unconditionally by __init__.
    _REQUIRED_KEYS = (
        "vocab_size",
        "embedding_dim",
        "num_numeric_features",
        "num_labels",
        "dropout",
        "conv_channels",
        "conv_kernel_size",
        "mlp_hidden",
        "mlp_out",
    )
    # Width of the numeric projection when the config does not override it.
    _DEFAULT_NUMERIC_PROJ_DIM = 32

    def __init__(self, config: dict) -> None:
        super().__init__()

        # Report every missing key at once instead of failing on the first.
        missing = [key for key in self._REQUIRED_KEYS if key not in config]
        if missing:
            raise KeyError(
                f"WAFClassifier config is missing required keys: {missing}"
            )

        vocab_size          = config["vocab_size"]
        embedding_dim       = config["embedding_dim"]
        num_numeric         = config["num_numeric_features"]
        num_labels          = config["num_labels"]
        dropout             = config["dropout"]
        conv_ch             = config["conv_channels"]
        conv_k              = config["conv_kernel_size"]
        mlp_hidden          = config["mlp_hidden"]
        mlp_out             = config["mlp_out"]
        numeric_dim         = config.get(
            "numeric_proj_dim", self._DEFAULT_NUMERIC_PROJ_DIM
        )

        # ------------------------------------------------------------------
        # 1. Token embedding  [B, S] -> [B, S, embedding_dim]
        #    padding_idx=0 keeps the vector for id 0 out of gradient flow.
        # ------------------------------------------------------------------
        self.embedding = nn.Embedding(
            vocab_size, embedding_dim, padding_idx=0
        )

        # ------------------------------------------------------------------
        # 2. Lightweight CNN text encoder
        #    Two Conv1d layers, each followed by BatchNorm1d + ReLU, then a
        #    global max pool that reduces the sequence axis to length 1.
        #    With an odd kernel size, padding = k // 2 preserves the
        #    sequence length through both convolutions.
        # ------------------------------------------------------------------
        pad = conv_k // 2  # "same" padding for odd kernel sizes

        self.conv_encoder = nn.Sequential(
            # Layer 1: project embedding_dim -> conv_ch
            nn.Conv1d(embedding_dim, conv_ch, kernel_size=conv_k, padding=pad),
            nn.BatchNorm1d(conv_ch),
            nn.ReLU(inplace=True),
            # Layer 2: refine features, same channel width
            nn.Conv1d(conv_ch, conv_ch, kernel_size=conv_k, padding=pad),
            nn.BatchNorm1d(conv_ch),
            nn.ReLU(inplace=True),
            # Global max pool -> [B, conv_ch, 1]
            nn.AdaptiveMaxPool1d(1),
        )

        # ------------------------------------------------------------------
        # 3. Numeric feature projector  [B, num_numeric] -> [B, numeric_dim]
        # ------------------------------------------------------------------
        self.numeric_proj = nn.Sequential(
            nn.Linear(num_numeric, numeric_dim),
            nn.ReLU(inplace=True),
        )

        # ------------------------------------------------------------------
        # 4. Fusion MLP  [B, conv_ch + numeric_dim] -> [B, mlp_out]
        #    Dropout sits between the two linear layers and is only active
        #    in training mode.
        # ------------------------------------------------------------------
        fusion_in = conv_ch + numeric_dim  # 128 + 32 = 160 with defaults

        self.fusion_mlp = nn.Sequential(
            nn.Linear(fusion_in, mlp_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(mlp_hidden, mlp_out),
            nn.ReLU(inplace=True),
        )

        # ------------------------------------------------------------------
        # 5. Output heads: linear layers producing raw logits.
        #    Sigmoid is applied in forward().
        # ------------------------------------------------------------------
        self.label_head = nn.Linear(mlp_out, num_labels)  # -> [B, num_labels]
        self.risk_head  = nn.Linear(mlp_out, 1)           # -> [B, 1]

        # ------------------------------------------------------------------
        # Weight initialisation
        # ------------------------------------------------------------------
        self._init_weights()

    # ------------------------------------------------------------------
    # Initialisation
    # ------------------------------------------------------------------
    def _init_weights(self) -> None:
        """
        Initialise Linear/Conv1d weights with Kaiming-uniform (ReLU gain)
        and zero biases; set BatchNorm1d weight to 1 and bias to 0.

        The embedding table is not touched here and keeps the
        initialisation applied by nn.Embedding itself.
        """
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv1d)):
                nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------
    def forward(
        self,
        input_ids: torch.Tensor,        # [B, S]  Long or Int32
        attention_mask: torch.Tensor,   # [B, S]  Long or Int32 (1/0)
        numeric_features: torch.Tensor, # [B, num_numeric]  floating point
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the classifier on a batch of tokenised requests.

        Parameters
        ----------
        input_ids : integer tensor [B, S]; cast to int64 before the lookup.
        attention_mask : integer tensor [B, S]; 1 for real tokens, 0 for
            padding.
        numeric_features : floating tensor [B, num_numeric]; cast to the
            model's working dtype, so float64 input is accepted.

        Returns
        -------
        label_probs : [B, num_labels]  sigmoid-activated per-label probs
        risk_score  : [B, 1]           sigmoid-activated risk in [0, 1]
        """
        # -- Token embeddings + mask application -------------------------
        x = self.embedding(input_ids.long())        # [B, S, E]
        # Zero the embeddings at padded positions.  The mask follows the
        # embedding dtype so the product keeps the model's working dtype.
        # NOTE(review): this only zeroes the encoder *input*.  Once conv
        # biases / BatchNorm shifts are non-zero, padded positions still
        # produce activations that take part in the max-pool; masking the
        # activations before pooling would change the numerics of trained
        # checkpoints, so it is intentionally not done here.
        mask = attention_mask.long().unsqueeze(-1).to(x.dtype)  # [B, S, 1]
        x = x * mask                                 # [B, S, E]

        # -- Conv encoder ------------------------------------------------
        # Conv1d expects channel-first: [B, E, S]
        x = x.permute(0, 2, 1).contiguous()         # [B, E, S]
        x = self.conv_encoder(x)                     # [B, conv_ch, 1]
        x = x.squeeze(-1)                            # [B, conv_ch]

        # -- Numeric projector -------------------------------------------
        # Align dtype with the text branch so nn.Linear does not reject
        # e.g. float64 features coming straight from NumPy.
        n = self.numeric_proj(numeric_features.to(x.dtype))  # [B, numeric_dim]

        # -- Fusion MLP --------------------------------------------------
        combined = torch.cat([x, n], dim=1)          # [B, conv_ch + numeric_dim]
        features = self.fusion_mlp(combined)         # [B, mlp_out]

        # -- Output heads ------------------------------------------------
        label_logits = self.label_head(features)     # [B, num_labels]
        label_probs  = torch.sigmoid(label_logits)   # [B, num_labels]

        risk_logit   = self.risk_head(features)      # [B, 1]
        risk_score   = torch.sigmoid(risk_logit)     # [B, 1]

        return label_probs, risk_score


# ---------------------------------------------------------------------------
# Helper utilities
# ---------------------------------------------------------------------------

def build_model(config: dict | None = None) -> WAFClassifier:
    """
    Create a WAFClassifier from DEFAULT_CONFIG with optional overrides.

    Entries in ``config`` replace the matching defaults; a ``None`` or empty
    ``config`` yields a model built purely from DEFAULT_CONFIG.  The
    module-level DEFAULT_CONFIG is never mutated.
    """
    merged = dict(DEFAULT_CONFIG)
    merged.update(config or {})
    return WAFClassifier(merged)


def load_config(config_path: str | Path) -> dict:
    """
    Load a JSON config file and merge it over DEFAULT_CONFIG.

    Parameters
    ----------
    config_path : str | Path
        Location of the JSON file holding the overrides.

    Returns
    -------
    dict
        A copy of DEFAULT_CONFIG updated with the file's contents.  If the
        file does not exist, a warning is printed and the unmodified
        defaults are returned.

    Raises
    ------
    json.JSONDecodeError
        If the file exists but does not contain valid JSON.
    """
    cfg = DEFAULT_CONFIG.copy()
    path = Path(config_path)
    try:
        # JSON is UTF-8 by specification; do not rely on the platform's
        # locale-dependent default encoding.
        with path.open("r", encoding="utf-8") as fh:
            overrides = json.load(fh)
    except FileNotFoundError:
        print(f"[WARN] config.json not found at {path}; using defaults.")
        return cfg
    cfg.update(overrides)
    return cfg


def print_param_count(model: nn.Module) -> int:
    """
    Print and return the total trainable parameter count of ``model``.

    After the total, one line per direct child module is printed with that
    child's trainable parameter count, in registration order.  For
    WAFClassifier these are embedding, conv_encoder, numeric_proj,
    fusion_mlp, label_head and risk_head.  Both the total and the breakdown
    count only parameters with ``requires_grad=True``, so the breakdown
    stays consistent with the total when parts of the model are frozen.
    Parameters registered directly on ``model`` (not inside a child) are
    included in the total but have no breakdown line.

    Parameters
    ----------
    model : nn.Module
        Any module; it does not have to be a WAFClassifier.

    Returns
    -------
    int
        Number of trainable parameters in ``model``.
    """

    def _trainable(module: nn.Module) -> int:
        # Count elements of parameters that will receive gradients.
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    total = _trainable(model)
    print(f"WAFClassifier trainable parameters: {total:,}")
    # Breakdown by component
    for name, child in model.named_children():
        print(f"  {name:<16}: {_trainable(child):>10,}")
    return total


# ---------------------------------------------------------------------------
# Quick sanity check (run directly: python model.py)
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke test: build the default model, check its size, and push one
    # random batch through it.  The seed is fixed so the printed numbers are
    # reproducible from run to run.
    torch.manual_seed(42)
    demo_cfg = DEFAULT_CONFIG.copy()
    demo_model = WAFClassifier(demo_cfg)
    demo_model.eval()

    total = print_param_count(demo_model)
    assert total < 2_000_000, f"Model too large: {total:,} params"

    batch, seq_len = 4, 128
    token_ids = torch.randint(0, demo_cfg["vocab_size"], (batch, seq_len))
    # First 100 positions are real tokens, the remainder simulates padding.
    attn_mask = torch.zeros(batch, seq_len, dtype=torch.long)
    attn_mask[:, :100] = 1
    numeric = torch.randn(batch, demo_cfg["num_numeric_features"])

    with torch.no_grad():
        probs, risk = demo_model(token_ids, attn_mask, numeric)

    # Shape checks
    assert probs.shape == (batch, NUM_LABELS), f"Bad probs shape: {probs.shape}"
    assert risk.shape == (batch, 1), f"Bad risk shape: {risk.shape}"
    # Both heads are sigmoid outputs and must lie in [0, 1].
    assert probs.min() >= 0.0 and probs.max() <= 1.0
    assert risk.min() >= 0.0 and risk.max() <= 1.0

    print(f"\nForward pass OK  |  label_probs: {probs.shape}  risk_score: {risk.shape}")
    print(f"Label probs (first example): {probs[0].tolist()}")
    print(f"Risk score  (first example): {risk[0].item():.4f}")