"""
DefenAI – Module 0 | Lab 0
Malware Text Classification with DistilBERT
Gradio Interactive App

Run locally : python app.py
Deploy HF   : Push folder to HuggingFace Space (SDK=Gradio)
"""
import inspect
import logging
import random
import re
import time
from collections import Counter

import numpy as np
import pandas as pd
import gradio as gr
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.figure import Figure
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    pipeline,
)
from datasets import Dataset
import evaluate
import torch

logger = logging.getLogger(__name__)

# ─── Colour palette ────────────────────────────────────────────────────────
COLORS = {
    "ransomware": "#DC2626",
    "trojan": "#EA580C",
    "worm": "#D97706",
    "spyware": "#7C3AED",
    "benign": "#16A34A",
}
RISK_COLOR = {"HIGH": "#DC2626", "MEDIUM": "#D97706", "LOW": "#16A34A"}

LABELS = ["ransomware", "trojan", "worm", "spyware", "benign"]
label2id = {l: i for i, l in enumerate(LABELS)}
id2label = {i: l for l, i in label2id.items()}

# ─── Training data ─────────────────────────────────────────────────────────
# Each unique (text, label) pair is repeated 8 times — presumably to give the
# tiny demo set more optimisation steps. build_and_train() de-duplicates
# before splitting so no sentence appears in both train and test.
SAMPLES = [
    ("This ransomware encrypts all files and demands Bitcoin payment", "ransomware"),
    ("Trojan disguised as PDF reader silently exfiltrates credentials", "trojan"),
    ("Worm propagates across network shares by exploiting SMB vulnerability", "worm"),
    ("Spyware monitors keystrokes and captures screen every 30 seconds", "spyware"),
    ("Legitimate system update downloaded from official vendor website", "benign"),
    ("Malware encrypts documents and deletes shadow copies", "ransomware"),
    ("Remote access trojan opens backdoor on port 4444", "trojan"),
    ("Self-replicating code scans subnet and infects unpatched Windows hosts", "worm"),
    ("Adware injects ads into browser and tracks browsing history", "spyware"),
    ("Scheduled task runs Windows Defender update from Microsoft", "benign"),
    ("Crypto locker demands 0.5 BTC and leaves ransom note", "ransomware"),
    ("Dropper downloads secondary payload into svchost.exe", "trojan"),
    ("Network worm exploits EternalBlue to spread across enterprise", "worm"),
    ("Keylogger records all input and emails log to attacker", "spyware"),
    ("Antivirus signature update received from vendor portal", "benign"),
    ("Ransomware deletes volume shadow copies using vssadmin", "ransomware"),
    ("Banking trojan intercepts web traffic to steal credentials", "trojan"),
    ("Email worm sends itself to all Outlook contacts", "worm"),
    ("Spyware activates webcam remotely without user consent", "spyware"),
    ("Security patch KB5034441 downloaded from Windows Update", "benign"),
    ("File-encrypting payload targets network drives and backups", "ransomware"),
    ("Stealer harvests passwords from Chrome and Firefox", "trojan"),
    ("Virus copies itself to USB drives for propagation", "worm"),
    ("Hidden process uploads screenshots to remote FTP server", "spyware"),
    ("Driver update from trusted manufacturer with valid certificate", "benign"),
] * 8

# ─── Global model state ─────────────────────────────────────────────────────
# "pipe" holds the trained text-classification pipeline (None until
# load_model() succeeds); "trained" is set alongside it.
_model_cache = {"pipe": None, "trained": False}

EXAMPLE_ALERTS = [
    "Malware encrypts home directory and drops ransom note demanding ETH",
    "Legitimate browser extension downloaded from Chrome Web Store",
    "Hidden backdoor connects to attacker IP every hour via C2 channel",
    "Worm spreading via Windows print spooler vulnerability CVE-2021-34527",
    "Adware captures clipboard content and sends to remote server",
    "Security software update verified and signed by Microsoft",
    "Trojan disguised as game installer exfiltrates saved passwords",
    "Software update modifies registry keys and disables antivirus",
]

# Analyst guidance shown under a classification; any label not listed here
# falls back to the benign status message.
_THREAT_ADVICE = {
    "ransomware": "💀 **Threat:** File-encrypting malware. Isolate system immediately. Check backups.",
    "trojan": "🐴 **Threat:** Trojan/Backdoor. Check network connections. Scan for lateral movement.",
    "worm": "🪱 **Threat:** Self-propagating worm. Isolate from network. Patch vulnerabilities.",
    "spyware": "👁️ **Threat:** Spyware/Keylogger. Reset credentials. Check outbound traffic.",
}
_BENIGN_ADVICE = "✅ **Status:** No malicious indicators detected. Continue monitoring."


# ─── Helper: build & train model ───────────────────────────────────────────
def _split_unique(samples, seed=42):
    """Split (text, label) pairs into train/test lists that share no text.

    ``samples`` may contain repeated pairs (SAMPLES does). Splitting after
    repetition would put copies of one sentence in both sets and turn the
    test score into a memorisation check. Instead the unique pairs are
    shuffled, one pair per label (for labels with at least two unique pairs)
    is held out for testing, and only the training pairs are repeated back
    to roughly the original volume.

    Returns:
        (train_pairs, test_pairs): two lists of (text, label) tuples.
    """
    unique = list(dict.fromkeys(samples))  # de-duplicate, keep first-seen order
    if not unique:
        return [], []
    repeats = max(1, len(samples) // len(unique))
    per_label = Counter(label for _, label in unique)

    shuffled = unique[:]
    random.Random(seed).shuffle(shuffled)
    train, test, held_out = [], [], set()
    for text, label in shuffled:
        # Never hold out a label's only example, or it would never be trained on.
        if label not in held_out and per_label[label] > 1:
            held_out.add(label)
            test.append((text, label))
        else:
            train.append((text, label))
    return train * repeats, test


def _to_dataset(pairs, seed=42):
    """Convert (text, label) pairs into a shuffled Dataset with integer labels."""
    df = pd.DataFrame({
        "text": [text for text, _ in pairs],
        "label": [label2id[label] for _, label in pairs],
    })
    df = df.sample(frac=1, random_state=seed).reset_index(drop=True)
    return Dataset.from_pandas(df)


def _eval_strategy_kwargs(value):
    """Return ``{<eval-strategy keyword>: value}`` for this transformers version.

    Recent transformers releases renamed ``evaluation_strategy`` to
    ``eval_strategy`` and then dropped the old name, so passing the old one
    raises TypeError there. Pick whichever keyword TrainingArguments accepts.
    """
    params = inspect.signature(TrainingArguments.__init__).parameters
    key = "eval_strategy" if "eval_strategy" in params else "evaluation_strategy"
    return {key: value}


def build_and_train():
    """Fine-tune DistilBERT on SAMPLES and wrap it in a classification pipeline.

    Returns:
        (pipe, f1): a ``text-classification`` pipeline (created with
        ``device=-1``) and the weighted F1 on the held-out sentences, or 0 if
        the evaluation did not report ``eval_f1``.
    """
    train_pairs, test_pairs = _split_unique(SAMPLES)
    train_ds = _to_dataset(train_pairs)
    test_ds = _to_dataset(test_pairs)

    model_name = "distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize(batch):
        return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=64)

    train_tok = train_ds.map(tokenize, batched=True)
    test_tok = test_ds.map(tokenize, batched=True)

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(LABELS),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

    f1_metric = evaluate.load("f1")

    def compute_metrics(ep):
        preds = np.argmax(ep.predictions, axis=-1)
        return f1_metric.compute(predictions=preds, references=ep.label_ids, average="weighted")

    args = TrainingArguments(
        output_dir="malware_clf",
        num_train_epochs=2,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        save_strategy="no",
        logging_steps=50,
        report_to="none",
        **_eval_strategy_kwargs("epoch"),
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_tok,
        eval_dataset=test_tok,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    results = trainer.evaluate()

    pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=-1)
    return pipe, results.get("eval_f1", 0)


# ─── Charts ────────────────────────────────────────────────────────────────
# Both charts use a standalone matplotlib Figure rather than pyplot: pyplot
# keeps every figure open in a global registry until plt.close(), so a
# long-running app that builds one per request would accumulate them.
def _confidence_chart(results):
    """Return a horizontal bar chart of each label's score, in the given order."""
    labels = [r["label"] for r in results]
    scores = [r["score"] for r in results]

    fig = Figure(figsize=(7, 3))
    ax = fig.subplots()
    bars = ax.barh(
        labels, scores,
        color=[COLORS.get(l, "#666") for l in labels],
        edgecolor="white", height=0.6,
    )
    ax.invert_yaxis()  # first entry of ``results`` at the top
    ax.set_xlim(0, 1)
    ax.set_xlabel("Confidence Score", fontsize=11)
    ax.set_title("Classification Confidence by Malware Type", fontsize=12, fontweight="bold")
    for bar, sc in zip(bars, scores):
        ax.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2,
                f"{sc:.1%}", va="center", fontsize=10)
    ax.spines[["top", "right"]].set_visible(False)
    fig.patch.set_facecolor("#F8FAFC")
    ax.set_facecolor("#F8FAFC")
    fig.tight_layout()
    return fig


def _distribution_chart(label_series):
    """Return a pie chart of how often each (upper-case) label occurs."""
    counts = label_series.value_counts()
    fig = Figure(figsize=(5, 5))
    ax = fig.subplots()
    _, _, autotexts = ax.pie(
        counts.values,
        labels=counts.index,
        autopct="%1.0f%%",
        colors=[COLORS.get(l.lower(), "#999") for l in counts.index],
        startangle=90,
        wedgeprops={"edgecolor": "white", "linewidth": 2},
    )
    for at in autotexts:
        at.set_fontsize(11)
        at.set_fontweight("bold")
        at.set_color("white")
    ax.set_title("SOC Alert Distribution", fontsize=13, fontweight="bold")
    fig.patch.set_facecolor("#F8FAFC")
    fig.tight_layout()
    return fig


# ─── Classify alert ────────────────────────────────────────────────────────
def classify_alert(text):
    """Classify one alert.

    Returns:
        (risk_html, figure_or_None, details_markdown). On empty input or an
        unloaded model, the first element is a message and the others are
        empty.
    """
    if not text or not text.strip():
        return "⚠️ Please enter an alert text.", None, ""
    pipe = _model_cache["pipe"]
    if pipe is None:
        return "❌ Model not loaded. Click **Load & Train Model** first.", None, ""

    # A single forward pass: top_k=None yields a score for every label.
    all_results = sorted(pipe(text, top_k=None), key=lambda r: r["score"], reverse=True)
    label = all_results[0]["label"]
    score = all_results[0]["score"]
    risk = "HIGH" if label != "benign" else "LOW"
    color = COLORS.get(label, "#666")
    risk_color = RISK_COLOR[risk]

    fig = _confidence_chart(all_results)

    # The class name matches the .risk-high / .risk-low rules in CSS.
    risk_html = f"""
<div class="risk-{risk.lower()}" style="padding:16px 20px;border-radius:10px;background:#F8FAFC;">
  <div style="font-size:22px;font-weight:700;color:{risk_color};">{risk} RISK</div>
  <div style="font-size:16px;margin-top:6px;">
    <span style="color:{color};font-weight:700;">{label.upper()}</span>&nbsp;&nbsp;|&nbsp;&nbsp;{score:.1%} confidence
  </div>
</div>
"""

    details = (
        f"🔍 **Analysis:** Classified as **{label.upper()}** with {score:.1%} confidence.\n\n"
        + _THREAT_ADVICE.get(label, _BENIGN_ADVICE)
    )
    return risk_html, fig, details


# ─── Batch analyse ─────────────────────────────────────────────────────────
def _truncate(text, limit=60):
    """Shorten ``text`` to ``limit`` characters, adding "…" only if it was cut."""
    return text if len(text) <= limit else text[:limit] + "…"


def batch_analyse():
    """Classify every EXAMPLE_ALERTS entry.

    Returns:
        (DataFrame, figure_or_None): the results table and the label
        distribution chart.
    """
    pipe = _model_cache["pipe"]
    if pipe is None:
        # The first output feeds a gr.Dataframe, so report the problem as a
        # table rather than a bare string.
        return pd.DataFrame({"Status": ["❌ Model not loaded."]}), None

    rows = []
    for alert in EXAMPLE_ALERTS:
        top = pipe(alert)[0]
        rows.append({
            "Alert": _truncate(alert),
            "Label": top["label"].upper(),
            "Confidence": f"{top['score']:.1%}",
            "Risk": "🔴 HIGH" if top["label"] != "benign" else "🟢 LOW",
        })
    df = pd.DataFrame(rows)
    return df, _distribution_chart(df["Label"])


# ─── Load model ────────────────────────────────────────────────────────────
def load_model():
    """Train the model and yield Markdown status updates for the UI."""
    yield "⏳ Loading and training DistilBERT malware classifier...\n(This takes ~60-90 seconds on CPU)"
    try:
        pipe, f1 = build_and_train()
    except Exception as e:  # UI boundary: report and log instead of crashing the app
        logger.exception("Model training failed")
        yield f"❌ Error: {e}"
        return
    _model_cache["pipe"] = pipe
    _model_cache["trained"] = True
    yield (
        f"✅ **Model Ready!** Weighted F1 on held-out alerts = **{f1:.4f}**\n\n"
        "You can now classify threat alerts below."
    )


# ─── Gradio UI ─────────────────────────────────────────────────────────────
CSS = """
.gradio-container { font-family: 'Segoe UI', sans-serif; max-width: 1100px; margin: auto; }
.header-box { background: linear-gradient(135deg,#1B2A4A,#2563EB); padding:24px; border-radius:12px; color:white; margin-bottom:16px; }
.risk-high { border-left: 6px solid #DC2626; }
.risk-low  { border-left: 6px solid #16A34A; }
footer { display:none !important; }
"""

with gr.Blocks(css=CSS, title="DefenAI – Module 0: Malware Classifier") as demo:

    # Header
    gr.HTML("""
<div class="header-box">
  <h1 style="margin:0;">🛡️ DefenAI – Module 0</h1>
  <h3 style="margin:6px 0 0;font-weight:400;">SOC Malware Alert Triage · DistilBERT Text Classifier</h3>
  <p style="margin:8px 0 0;opacity:0.9;">Classify threat intelligence alerts as Ransomware · Trojan · Worm · Spyware · Benign</p>
</div>
""")

    with gr.Tabs():

        # ── Tab 1: Setup ──────────────────────────────────────────────────
        with gr.Tab("⚙️ Step 1: Load Model"):
            gr.Markdown(f"""
### How This App Works
1. Click **Load & Train Model** to fine-tune DistilBERT on synthetic malware alerts
2. Go to **Classify Alert** to test single threat intel texts
3. Go to **Batch SOC Triage** to analyse {len(EXAMPLE_ALERTS)} pre-loaded alerts at once

> **Module 0 Concept:** AI can automate SOC Level-1 alert triage using NLP models
""")
            load_btn = gr.Button("🚀 Load & Train Model", variant="primary", size="lg")
            status_box = gr.Markdown("_Model not loaded. Click above to start._")
            load_btn.click(load_model, outputs=status_box)

        # ── Tab 2: Single classify ─────────────────────────────────────────
        with gr.Tab("🔍 Step 2: Classify Alert"):
            gr.Markdown("### Enter a threat intelligence alert text to classify")
            with gr.Row():
                with gr.Column(scale=3):
                    alert_input = gr.Textbox(
                        label="Threat Intel Alert",
                        placeholder="e.g. Ransomware encrypts all files and demands Bitcoin payment…",
                        lines=3,
                    )
                    gr.Examples(
                        examples=[[e] for e in EXAMPLE_ALERTS[:5]],
                        inputs=alert_input,
                        label="📋 Example Alerts (click to load)",
                    )
                    classify_btn = gr.Button("🔎 Classify Alert", variant="primary")
                with gr.Column(scale=2):
                    risk_display = gr.HTML(label="Risk Assessment")
                    detail_text = gr.Markdown()
            conf_chart = gr.Plot(label="Confidence Scores by Category")
            classify_btn.click(
                classify_alert,
                inputs=alert_input,
                outputs=[risk_display, conf_chart, detail_text],
            )

        # ── Tab 3: Batch triage ────────────────────────────────────────────
        with gr.Tab("📊 Step 3: Batch SOC Triage"):
            gr.Markdown(f"### Analyse {len(EXAMPLE_ALERTS)} pre-loaded SOC alerts at once")
            batch_btn = gr.Button("▶️ Run Batch Triage", variant="primary", size="lg")
            batch_table = gr.Dataframe(label="Triage Results", interactive=False)
            dist_chart = gr.Plot(label="Alert Distribution")
            batch_btn.click(batch_analyse, outputs=[batch_table, dist_chart])

        # ── Tab 4: Learn ───────────────────────────────────────────────────
        with gr.Tab("📚 Learn: How It Works"):
            gr.Markdown("""
## 🧠 How DistilBERT Classifies Malware Alerts

### Model Architecture
- **DistilBERT** is a smaller, faster version of BERT (66M parameters)
- Fine-tuned on labelled malware descriptions for **5-class classification**
- Runs entirely on **CPU** — no GPU needed

### Attack Categories
| Label | Description | Example |
|-------|-------------|---------|
| 🔴 Ransomware | Encrypts files, demands ransom | WannaCry, LockBit |
| 🟠 Trojan | Disguised malware, backdoor | Agent Tesla, Emotet |
| 🟡 Worm | Self-propagating across network | Conficker, Stuxnet |
| 🟣 Spyware | Surveillance, data theft | Pegasus, DarkComet |
| 🟢 Benign | Legitimate software/update | Windows Update |

### Why F1 over Accuracy?
This lab's synthetic dataset is balanced, but real SOC telemetry is not: benign events vastly
outnumber malicious ones, so a model that always answers "benign" can still reach high accuracy.
**Weighted F1** computes F1 (the harmonic mean of precision and recall) for each class and averages
them weighted by class size, so poor recall on a malicious class shows up in the score.
(Macro F1 weights every class equally and is even stricter on rare classes.)

### Assessment Questions
1. Why is weighted F1 better than accuracy for imbalanced datasets?
2. A new alert says: *"Software modifies registry and disables antivirus"* — what label?
3. How would you update this model when new malware families emerge?
""")


if __name__ == "__main__":
    demo.launch()