File size: 3,287 Bytes
d852973
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
DGA-BiLSTM: Bidirectional LSTM + Self-Attention for DGA detection.
Based on Namgung et al. (Security and Communication Networks, 2021).
Trained on 54 DGA families.
"""
import string
import torch
import torch.nn as nn

CHARS = string.ascii_lowercase + string.digits + "-._"
# Index 0 is reserved for padding *and* out-of-vocabulary characters,
# so valid characters start at 1.
CHAR2IDX = {ch: pos for pos, ch in enumerate(CHARS, start=1)}
VOCAB_SIZE = len(CHARS) + 1  # 39 chars + pad/OOV slot = 40
MAXLEN = 75
EMBED_DIM = 32
BILSTM_DIM = 128  # per direction -> 256 total
DROPOUT = 0.5
FC_HIDDEN = 64


def encode_domain(domain: str) -> list:
    """Map a domain name to a fixed-length list of MAXLEN char indices.

    Lowercases and strips the input, truncates to MAXLEN, maps unknown
    characters to 0, and right-pads with 0 up to MAXLEN.
    """
    cleaned = str(domain).lower().strip()
    ids = [CHAR2IDX.get(ch, 0) for ch in cleaned[:MAXLEN]]
    padding = [0] * (MAXLEN - len(ids))
    return ids + padding


class SelfAttention(nn.Module):
    """Single-projection self-attention pooling over the time axis.

    A bias-free linear layer assigns one score per timestep; the scores
    are softmax-normalized over time and used to take a weighted sum of
    the sequence, collapsing (batch, seq, hidden) -> (batch, hidden).
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Attribute name kept as `attn` so existing state_dicts still load.
        self.attn = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, lstm_out):
        # lstm_out assumed (batch, seq, hidden) -- softmax runs over dim 1 (time).
        raw_scores = self.attn(lstm_out)          # (batch, seq, 1)
        alpha = raw_scores.softmax(dim=1)         # attention weights over time
        pooled = (alpha * lstm_out).sum(dim=1)    # (batch, hidden)
        return pooled


class BiLSTMAttention(nn.Module):
    """
    DGA detector after Namgung et al. (2021):
    Embedding -> BiLSTM(128 per direction) -> Self-Attention pooling
    -> FC(64) -> ReLU -> Dropout(0.5) -> single logit.

    Note: forward() returns raw logits; callers apply sigmoid themselves.
    Submodule attribute names (embedding/bilstm/attention/fc) are part of
    the saved state_dict layout and must not be renamed.
    """

    def __init__(self):
        super().__init__()
        feature_dim = 2 * BILSTM_DIM  # forward + backward states concatenated
        self.embedding = nn.Embedding(VOCAB_SIZE, EMBED_DIM, padding_idx=0)
        self.bilstm = nn.LSTM(
            input_size=EMBED_DIM,
            hidden_size=BILSTM_DIM,
            batch_first=True,
            bidirectional=True,
        )
        self.attention = SelfAttention(feature_dim)
        self.fc = nn.Sequential(
            nn.Linear(feature_dim, FC_HIDDEN),
            nn.ReLU(),
            nn.Dropout(DROPOUT),
            nn.Linear(FC_HIDDEN, 1),
        )

    def forward(self, x):
        # x: (batch, MAXLEN) integer char indices.
        sequence, _ = self.bilstm(self.embedding(x))
        pooled = self.attention(sequence)          # (batch, 2*BILSTM_DIM)
        logits = self.fc(pooled)                   # (batch, 1)
        return logits.squeeze(1)                   # (batch,)


def load_model(weights_path: str, device: str = None):
    """Load a trained BiLSTMAttention model from a local weights path.

    Args:
        weights_path: path to a ``torch.save``-d state dict.
        device: target device string; auto-selects "cuda"/"cpu" when None.

    Returns:
        The model on ``device``, in eval mode.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = BiLSTMAttention()
    # weights_only=True restricts unpickling to tensors/primitives: a plain
    # state dict still loads, but a maliciously crafted checkpoint can no
    # longer execute arbitrary code during torch.load.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def predict(model, domains, device: str = None, batch_size: int = 256):
    """
    Classify domain strings as DGA-generated or legitimate.

    Accepts a single domain or a list of them; scores them in mini-batches
    of ``batch_size`` on ``device`` (defaults to the model's own device).
    Returns one dict per domain:
    [{"domain": ..., "label": "dga"/"legit", "score": float}]
    where score is the sigmoid probability rounded to 4 places and the
    label is "dga" when score >= 0.5.
    """
    if device is None:
        device = next(model.parameters()).device
    if isinstance(domains, str):
        domains = [domains]

    results = []
    for start in range(0, len(domains), batch_size):
        chunk = domains[start : start + batch_size]
        batch_tensor = torch.tensor(
            [encode_domain(name) for name in chunk], dtype=torch.long
        ).to(device)
        with torch.no_grad():
            probs = torch.sigmoid(model(batch_tensor)).cpu().tolist()
        for name, prob in zip(chunk, probs):
            is_dga = prob >= 0.5
            results.append({
                "domain": name,
                "label": "dga" if is_dga else "legit",
                "score": round(prob, 4),
            })
    return results