""" DGA-CNN: Character-level CNN for DGA detection. Architecture from Patton et al. (adapted), trained on 54 DGA families. """ import string import torch import torch.nn as nn CHARS = string.ascii_lowercase + string.digits + "-._" CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)} # 0 = padding VOCAB_SIZE = len(CHARS) + 1 # 40 MAXLEN = 75 def encode_domain(domain: str) -> list: domain = str(domain).lower().strip() encoded = [CHAR2IDX.get(c, 0) for c in domain[:MAXLEN]] return encoded + [0] * (MAXLEN - len(encoded)) class DGACNN(nn.Module): def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=32, num_classes=2): super().__init__() self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0) self.conv1 = nn.Conv1d(embedding_dim, 64, kernel_size=3, padding=1) self.relu = nn.ReLU() self.pool = nn.MaxPool1d(2) self.dropout = nn.Dropout(0.3) self.fc = nn.Linear(64 * (MAXLEN // 2), num_classes) def forward(self, x): x = self.embedding(x).transpose(1, 2) x = self.pool(self.relu(self.conv1(x))) x = x.view(x.size(0), -1) x = self.dropout(x) return self.fc(x) def load_model(weights_path: str, device: str = None): """Load trained model from a local weights path.""" if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" model = DGACNN() model.load_state_dict(torch.load(weights_path, map_location=device)) model.to(device) model.eval() return model def predict(model, domains, device: str = None, batch_size: int = 256): """ Predict DGA vs legit for a list of domain strings. Returns list of dicts: [{"domain": ..., "label": "dga"/"legit", "score": float}] """ if device is None: device = next(model.parameters()).device if isinstance(domains, str): domains = [domains] results = [] for i in range(0, len(domains), batch_size): batch = domains[i : i + batch_size] encoded = [encode_domain(d) for d in batch] x = torch.tensor(encoded, dtype=torch.long).to(device) with torch.no_grad(): logits = model(x) probs = torch.softmax(logits, dim=1) preds = logits.argmax(dim=1).cpu().tolist() scores = probs[:, 1].cpu().tolist() for domain, pred, score in zip(batch, preds, scores): results.append({ "domain": domain, "label": "dga" if pred == 1 else "legit", "score": round(score, 4), }) return results