File size: 4,233 Bytes
5d1d43b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""DGA Detection Model using Transformer Encoder.

This model treats domain names as sequences of characters and uses a Transformer
encoder to learn patterns that distinguish DGA (algorithmically generated) domains
from legitimate ones.
"""

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput

from charset import PAD, VOCAB_SIZE

# Width of DGAEncoder's classification head: two classes, distinguishing DGA
# domains from legitimate ones (see module docstring). Which label index means
# which class is decided by the training data, not visible here.
NUM_CLASSES = 2


class DGAEncoder(nn.Module):
    """Transformer encoder for DGA (Domain Generation Algorithm) detection.

    Token ids are embedded, summed with learned absolute position embeddings,
    run through a pre-LayerNorm ``nn.TransformerEncoder``, and the hidden state
    at position 0 is layer-normalised and projected to class logits.

    Args:
        vocab_size: Number of rows in the token-embedding table.
        max_len: Longest sequence ``forward`` accepts; sizes the
            position-embedding table.
        d_model: Hidden size of the embeddings and encoder layers.
        nhead: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability inside each encoder layer.
        ffn_mult: Feed-forward width as a multiple of ``d_model``.
        num_classes: Number of output logits. ``None`` (default) uses the
            module-level ``NUM_CLASSES``.
    """

    def __init__(
        self,
        *,
        vocab_size: int,
        max_len: int = 64,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 4,
        dropout: float = 0.1,
        ffn_mult: int = 4,
        num_classes: Optional[int] = None,
    ) -> None:
        super().__init__()

        # Resolved at call time (not as a default value) so the head width
        # can be overridden while keeping NUM_CLASSES as the default.
        if num_classes is None:
            num_classes = NUM_CLASSES

        # padding_idx keeps the PAD row out of gradient updates.
        self.tok = nn.Embedding(vocab_size, d_model, padding_idx=PAD)
        self.pos = nn.Embedding(max_len, d_model)

        # Non-persistent: rebuilt from max_len on construction, so it is
        # neither saved to nor expected in state dicts.
        self.register_buffer(
            "position_ids",
            torch.arange(max_len).unsqueeze(0),
            persistent=False,
        )

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ffn_mult * d_model,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )

        self.enc = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
        # With norm_first=True the encoder output is not normalised, so a
        # final LayerNorm is applied before the classifier.
        self.norm = nn.LayerNorm(d_model)
        self.clf = nn.Linear(d_model, num_classes)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode a batch of token-id sequences and return class logits.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)`` with
                ``seq_len <= max_len``.
            attention_mask: Optional ``(batch, seq_len)`` mask in the
                HuggingFace convention: non-zero / ``True`` marks tokens to
                attend to, zero / ``False`` marks padding to ignore as attention
                keys. ``None`` (default) attends to every position, padding
                included. The classifier reads position 0, so that position
                should normally be a real token. A row with every position
                masked may yield NaN outputs.

        Returns:
            Logits of shape ``(batch, num_classes)``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_len``.
        """
        batch, seq_len = x.shape
        max_len = self.position_ids.size(1)
        if seq_len > max_len:
            # Without this check, expand() below fails with an opaque shape error.
            raise ValueError(
                f"input sequence length {seq_len} exceeds max_len={max_len}"
            )

        pos = self.position_ids[:, :seq_len].expand(batch, seq_len)
        h = self.tok(x) + self.pos(pos)

        # PyTorch's key-padding mask uses True for "ignore", the inverse
        # of the HF attention_mask convention.
        key_padding_mask = None if attention_mask is None else attention_mask == 0
        h = self.enc(h, src_key_padding_mask=key_padding_mask)

        cls = self.norm(h[:, 0])
        return self.clf(cls)


class DGAEncoderConfig(PretrainedConfig):
    """HuggingFace ``PretrainedConfig`` describing a ``DGAEncoder`` model.

    Stores the encoder hyper-parameters (``vocab_size``, ``max_len``,
    ``d_model``, ``nhead``, ``num_layers``, ``dropout``, ``ffn_mult``) plus
    ``num_labels``, the number of target classes. Any other keyword arguments
    are handed to ``PretrainedConfig.__init__`` unchanged.
    """

    model_type = "dga_encoder"

    def __init__(
        self,
        vocab_size: int = VOCAB_SIZE,
        max_len: int = 64,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 4,
        dropout: float = 0.1,
        ffn_mult: int = 4,
        num_labels: int = 2,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Assigned in a fixed order after the base initialiser so these values
        # take precedence over anything the base class set up.
        hyperparams = {
            "vocab_size": vocab_size,
            "max_len": max_len,
            "d_model": d_model,
            "nhead": nhead,
            "num_layers": num_layers,
            "dropout": dropout,
            "ffn_mult": ffn_mult,
            "num_labels": num_labels,
        }
        for attr_name, attr_value in hyperparams.items():
            setattr(self, attr_name, attr_value)


class DGAEncoderForSequenceClassification(PreTrainedModel):
    """HuggingFace-compatible wrapper around ``DGAEncoder``.

    Builds a ``DGAEncoder`` from a ``DGAEncoderConfig`` and exposes the
    ``forward(input_ids, attention_mask, labels, return_dict)`` signature used
    by the HF ``Trainer``. The classification head has ``config.num_labels``
    outputs. When ``labels`` are given, a single-label cross-entropy loss is
    returned alongside the logits.
    """

    config_class = DGAEncoderConfig

    def __init__(self, config: DGAEncoderConfig):
        # PreTrainedModel.__init__ stores the config as self.config.
        super().__init__(config)

        if config.num_labels < 2:
            raise ValueError(
                "DGAEncoderForSequenceClassification uses cross-entropy over "
                f"class logits and needs num_labels >= 2, got {config.num_labels}"
            )

        self.encoder = DGAEncoder(
            vocab_size=config.vocab_size,
            max_len=config.max_len,
            d_model=config.d_model,
            nhead=config.nhead,
            num_layers=config.num_layers,
            dropout=config.dropout,
            ffn_mult=config.ffn_mult,
        )

        # DGAEncoder is built with its default head width (NUM_CLASSES). Resize
        # the head so the logits width always equals config.num_labels.
        # Otherwise a config with a different label count yields mismatched
        # logits and a failing loss.
        if config.num_labels != NUM_CLASSES:
            self.encoder.clf = nn.Linear(config.d_model, config.num_labels)

        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Compute class logits and, if ``labels`` are given, the loss.

        Args:
            input_ids: Token ids passed straight to ``DGAEncoder``.
            attention_mask: Accepted for HF Trainer compatibility but NOT
                applied. The encoder attends to every position, padding
                included.
            labels: Optional class indices. When present, cross-entropy loss
                over the logits is computed.
            return_dict: Return a ``SequenceClassifierOutput`` when true, or a
                plain tuple when false. ``None`` falls back to
                ``config.use_return_dict``.
            **kwargs: Ignored.

        Returns:
            ``SequenceClassifierOutput(loss, logits)``. When ``return_dict`` is
            false, returns ``(loss, logits)`` if a loss was computed, otherwise
            ``(logits,)``.
        """
        return_dict = (
            return_dict
            if return_dict is not None
            else self.config.use_return_dict
        )

        # NOTE(review): attention_mask is intentionally not forwarded. Models
        # trained through this wrapper never saw masked padding, so enabling
        # masking here would shift their predictions. Revisit if padding-aware
        # training is introduced.
        logits = self.encoder(input_ids)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Use the real logits width rather than trusting the config value.
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )