{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "a2b393f1",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"# Must be in place before torch is imported.\n",
"# NOTE(review): per the PyTorch MPS documentation, a ratio of 0.0 removes the\n",
"# high-watermark cap on MPS memory allocations -- confirm that is intended here.\n",
"os.environ.update({\"PYTORCH_MPS_HIGH_WATERMARK_RATIO\": \"0.0\"})"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "612e7253",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" ID | \n",
" CVE-ID | \n",
" CVSS-V3 | \n",
" CVSS-V2 | \n",
" SEVERITY | \n",
" DESCRIPTION | \n",
" CWE-ID | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1 | \n",
" CVE-1999-0001 | \n",
" NaN | \n",
" 5.0 | \n",
" MEDIUM | \n",
" ip_input.c in BSD-derived TCP/IP implementatio... | \n",
" CWE-20 | \n",
"
\n",
" \n",
" | 1 | \n",
" 2 | \n",
" CVE-1999-0002 | \n",
" NaN | \n",
" 10.0 | \n",
" HIGH | \n",
" Buffer overflow in NFS mountd gives root acces... | \n",
" CWE-119 | \n",
"
\n",
" \n",
" | 2 | \n",
" 3 | \n",
" CVE-1999-0003 | \n",
" NaN | \n",
" 10.0 | \n",
" HIGH | \n",
" Execute commands as root via buffer overflow i... | \n",
" NVD-CWE-Other | \n",
"
\n",
" \n",
" | 3 | \n",
" 4 | \n",
" CVE-1999-0004 | \n",
" NaN | \n",
" 5.0 | \n",
" MEDIUM | \n",
" MIME buffer overflow in email clients, e.g. So... | \n",
" NVD-CWE-Other | \n",
"
\n",
" \n",
" | 4 | \n",
" 5 | \n",
" CVE-1999-0005 | \n",
" NaN | \n",
" 10.0 | \n",
" HIGH | \n",
" Arbitrary command execution via IMAP buffer ov... | \n",
" NVD-CWE-Other | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" ID CVE-ID CVSS-V3 CVSS-V2 SEVERITY \\\n",
"0 1 CVE-1999-0001 NaN 5.0 MEDIUM \n",
"1 2 CVE-1999-0002 NaN 10.0 HIGH \n",
"2 3 CVE-1999-0003 NaN 10.0 HIGH \n",
"3 4 CVE-1999-0004 NaN 5.0 MEDIUM \n",
"4 5 CVE-1999-0005 NaN 10.0 HIGH \n",
"\n",
" DESCRIPTION CWE-ID \n",
"0 ip_input.c in BSD-derived TCP/IP implementatio... CWE-20 \n",
"1 Buffer overflow in NFS mountd gives root acces... CWE-119 \n",
"2 Execute commands as root via buffer overflow i... NVD-CWE-Other \n",
"3 MIME buffer overflow in email clients, e.g. So... NVD-CWE-Other \n",
"4 Arbitrary command execution via IMAP buffer ov... NVD-CWE-Other "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Get the data\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"# Read the full dataset from disk and preview its first five rows.\n",
"df = pd.read_csv(\"data/Global_Dataset.csv\")\n",
"df.head(n=5)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "162b7a0b",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" DESCRIPTION | \n",
" CWE-ID | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" ip_input.c in BSD-derived TCP/IP implementatio... | \n",
" CWE-20 | \n",
"
\n",
" \n",
" | 1 | \n",
" Buffer overflow in NFS mountd gives root acces... | \n",
" CWE-119 | \n",
"
\n",
" \n",
" | 2 | \n",
" Execute commands as root via buffer overflow i... | \n",
" NVD-CWE-Other | \n",
"
\n",
" \n",
" | 3 | \n",
" MIME buffer overflow in email clients, e.g. So... | \n",
" NVD-CWE-Other | \n",
"
\n",
" \n",
" | 4 | \n",
" Arbitrary command execution via IMAP buffer ov... | \n",
" NVD-CWE-Other | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" DESCRIPTION CWE-ID\n",
"0 ip_input.c in BSD-derived TCP/IP implementatio... CWE-20\n",
"1 Buffer overflow in NFS mountd gives root acces... CWE-119\n",
"2 Execute commands as root via buffer overflow i... NVD-CWE-Other\n",
"3 MIME buffer overflow in email clients, e.g. So... NVD-CWE-Other\n",
"4 Arbitrary command execution via IMAP buffer ov... NVD-CWE-Other"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Narrow the frame to the description text and its CWE label.\n",
"df_subset = df.loc[:, [\"DESCRIPTION\", \"CWE-ID\"]]\n",
"df_subset.head(n=5)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a5473ff7",
"metadata": {},
"outputs": [],
"source": [
"# Later cells only read df_subset, so drop the name bound to the full frame.\n",
"del df"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "72abf8d8",
"metadata": {},
"outputs": [],
"source": [
"# Train a Hugging Face classifier to map DESCRIPTION -> CWE-ID\n",
"import os\n",
"import json\n",
"from datasets import Dataset\n",
"from sklearn.model_selection import train_test_split\n",
"from transformers import (\n",
" AutoTokenizer,\n",
" AutoModelForSequenceClassification,\n",
" TrainingArguments,\n",
" Trainer,\n",
" DataCollatorWithPadding,\n",
")\n",
"import numpy as np\n",
"import evaluate\n",
"import torch\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e233b212",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Checkpoint name that the tokenizer and classifier are loaded from.\n",
"MODEL_NAME = \"distilbert-base-uncased\"\n",
"\n",
"# Dataframe columns holding the input text and the target label, respectively.\n",
"TEXT_COL, LABEL_COL = \"DESCRIPTION\", \"CWE-ID\"\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "29ac0839",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Preparing dataframe...\n",
"Dropping overly generic buckets...\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" DESCRIPTION | \n",
" CWE-ID | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" ip_input.c in BSD-derived TCP/IP implementatio... | \n",
" CWE-20 | \n",
"
\n",
" \n",
" | 1 | \n",
" Buffer overflow in NFS mountd gives root acces... | \n",
" CWE-119 | \n",
"
\n",
" \n",
" | 6 | \n",
" Information from SSL-encrypted sessions via PK... | \n",
" CWE-327 | \n",
"
\n",
" \n",
" | 25 | \n",
" root privileges via buffer overflow in eject c... | \n",
" CWE-119 | \n",
"
\n",
" \n",
" | 176 | \n",
" Windows NT crashes or locks up when a Samba cl... | \n",
" CWE-17 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" DESCRIPTION CWE-ID\n",
"0 ip_input.c in BSD-derived TCP/IP implementatio... CWE-20\n",
"1 Buffer overflow in NFS mountd gives root acces... CWE-119\n",
"6 Information from SSL-encrypted sessions via PK... CWE-327\n",
"25 root privileges via buffer overflow in eject c... CWE-119\n",
"176 Windows NT crashes or locks up when a Samba cl... CWE-17"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"# Build the modelling frame: drop rows missing text or label, then normalise the\n",
"# label column to whitespace-stripped strings.\n",
"print(\"Preparing dataframe...\")\n",
"df_model = (\n",
"    df_subset\n",
"    .dropna(subset=[TEXT_COL, LABEL_COL])\n",
"    .assign(**{LABEL_COL: lambda frame: frame[LABEL_COL].astype(str).str.strip()})\n",
")\n",
"\n",
"# Exclude the two overly generic NVD buckets from the label set.\n",
"print(\"Dropping overly generic buckets...\")\n",
"df_model = df_model.loc[\n",
"    ~df_model[LABEL_COL].isin([\"NVD-CWE-Other\", \"NVD-CWE-noinfo\"])\n",
"].copy()\n",
"\n",
"df_model.head(n=5)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cb283ee4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Label counts before filtering:\n",
"CWE-ID\n",
"CWE-79 17438\n",
"CWE-119 11494\n",
"CWE-20 8575\n",
"CWE-89 6992\n",
"CWE-200 6749\n",
"CWE-264 5321\n",
"CWE-787 5211\n",
"CWE-22 4117\n",
"CWE-125 3866\n",
"CWE-352 3341\n",
"Name: count, dtype: int64\n",
"Labels with only 1 example: 72\n",
"\n",
"After filtering:\n",
"Total examples: 124045\n",
"Unique labels: 232\n"
]
}
],
"source": [
"# Drop labels that occur only once -- presumably because the stratified split in a\n",
"# later cell needs at least two rows per label (TODO confirm).\n",
"\n",
"# Tally how many rows each label has.\n",
"label_counts = df_model[LABEL_COL].value_counts()\n",
"print(\"Label counts before filtering:\")\n",
"print(label_counts.head(10))\n",
"print(f\"Labels with only 1 example: {label_counts.eq(1).sum()}\")\n",
"\n",
"# Retain only the rows whose label occurs at least twice.\n",
"labels_to_keep = label_counts.loc[label_counts.ge(2)].index\n",
"df_model = df_model.loc[df_model[LABEL_COL].isin(labels_to_keep)].copy()\n",
"\n",
"print(\"\\nAfter filtering:\")\n",
"print(f\"Total examples: {len(df_model)}\")\n",
"print(f\"Unique labels: {df_model[LABEL_COL].nunique()}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f55b8585",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Updated num labels: 232\n"
]
}
],
"source": [
"\n",
"# Rebuild the label <-> integer id maps from the filtered labels, in sorted order.\n",
"unique_labels = sorted(df_model[LABEL_COL].unique())\n",
"id2label = dict(enumerate(unique_labels))\n",
"label2id = {name: idx for idx, name in id2label.items()}\n",
"num_labels = len(id2label)\n",
"print(f\"Updated num labels: {num_labels}\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "0a67ebc1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset({\n",
" features: ['DESCRIPTION', 'CWE-ID', 'labels'],\n",
" num_rows: 111640\n",
"})\n",
"Dataset({\n",
" features: ['DESCRIPTION', 'CWE-ID', 'labels'],\n",
" num_rows: 12405\n",
"})\n"
]
}
],
"source": [
"\n",
"# Stratified train/validation split (10% held out, fixed random_state). Each part\n",
"# gains an integer `labels` column, looked up through label2id, while still in pandas.\n",
"train_df, val_df = (\n",
"    part.assign(labels=part[LABEL_COL].map(label2id))\n",
"    for part in train_test_split(\n",
"        df_model[[TEXT_COL, LABEL_COL]],\n",
"        stratify=df_model[LABEL_COL],\n",
"        test_size=0.1,\n",
"        random_state=1,\n",
"    )\n",
")\n",
"\n",
"# Convert to Hugging Face Datasets after discarding the original row index.\n",
"train_ds, val_ds = (\n",
"    Dataset.from_pandas(part.reset_index(drop=True)) for part in (train_df, val_df)\n",
")\n",
"\n",
"print(train_ds)\n",
"print(val_ds)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "567e7b1e",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f05985a8399f4591bfba5ad04c6b4b1d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/111640 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c641109db71c435ea5cb891d3a1e9767",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/12405 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset({\n",
" features: ['labels', 'input_ids', 'attention_mask'],\n",
" num_rows: 111640\n",
"})\n",
"Dataset({\n",
" features: ['labels', 'input_ids', 'attention_mask'],\n",
" num_rows: 12405\n",
"})\n"
]
}
],
"source": [
"\n",
"# Tokenizer loaded from the same checkpoint name as the model.\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
"\n",
"\n",
"def tokenize_fn(batch):\n",
"    \"\"\"Tokenize the text column of a batch, truncating to at most 512 tokens.\"\"\"\n",
"    texts = batch[TEXT_COL]\n",
"    return tokenizer(texts, max_length=512, truncation=True)\n",
"\n",
"# Columns to discard after tokenization: the raw text, the string label, and the\n",
"# \"__index_level_0__\" column if a split happens to carry one.\n",
"remove_cols_train, remove_cols_val = (\n",
"    [name for name in split.column_names if name in (TEXT_COL, LABEL_COL, \"__index_level_0__\")]\n",
"    for split in (train_ds, val_ds)\n",
")\n",
"\n",
"# Tokenize both splits in batches.\n",
"enc_train = train_ds.map(tokenize_fn, batched=True, remove_columns=remove_cols_train)\n",
"enc_val = val_ds.map(tokenize_fn, batched=True, remove_columns=remove_cols_val)\n",
"\n",
"print(enc_train)\n",
"print(enc_val)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "1dff2692",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"# Sequence-classification model loaded from MODEL_NAME, sized to the filtered label\n",
"# set and given both label maps.\n",
"model = AutoModelForSequenceClassification.from_pretrained(\n",
"    MODEL_NAME,\n",
"    label2id=label2id,\n",
"    id2label=id2label,\n",
"    num_labels=num_labels,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "e59de9ea",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Training configuration\n",
"args = TrainingArguments(\n",
"    output_dir=\"./results\",\n",
"    report_to=[],\n",
"    # Optimisation: one epoch; 2 samples per device x 8 accumulation steps, i.e.\n",
"    # 16 samples per optimizer step per device.\n",
"    num_train_epochs=1,\n",
"    learning_rate=2e-5,\n",
"    weight_decay=0.01,\n",
"    per_device_train_batch_size=2,\n",
"    gradient_accumulation_steps=8,\n",
"    # Evaluation and checkpointing are both step-based; evaluate every 1000 steps.\n",
"    eval_strategy=\"steps\",\n",
"    eval_steps=1000,\n",
"    per_device_eval_batch_size=2,\n",
"    save_strategy=\"steps\",\n",
"    load_best_model_at_end=False,\n",
"    # Data loading: no worker processes, no pinned memory.\n",
"    dataloader_num_workers=0,\n",
"    dataloader_pin_memory=False,\n",
"    dataloader_persistent_workers=False,\n",
")\n",
"\n",
"# Collator built from the tokenizer; it is handed to the Trainer together with the\n",
"# metric function that uses the two loaders below.\n",
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
"\n",
"# Metric objects from the `evaluate` library.\n",
"accuracy = evaluate.load(\"accuracy\")\n",
"f1 = evaluate.load(\"f1\")\n",
"\n",
"# Turn on gradient checkpointing for the model before training.\n",
"model.gradient_checkpointing_enable()\n",
"\n",
"\n",
"def compute_metrics(eval_pred):\n",
"    \"\"\"Return accuracy and macro-averaged F1 for a (logits, labels) pair.\n",
"\n",
"    The predicted class of each row is the arg-max over the last axis of the logits.\n",
"    \"\"\"\n",
"    scores, references = eval_pred\n",
"    predicted = np.argmax(scores, axis=-1)\n",
"    acc_result = accuracy.compute(predictions=predicted, references=references)\n",
"    f1_result = f1.compute(predictions=predicted, references=references, average=\"macro\")\n",
"    return {\"accuracy\": acc_result[\"accuracy\"], \"f1\": f1_result[\"f1\"]}\n",
"\n",
"# Assemble the Trainer from the model, the training arguments, the encoded splits,\n",
"# and the collator / metric helpers defined in this cell.\n",
"trainer = Trainer(\n",
"    args=args,\n",
"    model=model,\n",
"    processing_class=tokenizer,\n",
"    data_collator=data_collator,\n",
"    compute_metrics=compute_metrics,\n",
"    train_dataset=enc_train,\n",
"    eval_dataset=enc_val,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "4b087fae",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
"
\n",
" [6978/6978 1:18:31, Epoch 1/1]\n",
"
\n",
" \n",
" \n",
" \n",
" | Step | \n",
" Training Loss | \n",
" Validation Loss | \n",
" Accuracy | \n",
" F1 | \n",
"
\n",
" \n",
" \n",
" \n",
" | 1000 | \n",
" 1.044600 | \n",
" 1.252940 | \n",
" 0.704716 | \n",
" 0.220344 | \n",
"
\n",
" \n",
" | 2000 | \n",
" 1.158700 | \n",
" 1.188677 | \n",
" 0.711326 | \n",
" 0.229855 | \n",
"
\n",
" \n",
" | 3000 | \n",
" 1.119900 | \n",
" 1.159229 | \n",
" 0.719226 | \n",
" 0.235295 | \n",
"
\n",
" \n",
" | 4000 | \n",
" 1.112600 | \n",
" 1.119924 | \n",
" 0.720193 | \n",
" 0.242404 | \n",
"
\n",
" \n",
" | 5000 | \n",
" 1.110300 | \n",
" 1.111053 | \n",
" 0.722934 | \n",
" 0.244389 | \n",
"
\n",
" \n",
" | 6000 | \n",
" 1.134700 | \n",
" 1.082806 | \n",
" 0.727207 | \n",
" 0.251264 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=6978, training_loss=1.1011348515535433, metrics={'train_runtime': 4712.2885, 'train_samples_per_second': 23.691, 'train_steps_per_second': 1.481, 'total_flos': 2912105519756448.0, 'train_loss': 1.1011348515535433, 'epoch': 1.0})"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Run fine-tuning; the returned TrainOutput is left as the cell's displayed value.\n",
"train_output = trainer.train()\n",
"train_output"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "48faf17c",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Persist the model, its tokenizer and the label maps under artifacts/.\n",
"os.makedirs(\"artifacts\", exist_ok=True)\n",
"tokenizer.save_pretrained(\"artifacts/model\")\n",
"model.save_pretrained(\"artifacts/model\")\n",
"\n",
"# JSON object keys are strings, so id2label's integer keys are written as strings;\n",
"# the loading cell converts them back with int().\n",
"with open(\"artifacts/label_map.json\", \"w\") as f:\n",
"    f.write(json.dumps({\"label2id\": label2id, \"id2label\": id2label}, indent=2))"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "fcb11390",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a33b852d1c594a69974bb9c3d30c014a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"README.md: 0.00B [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"No files have been modified since last commit. Skipping to prevent empty commit.\n",
"No files have been modified since last commit. Skipping to prevent empty commit.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Model uploaded to: https://huggingface.co/mulliken/cwe-predictor\n"
]
}
],
"source": [
"# Reload the saved artifacts and publish both to the Hugging Face Hub.\n",
"# private=False is passed, i.e. the upload is not requested as a private repo.\n",
"repo_name = \"mulliken/cwe-predictor\"  # Change this!\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"artifacts/model\")\n",
"model = AutoModelForSequenceClassification.from_pretrained(\"artifacts/model\")\n",
"\n",
"model.push_to_hub(repo_name, private=False)\n",
"tokenizer.push_to_hub(repo_name, private=False)\n",
"\n",
"print(f\"✅ Model uploaded to: https://huggingface.co/{repo_name}\")"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "33847880",
"metadata": {},
"outputs": [],
"source": [
"# Reload the fine-tuned model and tokenizer saved under artifacts/model.\n",
"model = AutoModelForSequenceClassification.from_pretrained(\"artifacts/model\")\n",
"tokenizer = AutoTokenizer.from_pretrained(\"artifacts/model\")\n",
"\n",
"# Read the label maps once and close the file afterwards (the file was previously\n",
"# opened twice and never closed). JSON object keys are always strings, so the keys\n",
"# of id2label are converted back to int.\n",
"with open(\"artifacts/label_map.json\") as f:\n",
"    label_map = json.load(f)\n",
"id2label = {int(k): v for k, v in label_map[\"id2label\"].items()}\n",
"label2id = label_map[\"label2id\"]\n",
"\n",
"# Quick inference helper\n",
"def predict_cwe(text: str) -> str:\n",
"    \"\"\"Return the label whose logit is highest for a single piece of text.\n",
"\n",
"    The text is tokenized with truncation, run through the model without gradient\n",
"    tracking, and the arg-max class id is looked up in id2label.\n",
"    \"\"\"\n",
"    inputs = tokenizer(text, truncation=True, return_tensors=\"pt\")\n",
"    with torch.no_grad():\n",
"        outputs = model(**inputs)\n",
"    best_class = outputs.logits.argmax(dim=-1).item()\n",
"    return id2label[int(best_class)]\n"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "cdeaadbb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CWE-119\n",
"CWE-89\n",
"CWE-79\n",
"CWE-287\n",
"CWE-22\n",
"CWE-190\n",
"CWE-401\n",
"CWE-77\n",
"CWE-326\n"
]
}
],
"source": [
"\n",
"# Smoke-test the classifier on a handful of hand-written descriptions,\n",
"# printing one predicted label per line.\n",
"sample_descriptions = [\n",
"    \"Buffer overflow in POP servers allows remote attackers to gain root access using a long PASS command.\",\n",
"    \"SQL injection vulnerability in web application allows attackers to execute arbitrary SQL commands through user input fields.\",\n",
"    \"Cross-site scripting (XSS) vulnerability allows attackers to inject malicious scripts into web pages viewed by other users.\",\n",
"    \"Authentication bypass vulnerability allows unauthorized access to restricted areas of the application.\",\n",
"    \"Path traversal vulnerability enables attackers to access files outside the intended directory structure.\",\n",
"    \"Integer overflow condition causes unexpected behavior when processing large numeric values.\",\n",
"    \"Memory leak in network daemon causes gradual memory consumption leading to denial of service.\",\n",
"    \"Command injection vulnerability allows execution of arbitrary system commands through unsanitized user input.\",\n",
"    \"Weak cryptographic algorithm implementation makes encrypted data susceptible to brute force attacks.\",\n",
"]\n",
"for description in sample_descriptions:\n",
"    print(predict_cwe(description))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}