Reynier commited on
Commit
79052fb
·
verified ·
1 Parent(s): bdf66bf

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +76 -0
model.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DGA-CNN: Character-level CNN for DGA detection.
3
+ Architecture from Patton et al. (adapted), trained on 54 DGA families.
4
+ """
5
+ import string
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ CHARS = string.ascii_lowercase + string.digits + "-._"
10
+ CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)} # 0 = padding
11
+ VOCAB_SIZE = len(CHARS) + 1 # 40
12
+ MAXLEN = 75
13
+
14
+
15
+ def encode_domain(domain: str) -> list:
16
+ domain = str(domain).lower().strip()
17
+ encoded = [CHAR2IDX.get(c, 0) for c in domain[:MAXLEN]]
18
+ return encoded + [0] * (MAXLEN - len(encoded))
19
+
20
+
21
+ class DGACNN(nn.Module):
22
+ def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=32, num_classes=2):
23
+ super().__init__()
24
+ self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
25
+ self.conv1 = nn.Conv1d(embedding_dim, 64, kernel_size=3, padding=1)
26
+ self.relu = nn.ReLU()
27
+ self.pool = nn.MaxPool1d(2)
28
+ self.dropout = nn.Dropout(0.3)
29
+ self.fc = nn.Linear(64 * (MAXLEN // 2), num_classes)
30
+
31
+ def forward(self, x):
32
+ x = self.embedding(x).transpose(1, 2)
33
+ x = self.pool(self.relu(self.conv1(x)))
34
+ x = x.view(x.size(0), -1)
35
+ x = self.dropout(x)
36
+ return self.fc(x)
37
+
38
+
39
+ def load_model(weights_path: str, device: str = None):
40
+ """Load trained model from a local weights path."""
41
+ if device is None:
42
+ device = "cuda" if torch.cuda.is_available() else "cpu"
43
+ model = DGACNN()
44
+ model.load_state_dict(torch.load(weights_path, map_location=device))
45
+ model.to(device)
46
+ model.eval()
47
+ return model
48
+
49
+
50
+ def predict(model, domains, device: str = None, batch_size: int = 256):
51
+ """
52
+ Predict DGA vs legit for a list of domain strings.
53
+ Returns list of dicts: [{"domain": ..., "label": "dga"/"legit", "score": float}]
54
+ """
55
+ if device is None:
56
+ device = next(model.parameters()).device
57
+ if isinstance(domains, str):
58
+ domains = [domains]
59
+
60
+ results = []
61
+ for i in range(0, len(domains), batch_size):
62
+ batch = domains[i : i + batch_size]
63
+ encoded = [encode_domain(d) for d in batch]
64
+ x = torch.tensor(encoded, dtype=torch.long).to(device)
65
+ with torch.no_grad():
66
+ logits = model(x)
67
+ probs = torch.softmax(logits, dim=1)
68
+ preds = logits.argmax(dim=1).cpu().tolist()
69
+ scores = probs[:, 1].cpu().tolist()
70
+ for domain, pred, score in zip(batch, preds, scores):
71
+ results.append({
72
+ "domain": domain,
73
+ "label": "dga" if pred == 1 else "legit",
74
+ "score": round(score, 4),
75
+ })
76
+ return results