{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CYB002 Baseline Classifier: Inference Example\n", "\n", "End-to-end demo: load the trained XGBoost and PyTorch MLP models from the Hugging Face repo and predict the **MITRE ATT&CK kill-chain phase** of a new attack-event record.\n", "\n", "**The models predict one of 10 phases:** `dwell_idle`, `reconnaissance`, `initial_access`, `execution`, `persistence`, `privilege_escalation`, `lateral_movement`, `collection`, `exfiltration`, `impact`.\n", "\n", "**This is a baseline reference model**, not a production threat detector. See the model card for full metrics and limitations." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Install dependencies" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install --quiet xgboost torch safetensors pandas numpy huggingface_hub" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Download model artifacts from Hugging Face\n", "\n", "Five files are needed:\n", "- `model_xgb.json`: XGBoost weights\n", "- `model_mlp.safetensors`: PyTorch MLP weights\n", "- `feature_engineering.py`: feature pipeline (must match the one used at training)\n", "- `feature_meta.json`: feature column order + categorical levels\n", "- `feature_scaler.json`: MLP input standardization (mean / std)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import hf_hub_download\n", "\n", "REPO_ID = \"xpertsystems/cyb002-baseline-classifier\"\n", "\n", "files = {}\n", "for name in [\"model_xgb.json\", \"model_mlp.safetensors\",\n", "             \"feature_engineering.py\", \"feature_meta.json\",\n", "             \"feature_scaler.json\"]:\n", "    files[name] = hf_hub_download(repo_id=REPO_ID, filename=name)\n", "    print(f\" downloaded: {name}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Make feature_engineering.py 
importable\n", "import os\n", "import sys\n", "\n", "fe_dir = os.path.dirname(files[\"feature_engineering.py\"])\n", "if fe_dir not in sys.path:\n", "    sys.path.insert(0, fe_dir)\n", "\n", "from feature_engineering import (\n", "    transform_single, load_meta, INT_TO_LABEL, build_segment_lookup\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Load models and metadata" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import xgboost as xgb\n", "from safetensors.torch import load_file\n", "\n", "meta = load_meta(files[\"feature_meta.json\"])\n", "with open(files[\"feature_scaler.json\"]) as f:\n", "    scaler = json.load(f)\n", "\n", "N_FEATURES = len(meta[\"feature_names\"])\n", "N_CLASSES = len(meta[\"int_to_label\"])\n", "print(f\"feature count: {N_FEATURES}\")\n", "print(f\"class count: {N_CLASSES}\")\n", "print(f\"label classes: {list(meta['int_to_label'].values())}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# XGBoost\n", "xgb_model = xgb.XGBClassifier()\n", "xgb_model.load_model(files[\"model_xgb.json\"])\n", "\n", "# MLP architecture (must match training)\n", "class PhaseMLP(nn.Module):\n", "    def __init__(self, n_features, n_classes=10, hidden1=128, hidden2=64, dropout=0.3):\n", "        super().__init__()\n", "        self.net = nn.Sequential(\n", "            nn.Linear(n_features, hidden1),\n", "            nn.BatchNorm1d(hidden1),\n", "            nn.ReLU(),\n", "            nn.Dropout(dropout),\n", "            nn.Linear(hidden1, hidden2),\n", "            nn.BatchNorm1d(hidden2),\n", "            nn.ReLU(),\n", "            nn.Dropout(dropout),\n", "            nn.Linear(hidden2, n_classes),\n", "        )\n", "\n", "    def forward(self, x):\n", "        return self.net(x)\n", "\n", "mlp_model = PhaseMLP(N_FEATURES, n_classes=N_CLASSES)\n", "mlp_model.load_state_dict(load_file(files[\"model_mlp.safetensors\"]))\n", "mlp_model.eval()\n", "print(\"models loaded\")" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "## 4. Build segment-aggregate lookup from the dataset\n", "\n", "Per-segment topology aggregates (mean exposure, fraction with EDR, etc.) are computed at training time and must be available at inference time too. The helper `build_segment_lookup` pulls them from `network_topology.csv`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import snapshot_download\n", "\n", "ds_path = snapshot_download(repo_id=\"xpertsystems/cyb002-sample\", repo_type=\"dataset\")\n", "\n", "import os\n", "segment_aggregates_lookup = build_segment_lookup(\n", " os.path.join(ds_path, \"network_topology.csv\")\n", ")\n", "print(f\"loaded {len(segment_aggregates_lookup)} segment aggregates\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Prediction helper" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "MU = np.array(scaler[\"mean\"], dtype=np.float32)\n", "SD = np.array(scaler[\"std\"], dtype=np.float32)\n", "\n", "def predict_phase(record: dict) -> dict:\n", " \"\"\"Predict the kill-chain phase for one event record.\n", "\n", " `record` is a dict with event-level fields. 
Segment-level aggregates\n", " are pulled automatically from `segment_aggregates_lookup` using the\n", " `target_segment_id` field.\n", "\n", " Returns a dict with both models' predictions and per-class probabilities.\n", " \"\"\"\n", " seg_id = record.get(\"target_segment_id\")\n", " # Unknown segment IDs fall back to an empty aggregate dict.\n", " seg_agg = segment_aggregates_lookup.get(seg_id, {})\n", " X = transform_single(record, meta, segment_aggregates=seg_agg)\n", "\n", " xgb_proba = xgb_model.predict_proba(X)[0]\n", " xgb_label = INT_TO_LABEL[int(np.argmax(xgb_proba))]\n", "\n", " # The MLP expects inputs standardized with the stored mean / std.\n", " Xs = ((X - MU) / SD).astype(np.float32)\n", " with torch.no_grad():\n", "     logits = mlp_model(torch.tensor(Xs))\n", "     mlp_proba = torch.softmax(logits, dim=1).numpy()[0]\n", " mlp_label = INT_TO_LABEL[int(np.argmax(mlp_proba))]\n", "\n", " return {\n", "     \"xgboost\": {\n", "         \"label\": xgb_label,\n", "         \"probabilities\": {INT_TO_LABEL[i]: float(p) for i, p in enumerate(xgb_proba)},\n", "     },\n", "     \"mlp\": {\n", "         \"label\": mlp_label,\n", "         \"probabilities\": {INT_TO_LABEL[i]: float(p) for i, p in enumerate(mlp_proba)},\n", "     },\n", " }" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Run on an example record\n", "\n", "This is a real `reconnaissance` event taken from the sample dataset: an opportunistic attacker scanning an email server early in a campaign (timestep 0). Both models should predict `reconnaissance`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Real attack event from the sample dataset (true label: reconnaissance)\n", "example_record = {\n", "    \"campaign_id\": \"CAMP-000030\",\n", "    \"attacker_id\": \"ATK-0003\",\n", "    \"timestep\": 0,\n", "    \"target_segment_id\": \"SEG-0008\",\n", "    \"target_asset_type\": \"email_server\",\n", "    \"source_ip_class\": \"vpn_tunnel\",\n", "    \"dest_port\": 22,\n", "    \"protocol\": \"icmp\",\n", "    \"bytes_transferred\": 15648.48,\n", "    \"connection_duration_s\": 3.913,\n", "    \"auth_failure_count\": 0,\n", "    \"process_injection_flag\": 0,\n", "    \"lateral_hop_count\": 0,\n", "    \"c2_beacon_interval_s\": 0.0,\n", "    \"detection_outcome\": \"edr_blocked\",\n", "    \"alert_severity\": \"critical\",\n", "    \"siem_rule_triggered\": 0,\n", "    \"edr_blocked_flag\": 1,\n", "    \"attacker_capability_tier\": \"opportunistic\",\n", "    \"defender_maturity_level\": \"baseline\",\n", "}\n", "\n", "result = predict_phase(example_record)\n", "\n", "# Show each model's top-5 classes by probability\n", "print(f\"XGBoost -> {result['xgboost']['label']}\")\n", "for lbl, p in sorted(result['xgboost']['probabilities'].items(), key=lambda x: -x[1])[:5]:\n", "    print(f\" P({lbl:25s}) = {p:.4f}\")\n", "\n", "print(f\"\\nMLP -> {result['mlp']['label']}\")\n", "for lbl, p in sorted(result['mlp']['probabilities'].items(), key=lambda x: -x[1])[:5]:\n", "    print(f\" P({lbl:25s}) = {p:.4f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Note: when the two models disagree\n", "\n", "XGBoost and the MLP can disagree on out-of-distribution records, particularly hand-crafted inputs whose feature combinations fall outside the training-data manifold. The MLP, with BatchNorm and a small training set, has narrower competence than the tree ensemble. Disagreement is a useful triage signal: in a SOC workflow, events with conflicting predictions are worth a human analyst's review." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. 
Batch prediction on the sample dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "events = pd.read_csv(os.path.join(ds_path, \"attack_events.csv\"))\n", "\n", "# Drop leakage columns the model was never trained on\n", "events = events.drop(columns=[\"technique_id\", \"technique_name\", \"tactic_category\"],\n", "                     errors=\"ignore\")\n", "\n", "# Score the first 200 events\n", "sample = events.head(200).copy()\n", "preds = [predict_phase(row.to_dict())[\"xgboost\"][\"label\"] for _, row in sample.iterrows()]\n", "sample[\"xgb_pred\"] = preds\n", "\n", "# Confusion table: true phase (rows) vs. XGBoost prediction (columns)\n", "ct = pd.crosstab(sample[\"kill_chain_phase\"], sample[\"xgb_pred\"],\n", "                 rownames=[\"true\"], colnames=[\"pred\"])\n", "print(\"Confusion on first 200 sample rows (XGBoost):\")\n", "print(ct)\n", "acc = (sample[\"kill_chain_phase\"] == sample[\"xgb_pred\"]).mean()\n", "print(f\"\\nbatch accuracy on first 200 (in-distribution): {acc:.4f}\")\n", "print(\"\\nNote: this includes training-set events. See validation_results.json\\n\"\n", "      \"for proper held-out test-set metrics from disjoint campaigns.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 8. Next steps\n", "\n", "- See `validation_results.json` for held-out test-set metrics (15 disjoint campaigns, 726 events).\n", "- See `ablation_results.json` for per-feature-group contribution. `timestep` is by far the most predictive feature, which is expected: kill-chain phases progress in time, so position in the campaign timeline carries most of the phase signal.\n", "- The model card's **Limitations** section explains the gap between this baseline and production threat-detection systems.\n", "- For the full 380k-row CYB002 dataset and commercial licensing, contact **pradeep@xpertsystems.ai**." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10" } }, "nbformat": 4, "nbformat_minor": 5 }