File size: 1,851 Bytes
5956d68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig

class SecureBERTMultiHead(nn.Module):
    """Multi-task heads on top of a pretrained transformer encoder.

    The encoder's first-token embedding (``last_hidden_state[:, 0, :]``) is fed to:

    * one CVSS head per entry of ``CVSS_HEADS`` (``self.cvss_heads``). Each ends
      in ``nn.Softmax(dim=1)``, so it returns per-class probabilities.
    * one CWE head per entry of ``CWE_LEVELS`` (``self.cwe_heads``). Each output
      is L2-normalised along ``dim=1`` in ``forward``.

    Args:
        pretrained_model_name: Model id or path passed to
            ``AutoConfig.from_pretrained`` and ``AutoModel.from_pretrained``.
        cwe_embedding_dim: Output width of every CWE head (768 in the original
            design).

    NOTE(review): because the CVSS heads end in Softmax, their outputs are
    probabilities, not logits. ``nn.CrossEntropyLoss`` expects raw logits, so
    training with it on these outputs would apply softmax twice. Confirm
    against the training loop.
    """

    # Number of classes for each CVSS metric head. Insertion order fixes the
    # ModuleDict order (and therefore state_dict order) and the order of the
    # keys returned by forward().
    CVSS_HEADS = {
        "attack_vector": 4,
        "attack_complexity": 2,
        "privileges_required": 3,
        "user_interaction": 2,
        "scope": 2,
        "confidentiality": 3,
        "integrity": 3,
        "availability": 3,
    }
    # Names of the CWE embedding heads, in ModuleDict / output order.
    CWE_LEVELS = ("pillar", "class", "base", "variant")

    def __init__(self, pretrained_model_name="cisco-ai/SecureBERT2.0-biencoder",
                 cwe_embedding_dim=768):
        super().__init__()
        self.modernbert_config = AutoConfig.from_pretrained(pretrained_model_name)
        # Turn off the config's `reference_compile` switch when the config
        # exposes one. This is presumably ModernBERT's built-in torch.compile
        # toggle; confirm when upgrading transformers.
        if hasattr(self.modernbert_config, "reference_compile"):
            self.modernbert_config.reference_compile = False

        self.bert = AutoModel.from_pretrained(pretrained_model_name, config=self.modernbert_config)

        # Size the heads from the encoder instead of hard-coding 768, so that
        # an encoder of a different width does not fail with a shape mismatch.
        # For a 768-wide encoder the parameter shapes are unchanged.
        hidden_size = getattr(self.modernbert_config, "hidden_size", 768)

        self.cvss_heads = nn.ModuleDict({
            name: self._make_cvss_head(hidden_size, num_classes)
            for name, num_classes in self.CVSS_HEADS.items()
        })
        self.cwe_heads = nn.ModuleDict({
            level: self._make_cwe_head(hidden_size, cwe_embedding_dim)
            for level in self.CWE_LEVELS
        })

    @staticmethod
    def _make_cvss_head(hidden_size, num_classes, dropout=0.1):
        """Build one CVSS classifier: LayerNorm -> 2-layer GELU MLP -> Softmax(dim=1)."""
        # Layer order matches the original inline definition, so existing
        # state_dict keys (``cvss_heads.<name>.<index>.*``) still load.
        return nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_classes),
            nn.Softmax(dim=1),
        )

    @staticmethod
    def _make_cwe_head(hidden_size, out_dim, inner_dim=2048, dropout=0.1):
        """Build one CWE projection: LayerNorm -> 3-layer GELU MLP to ``out_dim`` (no activation at the end)."""
        # Layer order matches the original inline definition, so existing
        # state_dict keys (``cwe_heads.<level>.<index>.*``) still load.
        return nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, out_dim),
        )

    def forward(self, input_ids, attention_mask):
        """Encode a batch and run every head on its first-token embedding.

        Args:
            input_ids: Token ids, forwarded unchanged to the encoder.
            attention_mask: Attention mask, forwarded unchanged to the encoder.

        Returns:
            A dict with one entry per CVSS head (the head's Softmax output),
            followed by one entry per CWE level (the head output L2-normalised
            along ``dim=1``).
        """
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # First-token embedding. This assumes last_hidden_state is
        # (batch, seq, hidden); TODO confirm for non-default encoders.
        emb = out.last_hidden_state[:, 0, :]
        res = {name: head(emb) for name, head in self.cvss_heads.items()}
        res.update(
            (level, F.normalize(head(emb), p=2, dim=1))
            for level, head in self.cwe_heads.items()
        )
        return res