File size: 10,930 Bytes
26c425c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 | """DGA Detection Model using Transformer Encoder.
This model treats domain names as sequences of characters and uses a Transformer
encoder to learn patterns that distinguish DGA (algorithmically generated) domains
from legitimate ones.
Key design decisions:
1. Character-level tokenization: Captures subword patterns that LSTMs miss
- DGAs often have unusual character n-grams (e.g., "xkwj", "qmzo")
- Character level avoids OOV issues with new DGA families
2. Pre-LN Transformer: Modern architecture that's easier to train
- More stable gradients than Post-LN (original Transformer)
- No need for learning rate warmup
- Can go deeper without tricks
3. [CLS] token pooling: Standard approach for sequence classification
- Transformer learns to aggregate sequence info into [CLS]
- Better than mean/max pooling empirically
4. Learned positional embeddings: Domain structure is important
- TLD patterns (last few chars)
- Subdomain patterns (first few chars)
- Learned embeddings capture this better than fixed sinusoids for short seqs
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from .charset import PAD, VOCAB_SIZE
from .config import PROFILES
NUM_CLASSES = 2
# ------------------------------
# Core encoder (Pre-LayerNorm)
# ------------------------------
class DGAEncoder(nn.Module):
"""
Transformer encoder for DGA (Domain Generation Algorithm) detection.
Architecture overview:
1. Token + Position embeddings
2. Transformer encoder (Pre-LN variant)
3. Classification head on [CLS] token
Design choices:
- Pre-LN (Layer Norm before attention): More stable training, doesn't need warmup
- Positional embeddings (learned): Capture character position importance
- [CLS] token pooling: Standard for sequence classification, better than mean pooling
"""
def __init__(
self,
*,
vocab_size: int,
max_len: int = 64,
d_model: int = 256,
nhead: int = 8,
num_layers: int = 4,
dropout: float = 0.1,
ffn_mult: int = 4,
) -> None:
super().__init__()
# Token embeddings: Convert character IDs to dense vectors
# padding_idx=PAD tells the embedding to zero out padding tokens
# This prevents the model from learning anything from pad tokens
self.tok = nn.Embedding(vocab_size, d_model, padding_idx=PAD)
# Positional embeddings: Learned position encodings (not sinusoidal)
# Each position gets its own learned embedding vector
# For domain names, position matters (e.g., TLD vs subdomain patterns)
self.pos = nn.Embedding(max_len, d_model)
# Register position IDs as a buffer (not a parameter, but moves with model to GPU)
# This is just [0, 1, 2, ..., max_len-1] repeated for batching
self.register_buffer(
"position_ids",
torch.arange(max_len).unsqueeze(0),
persistent=False, # Don't save in checkpoint, we can recreate it
)
# Transformer Encoder Layer with Pre-LN architecture
# Pre-LN (norm_first=True) is more stable than Post-LN:
# - Gradients flow better (less vanishing gradient issues)
# - No need for learning rate warmup
# - Can train deeper models without special initialization tricks
#
# ffn_mult=4 means FFN hidden dim = 4 * d_model (standard Transformer ratio)
enc_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=ffn_mult * d_model,
dropout=dropout,
batch_first=True, # Expect input as (batch, seq, features)
norm_first=True, # Pre-LN: LayerNorm before attention (more stable!)
)
# Stack multiple encoder layers
# Each layer does: Self-Attention -> FFN
# With Pre-LN, each sublayer is: LN -> Sublayer -> Residual
self.enc = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
# Final LayerNorm on [CLS] token output
# This normalizes the representation before classification
# Helps with training stability and generalization
self.norm = nn.LayerNorm(d_model)
# Classification head: Simple linear layer
# Maps [CLS] representation (d_model) to class logits (NUM_CLASSES)
# No activation here - we'll use CrossEntropyLoss which applies softmax
self.clf = nn.Linear(d_model, NUM_CLASSES)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the encoder.
x: (B, L) token ids with CLS at index 0
Steps:
1. Look up token embeddings and add positional embeddings
2. Pass through transformer encoder layers
3. Extract [CLS] token (position 0) and normalize
4. Project to class logits
"""
b, L = x.shape # b = batch size, L = sequence length
# Expand position IDs to match batch size
# pos will be [[0,1,2,...,L-1], [0,1,2,...,L-1], ...] for batch
pos = self.position_ids[:, :L].expand(b, L)
# Token + position embeddings
# This is element-wise addition (broadcasting works because both are (B, L, d_model))
# Each position gets its own learned offset added to the token embedding
h = self.tok(x) + self.pos(pos) # h = hidden states (embeddings)
# Pass through transformer encoder
# Self-attention allows each character to attend to all other characters
# This captures long-range dependencies (e.g., suffix patterns, character distributions)
h = self.enc(h) # h = transformed hidden states
# Extract and normalize the [CLS] token representation
# [CLS] is always at position 0 in our encoding scheme
# The transformer has learned to aggregate sequence information into [CLS]
cls = self.norm(
h[:, 0]
) # cls = normalized [CLS] token (sequence representation)
# Project to class logits (benign vs DGA)
return self.clf(cls)
class DGAEncoderConfig(PretrainedConfig):
"""Configuration for DGAEncoder compatible with HuggingFace Transformers.
can be saved/loaded using HF's standard save_pretrained()
and from_pretrained() methods.
"""
model_type = "dga_encoder"
def __init__(
self,
vocab_size: int = VOCAB_SIZE,
max_len: int = 64,
d_model: int = 256,
nhead: int = 8,
num_layers: int = 4,
dropout: float = 0.1,
ffn_mult: int = 4,
num_labels: int = 2, # Binary classification: DGA vs Normal
**kwargs,
):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.max_len = max_len
self.d_model = d_model
self.nhead = nhead
self.num_layers = num_layers
self.dropout = dropout
self.ffn_mult = ffn_mult
self.num_labels = num_labels
class DGAEncoderForSequenceClassification(PreTrainedModel):
"""HuggingFace-compatible wrapper around DGAEncoder.
This enables:
- Automatic checkpoint management via Trainer
- save_pretrained() / from_pretrained() methods
- Integration with HF ecosystem (datasets, evaluate, etc.)
- W&B logging via Trainer's report_to="wandb"
"""
config_class = DGAEncoderConfig
def __init__(self, config: DGAEncoderConfig):
super().__init__(config)
self.config = config
self.encoder = DGAEncoder(
vocab_size=config.vocab_size,
max_len=config.max_len,
d_model=config.d_model,
nhead=config.nhead,
num_layers=config.num_layers,
dropout=config.dropout,
ffn_mult=config.ffn_mult,
)
# Initialize weights (HF convention)
self.post_init()
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
**kwargs,
):
"""Forward pass compatible with HF Trainer.
Args:
input_ids: Token IDs (B, L) with CLS at index 0
attention_mask: Not used (padding handled by PAD token automatically)
labels: Ground truth labels for classification (B,)
return_dict: Whether to return SequenceClassifierOutput
Returns:
SequenceClassifierOutput or tuple with loss and logits
Note on loss computation:
- CrossEntropyLoss combines LogSoftmax + NLLLoss
- It expects raw logits (no softmax applied) and class indices
- Automatically handles the softmax internally for numerical stability
"""
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
)
# Forward through the existing encoder
# This calls DGAEncoder.forward() which returns (B, NUM_CLASSES) logits
logits = self.encoder(input_ids)
# Compute loss if labels provided (training mode)
# CrossEntropyLoss expects:
# - Input: (N, C) where C is number of classes
# - Target: (N,) with class indices in [0, C-1]
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
logits.view(-1, self.config.num_labels), labels.view(-1)
)
# Return format depends on return_dict flag
# HF Trainer expects return_dict=True by default
if not return_dict:
output = (logits,)
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=None, # Could add intermediate layer outputs here
attentions=None, # Could add attention weights here for visualization
)
def build_model(size: str = "tiny") -> DGAEncoderForSequenceClassification:
"""
model = build_model("tiny")
model.save_pretrained("./my_model")
loaded = DGAEncoderForSequenceClassification.from_pretrained("./my_model")
"""
prof = PROFILES[size]
config = DGAEncoderConfig(
vocab_size=VOCAB_SIZE,
max_len=prof.max_len,
d_model=prof.d_model,
nhead=prof.nhead,
num_layers=prof.num_layers,
dropout=prof.dropout,
ffn_mult=prof.ffn_mult,
num_labels=2, # Binary classification
)
return DGAEncoderForSequenceClassification(config)
__all__ = [
"DGAEncoderConfig",
"DGAEncoderForSequenceClassification",
"build_model",
]
|