---

license: mit
datasets:
- bziemba/cve_cwe_cvss
language:
- en
base_model:
- cisco-ai/SecureBERT2.0-biencoder
tags:
- cybersecurity
- vulnerability-classification
- cvss
- cwe
- securebert
- multi-task-learning
---

# SecureBERT Vulnerability Classifier (CVSS & CWE Flat Classifier)

This model automatically analyzes raw vulnerability descriptions (e.g., CVE reports, bug bounty submissions) and predicts **CVSS v3.1 metrics** alongside a 4-level **CWE taxonomy** (Pillar, Class, Base, Variant). 

It is a fine-tuned version of the domain-specific [`cisco-ai/SecureBERT2.0-biencoder`](https://huggingface.co/cisco-ai/SecureBERT2.0-biencoder) encoder. It uses a Multi-Task Learning (MTL) architecture with flat classification heads, meaning one independent head per prediction task.

## 🎯 Intended Use

The primary use case is automating the initial **Vulnerability Triage** process. Given an unstructured vulnerability description, security analysts receive:
* **8 CVSS v3.1 Metrics:** Attack Vector, Attack Complexity, Privileges Required, User Interaction, Scope, Confidentiality, Integrity, and Availability.
* **CWE Classification:** Probabilistic mapping to the MITRE CWE tree across 4 levels of abstraction (Top-K predictions).
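The eight predicted metrics are exactly the inputs of the CVSS v3.1 base-score equations, so a triage pipeline can turn them into a standard vector string and score. The sketch below is an illustration, not part of this repository. It assumes hypothetical `cvss_map` key names (`METRIC_KEYS`, e.g. `attack_vector`) and label strings that follow the specification's value names, so that the first letter is the vector abbreviation (e.g. `"Network"` becomes `N`). The helper names `to_vector`, `roundup`, and `base_score` are hypothetical; the weights and rounding follow the FIRST CVSS v3.1 specification.

```python
import math

# Base-metric weights from the CVSS v3.1 specification (FIRST).
AV = {"N": 0.85, "A": 0.62, "L": 0.55, "P": 0.2}
AC = {"L": 0.77, "H": 0.44}
UI = {"N": 0.85, "R": 0.62}
CIA = {"H": 0.56, "L": 0.22, "N": 0.0}
PR = {"N": (0.85, 0.85), "L": (0.62, 0.68), "H": (0.27, 0.5)}  # (scope unchanged, scope changed)

# Hypothetical mapping from cvss_map keys to vector abbreviations; check the real config.json.
METRIC_KEYS = {
    "attack_vector": "AV", "attack_complexity": "AC", "privileges_required": "PR",
    "user_interaction": "UI", "scope": "S", "confidentiality": "C",
    "integrity": "I", "availability": "A",
}
ORDER = ["AV", "AC", "PR", "UI", "S", "C", "I", "A"]


def to_vector(cvss_pred: dict) -> str:
    """Build a vector from predict()['cvss']; assumes labels use the spec's value names."""
    abbrev = {METRIC_KEYS[k]: v["value"][0].upper() for k, v in cvss_pred.items() if k in METRIC_KEYS}
    return "CVSS:3.1/" + "/".join(f"{m}:{abbrev[m]}" for m in ORDER)


def roundup(value: float) -> float:
    """CVSS v3.1 Roundup: smallest number with one decimal place that is >= value."""
    int_input = round(value * 100000)
    if int_input % 10000 == 0:
        return int_input / 100000.0
    return (math.floor(int_input / 10000) + 1) / 10.0


def base_score(vector: str) -> float:
    """CVSS v3.1 base score for a vector such as 'CVSS:3.1/AV:N/AC:L/...'."""
    m = dict(part.split(":") for part in vector.split("/")[1:])
    changed = m["S"] == "C"
    iss = 1 - (1 - CIA[m["C"]]) * (1 - CIA[m["I"]]) * (1 - CIA[m["A"]])
    if changed:
        impact = 7.52 * (iss - 0.029) - 3.25 * (iss - 0.02) ** 15
    else:
        impact = 6.42 * iss
    exploitability = 8.22 * AV[m["AV"]] * AC[m["AC"]] * PR[m["PR"]][int(changed)] * UI[m["UI"]]
    if impact <= 0:
        return 0.0
    total = 1.08 * (impact + exploitability) if changed else impact + exploitability
    return roundup(min(total, 10))


# Example: a hypothetical prediction for a critical, unauthenticated network flaw.
pred = {k: {"value": v} for k, v in {
    "attack_vector": "Network", "attack_complexity": "Low", "privileges_required": "None",
    "user_interaction": "None", "scope": "Unchanged", "confidentiality": "High",
    "integrity": "High", "availability": "High"}.items()}
vector = to_vector(pred)
print(vector, base_score(vector))  # CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H 9.8
```

Because each metric is predicted by a separate head, a single wrong metric (Scope in particular) can shift the score noticeably, so the per-metric `confidence` values returned by the model are worth surfacing next to the computed score.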

## 🧠 Model Architecture

The model uses a shared `SecureBERT2.0` backbone with 12 independent classification heads on top of its final hidden states:
* **CVSS Heads (8):** Multi-Layer Perceptrons (MLPs) consisting of `LayerNorm -> Dropout -> Linear -> GELU -> Dropout -> Linear -> GELU -> Dropout -> Linear -> Softmax`. They use the `[CLS]` token embedding to predict the nominal and ordinal CVSS categories, outputting a probability distribution for each metric.
* **CWE Heads (4):** MLPs with the same layers minus the final `Softmax`, so they output raw logits (one per CWE at that level). These heads use the mean-pooled token embeddings, averaged over non-padding tokens only.
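To illustrate the two pooling strategies on toy numbers, here is a dependency-free sketch for a single sequence (the helper name `pool` is hypothetical). The CVSS heads see only the first (`[CLS]`) position, while the CWE heads see the average of the positions whose attention-mask value is 1, so padding does not dilute the embedding. It mirrors the pooling in the `forward` method of the usage snippet, including its `1e-9` guard against division by zero.

```python
from typing import List, Tuple


def pool(hidden: List[List[float]], mask: List[int]) -> Tuple[List[float], List[float]]:
    """Return ([CLS] embedding, masked mean embedding) for one sequence.

    hidden: seq_len x dim token embeddings; mask: 1 for real tokens, 0 for padding.
    """
    cls_emb = hidden[0]  # first position, as used by the CVSS heads
    n_real = max(sum(mask), 1e-9)  # same guard as torch.clamp(..., min=1e-9)
    dim = len(hidden[0])
    # Padding positions are multiplied by 0, so they contribute nothing to the sum
    mean_emb = [sum(tok[d] * m for tok, m in zip(hidden, mask)) / n_real for d in range(dim)]
    return cls_emb, mean_emb


# Two real tokens plus one padding token with large values that must be ignored.
cls_emb, mean_emb = pool([[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]], [1, 1, 0])
print(cls_emb, mean_emb)  # [1.0, 2.0] [2.0, 3.0]
```

An unmasked mean over the same three positions would give `[34.67, 35.33]` (rounded), which is why the attention mask matters when inputs in a batch are padded to different lengths.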

## 📂 Repository Structure & Custom Config

Unlike a standard Hugging Face model, this repository ships a customized `config.json` that determines the number and size of the classification heads and stores the label mappings used to decode predictions:
* `cvss_map`: The exact string labels for each of the 8 CVSS metrics (e.g., `["Network", "Adjacent", "Local", "Physical"]`). The length of each list sets the output size of the corresponding CVSS head.
* `cwe_labels`: For each level (`pillar`, `class`, `base`, `variant`), an ordered list of `{id, name}` entries for the supported CWEs. An entry's position in the list is the output index of that level's head.
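A minimal sketch of how these two mappings size the heads and decode their outputs, using a hypothetical two-entry excerpt. The real `config.json` covers all 8 metrics and every supported CWE; the `attack_vector` key, the mixed `id` formats, and the helper names `decode_cvss` and `decode_cwe` are assumptions for illustration. As in the usage snippet, an output index is looked up by list position, and CWE IDs are normalized to integers whether stored as `"CWE-89"` or `89`.

```python
import json

# Hypothetical excerpt of config.json; key names and id formats are assumptions.
CONFIG = json.loads("""
{
  "cvss_map": {"attack_vector": ["Network", "Adjacent", "Local", "Physical"]},
  "cwe_labels": {
    "base": [
      {"id": "CWE-89", "name": "Improper Neutralization of Special Elements used in an SQL Command ('SQL Injection')"},
      {"id": 79, "name": "Improper Neutralization of Input During Web Page Generation ('Cross-site Scripting')"}
    ]
  }
}
""")


def decode_cvss(config: dict, metric: str, index: int) -> str:
    """Map a CVSS head's argmax index to its label string."""
    return config["cvss_map"][metric][index]


def decode_cwe(config: dict, level: str, index: int) -> tuple:
    """Map a CWE head's output index to (numeric CWE id, name)."""
    entry = config["cwe_labels"][level][index]
    return int(str(entry["id"]).replace("CWE-", "")), entry["name"]


# Head output sizes follow directly from the list lengths.
print({lv: len(v) for lv, v in CONFIG["cwe_labels"].items()})  # {'base': 2}
print(decode_cvss(CONFIG, "attack_vector", 0))  # Network
print(decode_cwe(CONFIG, "base", 1)[0])  # 79
```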

**Note:** Because of the custom multi-head architecture, you cannot use the default `AutoModelForSequenceClassification`. You must define the custom PyTorch class provided in the usage snippet below.

## 💻 Usage & Inference

Below is a complete, standalone Python snippet to load the model, tokenizer, and configuration directly from this Hugging Face repository and perform predictions.

```python
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download

# 1. Define the Custom Architecture
class SecureBERTFlatClassifier(nn.Module):
    def __init__(self, model_name, cvss_map, class_counts):
        super().__init__()
        config = AutoConfig.from_pretrained(model_name)
        # Turn off ModernBERT's optional torch.compile path (if present) for portability
        if hasattr(config, "reference_compile"):
            config.reference_compile = False
        self.bert = AutoModel.from_pretrained(model_name, config=config)
        hidden = config.hidden_size  # 768 for the SecureBERT2.0 backbone

        def make_head(out_features, is_cvss=False):
            # Layer order must match the saved checkpoint's state_dict keys
            layers = [
                nn.LayerNorm(hidden), nn.Dropout(0.1),
                nn.Linear(hidden, hidden), nn.GELU(), nn.Dropout(0.1),
                nn.Linear(hidden, hidden), nn.GELU(), nn.Dropout(0.1),
                nn.Linear(hidden, out_features),
            ]
            if is_cvss:
                layers.append(nn.Softmax(dim=1))  # CVSS heads output probabilities
            return nn.Sequential(*layers)

        # One head per CVSS metric (sized by its label list) and per CWE level (sized by class count)
        self.cvss_heads = nn.ModuleDict({k: make_head(len(v), True) for k, v in cvss_map.items()})
        self.cwe_heads = nn.ModuleDict({k: make_head(v) for k, v in class_counts.items()})

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        cls_emb = out[:, 0, :]  # [CLS] embedding for the CVSS heads
        # Mean over real (non-padding) tokens for the CWE heads
        mask = attention_mask.unsqueeze(-1).expand(out.size()).float()
        mean_emb = torch.sum(out * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

        res = {}
        for k, head in self.cvss_heads.items():
            res[k] = head(cls_emb)
        for k, head in self.cwe_heads.items():
            res[k] = head(mean_emb)
        return res

# 2. Inference Wrapper
class VulnPredictor:
    def __init__(self, repo_id):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        conf_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")

        with open(conf_path, "r") as f:
            self.config = json.load(f)

        base_model = self.config.get("base_model", "cisco-ai/SecureBERT2.0-biencoder")
        counts = {k: len(v) for k, v in self.config.get("cwe_labels", {}).items()}

        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        self.model = SecureBERTFlatClassifier(base_model, self.config["cvss_map"], counts)
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
        # strict=False tolerates extra or missing buffers; report mismatches so untrained heads are not silent
        missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            print(f"Warning: missing keys={missing}, unexpected keys={unexpected}")
        self.model.to(self.device).eval()

    def predict(self, text, top_k=3):
        """Predict CVSS metrics and the top-k CWEs per level for a single description."""
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(self.device)
        with torch.no_grad():
            out = self.model(inputs["input_ids"], inputs["attention_mask"])

        res = {"cvss": {}, "cwe": {}}
        # CVSS heads already return probabilities: take the most likely label per metric
        for task, labels in self.config.get("cvss_map", {}).items():
            score, idx = torch.max(out[task], dim=1)
            res["cvss"][task] = {"value": labels[idx.item()], "confidence": round(score.item(), 4)}

        # CWE heads return logits: convert to probabilities, then keep the top-k per level
        for lv, cwe_data in self.config.get("cwe_labels", {}).items():
            if lv in out:
                probs = F.softmax(out[lv], dim=1)
                scores, idxs = torch.topk(probs, k=min(top_k, probs.size(1)))
                res["cwe"][lv] = [
                    {"id": int(str(cwe_data[i.item()]["id"]).replace("CWE-", "")),
                     "name": cwe_data[i.item()]["name"],
                     "score": round(s.item(), 4)}
                    for s, i in zip(scores[0], idxs[0])
                ]
        return res

# 3. Quickstart
if __name__ == "__main__":
    REPO_ID = "bziemba/SecureBERT2.0-final"

    predictor = VulnPredictor(REPO_ID)

    sample_cve = "An issue was discovered in the login panel allowing attackers to bypass authentication via crafted SQL queries."
    results = predictor.predict(sample_cve)

    print(json.dumps(results, indent=2))
```