"""
DGA-CNN: Character-level CNN for DGA detection.
Architecture from Patton et al. (adapted), trained on 54 DGA families.
"""
import string
import torch
import torch.nn as nn
# Character vocabulary: lowercase letters, digits and "-", ".", "_".
# Ids start at 1 because index 0 is reserved: encode_domain() uses 0 both for
# padding and for any character that is not in this vocabulary.
CHARS = string.ascii_lowercase + string.digits + "-._"
CHAR2IDX = dict(zip(CHARS, range(1, len(CHARS) + 1)))
VOCAB_SIZE = 1 + len(CHARS)  # 39 characters + the reserved 0 index = 40
MAXLEN = 75  # fixed encoded length (longer inputs truncated, shorter ones padded)
def encode_domain(domain: str) -> list:
    """Turn a domain name into a fixed-length list of integer character ids.

    The input is coerced with ``str()``, lower-cased and stripped of
    surrounding whitespace. Only the first ``MAXLEN`` characters are kept.
    Each one is looked up in ``CHAR2IDX``, and characters outside the
    vocabulary become 0. The list is then right-padded with 0 to exactly
    ``MAXLEN`` entries.
    """
    normalized = str(domain).lower().strip()[:MAXLEN]
    ids = [CHAR2IDX.get(ch, 0) for ch in normalized]
    ids.extend([0] * (MAXLEN - len(ids)))
    return ids
class DGACNN(nn.Module):
    """Character-level CNN that classifies encoded domain names.

    Layers: Embedding -> Conv1d(kernel 3, 64 filters, padding 1) -> ReLU ->
    MaxPool1d(2) -> flatten -> Dropout(0.3) -> Linear(num_classes).

    The final Linear layer is sized from ``MAXLEN``. Input is therefore
    expected to be integer character ids of length ``MAXLEN`` per row, as
    produced by ``encode_domain``.
    """

    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=32, num_classes=2):
        super().__init__()
        n_filters = 64
        # kernel_size=3 with padding=1 keeps the sequence length unchanged;
        # MaxPool1d(2) then floors it to MAXLEN // 2 positions.
        pooled_len = MAXLEN // 2
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.conv1 = nn.Conv1d(embedding_dim, n_filters, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(2)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(n_filters * pooled_len, num_classes)

    def forward(self, x):
        """Return unnormalized class logits, one row of ``num_classes`` per input."""
        # Conv1d wants channels before positions: swap the last two axes
        # of the (batch, positions, embedding_dim) embedding output.
        embedded = self.embedding(x).transpose(1, 2)
        features = self.pool(self.relu(self.conv1(embedded)))
        flat = features.view(features.size(0), -1)
        return self.fc(self.dropout(flat))
def load_model(weights_path: str, device: str = None):
    """Build a ``DGACNN`` and load trained weights from a local file.

    Args:
        weights_path: Path to a file holding the model's ``state_dict``.
        device: Target device (e.g. ``"cpu"`` or ``"cuda"``). Defaults to
            ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = DGACNN()
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs. Without it, older torch releases do a
    # full pickle load, so a tampered weights file could execute arbitrary code.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def predict(model, domains, device: str = None, batch_size: int = 256):
    """
    Predict DGA vs legit for one or more domain strings.

    Args:
        model: A ``DGACNN`` (or compatible) module. If it is in training
            mode, it is switched to eval mode for the duration of the call
            and restored afterwards, so dropout does not randomize scores.
        domains: A single domain string or any iterable of domains
            (list, tuple, generator, set, pandas Series, ...).
        device: Device to run inference on. Defaults to the device of the
            model's first parameter.
        batch_size: Number of domains encoded and scored per forward pass.

    Returns:
        List of dicts: [{"domain": ..., "label": "dga"/"legit", "score": float}].
        ``score`` is the softmax probability of class 1, rounded to 4
        decimals. ``label`` is "dga" when the arg-max class is 1.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    if device is None:
        device = next(model.parameters()).device
    if isinstance(domains, str):
        domains = [domains]
    else:
        # Materialize arbitrary iterables so they can be measured and sliced.
        domains = list(domains)

    was_training = model.training
    if was_training:
        # Dropout is active in training mode and would make predictions random.
        model.eval()
    results = []
    try:
        with torch.no_grad():
            for start in range(0, len(domains), batch_size):
                batch = domains[start : start + batch_size]
                x = torch.tensor(
                    [encode_domain(d) for d in batch], dtype=torch.long
                ).to(device)
                logits = model(x)
                probs = torch.softmax(logits, dim=1)
                preds = logits.argmax(dim=1).cpu().tolist()
                scores = probs[:, 1].cpu().tolist()
                for domain, pred, score in zip(batch, preds, scores):
                    results.append({
                        "domain": domain,
                        "label": "dga" if pred == 1 else "legit",
                        "score": round(score, 4),
                    })
    finally:
        if was_training:
            # Leave the caller's model in the mode we found it in.
            model.train()
    return results
|