Spaces:
Sleeping
Sleeping
File size: 16,304 Bytes
06e9934 c5f2c5c 06e9934 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 | """
DefenAI — Module 0 | Lab 0
Malware Text Classification with DistilBERT
Gradio Interactive App
Run locally : python app.py
Deploy HF : Push folder to HuggingFace Space (SDK=Gradio)
"""
import re, time, random
import numpy as np
import pandas as pd
import gradio as gr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from transformers import (
AutoTokenizer, AutoModelForSequenceClassification,
TrainingArguments, Trainer, pipeline
)
from datasets import Dataset
import evaluate, torch
# --- Colour palette and label vocabulary ---
# Hex colour used for each class in the charts and the risk banner text.
COLORS = dict(
    ransomware="#DC2626",
    trojan="#EA580C",
    worm="#D97706",
    spyware="#7C3AED",
    benign="#16A34A",
)

# Hex colour for each risk tier.
RISK_COLOR = dict(HIGH="#DC2626", MEDIUM="#D97706", LOW="#16A34A")

# Class names; a name's position in this list is its integer class id.
LABELS = ["ransomware", "trojan", "worm", "spyware", "benign"]

# Forward (name -> id) and reverse (id -> name) lookups derived from LABELS.
label2id = dict(zip(LABELS, range(len(LABELS))))
id2label = dict(enumerate(LABELS))
# --- Training data ---
# Hand-written (alert text, class name) pairs, five per class, listed in
# ransomware / trojan / worm / spyware / benign rotation.
_BASE_SAMPLES = (
    ("This ransomware encrypts all files and demands Bitcoin payment", "ransomware"),
    ("Trojan disguised as PDF reader silently exfiltrates credentials", "trojan"),
    ("Worm propagates across network shares by exploiting SMB vulnerability", "worm"),
    ("Spyware monitors keystrokes and captures screen every 30 seconds", "spyware"),
    ("Legitimate system update downloaded from official vendor website", "benign"),
    ("Malware encrypts documents and deletes shadow copies", "ransomware"),
    ("Remote access trojan opens backdoor on port 4444", "trojan"),
    ("Self-replicating code scans subnet and infects unpatched Windows hosts", "worm"),
    ("Adware injects ads into browser and tracks browsing history", "spyware"),
    ("Scheduled task runs Windows Defender update from Microsoft", "benign"),
    ("Crypto locker demands 0.5 BTC and leaves ransom note", "ransomware"),
    ("Dropper downloads secondary payload into svchost.exe", "trojan"),
    ("Network worm exploits EternalBlue to spread across enterprise", "worm"),
    ("Keylogger records all input and emails log to attacker", "spyware"),
    ("Antivirus signature update received from vendor portal", "benign"),
    ("Ransomware deletes volume shadow copies using vssadmin", "ransomware"),
    ("Banking trojan intercepts web traffic to steal credentials", "trojan"),
    ("Email worm sends itself to all Outlook contacts", "worm"),
    ("Spyware activates webcam remotely without user consent", "spyware"),
    ("Security patch KB5034441 downloaded from Windows Update", "benign"),
    ("File-encrypting payload targets network drives and backups", "ransomware"),
    ("Stealer harvests passwords from Chrome and Firefox", "trojan"),
    ("Virus copies itself to USB drives for propagation", "worm"),
    ("Hidden process uploads screenshots to remote FTP server", "spyware"),
    ("Driver update from trusted manufacturer with valid certificate", "benign"),
)

# The training corpus is the base set repeated eight times, in the same order.
SAMPLES = list(_BASE_SAMPLES) * 8
# --- Global model state ---
# Module-wide holder for the inference pipeline. "pipe" is None until
# load_model() stores a trained pipeline; "trained" is set to True alongside it.
_model_cache = dict(pipe=None, trained=False)

# Demo alerts: batch_analyse() classifies all of them and the first five are
# offered as click-to-load examples in the UI.
EXAMPLE_ALERTS = [
    "Malware encrypts home directory and drops ransom note demanding ETH",
    "Legitimate browser extension downloaded from Chrome Web Store",
    "Hidden backdoor connects to attacker IP every hour via C2 channel",
    "Worm spreading via Windows print spooler vulnerability CVE-2021-34527",
    "Adware captures clipboard content and sends to remote server",
    "Security software update verified and signed by Microsoft",
    "Trojan disguised as game installer exfiltrates saved passwords",
    "Software update modifies registry keys and disables antivirus",
]
# --- Helper: build & train model ---
def build_and_train():
    """Fine-tune DistilBERT on ``SAMPLES`` and wrap it in a CPU pipeline.

    The samples are shuffled with a fixed seed, split 80/20 into training
    and evaluation sets, tokenized to a fixed length of 64 tokens, and used
    for two epochs of training with evaluation after each epoch.

    Returns:
        tuple: ``(pipe, f1)`` where ``pipe`` is a ``text-classification``
        pipeline over the fine-tuned model (``device=-1``, i.e. CPU) and
        ``f1`` is the weighted F1 reported by the final evaluation, or
        ``0`` if the metric is absent from the results.
    """
    # Local import: only needed for the keyword-compatibility check below.
    import inspect

    texts, labs = zip(*SAMPLES)
    df = pd.DataFrame({
        "text": list(texts),
        "label": [label2id[name] for name in labs],
    })
    # Fixed seed keeps the shuffle, and therefore the split, reproducible.
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)
    split = int(0.8 * len(df))
    # NOTE(review): SAMPLES is the same 25 texts repeated 8 times, so rows in
    # the evaluation slice are exact duplicates of training rows and the
    # reported F1 is optimistic. Split on unique texts if the score matters.
    train_ds = Dataset.from_pandas(df[:split])
    test_ds = Dataset.from_pandas(df[split:])

    model_name = "distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize(batch):
        # Pad/truncate every text to exactly 64 tokens.
        return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=64)

    train_tok = train_ds.map(tokenize, batched=True)
    test_tok = test_ds.map(tokenize, batched=True)

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=len(LABELS),
        id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True
    )

    f1_metric = evaluate.load("f1")

    def compute_metrics(ep):
        # Predicted class is the arg-max over the last axis of the predictions.
        preds = np.argmax(ep.predictions, axis=-1)
        return f1_metric.compute(predictions=preds, references=ep.label_ids, average="weighted")

    # The evaluation-schedule keyword was renamed from ``evaluation_strategy``
    # to ``eval_strategy`` in newer transformers releases; pass whichever one
    # the installed TrainingArguments accepts so both old and new work.
    accepted = inspect.signature(TrainingArguments.__init__).parameters
    eval_key = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"

    args = TrainingArguments(
        output_dir="malware_clf", num_train_epochs=2,
        per_device_train_batch_size=16, per_device_eval_batch_size=16,
        save_strategy="no",
        logging_steps=50, report_to="none",
        **{eval_key: "epoch"},
    )
    trainer = Trainer(
        model=model, args=args,
        train_dataset=train_tok, eval_dataset=test_tok,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    results = trainer.evaluate()

    pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=-1)
    return pipe, results.get("eval_f1", 0)
# --- Classify alert ---
def classify_alert(text):
    """Classify one alert and build the three outputs shown in the UI.

    Args:
        text: Alert text entered by the user. ``None``, empty, or
            whitespace-only input is rejected with a prompt.

    Returns:
        tuple: ``(risk_html, fig, details)``. ``risk_html`` is an HTML risk
        banner, ``fig`` is a Matplotlib figure with one confidence bar per
        class, and ``details`` is a Markdown explanation. On the two early
        exits (no text, model not loaded) the first item is a plain message,
        ``fig`` is ``None`` and ``details`` is ``""``.
    """
    if not text or not text.strip():
        return "β οΈ Please enter an alert text.", None, ""
    pipe = _model_cache["pipe"]
    if pipe is None:
        return "β Model not loaded. Click **Load & Train Model** first.", None, ""

    # A single forward pass returns the score of every class; the headline
    # prediction is simply the highest-scoring entry, so no second call is needed.
    all_results = pipe(text, top_k=None)
    best = max(all_results, key=lambda r: r["score"])
    label = best["label"]
    score = best["score"]
    # Anything that is not "benign" is treated as high risk.
    risk = "HIGH" if label != "benign" else "LOW"
    color = COLORS.get(label, "#666")
    banner_color = RISK_COLOR[risk]

    # Confidence bar chart
    labels_sorted = [r["label"] for r in all_results]
    scores_sorted = [r["score"] for r in all_results]
    bar_colors = [COLORS.get(name, "#666") for name in labels_sorted]

    # Build the figure directly instead of through pyplot: pyplot keeps every
    # figure it creates alive until it is closed, which leaks one figure per
    # request in a long-running server, and its "current figure" is global state.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(7, 3))
    ax = fig.subplots()
    bars = ax.barh(labels_sorted, scores_sorted, color=bar_colors, edgecolor="white", height=0.6)
    ax.set_xlim(0, 1)
    ax.set_xlabel("Confidence Score", fontsize=11)
    ax.set_title("Classification Confidence by Malware Type", fontsize=12, fontweight="bold")
    for bar, sc in zip(bars, scores_sorted):
        # Print the percentage just to the right of each bar, vertically centred.
        ax.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2,
                f"{sc:.1%}", va="center", fontsize=10)
    ax.spines[["top", "right"]].set_visible(False)
    fig.patch.set_facecolor("#F8FAFC")
    ax.set_facecolor("#F8FAFC")
    fig.tight_layout()

    # The trailing "22" is a hex alpha suffix on the background colour.
    risk_html = f"""
<div style="padding:16px;border-radius:10px;background:{banner_color}22;
border:2px solid {banner_color};text-align:center;margin:8px 0">
<div style="font-size:2em;font-weight:bold;color:{banner_color}">{risk} RISK</div>
<div style="font-size:1.3em;color:{color};font-weight:600;margin-top:4px">
{label.upper()} | {score:.1%} confidence
</div>
</div>
"""

    details = f"π **Analysis:** Classified as **{label.upper()}** with {score:.1%} confidence.\n\n"
    # Per-class response guidance; any label without an entry (including
    # "benign") gets the no-threat status line.
    guidance = {
        "ransomware": "π **Threat:** File-encrypting malware. Isolate system immediately. Check backups.",
        "trojan": "π΄ **Threat:** Trojan/Backdoor. Check network connections. Scan for lateral movement.",
        "worm": "πͺ± **Threat:** Self-propagating worm. Isolate from network. Patch vulnerabilities.",
        "spyware": "ποΈ **Threat:** Spyware/Keylogger. Reset credentials. Check outbound traffic.",
    }
    details += guidance.get(
        label,
        "✅ **Status:** No malicious indicators detected. Continue monitoring.",
    )
    return risk_html, fig, details
# --- Batch analyse ---
def batch_analyse():
    """Classify every alert in ``EXAMPLE_ALERTS`` and summarise the results.

    Returns:
        tuple: ``(df, fig)``. ``df`` has one row per alert with the columns
        ``Alert``, ``Label``, ``Confidence`` and ``Risk``; ``fig`` is a pie
        chart of how many alerts received each label. When the model has not
        been loaded, ``df`` is a one-row frame with a ``Status`` column
        holding the message and ``fig`` is ``None``.
    """
    pipe = _model_cache["pipe"]
    if pipe is None:
        # The first output is wired to a gr.Dataframe, so the message must be
        # tabular; a bare string is not a table (some Gradio versions try to
        # read a string value as a CSV path).
        return pd.DataFrame({"Status": ["β Model not loaded."]}), None

    rows = []
    for alert in EXAMPLE_ALERTS:
        top = pipe(alert)[0]
        # Only add the ellipsis when the text was actually cut at 60 characters.
        shown = alert if len(alert) <= 60 else alert[:60] + "β¦"
        rows.append({
            "Alert": shown,
            "Label": top["label"].upper(),
            "Confidence": f"{top['score']:.1%}",
            "Risk": "π΄ HIGH" if top["label"] != "benign" else "π’ LOW",
        })
    df = pd.DataFrame(rows)

    # Pie chart
    counts = df["Label"].value_counts()
    # Labels were upper-cased for display; COLORS is keyed by lower-case names.
    pie_colors = [COLORS.get(name.lower(), "#999") for name in counts.index]

    # Build the figure directly instead of through pyplot so it is not kept
    # alive by pyplot's global figure registry after the request finishes.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(5, 5))
    ax = fig.subplots()
    _wedges, _texts, autotexts = ax.pie(
        counts.values, labels=counts.index, autopct="%1.0f%%",
        colors=pie_colors, startangle=90,
        wedgeprops={"edgecolor": "white", "linewidth": 2},
    )
    for at in autotexts:
        at.set_fontsize(11)
        at.set_fontweight("bold")
        at.set_color("white")
    ax.set_title("SOC Alert Distribution", fontsize=13, fontweight="bold")
    fig.patch.set_facecolor("#F8FAFC")
    fig.tight_layout()
    return df, fig
# --- Load model ---
def load_model():
    """Train the classifier and stream status messages to the UI.

    On success the trained pipeline is stored in ``_model_cache["pipe"]``
    and ``_model_cache["trained"]`` is set to ``True``.

    Yields:
        str: First a progress message, then either a success message that
        includes the weighted F1 or an error message if training raised.
    """
    yield "β³ Loading and training DistilBERT malware classifier...\n(This takes ~60-90 seconds on CPU)"
    try:
        pipe, f1 = build_and_train()
    except Exception as e:
        # UI boundary: surface the failure in the status box instead of
        # letting the event handler crash.
        yield f"β Error: {e}"
        return
    _model_cache["pipe"] = pipe
    _model_cache["trained"] = True
    yield f"✅ **Model Ready!** Weighted F1 = **{f1:.4f}**\n\nYou can now classify threat alerts below."
# --- Gradio UI ---
# Stylesheet handed to gr.Blocks(css=...): constrains and centres the page,
# styles the gradient header banner (.header-box), defines left-border accents
# for high/low risk (.risk-high / .risk-low), and hides the default footer.
CSS = """
.gradio-container { font-family: 'Segoe UI', sans-serif; max-width: 1100px; margin: auto; }
.header-box { background: linear-gradient(135deg,#1B2A4A,#2563EB);
padding:24px; border-radius:12px; color:white; margin-bottom:16px; }
.risk-high { border-left: 6px solid #DC2626; }
.risk-low { border-left: 6px solid #16A34A; }
footer { display:none !important; }
"""
# Top-level app layout: a header banner followed by four tabs.
# NOTE(review): the nesting below is reconstructed from what each widget does;
# confirm the placement of classify_btn and conf_chart against the intended layout.
with gr.Blocks(css=CSS, title="DefenAI β Module 0: Malware Classifier") as demo:
    # Header banner (styled by the .header-box rule in CSS).
    gr.HTML("""
<div class="header-box">
<h1 style="margin:0;font-size:1.8em">π‘οΈ DefenAI β Module 0</h1>
<h2 style="margin:4px 0 0;font-weight:400;opacity:.9">
SOC Malware Alert Triage Β· DistilBERT Text Classifier
</h2>
<p style="margin:8px 0 0;opacity:.75;font-size:.9em">
Classify threat intelligence alerts as Ransomware Β· Trojan Β· Worm Β· Spyware Β· Benign
</p>
</div>
""")

    with gr.Tabs():
        # -- Tab 1: train the model and report status --
        with gr.Tab("βοΈ Step 1: Load Model"):
            gr.Markdown("""
### How This App Works
1. Click **Load & Train Model** to fine-tune DistilBERT on synthetic malware alerts
2. Go to **Classify Alert** to test single threat intel texts
3. Go to **Batch SOC Triage** to analyse 8 pre-loaded alerts at once
> **Module 0 Concept:** AI can automate SOC Level-1 alert triage using NLP models
""")
            load_btn = gr.Button("π Load & Train Model", variant="primary", size="lg")
            status_box = gr.Markdown("_Model not loaded. Click above to start._")
            # load_model yields successive status strings into the Markdown box.
            load_btn.click(fn=load_model, outputs=status_box)

        # -- Tab 2: classify a single alert --
        with gr.Tab("π Step 2: Classify Alert"):
            gr.Markdown("### Enter a threat intelligence alert text to classify")
            with gr.Row():
                # Left column: input, examples and the trigger button.
                with gr.Column(scale=3):
                    alert_input = gr.Textbox(
                        label="Threat Intel Alert",
                        placeholder="e.g. Ransomware encrypts all files and demands Bitcoin paymentβ¦",
                        lines=3,
                    )
                    gr.Examples(
                        examples=[[alert] for alert in EXAMPLE_ALERTS[:5]],
                        inputs=alert_input,
                        label="π Example Alerts (click to load)",
                    )
                    classify_btn = gr.Button("π Classify Alert", variant="primary")
                # Right column: risk banner and textual explanation.
                with gr.Column(scale=2):
                    risk_display = gr.HTML(label="Risk Assessment")
                    detail_text = gr.Markdown()
            conf_chart = gr.Plot(label="Confidence Scores by Category")
            # Output order matches classify_alert's (risk_html, fig, details) tuple.
            classify_btn.click(
                fn=classify_alert,
                inputs=alert_input,
                outputs=[risk_display, conf_chart, detail_text],
            )

        # -- Tab 3: classify all example alerts at once --
        with gr.Tab("π Step 3: Batch SOC Triage"):
            gr.Markdown("### Analyse 8 pre-loaded SOC alerts at once")
            batch_btn = gr.Button("βΆοΈ Run Batch Triage", variant="primary", size="lg")
            batch_table = gr.Dataframe(label="Triage Results", interactive=False)
            dist_chart = gr.Plot(label="Alert Distribution")
            # Output order matches batch_analyse's (table, figure) tuple.
            batch_btn.click(fn=batch_analyse, outputs=[batch_table, dist_chart])

        # -- Tab 4: static explanatory content --
        with gr.Tab("π Learn: How It Works"):
            gr.Markdown("""
## π§ How DistilBERT Classifies Malware Alerts
### Model Architecture
- **DistilBERT** is a smaller, faster version of BERT (66M parameters)
- Fine-tuned on labelled malware descriptions for **5-class classification**
- Runs entirely on **CPU** β no GPU needed
### Attack Categories
| Label | Description | Example |
|-------|-------------|---------|
| π΄ Ransomware | Encrypts files, demands ransom | WannaCry, LockBit |
| π Trojan | Disguised malware, backdoor | Agent Tesla, Emotet |
| π‘ Worm | Self-propagating across network | Conficker, Stuxnet |
| π£ Spyware | Surveillance, data theft | Pegasus, DarkComet |
| π’ Benign | Legitimate software/update | Windows Update |
### Why F1 over Accuracy?
Class imbalance (few malware vs many benign) makes accuracy misleading.
**Weighted F1** balances precision and recall across all classes.
### Assessment Questions
1. Why is weighted F1 better than accuracy for imbalanced datasets?
2. A new alert says: *"Software modifies registry and disables antivirus"* β what label?
3. How would you update this model when new malware families emerge?
""")
if __name__ == "__main__":
    # load_model is a generator handler that streams status updates.
    # NOTE(review): Gradio 3.x is understood to run generator handlers only
    # when the queue is enabled; on releases where the queue is already on by
    # default this call should be a harmless no-op. Confirm against the pinned
    # Gradio version.
    demo.queue().launch()
|