# Source: Hugging Face Hub upload by bziemba
# Commit 5956d68 (verified): "Upload folder using huggingface_hub"
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig
class SecureBERTMultiHead(nn.Module):
    """Multi-task model: a SecureBERT 2.0 encoder shared by CVSS and CWE heads.

    The encoder's first-token hidden state feeds two groups of heads:

    * ``cvss_heads`` -- one classifier per CVSS base metric (see
      ``CVSS_HEADS``). Each head ends in ``nn.Softmax``, so by default
      :meth:`forward` returns per-class probabilities for these entries.
    * ``cwe_heads`` -- one MLP projection per CWE abstraction level (see
      ``CWE_LEVELS``), whose output :meth:`forward` L2-normalises.

    The submodule names (``bert``, ``cvss_heads``, ``cwe_heads``), the head
    keys and the layer order inside every ``nn.Sequential`` determine the
    ``state_dict`` keys. Do not change them, or existing checkpoints will no
    longer load.
    """

    # CVSS base metric -> number of output classes. The counts match the
    # CVSS v3.x value sets (e.g. attack_vector has N/A/L/P). Which class index
    # corresponds to which value is fixed by the training labels, not here.
    CVSS_HEADS = {
        "attack_vector": 4,
        "attack_complexity": 2,
        "privileges_required": 3,
        "user_interaction": 2,
        "scope": 2,
        "confidentiality": 3,
        "integrity": 3,
        "availability": 3,
    }

    # CWE abstraction levels; each level gets its own embedding head.
    CWE_LEVELS = ("pillar", "class", "base", "variant")

    def __init__(self, pretrained_model_name="cisco-ai/SecureBERT2.0-biencoder"):
        """Load the pretrained encoder and build freshly initialised heads.

        Args:
            pretrained_model_name: Hugging Face model id or local path passed
                to ``AutoConfig.from_pretrained`` and
                ``AutoModel.from_pretrained``.
        """
        super().__init__()
        self.modernbert_config = AutoConfig.from_pretrained(pretrained_model_name)
        # NOTE(review): ModernBERT configs expose ``reference_compile``. Forcing
        # it off presumably avoids torch.compile of the encoder -- confirm intent.
        if hasattr(self.modernbert_config, "reference_compile"):
            self.modernbert_config.reference_compile = False
        self.bert = AutoModel.from_pretrained(
            pretrained_model_name, config=self.modernbert_config
        )
        # Size the heads from the encoder rather than hard-coding 768. A
        # backbone with a different hidden size would otherwise fail with a
        # shape mismatch on the first forward pass. Fall back to the previous
        # hard-coded value if the config lacks ``hidden_size``.
        hidden_size = getattr(self.modernbert_config, "hidden_size", 768)
        self.cvss_heads = self._build_cvss_heads(hidden_size)
        self.cwe_heads = self._build_cwe_heads(hidden_size)

    @classmethod
    def _build_cvss_heads(cls, hidden_size=768):
        """Return a ``ModuleDict`` with one softmax classifier per CVSS metric.

        The layer order (Sequential indices 0-6) is part of the checkpoint
        format: LayerNorm, Dropout, Linear, GELU, Dropout, Linear, Softmax.
        """
        return nn.ModuleDict({
            name: nn.Sequential(
                nn.LayerNorm(hidden_size),
                nn.Dropout(0.1),
                nn.Linear(hidden_size, hidden_size),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_size, num_classes),
                # Normalise over the class dimension of a (batch, classes) output.
                nn.Softmax(dim=1),
            )
            for name, num_classes in cls.CVSS_HEADS.items()
        })

    @classmethod
    def _build_cwe_heads(cls, hidden_size=768):
        """Return a ``ModuleDict`` with one MLP projection per CWE level.

        Each head maps ``hidden_size -> 2048 -> hidden_size -> hidden_size``.
        The layer order (Sequential indices 0-7) is part of the checkpoint
        format.
        """
        return nn.ModuleDict({
            level: nn.Sequential(
                nn.LayerNorm(hidden_size),
                nn.Linear(hidden_size, 2048),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(2048, hidden_size),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_size, hidden_size),
            )
            for level in cls.CWE_LEVELS
        })

    def forward(self, input_ids, attention_mask, *, return_logits=False):
        """Encode a batch and run every head on the first-token embedding.

        Args:
            input_ids: Token ids for the encoder (presumably shaped
                ``(batch, seq_len)`` by the matching tokenizer -- confirm).
            attention_mask: Attention mask matching ``input_ids``.
            return_logits: If True, the CVSS entries are raw pre-softmax
                scores, which is what ``nn.CrossEntropyLoss`` expects. If
                False (the default, original behaviour), they are softmax
                probabilities. Do not feed the probabilities to a
                cross-entropy loss.

        Returns:
            A dict with one entry per CVSS metric, each a
            ``(batch, num_classes)`` tensor, and one entry per CWE level, each
            a ``(batch, hidden_size)`` embedding with unit L2 norm per row.
        """
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Pool by taking the first token's hidden state (presumably [CLS]).
        emb = out.last_hidden_state[:, 0, :]
        return self._apply_heads(emb, return_logits=return_logits)

    def _apply_heads(self, emb, *, return_logits=False):
        """Run all heads on a pooled ``(batch, hidden_size)`` embedding."""
        res = {}
        for name, head in self.cvss_heads.items():
            # head[:-1] drops the trailing Softmax and reuses the same modules.
            res[name] = head[:-1](emb) if return_logits else head(emb)
        for level, head in self.cwe_heads.items():
            res[level] = F.normalize(head(emb), p=2, dim=1)
        return res