Spaces:
Sleeping
Sleeping
| """ | |
| DefenAI β Module 0 | Lab 0 | |
| Malware Text Classification with DistilBERT | |
| Gradio Interactive App | |
| Run locally : python app.py | |
| Deploy HF : Push folder to HuggingFace Space (SDK=Gradio) | |
| """ | |
| import re, time, random | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| from transformers import ( | |
| AutoTokenizer, AutoModelForSequenceClassification, | |
| TrainingArguments, Trainer, pipeline | |
| ) | |
| from datasets import Dataset | |
| import evaluate, torch | |
# ─── Colour palette ───────────────────────────────────────────────────────
# Hex colour per malware label; used for chart bars/wedges and label text.
COLORS = dict(
    ransomware="#DC2626",
    trojan="#EA580C",
    worm="#D97706",
    spyware="#7C3AED",
    benign="#16A34A",
)
# Banner colour per risk level shown in the single-alert view.
RISK_COLOR = dict(HIGH="#DC2626", MEDIUM="#D97706", LOW="#16A34A")

# Class order fixes the integer ids the classifier is trained with.
LABELS = ["ransomware", "trojan", "worm", "spyware", "benign"]
id2label = dict(enumerate(LABELS))
label2id = {name: idx for idx, name in id2label.items()}
# ─── Training data ────────────────────────────────────────────────────────
# Synthetic (alert text, label) pairs cycling through the five labels in a
# fixed order. The whole base list is repeated verbatim, so SAMPLES contains
# exact duplicates; order matters because the training split shuffles by seed.
_SAMPLE_REPEATS = 8
_BASE_SAMPLES = (
    ("This ransomware encrypts all files and demands Bitcoin payment", "ransomware"),
    ("Trojan disguised as PDF reader silently exfiltrates credentials", "trojan"),
    ("Worm propagates across network shares by exploiting SMB vulnerability", "worm"),
    ("Spyware monitors keystrokes and captures screen every 30 seconds", "spyware"),
    ("Legitimate system update downloaded from official vendor website", "benign"),
    ("Malware encrypts documents and deletes shadow copies", "ransomware"),
    ("Remote access trojan opens backdoor on port 4444", "trojan"),
    ("Self-replicating code scans subnet and infects unpatched Windows hosts", "worm"),
    ("Adware injects ads into browser and tracks browsing history", "spyware"),
    ("Scheduled task runs Windows Defender update from Microsoft", "benign"),
    ("Crypto locker demands 0.5 BTC and leaves ransom note", "ransomware"),
    ("Dropper downloads secondary payload into svchost.exe", "trojan"),
    ("Network worm exploits EternalBlue to spread across enterprise", "worm"),
    ("Keylogger records all input and emails log to attacker", "spyware"),
    ("Antivirus signature update received from vendor portal", "benign"),
    ("Ransomware deletes volume shadow copies using vssadmin", "ransomware"),
    ("Banking trojan intercepts web traffic to steal credentials", "trojan"),
    ("Email worm sends itself to all Outlook contacts", "worm"),
    ("Spyware activates webcam remotely without user consent", "spyware"),
    ("Security patch KB5034441 downloaded from Windows Update", "benign"),
    ("File-encrypting payload targets network drives and backups", "ransomware"),
    ("Stealer harvests passwords from Chrome and Firefox", "trojan"),
    ("Virus copies itself to USB drives for propagation", "worm"),
    ("Hidden process uploads screenshots to remote FTP server", "spyware"),
    ("Driver update from trusted manufacturer with valid certificate", "benign"),
)
SAMPLES = list(_BASE_SAMPLES) * _SAMPLE_REPEATS
| # βββ Global model state ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _model_cache = {"pipe": None, "trained": False} | |
| EXAMPLE_ALERTS = [ | |
| "Malware encrypts home directory and drops ransom note demanding ETH", | |
| "Legitimate browser extension downloaded from Chrome Web Store", | |
| "Hidden backdoor connects to attacker IP every hour via C2 channel", | |
| "Worm spreading via Windows print spooler vulnerability CVE-2021-34527", | |
| "Adware captures clipboard content and sends to remote server", | |
| "Security software update verified and signed by Microsoft", | |
| "Trojan disguised as game installer exfiltrates saved passwords", | |
| "Software update modifies registry keys and disables antivirus", | |
| ] | |
| # βββ Helper: build & train model ββββββββββββββββββββββββββββββββββββββββββ | |
| def build_and_train(): | |
| texts, labs = zip(*SAMPLES) | |
| df = pd.DataFrame({"text": list(texts), "label": [label2id[l] for l in labs]}) | |
| df = df.sample(frac=1, random_state=42).reset_index(drop=True) | |
| split = int(0.8 * len(df)) | |
| train_ds = Dataset.from_pandas(df[:split]) | |
| test_ds = Dataset.from_pandas(df[split:]) | |
| model_name = "distilbert-base-uncased" | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| def tokenize(batch): | |
| return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=64) | |
| train_tok = train_ds.map(tokenize, batched=True) | |
| test_tok = test_ds.map(tokenize, batched=True) | |
| model = AutoModelForSequenceClassification.from_pretrained( | |
| model_name, num_labels=len(LABELS), | |
| id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True | |
| ) | |
| f1_metric = evaluate.load("f1") | |
| def compute_metrics(ep): | |
| preds = np.argmax(ep.predictions, axis=-1) | |
| return f1_metric.compute(predictions=preds, references=ep.label_ids, average="weighted") | |
| args = TrainingArguments( | |
| output_dir="malware_clf", num_train_epochs=2, | |
| per_device_train_batch_size=16, per_device_eval_batch_size=16, | |
| evaluation_strategy="epoch", save_strategy="no", | |
| logging_steps=50, report_to="none", | |
| ) | |
| trainer = Trainer( | |
| model=model, args=args, | |
| train_dataset=train_tok, eval_dataset=test_tok, | |
| compute_metrics=compute_metrics, | |
| ) | |
| trainer.train() | |
| results = trainer.evaluate() | |
| pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=-1) | |
| return pipe, results.get("eval_f1", 0) | |
| # βββ Classify alert βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def classify_alert(text): | |
| if not text.strip(): | |
| return "β οΈ Please enter an alert text.", None, "" | |
| if _model_cache["pipe"] is None: | |
| return "β Model not loaded. Click **Load & Train Model** first.", None, "" | |
| result = _model_cache["pipe"](text)[0] | |
| label = result["label"] | |
| score = result["score"] | |
| risk = "HIGH" if label != "benign" else "LOW" | |
| color = COLORS.get(label, "#666") | |
| # Confidence bar chart | |
| all_results = _model_cache["pipe"](text, top_k=None) | |
| labels_sorted = [r["label"] for r in all_results] | |
| scores_sorted = [r["score"] for r in all_results] | |
| bar_colors = [COLORS.get(l, "#666") for l in labels_sorted] | |
| fig, ax = plt.subplots(figsize=(7, 3)) | |
| bars = ax.barh(labels_sorted, scores_sorted, color=bar_colors, edgecolor="white", height=0.6) | |
| ax.set_xlim(0, 1) | |
| ax.set_xlabel("Confidence Score", fontsize=11) | |
| ax.set_title("Classification Confidence by Malware Type", fontsize=12, fontweight="bold") | |
| for bar, sc in zip(bars, scores_sorted): | |
| ax.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2, | |
| f"{sc:.1%}", va="center", fontsize=10) | |
| ax.spines[["top", "right"]].set_visible(False) | |
| fig.patch.set_facecolor("#F8FAFC") | |
| ax.set_facecolor("#F8FAFC") | |
| plt.tight_layout() | |
| risk_html = f""" | |
| <div style="padding:16px;border-radius:10px;background:{RISK_COLOR[risk]}22; | |
| border:2px solid {RISK_COLOR[risk]};text-align:center;margin:8px 0"> | |
| <div style="font-size:2em;font-weight:bold;color:{RISK_COLOR[risk]}">{risk} RISK</div> | |
| <div style="font-size:1.3em;color:{color};font-weight:600;margin-top:4px"> | |
| {label.upper()} | {score:.1%} confidence | |
| </div> | |
| </div> | |
| """ | |
| details = f"π **Analysis:** Classified as **{label.upper()}** with {score:.1%} confidence.\n\n" | |
| if label == "ransomware": | |
| details += "π **Threat:** File-encrypting malware. Isolate system immediately. Check backups." | |
| elif label == "trojan": | |
| details += "π΄ **Threat:** Trojan/Backdoor. Check network connections. Scan for lateral movement." | |
| elif label == "worm": | |
| details += "πͺ± **Threat:** Self-propagating worm. Isolate from network. Patch vulnerabilities." | |
| elif label == "spyware": | |
| details += "ποΈ **Threat:** Spyware/Keylogger. Reset credentials. Check outbound traffic." | |
| else: | |
| details += "β **Status:** No malicious indicators detected. Continue monitoring." | |
| return risk_html, fig, details | |
| # βββ Batch analyse ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def batch_analyse(): | |
| if _model_cache["pipe"] is None: | |
| return "β Model not loaded.", None | |
| rows = [] | |
| for alert in EXAMPLE_ALERTS: | |
| r = _model_cache["pipe"](alert)[0] | |
| rows.append({ | |
| "Alert": alert[:60] + "β¦", | |
| "Label": r["label"].upper(), | |
| "Confidence": f"{r['score']:.1%}", | |
| "Risk": "π΄ HIGH" if r["label"] != "benign" else "π’ LOW", | |
| }) | |
| df = pd.DataFrame(rows) | |
| # Pie chart | |
| counts = df["Label"].value_counts() | |
| pie_colors = [COLORS.get(l.lower(), "#999") for l in counts.index] | |
| fig, ax = plt.subplots(figsize=(5, 5)) | |
| wedges, texts, autotexts = ax.pie( | |
| counts.values, labels=counts.index, autopct="%1.0f%%", | |
| colors=pie_colors, startangle=90, | |
| wedgeprops={"edgecolor": "white", "linewidth": 2}, | |
| ) | |
| for at in autotexts: | |
| at.set_fontsize(11); at.set_fontweight("bold"); at.set_color("white") | |
| ax.set_title("SOC Alert Distribution", fontsize=13, fontweight="bold") | |
| fig.patch.set_facecolor("#F8FAFC") | |
| plt.tight_layout() | |
| return df, fig | |
# ─── Load model ───────────────────────────────────────────────────────────
def load_model():
    """Train the classifier while streaming status Markdown to the UI.

    Yields a "training in progress" notice first, then either a success
    message showing the weighted F1 returned by build_and_train() or the
    error text. On success the pipeline is stored in ``_model_cache`` for the
    classify and batch handlers.
    """
    yield "β³ Loading and training DistilBERT malware classifier...\n(This takes ~60-90 seconds on CPU)"
    try:
        pipe, f1 = build_and_train()
        _model_cache.update(pipe=pipe, trained=True)
        yield f"β **Model Ready!** Weighted F1 = **{f1:.4f}**\n\nYou can now classify threat alerts below."
    except Exception as exc:  # UI boundary: surface the failure instead of killing the handler
        yield f"β Error: {exc!s}"
# ─── Gradio UI ────────────────────────────────────────────────────────────
# Stylesheet passed to gr.Blocks: page font/width, the gradient header
# banner, left-border accents for risk levels, and a hidden Gradio footer.
_CSS_LINES = (
    ".gradio-container { font-family: 'Segoe UI', sans-serif; max-width: 1100px; margin: auto; }",
    ".header-box { background: linear-gradient(135deg,#1B2A4A,#2563EB);",
    "padding:24px; border-radius:12px; color:white; margin-bottom:16px; }",
    ".risk-high { border-left: 6px solid #DC2626; }",
    ".risk-low { border-left: 6px solid #16A34A; }",
    "footer { display:none !important; }",
)
CSS = "\n" + "\n".join(_CSS_LINES) + "\n"
# Static page text, kept apart from the layout code below.
_HEADER_HTML = """
<div class="header-box">
<h1 style="margin:0;font-size:1.8em">π‘οΈ DefenAI β Module 0</h1>
<h2 style="margin:4px 0 0;font-weight:400;opacity:.9">
SOC Malware Alert Triage Β· DistilBERT Text Classifier
</h2>
<p style="margin:8px 0 0;opacity:.75;font-size:.9em">
Classify threat intelligence alerts as Ransomware Β· Trojan Β· Worm Β· Spyware Β· Benign
</p>
</div>
"""

_SETUP_GUIDE_MD = """
### How This App Works
1. Click **Load & Train Model** to fine-tune DistilBERT on synthetic malware alerts
2. Go to **Classify Alert** to test single threat intel texts
3. Go to **Batch SOC Triage** to analyse 8 pre-loaded alerts at once
> **Module 0 Concept:** AI can automate SOC Level-1 alert triage using NLP models
"""

_LEARN_MD = """
## π§ How DistilBERT Classifies Malware Alerts
### Model Architecture
- **DistilBERT** is a smaller, faster version of BERT (66M parameters)
- Fine-tuned on labelled malware descriptions for **5-class classification**
- Runs entirely on **CPU** β no GPU needed
### Attack Categories
| Label | Description | Example |
|-------|-------------|---------|
| π΄ Ransomware | Encrypts files, demands ransom | WannaCry, LockBit |
| π Trojan | Disguised malware, backdoor | Agent Tesla, Emotet |
| π‘ Worm | Self-propagating across network | Conficker, Stuxnet |
| π£ Spyware | Surveillance, data theft | Pegasus, DarkComet |
| π’ Benign | Legitimate software/update | Windows Update |
### Why F1 over Accuracy?
Class imbalance (few malware vs many benign) makes accuracy misleading.
**Weighted F1** balances precision and recall across all classes.
### Assessment Questions
1. Why is weighted F1 better than accuracy for imbalanced datasets?
2. A new alert says: *"Software modifies registry and disables antivirus"* β what label?
3. How would you update this model when new malware families emerge?
"""

# Layout: header banner, then four tabs (train, single classify, batch, learn).
with gr.Blocks(css=CSS, title="DefenAI β Module 0: Malware Classifier") as demo:
    gr.HTML(_HEADER_HTML)

    with gr.Tabs():
        # Tab 1: train the model; load_model streams status text into status_box.
        with gr.Tab("βοΈ Step 1: Load Model"):
            gr.Markdown(_SETUP_GUIDE_MD)
            load_btn = gr.Button("π Load & Train Model", variant="primary", size="lg")
            status_box = gr.Markdown("_Model not loaded. Click above to start._")
            load_btn.click(fn=load_model, outputs=status_box)

        # Tab 2: classify one alert typed (or picked from examples) by the user.
        with gr.Tab("π Step 2: Classify Alert"):
            gr.Markdown("### Enter a threat intelligence alert text to classify")
            with gr.Row():
                with gr.Column(scale=3):
                    alert_input = gr.Textbox(
                        label="Threat Intel Alert",
                        placeholder="e.g. Ransomware encrypts all files and demands Bitcoin paymentβ¦",
                        lines=3,
                    )
                    gr.Examples(
                        examples=[[alert] for alert in EXAMPLE_ALERTS[:5]],
                        inputs=alert_input,
                        label="π Example Alerts (click to load)",
                    )
                    classify_btn = gr.Button("π Classify Alert", variant="primary")
                with gr.Column(scale=2):
                    risk_display = gr.HTML(label="Risk Assessment")
                    detail_text = gr.Markdown()
            conf_chart = gr.Plot(label="Confidence Scores by Category")
            classify_btn.click(
                fn=classify_alert,
                inputs=alert_input,
                outputs=[risk_display, conf_chart, detail_text],
            )

        # Tab 3: run all EXAMPLE_ALERTS through the model at once.
        with gr.Tab("π Step 3: Batch SOC Triage"):
            gr.Markdown("### Analyse 8 pre-loaded SOC alerts at once")
            batch_btn = gr.Button("βΆοΈ Run Batch Triage", variant="primary", size="lg")
            batch_table = gr.Dataframe(label="Triage Results", interactive=False)
            dist_chart = gr.Plot(label="Alert Distribution")
            batch_btn.click(fn=batch_analyse, outputs=[batch_table, dist_chart])

        # Tab 4: static teaching material.
        with gr.Tab("π Learn: How It Works"):
            gr.Markdown(_LEARN_MD)
if __name__ == "__main__":
    # load_model is a generator (it streams status updates). Gradio 3.x only
    # runs generator handlers when the queue is enabled; Gradio 4+ enables
    # the queue by default, so this call is harmless there.
    demo.queue().launch()