{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "a2b393f1", "metadata": {}, "outputs": [], "source": [
"import os\n",
"\n",
"# Set the MPS allocator's high-watermark ratio before torch is imported;\n",
"# per the PyTorch MPS docs a ratio of 0.0 disables the upper memory limit.\n",
"MPS_WATERMARK_VAR = \"PYTORCH_MPS_HIGH_WATERMARK_RATIO\"\n",
"os.environ[MPS_WATERMARK_VAR] = \"0.0\""
] }, { "cell_type": "code", "execution_count": 2, "id": "612e7253", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
IDCVE-IDCVSS-V3CVSS-V2SEVERITYDESCRIPTIONCWE-ID
01CVE-1999-0001NaN5.0MEDIUMip_input.c in BSD-derived TCP/IP implementatio...CWE-20
12CVE-1999-0002NaN10.0HIGHBuffer overflow in NFS mountd gives root acces...CWE-119
23CVE-1999-0003NaN10.0HIGHExecute commands as root via buffer overflow i...NVD-CWE-Other
34CVE-1999-0004NaN5.0MEDIUMMIME buffer overflow in email clients, e.g. So...NVD-CWE-Other
45CVE-1999-0005NaN10.0HIGHArbitrary command execution via IMAP buffer ov...NVD-CWE-Other
\n", "
" ], "text/plain": [ " ID CVE-ID CVSS-V3 CVSS-V2 SEVERITY \\\n", "0 1 CVE-1999-0001 NaN 5.0 MEDIUM \n", "1 2 CVE-1999-0002 NaN 10.0 HIGH \n", "2 3 CVE-1999-0003 NaN 10.0 HIGH \n", "3 4 CVE-1999-0004 NaN 5.0 MEDIUM \n", "4 5 CVE-1999-0005 NaN 10.0 HIGH \n", "\n", " DESCRIPTION CWE-ID \n", "0 ip_input.c in BSD-derived TCP/IP implementatio... CWE-20 \n", "1 Buffer overflow in NFS mountd gives root acces... CWE-119 \n", "2 Execute commands as root via buffer overflow i... NVD-CWE-Other \n", "3 MIME buffer overflow in email clients, e.g. So... NVD-CWE-Other \n", "4 Arbitrary command execution via IMAP buffer ov... NVD-CWE-Other " ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Load the raw CVE/CWE dataset\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"# Location of the CSV, relative to the notebook's working directory\n",
"RAW_CSV_PATH = 'data/Global_Dataset.csv'\n",
"df = pd.read_csv(RAW_CSV_PATH)\n",
"\n",
"# Preview the first rows to check which columns are present\n",
"df.head()"
] }, { "cell_type": "code", "execution_count": 3, "id": "162b7a0b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
DESCRIPTIONCWE-ID
0ip_input.c in BSD-derived TCP/IP implementatio...CWE-20
1Buffer overflow in NFS mountd gives root acces...CWE-119
2Execute commands as root via buffer overflow i...NVD-CWE-Other
3MIME buffer overflow in email clients, e.g. So...NVD-CWE-Other
4Arbitrary command execution via IMAP buffer ov...NVD-CWE-Other
\n", "
" ], "text/plain": [ " DESCRIPTION CWE-ID\n", "0 ip_input.c in BSD-derived TCP/IP implementatio... CWE-20\n", "1 Buffer overflow in NFS mountd gives root acces... CWE-119\n", "2 Execute commands as root via buffer overflow i... NVD-CWE-Other\n", "3 MIME buffer overflow in email clients, e.g. So... NVD-CWE-Other\n", "4 Arbitrary command execution via IMAP buffer ov... NVD-CWE-Other" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Keep only the two columns used downstream: the description text and its CWE label\n",
"keep_cols = ['DESCRIPTION', 'CWE-ID']\n",
"df_subset = df.loc[:, keep_cols]\n",
"df_subset.head()\n"
] }, { "cell_type": "code", "execution_count": 4, "id": "a5473ff7", "metadata": {}, "outputs": [], "source": [
"# Drop the reference to the full frame now that df_subset holds the needed columns\n",
"del df"
] }, { "cell_type": "code", "execution_count": 5, "id": "72abf8d8", "metadata": {}, "outputs": [], "source": [ "# Train a Hugging Face classifier to map DESCRIPTION -> CWE-ID\n", "import os\n", "import json\n", "from datasets import Dataset\n", "from sklearn.model_selection import train_test_split\n", "from transformers import (\n", " AutoTokenizer,\n", " AutoModelForSequenceClassification,\n", " TrainingArguments,\n", " Trainer,\n", " DataCollatorWithPadding,\n", ")\n", "import numpy as np\n", "import evaluate\n", "import torch\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "e233b212", "metadata": {}, "outputs": [], "source": [
"# Notebook-wide settings.\n",
"# MODEL_NAME is a Hugging Face checkpoint id (presumably consumed by the\n",
"# tokenizer/model loading cells, which are not visible here - TODO confirm).\n",
"MODEL_NAME = \"distilbert-base-uncased\"\n",
"# Dataframe columns holding the input text and the target label\n",
"TEXT_COL = \"DESCRIPTION\"\n",
"LABEL_COL = \"CWE-ID\"\n"
] }, { "cell_type": "code", "execution_count": 7, "id": "29ac0839", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Preparing dataframe...\n", "Dropping overly generic buckets...\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
DESCRIPTIONCWE-ID
0ip_input.c in BSD-derived TCP/IP implementatio...CWE-20
1Buffer overflow in NFS mountd gives root acces...CWE-119
6Information from SSL-encrypted sessions via PK...CWE-327
25root privileges via buffer overflow in eject c...CWE-119
176Windows NT crashes or locks up when a Samba cl...CWE-17
\n", "
" ], "text/plain": [ " DESCRIPTION CWE-ID\n", "0 ip_input.c in BSD-derived TCP/IP implementatio... CWE-20\n", "1 Buffer overflow in NFS mountd gives root acces... CWE-119\n", "6 Information from SSL-encrypted sessions via PK... CWE-327\n", "25 root privileges via buffer overflow in eject c... CWE-119\n", "176 Windows NT crashes or locks up when a Samba cl... CWE-17" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Build the modelling frame: rows that have both text and label,\n",
"# with labels normalised to whitespace-stripped strings\n",
"print(\"Preparing dataframe...\")\n",
"labelled_rows = df_subset.dropna(subset=[TEXT_COL, LABEL_COL])\n",
"labelled_rows = labelled_rows.assign(\n",
"    **{LABEL_COL: labelled_rows[LABEL_COL].astype(str).str.strip()}\n",
")\n",
"\n",
"# Exclude the catch-all NVD buckets from the label set\n",
"print(\"Dropping overly generic buckets...\")\n",
"GENERIC_LABELS = [\"NVD-CWE-Other\", \"NVD-CWE-noinfo\"]\n",
"is_generic = labelled_rows[LABEL_COL].isin(GENERIC_LABELS)\n",
"df_model = labelled_rows.loc[~is_generic].copy()\n",
"\n",
"df_model.head()"
] }, { "cell_type": "code", "execution_count": 8, "id": "cb283ee4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Label counts before filtering:\n", "CWE-ID\n", "CWE-79 17438\n", "CWE-119 11494\n", "CWE-20 8575\n", "CWE-89 6992\n", "CWE-200 6749\n", "CWE-264 5321\n", "CWE-787 5211\n", "CWE-22 4117\n", "CWE-125 3866\n", "CWE-352 3341\n", "Name: count, dtype: int64\n", "Labels with only 1 example: 72\n", "\n", "After filtering:\n", "Total examples: 124045\n", "Unique labels: 232\n" ] } ], "source": [
"# Report the label distribution, then drop labels seen only once\n",
"# (the train/val split in a later cell stratifies on the label column)\n",
"MIN_EXAMPLES_PER_LABEL = 2\n",
"label_counts = df_model[LABEL_COL].value_counts()\n",
"print(\"Label counts before filtering:\")\n",
"print(label_counts.head(10))\n",
"n_singletons = label_counts.eq(1).sum()\n",
"print(f\"Labels with only 1 example: {n_singletons}\")\n",
"\n",
"# Retain rows whose label occurs at least MIN_EXAMPLES_PER_LABEL times\n",
"labels_to_keep = label_counts[label_counts.ge(MIN_EXAMPLES_PER_LABEL)].index\n",
"has_kept_label = df_model[LABEL_COL].isin(labels_to_keep)\n",
"df_model = df_model.loc[has_kept_label].copy()\n",
"\n",
"print(\"\\nAfter filtering:\")\n",
"print(f\"Total examples: {len(df_model)}\")\n",
"print(f\"Unique 
labels: {df_model[LABEL_COL].nunique()}\")\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "f55b8585", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Updated num labels: 232\n" ] } ], "source": [ "\n", "# Update label maps after filtering\n", "unique_labels = sorted(df_model[LABEL_COL].unique())\n", "label2id = {label: i for i, label in enumerate(unique_labels)}\n", "id2label = {i: label for label, i in label2id.items()}\n", "num_labels = len(unique_labels)\n", "print(f\"Updated num labels: {num_labels}\")" ] }, { "cell_type": "code", "execution_count": 10, "id": "0a67ebc1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset({\n", " features: ['DESCRIPTION', 'CWE-ID', 'labels'],\n", " num_rows: 111640\n", "})\n", "Dataset({\n", " features: ['DESCRIPTION', 'CWE-ID', 'labels'],\n", " num_rows: 12405\n", "})\n" ] } ], "source": [ "\n", "# Train/val split\n", "train_df, val_df = train_test_split(\n", " df_model[[TEXT_COL, LABEL_COL]],\n", " test_size=0.1,\n", " random_state=1,\n", " stratify=df_model[LABEL_COL],\n", ")\n", "\n", "# Add numeric labels in pandas before converting to HF Datasets\n", "train_df = train_df.assign(labels=train_df[LABEL_COL].map(label2id))\n", "val_df = val_df.assign(labels=val_df[LABEL_COL].map(label2id))\n", "\n", "train_ds = Dataset.from_pandas(train_df.reset_index(drop=True))\n", "val_ds = Dataset.from_pandas(val_df.reset_index(drop=True))\n", "\n", "print(train_ds)\n", "print(val_ds)" ] }, { "cell_type": "code", "execution_count": 11, "id": "567e7b1e", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f05985a8399f4591bfba5ad04c6b4b1d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/111640 [00:00\n", " \n", " \n", " [6978/6978 1:18:31, Epoch 1/1]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining LossValidation LossAccuracyF1
10001.0446001.2529400.7047160.220344
20001.1587001.1886770.7113260.229855
30001.1199001.1592290.7192260.235295
40001.1126001.1199240.7201930.242404
50001.1103001.1110530.7229340.244389
60001.1347001.0828060.7272070.251264

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=6978, training_loss=1.1011348515535433, metrics={'train_runtime': 4712.2885, 'train_samples_per_second': 23.691, 'train_steps_per_second': 1.481, 'total_flos': 2912105519756448.0, 'train_loss': 1.1011348515535433, 'epoch': 1.0})" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] }, { "cell_type": "code", "execution_count": 33, "id": "48faf17c", "metadata": {}, "outputs": [], "source": [
"# Persist the model, its tokenizer and the label mappings under artifacts/\n",
"os.makedirs(\"artifacts\", exist_ok=True)\n",
"for saveable in (model, tokenizer):\n",
"    saveable.save_pretrained(\"artifacts/model\")\n",
"\n",
"# json writes id2label's integer keys as strings; the loading cell converts them back\n",
"label_maps = {\"label2id\": label2id, \"id2label\": id2label}\n",
"with open(\"artifacts/label_map.json\", \"w\") as map_file:\n",
"    json.dump(label_maps, map_file, indent=2)"
] }, { "cell_type": "code", "execution_count": 34, "id": "fcb11390", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a33b852d1c594a69974bb9c3d30c014a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "README.md: 0.00B [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "No files have been modified since last commit. Skipping to prevent empty commit.\n", "No files have been modified since last commit. 
Skipping to prevent empty commit.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "✅ Model uploaded to: https://huggingface.co/mulliken/cwe-predictor\n" ] } ], "source": [
"# Reload the saved artifacts from disk and publish them to the Hugging Face Hub\n",
"model = AutoModelForSequenceClassification.from_pretrained(\"artifacts/model\")\n",
"tokenizer = AutoTokenizer.from_pretrained(\"artifacts/model\")\n",
"\n",
"repo_name = \"mulliken/cwe-predictor\" # Change this!\n",
"\n",
"model.push_to_hub(repo_name, private=False)\n",
"tokenizer.push_to_hub(repo_name, private=False)\n",
"\n",
"print(f\"✅ Model uploaded to: https://huggingface.co/{repo_name}\")"
] }, { "cell_type": "code", "execution_count": 35, "id": "33847880", "metadata": {}, "outputs": [], "source": [
"# Load the saved model, tokenizer and label mappings\n",
"model = AutoModelForSequenceClassification.from_pretrained(\"artifacts/model\")\n",
"tokenizer = AutoTokenizer.from_pretrained(\"artifacts/model\")\n",
"\n",
"# Read the mapping file once and close the handle when done\n",
"with open(\"artifacts/label_map.json\") as f:\n",
"    label_maps = json.load(f)\n",
"# JSON object keys are strings, so restore the integer class ids\n",
"id2label = {int(k): v for k, v in label_maps[\"id2label\"].items()}\n",
"label2id = label_maps[\"label2id\"]\n",
"\n",
"\n",
"# Quick inference helper\n",
"def predict_cwe(text: str) -> str:\n",
"    \"\"\"Return the label the loaded model scores highest for one description.\n",
"\n",
"    Args:\n",
"        text: A single description string (the arg-max is read with\n",
"            ``.item()``, so exactly one prediction is expected).\n",
"\n",
"    Returns:\n",
"        The entry of ``id2label`` for the highest-scoring class id.\n",
"    \"\"\"\n",
"    # truncation=True with no explicit max_length, as in the original call\n",
"    encoded = tokenizer(text, return_tensors=\"pt\", truncation=True)\n",
"    # Inference only: no gradient tracking needed\n",
"    with torch.no_grad():\n",
"        logits = model(**encoded).logits\n",
"    pred_id = int(torch.argmax(logits, dim=-1).item())\n",
"    return id2label[pred_id]\n"
] }, { "cell_type": "code", "execution_count": 36, "id": "cdeaadbb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CWE-119\n", "CWE-89\n", "CWE-79\n", "CWE-287\n", "CWE-22\n", "CWE-190\n", "CWE-401\n", "CWE-77\n", "CWE-326\n" ] } ], "source": [
"# Smoke-test the helper on a handful of hand-written descriptions\n",
"sample_descriptions = [\n",
"    \"Buffer overflow in POP servers allows remote attackers to gain root access using a long PASS command.\",\n",
"    \"SQL injection vulnerability in web application allows attackers to execute arbitrary SQL commands through user input fields.\",\n",
"    \"Cross-site scripting (XSS) vulnerability allows attackers to inject malicious scripts into web pages viewed by other users.\",\n",
"    \"Authentication bypass vulnerability allows unauthorized access to restricted areas of the application.\",\n",
"    \"Path traversal vulnerability enables attackers to access files outside the intended directory structure.\",\n",
"    \"Integer overflow condition causes unexpected behavior when processing large numeric values.\",\n",
"    \"Memory leak in network daemon causes gradual memory consumption leading to denial of service.\",\n",
"    \"Command injection vulnerability allows execution of arbitrary system commands through unsanitized user input.\",\n",
"    \"Weak cryptographic algorithm implementation makes encrypted data susceptible to brute force attacks.\",\n",
"]\n",
"for description in sample_descriptions:\n",
"    print(predict_cwe(description))"
] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.13" } }, "nbformat": 4, "nbformat_minor": 5 }