Reynier committed on
Commit
d852973
·
verified ·
1 Parent(s): 75ef31c

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +103 -0
model.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
DGA-BiLSTM: Bidirectional LSTM + Self-Attention for DGA detection.
Based on Namgung et al. (Security and Communication Networks, 2021).
Trained on 54 DGA families.
"""
import string
import torch
import torch.nn as nn

# Character vocabulary: a-z, 0-9, hyphen, dot, underscore (39 symbols).
CHARS = string.ascii_lowercase + string.digits + "-._"
# 1-based index per character; 0 is reserved for padding (and unknown chars).
CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)}
VOCAB_SIZE = len(CHARS) + 1  # 40
MAXLEN = 75  # domains are truncated/zero-padded to this many characters
EMBED_DIM = 32  # character embedding width
BILSTM_DIM = 128  # per direction -> 256 total
DROPOUT = 0.5
FC_HIDDEN = 64
+
20
def encode_domain(domain: str) -> list:
    """Encode a domain name as a fixed-length list of character indices.

    The domain is lowercased, stripped, and truncated to MAXLEN characters.
    Characters outside the vocabulary map to 0, and the result is
    zero-padded on the right to exactly MAXLEN entries.
    """
    cleaned = str(domain).lower().strip()
    indices = []
    for ch in cleaned[:MAXLEN]:
        indices.append(CHAR2IDX.get(ch, 0))
    # Right-pad with the padding index so every sample has MAXLEN positions.
    indices.extend([0] * (MAXLEN - len(indices)))
    return indices
24
+
25
+
26
class SelfAttention(nn.Module):
    """Attention pooling that collapses a sequence into one vector.

    Each timestep receives a scalar score from a bias-free linear layer;
    the scores are softmax-normalized along the sequence axis and used as
    weights for a weighted sum of the inputs.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Single scalar score per timestep, no bias term.
        self.attn = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, lstm_out):
        # lstm_out: (batch, seq, hidden) -> raw scores: (batch, seq, 1)
        raw = self.attn(lstm_out)
        alpha = torch.softmax(raw, dim=1)
        # Weighted sum over the sequence axis -> (batch, hidden)
        pooled = torch.sum(alpha * lstm_out, dim=1)
        return pooled
35
+
36
+
37
class BiLSTMAttention(nn.Module):
    """BiLSTM + self-attention binary classifier for DGA detection.

    Architecture (Namgung et al. 2021):
    Embedding -> BiLSTM(128 per direction) -> self-attention pooling
    -> Linear(64) -> ReLU -> Dropout(0.5) -> Linear(1) logit.
    """

    def __init__(self):
        super().__init__()
        # Index 0 is the pad token; its embedding is kept at zero.
        self.embedding = nn.Embedding(VOCAB_SIZE, EMBED_DIM, padding_idx=0)
        self.bilstm = nn.LSTM(
            input_size=EMBED_DIM,
            hidden_size=BILSTM_DIM,
            batch_first=True,
            bidirectional=True,
        )
        feature_dim = 2 * BILSTM_DIM  # forward + backward halves -> 256
        self.attention = SelfAttention(feature_dim)
        self.fc = nn.Sequential(
            nn.Linear(feature_dim, FC_HIDDEN),
            nn.ReLU(),
            nn.Dropout(DROPOUT),
            nn.Linear(FC_HIDDEN, 1),
        )

    def forward(self, x):
        # x: (batch, MAXLEN) integer indices -> one raw logit per sample.
        embedded = self.embedding(x)
        sequence, _ = self.bilstm(embedded)
        pooled = self.attention(sequence)
        logit = self.fc(pooled)
        return logit.squeeze(1)
65
+
66
+
67
def load_model(weights_path: str, device: str = None):
    """Load the trained BiLSTMAttention model from a local weights path.

    Args:
        weights_path: Path to a saved state_dict checkpoint (.pt/.pth).
        device: Target device string ("cuda"/"cpu"). Auto-detected when None.

    Returns:
        The model on `device`, switched to eval mode.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = BiLSTMAttention()
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is safer for checkpoints downloaded from a model hub — plain
    # torch.load executes arbitrary pickle code (requires torch >= 1.13).
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
76
+
77
+
78
def predict(model, domains, device: str = None, batch_size: int = 256):
    """Classify domain names as DGA-generated or legitimate.

    Args:
        model: A loaded BiLSTMAttention model.
        domains: A single domain string, or a list of domain strings.
        device: Device for the input tensors; defaults to the model's device.
        batch_size: Number of domains scored per forward pass.

    Returns:
        One dict per input domain:
        {"domain": ..., "label": "dga"/"legit", "score": float}.
    """
    if device is None:
        device = next(model.parameters()).device
    if isinstance(domains, str):
        domains = [domains]

    output = []
    for start in range(0, len(domains), batch_size):
        chunk = domains[start : start + batch_size]
        encoded = [encode_domain(d) for d in chunk]
        batch_tensor = torch.tensor(encoded, dtype=torch.long).to(device)
        with torch.no_grad():
            logits = model(batch_tensor)
        # Sigmoid maps the raw logit to a DGA probability in [0, 1].
        probs = torch.sigmoid(logits).cpu().tolist()
        for name, prob in zip(chunk, probs):
            output.append(
                {
                    "domain": name,
                    "label": "dga" if prob >= 0.5 else "legit",
                    "score": round(prob, 4),
                }
            )
    return output