---
license: mit
datasets:
- bziemba/cve_cwe_cvss
language:
- en
base_model:
- cisco-ai/SecureBERT2.0-biencoder
tags:
- cybersecurity
- vulnerability-classification
- cvss
- cwe
- securebert
- multi-task-learning
---
# SecureBERT Vulnerability Classifier (CVSS & CWE Flat Classifier)
This model automatically analyzes raw vulnerability descriptions (e.g., CVE reports, bug bounty submissions) and predicts **CVSS v3.1 metrics** alongside a 4-level **CWE taxonomy** (Pillar, Class, Base, Variant).
It is a fine-tuned version of the domain-specific [`cisco-ai/SecureBERT2.0`](https://huggingface.co/cisco-ai/SecureBERT2.0) utilizing a Multi-Task Learning (MTL) architecture with flat classification heads.
## 🎯 Intended Use
The primary use case is automating the initial **Vulnerability Triage** process. By inputting unstructured threat narratives, security analysts can instantly receive:
* **8 CVSS v3.1 Metrics:** Attack Vector, Attack Complexity, Privileges Required, User Interaction, Scope, Confidentiality, Integrity, and Availability.
* **CWE Classification:** Probabilistic mapping to the MITRE CWE tree across 4 levels of abstraction (Top-K predictions).
## 🧠 Model Architecture
The model uses a shared `SecureBERT2.0` backbone with 12 distinct classification heads attached to the pooled outputs:
* **CVSS Heads (8):** Multi-Layer Perceptrons (MLP) consisting of `LayerNorm -> Dropout -> Linear -> GELU -> Dropout -> Linear -> GELU -> Dropout -> Linear -> Softmax`. They use the `[CLS]` (first-token) embedding to predict nominal and ordinal CVSS categories.
* **CWE Heads (4):** Multi-Layer Perceptrons (MLP) consisting of `LayerNorm -> Dropout -> Linear -> GELU -> Dropout -> Linear -> GELU -> Dropout -> Linear`, producing raw logits (softmax is applied at inference time). These heads use the mean-pooled token embeddings.
## 📂 Repository Structure & Custom Config
Unlike standard Hugging Face models, this repository features a highly customized `config.json`. It dynamically dictates the architecture and handles label decoding.
* `cvss_map`: Contains the exact string labels for all 8 CVSS metrics (e.g., `["Network", "Adjacent", "Local", "Physical"]`).
* `cwe_labels`: Contains ID-to-Name mappings for all supported CWEs across `pillar`, `class`, `base`, and `variant` levels.
**Note:** Because of the custom multi-head architecture, you cannot use the default `AutoModelForSequenceClassification`. You must define the custom PyTorch class provided in the usage snippet below.
## 💻 Usage & Inference
Below is a complete, standalone Python snippet to load the model, tokenizer, and configuration directly from this Hugging Face repository and perform predictions.
```python
import json
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoModel, AutoTokenizer
# 1. Define the Custom Architecture
class SecureBERTFlatClassifier(nn.Module):
    """Shared SecureBERT backbone with one flat classification head per task.

    CVSS heads read the first-token ([CLS]-position) embedding and end in a
    softmax, so they return probabilities. CWE heads read the attention-masked
    mean of the token embeddings and return raw logits.

    Args:
        model_name: Hugging Face model id or local path of the backbone.
        cvss_map: Mapping of CVSS metric name -> list of label strings; each
            head's output width is the number of labels.
        class_counts: Mapping of CWE level name -> number of classes.
    """

    def __init__(self, model_name, cvss_map, class_counts):
        super().__init__()
        config = AutoConfig.from_pretrained(model_name)
        # The option is only switched off when the backbone config defines it.
        # NOTE(review): presumably this avoids compiling the backbone at inference time — confirm.
        if hasattr(config, "reference_compile"):
            config.reference_compile = False
        self.bert = AutoModel.from_pretrained(model_name, config=config)
        # Size the heads from the backbone rather than assuming 768, so the
        # class also works with backbones of a different width.
        hidden = getattr(config, "hidden_size", 768)
        self.cvss_heads = nn.ModuleDict(
            {k: self._make_head(hidden, len(v), is_cvss=True) for k, v in cvss_map.items()}
        )
        self.cwe_heads = nn.ModuleDict(
            {k: self._make_head(hidden, v) for k, v in class_counts.items()}
        )

    @staticmethod
    def _make_head(hidden_size, out_features, is_cvss=False):
        """Build one MLP head.

        The layer order fixes the ``nn.Sequential`` indices used as state-dict
        keys, so it must not change or checkpoint weights will not line up.
        """
        layers = [
            nn.LayerNorm(hidden_size), nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(hidden_size, out_features),
        ]
        if is_cvss:
            layers.append(nn.Softmax(dim=1))
        return nn.Sequential(*layers)

    @staticmethod
    def _mean_pool(hidden, attention_mask):
        """Average ``hidden`` over positions where ``attention_mask`` is non-zero.

        The sum is computed in float32 for stability and cast back to the
        backbone's dtype so half-precision backbones feed matching heads.
        """
        mask = attention_mask.unsqueeze(-1).to(torch.float32)
        summed = (hidden.to(torch.float32) * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid 0/0 for an all-padding row
        return (summed / counts).to(hidden.dtype)

    def forward(self, input_ids, attention_mask):
        """Run the backbone once and return a dict of task name -> head output."""
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        cls_emb = hidden[:, 0, :]
        mean_emb = self._mean_pool(hidden, attention_mask)
        res = {k: head(cls_emb) for k, head in self.cvss_heads.items()}
        res.update({k: head(mean_emb) for k, head in self.cwe_heads.items()})
        return res
# 2. Inference Wrapper
class VulnPredictor:
    """Load the fine-tuned checkpoint from the Hub and decode its predictions.

    Args:
        repo_id: Hub repository containing ``config.json`` (label maps and
            backbone name) and ``pytorch_model.bin`` (the state dict).
    """

    def __init__(self, repo_id):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        conf_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        with open(conf_path, "r", encoding="utf-8") as f:
            self.config = json.load(f)
        base_model = self.config.get("base_model", "cisco-ai/SecureBERT2.0-biencoder")
        counts = {k: len(v) for k, v in self.config.get("cwe_labels", {}).items()}
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        self.model = SecureBERTFlatClassifier(base_model, self.config["cvss_map"], counts)
        # weights_only=True refuses to unpickle arbitrary objects from the
        # downloaded file; a plain tensor state dict loads unchanged.
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
        # strict=False tolerates key differences, but a silent mismatch would
        # leave heads randomly initialised, so report it instead of hiding it.
        incompatible = self.model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            warnings.warn(
                "Checkpoint does not match the model architecture: "
                f"missing={incompatible.missing_keys[:5]} "
                f"unexpected={incompatible.unexpected_keys[:5]}",
                stacklevel=2,
            )
        self.model.to(self.device).eval()

    def predict(self, text, top_k=3):
        """Classify one description.

        Returns a dict ``{"cvss": {metric: {"value", "confidence"}},
        "cwe": {level: [{"id", "name", "score"}, ...]}}`` with up to ``top_k``
        CWE candidates per level.
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        ).to(self.device)
        with torch.no_grad():
            out = self.model(inputs["input_ids"], inputs["attention_mask"])
        return self._decode(out, top_k)

    def _decode(self, out, top_k):
        """Turn raw head outputs for the first batch row into labelled results."""
        res = {"cvss": {}, "cwe": {}}
        # CVSS heads already end in softmax, so the max is the probability.
        for task, labels in self.config.get("cvss_map", {}).items():
            score, idx = torch.max(out[task], dim=1)
            res["cvss"][task] = {"value": labels[idx.item()], "confidence": round(score.item(), 4)}
        # CWE heads emit logits; normalise before ranking.
        for lv, cwe_data in self.config.get("cwe_labels", {}).items():
            if lv not in out:
                continue
            probs = F.softmax(out[lv], dim=1)
            scores, idxs = torch.topk(probs, k=min(top_k, probs.size(1)))
            res["cwe"][lv] = [
                {
                    "id": self._parse_cwe_id(cwe_data[i]["id"]),
                    "name": cwe_data[i]["name"],
                    "score": round(s, 4),
                }
                for s, i in zip(scores[0].tolist(), idxs[0].tolist())
            ]
        return res

    @staticmethod
    def _parse_cwe_id(raw):
        """Return the numeric CWE id, or the raw id string if it is not numeric.

        Non-numeric ids such as ``NVD-CWE-noinfo`` used to raise ValueError.
        """
        try:
            return int(str(raw).replace("CWE-", ""))
        except ValueError:
            return str(raw)
# 3. Quickstart
if __name__ == "__main__":
    # Quickstart: fetch the published checkpoint, classify a single sample
    # description, and pretty-print the nested result dictionary.
    REPO_ID = "bziemba/SecureBERT2.0-final"
    sample_cve = (
        "An issue was discovered in the login panel allowing attackers "
        "to bypass authentication via crafted SQL queries."
    )
    results = VulnPredictor(REPO_ID).predict(sample_cve)
    print(json.dumps(results, indent=2))
```