---
license: mit
datasets:
- bziemba/cve_cwe_cvss
language:
- en
base_model:
- cisco-ai/SecureBERT2.0-biencoder
tags:
- cybersecurity
- vulnerability-classification
- cvss
- cwe
- securebert
- multi-task-learning
---
# SecureBERT Vulnerability Classifier (CVSS & CWE Flat Classifier)
This model analyzes raw vulnerability descriptions (e.g., CVE reports, bug bounty submissions) and predicts **CVSS v3.1 metrics** alongside **CWE categories** at four levels of the CWE taxonomy (Pillar, Class, Base, Variant).
It is a fine-tuned version of the domain-specific [`cisco-ai/SecureBERT2.0-biencoder`](https://huggingface.co/cisco-ai/SecureBERT2.0-biencoder) encoder (the `base_model` listed in this card's metadata), using a Multi-Task Learning (MTL) architecture with flat classification heads.
## 🎯 Intended Use
The primary use case is automating the initial **Vulnerability Triage** step. Given an unstructured threat narrative, the model returns:
* **8 CVSS v3.1 Metrics:** Attack Vector, Attack Complexity, Privileges Required, User Interaction, Scope, Confidentiality, Integrity, and Availability.
* **CWE Classification:** Probabilistic mapping to the MITRE CWE tree across 4 levels of abstraction (Top-K predictions).
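For triage tooling, the eight predicted CVSS values can be assembled into a standard CVSS v3.1 vector string. The sketch below is not part of this repository: it assumes the predictions come in the `{"value": ..., "confidence": ...}` shape returned by `predict` in the usage snippet, that the values are full-word labels such as `"Network"` (as in the `cvss_map` example), and it uses **hypothetical** task keys (`attack_vector`, `scope`, ...) that must be replaced with the keys actually stored in `config.json`.

```python
# Sketch only: assemble a CVSS v3.1 vector string from per-metric predictions.
# ASSUMPTION: these task keys are hypothetical placeholders; replace them with
# the keys actually used in this repository's config.json "cvss_map".
METRIC_ABBREVIATIONS = {
    "attack_vector": "AV",
    "attack_complexity": "AC",
    "privileges_required": "PR",
    "user_interaction": "UI",
    "scope": "S",
    "confidentiality": "C",
    "integrity": "I",
    "availability": "A",
}


def to_cvss_vector(cvss_preds: dict) -> str:
    """cvss_preds maps task key -> {"value": full-word label, "confidence": float}."""
    parts = ["CVSS:3.1"]
    for key, abbr in METRIC_ABBREVIATIONS.items():
        label = cvss_preds[key]["value"]
        # Each CVSS v3.1 base-metric value abbreviates to its first letter:
        # Network/Adjacent/Local/Physical -> N/A/L/P, None/Required -> N/R,
        # Unchanged/Changed -> U/C, High/Low/None -> H/L/N.
        parts.append(f"{abbr}:{label[0].upper()}")
    return "/".join(parts)


# Usage with made-up predictions in the same shape as predictor.predict(...)["cvss"]
labels = ["Network", "Low", "None", "None", "Unchanged", "High", "High", "High"]
example = {key: {"value": v, "confidence": 1.0} for key, v in zip(METRIC_ABBREVIATIONS, labels)}
print(to_cvss_vector(example))  # CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H
```

The resulting string can be passed to any CVSS v3.1 calculator to obtain a base score; the confidences are dropped, so low-confidence metrics should still be reviewed by an analyst.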
## 🧠 Model Architecture
The model uses a shared `SecureBERT2.0` backbone with 12 classification heads (8 CVSS + 4 CWE) that read two different summaries of the encoder's last hidden state:
* **CVSS Heads (8):** Multi-Layer Perceptrons (MLP) consisting of `LayerNorm -> Dropout -> Linear -> GELU -> Dropout -> Linear -> GELU -> Dropout -> Linear -> Softmax` (as defined by `make_head` in the usage snippet). They use the `[CLS]` token embedding to predict nominal and ordinal CVSS categories.
* **CWE Heads (4):** MLPs with the same layer stack but without the final `Softmax`, so they output raw logits (softmax is applied at inference time). These heads use the mean-pooled token embeddings, averaged over non-padding tokens only.
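The two pooling strategies can be illustrated on a dummy batch. This sketch mirrors the `forward` method of the usage snippet (it uses broadcasting instead of `expand`, which gives the same result): the CVSS heads read position 0 (the `[CLS]` token), while the CWE heads read an attention-masked mean in which padding positions contribute nothing. The tensor sizes are arbitrary toy values, not the model's real 768-dimensional embeddings.

```python
import torch


def cls_and_mean_pool(hidden: torch.Tensor, attention_mask: torch.Tensor):
    """hidden: (batch, seq_len, dim); attention_mask: (batch, seq_len) of 0/1."""
    # CVSS heads: embedding of the first token ([CLS]).
    cls_emb = hidden[:, 0, :]
    # CWE heads: average over real tokens only; padding (mask == 0) is zeroed out.
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # (batch, seq_len, 1)
    summed = (hidden * mask).sum(dim=1)                    # (batch, dim)
    counts = mask.sum(dim=1).clamp(min=1e-9)               # (batch, 1), avoids division by zero
    return cls_emb, summed / counts


# Toy batch: 2 sequences, 4 positions, 3-dim embeddings; the second has 2 padding tokens.
hidden = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
cls_emb, mean_emb = cls_and_mean_pool(hidden, attention_mask)
print(cls_emb.tolist())   # [[0.0, 1.0, 2.0], [12.0, 13.0, 14.0]]
print(mean_emb.tolist())  # [[4.5, 5.5, 6.5], [13.5, 14.5, 15.5]]
```

For the second sequence, only the first two rows are averaged, so the padding values `18..23` never leak into the CWE representation.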
## 📂 Repository Structure & Custom Config
Unlike standard Hugging Face models, this repository ships a customized `config.json`: the label lists it stores determine the output size of every classification head and are used to decode predictions back into labels.
* `cvss_map`: Contains the exact string labels for all 8 CVSS metrics (e.g., `["Network", "Adjacent", "Local", "Physical"]`).
* `cwe_labels`: For each level (`pillar`, `class`, `base`, `variant`), an ordered list of supported CWEs with `id` and `name` fields; an entry's position in the list is the corresponding output index of that level's head.
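How these two fields drive the model can be sketched with a tiny, made-up config. The label lists, CWE entries, and the `attack_vector` key below are illustrative placeholders (the names are shortened), not the repository's real vocabularies. The sketch only shows that each head's output size equals the length of its label list and that an output index is decoded by indexing back into the same list, which is also how `predict` in the usage snippet works. Unlike `predict`, which reads only the first row (`scores[0]`), the hypothetical `decode_top_k` helper decodes every row of a batch.

```python
import torch

# Hypothetical, heavily truncated config; the real config.json has 8 CVSS
# tasks and many more CWE entries per level. Names are shortened placeholders.
config = {
    "cvss_map": {"attack_vector": ["Network", "Adjacent", "Local", "Physical"]},
    "cwe_labels": {
        "base": [
            {"id": "CWE-79", "name": "Cross-site Scripting"},
            {"id": "CWE-89", "name": "SQL Injection"},
            {"id": "CWE-787", "name": "Out-of-bounds Write"},
        ]
    },
}

# Head output sizes are derived from the label lists.
cvss_sizes = {task: len(labels) for task, labels in config["cvss_map"].items()}
cwe_sizes = {level: len(entries) for level, entries in config["cwe_labels"].items()}


def decode_top_k(logits: torch.Tensor, entries: list, k: int = 2):
    """Decode a (batch, num_classes) logit tensor into top-k (id, score) pairs per row."""
    probs = torch.softmax(logits, dim=1)
    scores, idxs = torch.topk(probs, k=min(k, probs.size(1)), dim=1)
    return [
        [(entries[i]["id"], round(s, 4)) for s, i in zip(row_s.tolist(), row_i.tolist())]
        for row_s, row_i in zip(scores, idxs)
    ]


# Two fake rows of CWE "base" logits (a batch of 2 descriptions).
fake_logits = torch.tensor([[0.1, 3.0, 0.2], [2.5, 0.0, 1.0]])
decoded = decode_top_k(fake_logits, config["cwe_labels"]["base"])
print(cvss_sizes, cwe_sizes)  # {'attack_vector': 4} {'base': 3}
print(decoded)
```

Because decoding depends only on list order, reordering or editing `cwe_labels` without retraining would silently mislabel predictions.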
**Note:** Because of the custom multi-head architecture, you cannot use the default `AutoModelForSequenceClassification`. You must define the custom PyTorch class provided in the usage snippet below.
## 💻 Usage & Inference
Below is a complete, standalone Python snippet to load the model, tokenizer, and configuration directly from this Hugging Face repository and perform predictions.
```python
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download

# 1. Define the custom architecture
class SecureBERTFlatClassifier(nn.Module):
    def __init__(self, model_name, cvss_map, class_counts):
        super().__init__()
        config = AutoConfig.from_pretrained(model_name)
        # Skip torch.compile of the encoder layers when the config exposes this ModernBERT option
        if hasattr(config, "reference_compile"):
            config.reference_compile = False
        self.bert = AutoModel.from_pretrained(model_name, config=config)
        hidden = self.bert.config.hidden_size  # 768 for the SecureBERT 2.0 base encoder

        def make_head(out_features, is_cvss=False):
            layers = [
                nn.LayerNorm(hidden), nn.Dropout(0.1),
                nn.Linear(hidden, hidden), nn.GELU(), nn.Dropout(0.1),
                nn.Linear(hidden, hidden), nn.GELU(), nn.Dropout(0.1),
                nn.Linear(hidden, out_features),
            ]
            # CVSS heads return probabilities; CWE heads return raw logits
            if is_cvss:
                layers.append(nn.Softmax(dim=1))
            return nn.Sequential(*layers)

        self.cvss_heads = nn.ModuleDict({k: make_head(len(v), True) for k, v in cvss_map.items()})
        self.cwe_heads = nn.ModuleDict({k: make_head(v) for k, v in class_counts.items()})

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        cls_emb = out[:, 0, :]  # [CLS] embedding -> CVSS heads
        mask = attention_mask.unsqueeze(-1).expand(out.size()).float()
        # Mean over non-padding tokens -> CWE heads
        mean_emb = torch.sum(out * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
        res = {}
        for k, head in self.cvss_heads.items():
            res[k] = head(cls_emb)
        for k, head in self.cwe_heads.items():
            res[k] = head(mean_emb)
        return res

# 2. Inference wrapper
class VulnPredictor:
    def __init__(self, repo_id):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        conf_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        with open(conf_path, "r") as f:
            self.config = json.load(f)

        base_model = self.config.get("base_model", "cisco-ai/SecureBERT2.0-biencoder")
        # Number of classes per CWE level = number of labels stored in config.json
        counts = {k: len(v) for k, v in self.config.get("cwe_labels", {}).items()}

        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        self.model = SecureBERTFlatClassifier(base_model, self.config["cvss_map"], counts)
        state_dict = torch.load(model_path, map_location=self.device)
        # strict=False tolerates key mismatches; report missing weights so they are not silent
        result = self.model.load_state_dict(state_dict, strict=False)
        if result.missing_keys:
            print(f"Warning: {len(result.missing_keys)} weights missing from checkpoint, e.g. {result.missing_keys[:3]}")
        self.model.to(self.device).eval()

    def predict(self, text, top_k=3):
        # Single description; inputs longer than 512 tokens are truncated
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(self.device)
        with torch.no_grad():
            out = self.model(inputs["input_ids"], inputs["attention_mask"])

        res = {"cvss": {}, "cwe": {}}
        # CVSS heads already output probabilities: keep the most likely label
        for task, labels in self.config.get("cvss_map", {}).items():
            score, idx = torch.max(out[task], dim=1)
            res["cvss"][task] = {"value": labels[idx.item()], "confidence": round(score.item(), 4)}

        # CWE heads output logits: apply softmax, then keep the top-k classes per level
        for lv, cwe_data in self.config.get("cwe_labels", {}).items():
            if lv in out:
                probs = F.softmax(out[lv], dim=1)
                scores, idxs = torch.topk(probs, k=min(top_k, probs.size(1)))
                res["cwe"][lv] = [
                    {
                        "id": int(str(cwe_data[i.item()]["id"]).replace("CWE-", "")),
                        "name": cwe_data[i.item()]["name"],
                        "score": round(s.item(), 4),
                    }
                    for s, i in zip(scores[0], idxs[0])
                ]
        return res

# 3. Quickstart
if __name__ == "__main__":
    REPO_ID = "bziemba/SecureBERT2.0-final"
    predictor = VulnPredictor(REPO_ID)
    sample_cve = (
        "An issue was discovered in the login panel allowing attackers "
        "to bypass authentication via crafted SQL queries."
    )
    results = predictor.predict(sample_cve)
    print(json.dumps(results, indent=2))
```