ccss17 committed on
Commit
26c425c
·
verified ·
1 Parent(s): bf3475e

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +305 -0
model.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DGA Detection Model using Transformer Encoder.
2
+
3
+ This model treats domain names as sequences of characters and uses a Transformer
4
+ encoder to learn patterns that distinguish DGA (algorithmically generated) domains
5
+ from legitimate ones.
6
+
7
+ Key design decisions:
8
+ 1. Character-level tokenization: Captures subword patterns that LSTMs miss
9
+ - DGAs often have unusual character n-grams (e.g., "xkwj", "qmzo")
10
+ - Character level avoids OOV issues with new DGA families
11
+
12
+ 2. Pre-LN Transformer: Modern architecture that's easier to train
13
+ - More stable gradients than Post-LN (original Transformer)
14
+ - No need for learning rate warmup
15
+ - Can go deeper without tricks
16
+
17
+ 3. [CLS] token pooling: Standard approach for sequence classification
18
+ - Transformer learns to aggregate sequence info into [CLS]
19
+ - Better than mean/max pooling empirically
20
+
21
+ 4. Learned positional embeddings: Domain structure is important
22
+ - TLD patterns (last few chars)
23
+ - Subdomain patterns (first few chars)
24
+ - Learned embeddings capture this better than fixed sinusoids for short seqs
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ from typing import Optional
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ from transformers import PreTrainedModel, PretrainedConfig
34
+ from transformers.modeling_outputs import SequenceClassifierOutput
35
+
36
+ from .charset import PAD, VOCAB_SIZE
37
+ from .config import PROFILES
38
+
39
+ NUM_CLASSES = 2
40
+
41
+
42
+ # ------------------------------
43
+ # Core encoder (Pre-LayerNorm)
44
+ # ------------------------------
45
class DGAEncoder(nn.Module):
    """Transformer encoder for DGA (Domain Generation Algorithm) detection.

    Pipeline:
        1. Token + learned positional embeddings
        2. Pre-LN Transformer encoder stack
        3. LayerNorm + linear head on the [CLS] token (position 0)

    Design choices:
        - Pre-LN (``norm_first=True``): more stable gradients than Post-LN,
          no learning-rate warmup required.
        - Learned positional embeddings: character position matters for short
          domain strings (TLD at the end, subdomain at the start).
        - [CLS] pooling: the encoder aggregates sequence information into
          position 0, which the linear head then classifies.
    """

    def __init__(
        self,
        *,
        vocab_size: int,
        max_len: int = 64,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 4,
        dropout: float = 0.1,
        ffn_mult: int = 4,
        num_classes: Optional[int] = None,
    ) -> None:
        """Build the encoder.

        Args:
            vocab_size: size of the character vocabulary.
            max_len: maximum sequence length supported by positional embeddings.
            d_model: embedding / hidden dimension.
            nhead: number of attention heads.
            num_layers: number of stacked encoder layers.
            dropout: dropout probability inside the encoder layers.
            ffn_mult: FFN hidden dim = ffn_mult * d_model (4 is the standard ratio).
            num_classes: number of output classes; ``None`` (the default) uses
                the module-level ``NUM_CLASSES`` (2), preserving prior behavior.
        """
        super().__init__()

        # Token embeddings. padding_idx=PAD pins the PAD row to zeros, so pad
        # tokens contribute nothing through the token embedding itself.
        self.tok = nn.Embedding(vocab_size, d_model, padding_idx=PAD)

        # Learned positional embeddings (not sinusoidal): for short domain
        # names, absolute position (prefix vs. suffix) is informative.
        self.pos = nn.Embedding(max_len, d_model)

        # [0, 1, ..., max_len-1] as a non-persistent buffer: it moves with the
        # model across devices but is not stored in checkpoints (recreatable).
        self.register_buffer(
            "position_ids",
            torch.arange(max_len).unsqueeze(0),
            persistent=False,
        )

        # Pre-LN encoder layer (norm_first=True): LayerNorm before each
        # sublayer gives better gradient flow and removes the need for warmup.
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ffn_mult * d_model,
            dropout=dropout,
            batch_first=True,  # inputs are (batch, seq, features)
            norm_first=True,
        )

        # Stack of encoder layers; each is LN -> sublayer -> residual.
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        # Final LayerNorm on the pooled [CLS] representation.
        self.norm = nn.LayerNorm(d_model)

        # Linear head producing raw logits; the loss side (CrossEntropyLoss)
        # applies the softmax, so no activation here.
        self.clf = nn.Linear(
            d_model, NUM_CLASSES if num_classes is None else num_classes
        )

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return class logits for a batch of encoded domains.

        Args:
            x: (B, L) token ids with [CLS] at index 0.
            key_padding_mask: optional (B, L) bool mask, True at PAD positions.
                When given, padded positions are excluded from attention.
                The default ``None`` preserves the original behavior, where pad
                tokens (zero token embedding + positional offset) still take
                part in attention.

        Returns:
            (B, num_classes) raw logits.
        """
        b, L = x.shape

        # Broadcast the cached position ids to the batch: (B, L).
        pos = self.position_ids[:, :L].expand(b, L)

        # Sum of token and positional embeddings: (B, L, d_model).
        h = self.tok(x) + self.pos(pos)

        # Self-attention lets every character attend to every other one,
        # capturing long-range cues (suffix patterns, character distribution).
        h = self.enc(h, src_key_padding_mask=key_padding_mask)

        # Pool: take position 0 ([CLS]) and normalize it.
        cls = self.norm(h[:, 0])

        # Raw logits; pair with CrossEntropyLoss downstream.
        return self.clf(cls)
159
+
160
+
161
class DGAEncoderConfig(PretrainedConfig):
    """HuggingFace-compatible configuration for :class:`DGAEncoder`.

    Storing hyperparameters on a ``PretrainedConfig`` subclass lets the model
    round-trip through HF's standard ``save_pretrained()`` /
    ``from_pretrained()`` machinery.
    """

    model_type = "dga_encoder"

    def __init__(
        self,
        vocab_size: int = VOCAB_SIZE,
        max_len: int = 64,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 4,
        dropout: float = 0.1,
        ffn_mult: int = 4,
        num_labels: int = 2,  # binary classification: DGA vs Normal
        **kwargs,
    ):
        # Let the HF base class consume its own kwargs first, then record
        # every encoder hyperparameter on the instance.
        super().__init__(**kwargs)
        for attr_name, attr_value in (
            ("vocab_size", vocab_size),
            ("max_len", max_len),
            ("d_model", d_model),
            ("nhead", nhead),
            ("num_layers", num_layers),
            ("dropout", dropout),
            ("ffn_mult", ffn_mult),
            ("num_labels", num_labels),
        ):
            setattr(self, attr_name, attr_value)
191
+
192
+
193
class DGAEncoderForSequenceClassification(PreTrainedModel):
    """HuggingFace-compatible wrapper around :class:`DGAEncoder`.

    Wrapping the plain ``nn.Module`` in a ``PreTrainedModel`` provides:
    - automatic checkpoint management via ``Trainer``
    - ``save_pretrained()`` / ``from_pretrained()``
    - integration with the HF ecosystem (datasets, evaluate, etc.)
    - W&B logging via ``Trainer``'s ``report_to="wandb"``
    """

    config_class = DGAEncoderConfig

    def __init__(self, config: DGAEncoderConfig):
        super().__init__(config)
        self.config = config

        # Instantiate the core encoder from the stored hyperparameters.
        self.encoder = DGAEncoder(
            vocab_size=config.vocab_size,
            max_len=config.max_len,
            d_model=config.d_model,
            nhead=config.nhead,
            num_layers=config.num_layers,
            dropout=config.dropout,
            ffn_mult=config.ffn_mult,
        )

        # HF convention: finalize weight init after submodules exist.
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Forward pass compatible with the HF ``Trainer``.

        Args:
            input_ids: token IDs (B, L) with [CLS] at index 0.
            attention_mask: accepted for HF-API compatibility but not used
                (padding is handled via the PAD embedding in the encoder).
            labels: ground-truth class indices (B,); when given, a
                cross-entropy loss is computed (CrossEntropyLoss expects raw
                logits and applies log-softmax internally for stability).
            return_dict: whether to return a ``SequenceClassifierOutput``;
                falls back to ``config.use_return_dict`` when ``None``.

        Returns:
            ``SequenceClassifierOutput`` or a tuple of (loss, logits) /
            (logits,) depending on ``return_dict``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # (B, NUM_CLASSES) raw logits from the core encoder.
        logits = self.encoder(input_ids)

        # Training mode: compute cross-entropy over (N, C) logits vs (N,) ids.
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(
                logits.view(-1, self.config.num_labels), labels.view(-1)
            )

        if return_dict:
            # Trainer's default path.
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=None,  # intermediate layer outputs could go here
                attentions=None,  # attention weights could go here for viz
            )

        # Legacy tuple path: prepend the loss only when it was computed.
        output = (logits,)
        return output if loss is None else (loss,) + output
279
+
280
+
281
def build_model(size: str = "tiny") -> DGAEncoderForSequenceClassification:
    """Build a freshly initialized DGA classifier from a named size profile.

    Args:
        size: key into ``PROFILES`` selecting the hyperparameter profile
            (e.g. ``"tiny"``).

    Returns:
        An untrained ``DGAEncoderForSequenceClassification``.

    Raises:
        KeyError: if ``size`` is not a known profile name.

    Example:
        model = build_model("tiny")
        model.save_pretrained("./my_model")
        loaded = DGAEncoderForSequenceClassification.from_pretrained("./my_model")
    """
    # Fail with a helpful message (same KeyError type callers already catch).
    if size not in PROFILES:
        raise KeyError(
            f"Unknown model size {size!r}; available profiles: {sorted(PROFILES)}"
        )
    prof = PROFILES[size]
    config = DGAEncoderConfig(
        vocab_size=VOCAB_SIZE,
        max_len=prof.max_len,
        d_model=prof.d_model,
        nhead=prof.nhead,
        num_layers=prof.num_layers,
        dropout=prof.dropout,
        ffn_mult=prof.ffn_mult,
        num_labels=2,  # binary classification: benign vs DGA
    )
    return DGAEncoderForSequenceClassification(config)
299
+
300
+
301
# Public API of this module. DGAEncoder is a public class defined here and
# was previously missing from the export list.
__all__ = [
    "DGAEncoder",
    "DGAEncoderConfig",
    "DGAEncoderForSequenceClassification",
    "build_model",
]