Reynier committed on
Commit
c028283
Β·
verified Β·
1 Parent(s): 18ee337

Upload dga_loader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dga_loader.py +164 -0
dga_loader.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DGA Benchmark Loader β€” use this in Colab to load any model from HuggingFace.
3
+
4
+ Usage:
5
+ from dga_loader import load_dga_model, predict_domains
6
+
7
+ model, mod = load_dga_model("cnn")
8
+ results = predict_domains(mod, model, ["google.com", "xkr3f9mq.ru"])
9
+
10
+ Available models:
11
+ "cnn" -> Reynier/dga-cnn
12
+ "bilbo" -> Reynier/dga-bilbo
13
+ "bilstm" -> Reynier/dga-bilstm
14
+ "labin" -> Reynier/dga-labin
15
+ "logit" -> Reynier/dga-logit
16
+ "fanci" -> Reynier/dga-fanci
17
+ "modernbert" -> Reynier/modernbert-dga-detector (HF pipeline)
18
+ "domurlsbert" -> Reynier/dga-domurlsbert (PEFT/LoRA)
19
+ """
20
+ import importlib.util
21
+ import sys
22
+
23
+ from huggingface_hub import hf_hub_download
24
+
25
# Registry of the "standard" (non-transformer) models.
# key -> (HuggingFace repo id, weights filename, loader-module filename)
REGISTRY = dict(
    cnn=("Reynier/dga-cnn", "dga_cnn_model_1M.pth", "model.py"),
    bilbo=("Reynier/dga-bilbo", "bilbo_best.pth", "model.py"),
    bilstm=("Reynier/dga-bilstm", "bilstm_best.pth", "model.py"),
    labin=("Reynier/dga-labin", "LABin_best_model.keras", "model.py"),
    logit=("Reynier/dga-logit", "artifacts.joblib", "model.py"),
    fanci=("Reynier/dga-fanci", "fanci_dga_detector.joblib", "model.py"),
)
33
+
34
+
35
+ def _import_module(path: str, name: str):
36
+ """Dynamically import a Python file as a module."""
37
+ spec = importlib.util.spec_from_file_location(name, path)
38
+ mod = importlib.util.module_from_spec(spec)
39
+ spec.loader.exec_module(mod)
40
+ sys.modules[name] = mod
41
+ return mod
42
+
43
+
44
def load_dga_model(model_name: str, device: str = None):
    """
    Download and load a DGA model from HuggingFace.

    Parameters
    ----------
    model_name : str
        One of: cnn, bilbo, bilstm, labin, logit, fanci, modernbert,
        domurlsbert (case-insensitive).
    device : str, optional
        'cpu' or 'cuda'. Auto-detected if None.

    Returns
    -------
    model : loaded model object
    mod : the model module (call mod.predict(model, domains) to get predictions)
        For modernbert/domurlsbert, mod=None (use the pipeline/model directly).

    Raises
    ------
    ValueError
        If *model_name* is not one of the supported keys.
    """
    model_name = model_name.lower()

    # ── Transformer models (special handling) ─────────────────────────────
    if model_name == "modernbert":
        from transformers import pipeline
        print("Loading Reynier/modernbert-dga-detector ...")
        # Fix: honor an explicit `device` argument here; previously it was
        # ignored for this branch and CUDA was always auto-detected.
        if device is None:
            use_cuda = _cuda_available()
        else:
            use_cuda = device.startswith("cuda")
        pipe = pipeline(
            "text-classification",
            model="Reynier/modernbert-dga-detector",
            device=0 if use_cuda else -1,
        )
        return pipe, None

    if model_name == "domurlsbert":
        import torch
        from transformers import BertTokenizer, BertForSequenceClassification
        from peft import PeftModel
        print("Loading Reynier/dga-domurlsbert ...")
        tok = BertTokenizer.from_pretrained("Reynier/dga-domurlsbert")
        base = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
        model = PeftModel.from_pretrained(base, "Reynier/dga-domurlsbert").eval()
        dev = device or ("cuda" if _cuda_available() else "cpu")
        model.to(dev)
        # Stash the tokenizer/device on the model so predict_domains() can
        # recognize and drive this model without extra arguments.
        model._tokenizer = tok
        model._device = dev
        return model, None

    # ── Standard models ───────────────────────────────────────────────────
    if model_name not in REGISTRY:
        raise ValueError(
            f"Unknown model '{model_name}'. "
            f"Choose from: {list(REGISTRY.keys()) + ['modernbert', 'domurlsbert']}"
        )

    repo_id, weights_file, module_file = REGISTRY[model_name]
    print(f"Downloading {model_name} from {repo_id} ...")

    weights_path = hf_hub_download(repo_id, weights_file)
    module_path = hf_hub_download(repo_id, module_file)

    # Import each repo's model.py under a unique name so models don't clash.
    mod = _import_module(module_path, f"dga_{model_name}")
    # Only pass `device` through when given — loaders may not all accept it.
    model = mod.load_model(weights_path) if device is None else mod.load_model(weights_path, device)

    print(f" {model_name} ready.")
    return model, mod
106
+
107
+
108
def predict_domains(mod, model, domains):
    """
    Unified prediction interface.

    Works with both standard models (mod + model) and transformer pipelines.

    Parameters
    ----------
    mod : module returned by load_dga_model, or None for transformers
    model : loaded model
    domains : str or list of str

    Returns
    -------
    list of dicts: [{"domain": ..., "label": "dga"/"legit", "score": float}]

    NOTE(review): for the PEFT model (domurlsbert) the score is always the
    softmax probability of the "dga" class, even when the predicted label
    is "legit"; for the HF pipeline it is the confidence of the predicted
    label. Callers comparing scores across models should be aware of this.
    """
    if isinstance(domains, str):
        domains = [domains]

    is_transformer = mod is None

    # HF pipeline (modernbert): a callable without the PEFT tokenizer marker.
    if is_transformer and callable(model) and not hasattr(model, '_tokenizer'):
        outputs = model(domains)
        formatted = []
        for domain, out in zip(domains, outputs):
            # Pipeline labels arrive as LABEL_0 / LABEL_1; map to readable names.
            label = out["label"].lower().replace("label_1", "dga").replace("label_0", "legit")
            formatted.append({
                "domain": domain,
                "label": label,
                "score": round(out["score"], 4),
            })
        return formatted

    # PEFT/LoRA model (domurlsbert): tokenizer/device were attached at load time.
    if is_transformer and hasattr(model, '_tokenizer'):
        import torch
        tokenizer = model._tokenizer
        dev = model._device
        class_names = {0: "legit", 1: "dga"}
        predictions = []
        for domain in domains:
            encoded = tokenizer(domain, return_tensors="pt", truncation=True).to(dev)
            with torch.no_grad():
                logits = model(**encoded).logits
            predicted = torch.argmax(logits, dim=1).item()
            dga_prob = torch.softmax(logits, dim=1)[0, 1].item()
            predictions.append({
                "domain": domain,
                "label": class_names[predicted],
                "score": round(dga_prob, 4),
            })
        return predictions

    # Standard models delegate to the repo module's own predict().
    return mod.predict(model, domains)
157
+
158
+
159
+ def _cuda_available():
160
+ try:
161
+ import torch
162
+ return torch.cuda.is_available()
163
+ except ImportError:
164
+ return False