""" DefenAI – Module 0 | Lab 0 Malware Text Classification with DistilBERT Gradio Interactive App Run locally : python app.py Deploy HF : Push folder to HuggingFace Space (SDK=Gradio) """ import re, time, random import numpy as np import pandas as pd import gradio as gr import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import matplotlib.patches as mpatches from transformers import ( AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, pipeline ) from datasets import Dataset import evaluate, torch # ─── Colour palette ──────────────────────────────────────────────────────── COLORS = { "ransomware": "#DC2626", "trojan": "#EA580C", "worm": "#D97706", "spyware": "#7C3AED", "benign": "#16A34A", } RISK_COLOR = {"HIGH": "#DC2626", "MEDIUM": "#D97706", "LOW": "#16A34A"} LABELS = ["ransomware", "trojan", "worm", "spyware", "benign"] label2id = {l: i for i, l in enumerate(LABELS)} id2label = {i: l for l, i in label2id.items()} # ─── Training data ───────────────────────────────────────────────────────── SAMPLES = [ ("This ransomware encrypts all files and demands Bitcoin payment", "ransomware"), ("Trojan disguised as PDF reader silently exfiltrates credentials", "trojan"), ("Worm propagates across network shares by exploiting SMB vulnerability", "worm"), ("Spyware monitors keystrokes and captures screen every 30 seconds", "spyware"), ("Legitimate system update downloaded from official vendor website", "benign"), ("Malware encrypts documents and deletes shadow copies", "ransomware"), ("Remote access trojan opens backdoor on port 4444", "trojan"), ("Self-replicating code scans subnet and infects unpatched Windows hosts", "worm"), ("Adware injects ads into browser and tracks browsing history", "spyware"), ("Scheduled task runs Windows Defender update from Microsoft", "benign"), ("Crypto locker demands 0.5 BTC and leaves ransom note", "ransomware"), ("Dropper downloads secondary payload into svchost.exe", "trojan"), ("Network worm exploits EternalBlue to spread across enterprise", "worm"), ("Keylogger records all input and emails log to attacker", "spyware"), ("Antivirus signature update received from vendor portal", "benign"), ("Ransomware deletes volume shadow copies using vssadmin", "ransomware"), ("Banking trojan intercepts web traffic to steal credentials", "trojan"), ("Email worm sends itself to all Outlook contacts", "worm"), ("Spyware activates webcam remotely without user consent", "spyware"), ("Security patch KB5034441 downloaded from Windows Update", "benign"), ("File-encrypting payload targets network drives and backups", "ransomware"), ("Stealer harvests passwords from Chrome and Firefox", "trojan"), ("Virus copies itself to USB drives for propagation", "worm"), ("Hidden process uploads screenshots to remote FTP server", "spyware"), ("Driver update from trusted manufacturer with valid certificate", "benign"), ] * 8 # ─── Global model state ──────────────────────────────────────────────────── _model_cache = {"pipe": None, "trained": False} EXAMPLE_ALERTS = [ "Malware encrypts home directory and drops ransom note demanding ETH", "Legitimate browser extension downloaded from Chrome Web Store", "Hidden backdoor connects to attacker IP every hour via C2 channel", "Worm spreading via Windows print spooler vulnerability CVE-2021-34527", "Adware captures clipboard content and sends to remote server", "Security software update verified and signed by Microsoft", "Trojan disguised as game installer exfiltrates saved passwords", "Software update modifies registry keys and disables antivirus", ] # ─── Helper: build & train model ────────────────────────────────────────── def build_and_train(): texts, labs = zip(*SAMPLES) df = pd.DataFrame({"text": list(texts), "label": [label2id[l] for l in labs]}) df = df.sample(frac=1, random_state=42).reset_index(drop=True) split = int(0.8 * len(df)) train_ds = Dataset.from_pandas(df[:split]) test_ds = Dataset.from_pandas(df[split:]) model_name = "distilbert-base-uncased" tokenizer = AutoTokenizer.from_pretrained(model_name) def tokenize(batch): return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=64) train_tok = train_ds.map(tokenize, batched=True) test_tok = test_ds.map(tokenize, batched=True) model = AutoModelForSequenceClassification.from_pretrained( model_name, num_labels=len(LABELS), id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True ) f1_metric = evaluate.load("f1") def compute_metrics(ep): preds = np.argmax(ep.predictions, axis=-1) return f1_metric.compute(predictions=preds, references=ep.label_ids, average="weighted") args = TrainingArguments( output_dir="malware_clf", num_train_epochs=2, per_device_train_batch_size=16, per_device_eval_batch_size=16, evaluation_strategy="epoch", save_strategy="no", logging_steps=50, report_to="none", ) trainer = Trainer( model=model, args=args, train_dataset=train_tok, eval_dataset=test_tok, compute_metrics=compute_metrics, ) trainer.train() results = trainer.evaluate() pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=-1) return pipe, results.get("eval_f1", 0) # ─── Classify alert ─────────────────────────────────────────────────────── def classify_alert(text): if not text.strip(): return "⚠️ Please enter an alert text.", None, "" if _model_cache["pipe"] is None: return "❌ Model not loaded. Click **Load & Train Model** first.", None, "" result = _model_cache["pipe"](text)[0] label = result["label"] score = result["score"] risk = "HIGH" if label != "benign" else "LOW" color = COLORS.get(label, "#666") # Confidence bar chart all_results = _model_cache["pipe"](text, top_k=None) labels_sorted = [r["label"] for r in all_results] scores_sorted = [r["score"] for r in all_results] bar_colors = [COLORS.get(l, "#666") for l in labels_sorted] fig, ax = plt.subplots(figsize=(7, 3)) bars = ax.barh(labels_sorted, scores_sorted, color=bar_colors, edgecolor="white", height=0.6) ax.set_xlim(0, 1) ax.set_xlabel("Confidence Score", fontsize=11) ax.set_title("Classification Confidence by Malware Type", fontsize=12, fontweight="bold") for bar, sc in zip(bars, scores_sorted): ax.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2, f"{sc:.1%}", va="center", fontsize=10) ax.spines[["top", "right"]].set_visible(False) fig.patch.set_facecolor("#F8FAFC") ax.set_facecolor("#F8FAFC") plt.tight_layout() risk_html = f"""
Classify threat intelligence alerts as Ransomware · Trojan · Worm · Spyware · Benign