dga-bilstm / model.py
Reynier's picture
Upload model.py with huggingface_hub
d852973 verified
"""
DGA-BiLSTM: Bidirectional LSTM + Self-Attention for DGA detection.
Based on Namgung et al. (Security and Communication Networks, 2021).
Trained on 54 DGA families.
"""
import string
import torch
import torch.nn as nn
# Character vocabulary. Real characters are numbered from 1 upward; index 0 is
# shared by padding and by any character not listed here.
CHARS = string.ascii_lowercase + string.digits + "-._"
CHAR2IDX = dict(zip(CHARS, range(1, len(CHARS) + 1)))
VOCAB_SIZE = len(CHARS) + 1  # 39 characters + index 0 -> 40
MAXLEN = 75  # domains are truncated / zero-padded to exactly this length
EMBED_DIM = 32
BILSTM_DIM = 128  # hidden units per LSTM direction; 256 once both are concatenated
DROPOUT = 0.5
FC_HIDDEN = 64
def encode_domain(domain: str) -> list:
    """Turn a domain name into a fixed-length list of MAXLEN character indices.

    The input is coerced with ``str()``, lowercased, whitespace-stripped and
    cut to its first MAXLEN characters. Each character is looked up in
    CHAR2IDX, and characters that are not in the vocabulary become 0. The list
    is then right-padded with 0 until it is exactly MAXLEN long.
    """
    normalized = str(domain).lower().strip()[:MAXLEN]
    indices = [CHAR2IDX.get(ch, 0) for ch in normalized]
    # len(indices) <= MAXLEN after the slice above, so this pad is never negative.
    indices.extend([0] * (MAXLEN - len(indices)))
    return indices
class SelfAttention(nn.Module):
    """Pool a sequence of hidden states into one vector with learned weights.

    Each timestep is scored by a bias-free linear projection to a scalar. The
    scores are softmax-normalised over dim 1, and the hidden states are summed
    over dim 1 using those weights.

    NOTE: there is no padding mask, so padded timesteps also receive
    attention weight. Adding a mask would change the outputs of existing
    trained weights.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # "attn" becomes the state_dict key "attn.weight"; renaming it would
        # break loading of previously saved weights.
        self.attn = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, lstm_out):
        # assumes lstm_out is (batch, seq_len, hidden_size), i.e. the output of
        # a batch_first LSTM — TODO confirm for other callers
        alpha = self.attn(lstm_out).softmax(dim=1)  # (batch, seq_len, 1)
        weighted = alpha * lstm_out
        return weighted.sum(dim=1)
class BiLSTMAttention(nn.Module):
    """
    Namgung et al. 2021:
    Embedding -> BiLSTM(128) -> Self-Attention -> FC(64) -> ReLU -> Dropout(0.5) -> Linear(1)

    ``forward`` returns raw logits of shape ``(batch,)``. No sigmoid is applied
    inside the model, so callers apply ``torch.sigmoid`` to get DGA
    probabilities (and a logits-based loss such as ``nn.BCEWithLogitsLoss``
    is appropriate for training).

    The hyperparameters default to the module-level constants, which is the
    configuration ``load_model`` builds. Changing any of them changes
    parameter shapes, so weights saved under one configuration will not load
    into another.
    """

    def __init__(
        self,
        *,
        vocab_size: int = VOCAB_SIZE,
        embed_dim: int = EMBED_DIM,
        lstm_dim: int = BILSTM_DIM,
        fc_hidden: int = FC_HIDDEN,
        dropout: float = DROPOUT,
    ):
        """
        Args:
            vocab_size: number of embedding rows (characters + padding index 0).
            embed_dim: character embedding size.
            lstm_dim: hidden units per LSTM direction.
            fc_hidden: width of the hidden fully-connected layer.
            dropout: dropout probability applied after the hidden FC layer.
        """
        super().__init__()
        # Index 0 is the padding index (see padding_idx).
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.bilstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=lstm_dim,
            batch_first=True,
            bidirectional=True,
        )
        bilstm_out = lstm_dim * 2  # forward + backward outputs are concatenated
        self.attention = SelfAttention(bilstm_out)
        # Attribute names and Sequential order define the state_dict keys
        # (embedding.*, bilstm.*, attention.attn.*, fc.0.*, fc.3.*); keep them stable.
        self.fc = nn.Sequential(
            nn.Linear(bilstm_out, fc_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_hidden, 1),
        )

    def forward(self, x):
        """Score a batch of encoded domains.

        Args:
            x: long tensor of character indices with shape ``(batch, seq_len)``
               (the LSTM is ``batch_first``), e.g. rows built by ``encode_domain``.

        Returns:
            Float tensor of logits with shape ``(batch,)``; no sigmoid applied.
        """
        emb = self.embedding(x)  # (batch, seq_len, embed_dim)
        out, _ = self.bilstm(emb)  # (batch, seq_len, 2 * lstm_dim)
        context = self.attention(out)  # (batch, 2 * lstm_dim)
        return self.fc(context).squeeze(1)  # (batch, 1) -> (batch,)
def load_model(weights_path: str, device: str = None):
    """Load trained model from a local weights path.

    Args:
        weights_path: path to a ``state_dict`` saved with ``torch.save`` for a
            default-configured ``BiLSTMAttention``.
        device: target device (e.g. ``"cpu"``, ``"cuda"``). Defaults to
            ``"cuda"`` when available, otherwise ``"cpu"``.

    Returns:
        A ``BiLSTMAttention`` moved to ``device`` and switched to eval mode.

    Raises:
        RuntimeError: from ``load_state_dict`` if the file's keys/shapes do not
            match the model (strict loading).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = BiLSTMAttention()
    # torch.load is pickle-based. weights_only=True restricts unpickling to
    # tensors and plain containers, so a tampered or untrusted weights file
    # cannot execute arbitrary code while loading. A state_dict needs nothing
    # more. (Requires torch >= 1.13; it is the default from torch 2.6.)
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def predict(model, domains, device: str = None, batch_size: int = 256,
            threshold: float = 0.5):
    """
    Predict DGA vs legit for a list of domain strings.

    Args:
        model: module mapping a ``(batch, MAXLEN)`` long tensor of encoded
            domains to one logit per row (e.g. ``BiLSTMAttention``). It is run
            in eval mode (dropout off). Every submodule's training flag is
            restored before returning, so the caller's model is left unchanged.
        domains: a single domain string or any iterable of domain strings.
        device: device for the input tensors. Defaults to the device of the
            model's first parameter.
        batch_size: number of domains scored per forward pass; must be >= 1.
        threshold: sigmoid score at or above which a domain is labelled
            ``"dga"`` (default 0.5).

    Returns list of dicts: [{"domain": ..., "label": "dga"/"legit", "score": float}]
    in input order. ``score`` is rounded to 4 decimals, and the label is
    decided on the unrounded score.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if device is None:
        device = next(model.parameters()).device
    if isinstance(domains, str):
        domains = [domains]
    else:
        # Materialise once so generators and other non-sliceable iterables work.
        domains = list(domains)

    # Remember each submodule's flag rather than just model.training, so mixed
    # train/eval submodules are restored exactly.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()  # dropout must be disabled for deterministic scores
    results = []
    try:
        with torch.no_grad():
            for start in range(0, len(domains), batch_size):
                batch = domains[start : start + batch_size]
                x = torch.tensor(
                    [encode_domain(d) for d in batch],
                    dtype=torch.long,
                    device=device,
                )
                scores = torch.sigmoid(model(x)).cpu().tolist()
                for domain, score in zip(batch, scores):
                    results.append({
                        "domain": domain,
                        "label": "dga" if score >= threshold else "legit",
                        "score": round(score, 4),
                    })
    finally:
        for module, was_training in saved_modes:
            module.training = was_training
    return results