| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import AutoModel, AutoConfig | |
class SecureBERTMultiHead(nn.Module):
    """SecureBERT2.0 encoder with multiple task-specific heads.

    Wraps a pretrained ModernBERT-style encoder and attaches:

    * ``cvss_heads`` — eight small classifiers, one per CVSS v3 base metric
      (attack_vector, attack_complexity, privileges_required, user_interaction,
      scope, confidentiality, integrity, availability), each emitting a
      per-class probability distribution (softmax output).
    * ``cwe_heads`` — four projection heads ('pillar', 'class', 'base',
      'variant') emitting L2-normalized embeddings, suitable for
      cosine-similarity retrieval against CWE label embeddings.
    """

    def __init__(self, pretrained_model_name: str = "cisco-ai/SecureBERT2.0-biencoder"):
        """Build the encoder and all heads.

        Args:
            pretrained_model_name: HuggingFace model id or local path of the
                pretrained bi-encoder checkpoint to load.
        """
        super().__init__()
        self.modernbert_config = AutoConfig.from_pretrained(pretrained_model_name)
        # ModernBERT configs expose `reference_compile`; disable it so loading
        # does not trigger torch.compile of the reference implementation.
        if hasattr(self.modernbert_config, "reference_compile"):
            self.modernbert_config.reference_compile = False
        self.bert = AutoModel.from_pretrained(pretrained_model_name, config=self.modernbert_config)

        # Derive the encoder width from the config instead of hard-coding 768,
        # so the heads stay correct if a checkpoint with a different hidden
        # size is supplied. Falls back to 768 (the original hard-coded value).
        hidden_size = getattr(self.modernbert_config, "hidden_size", 768)

        # Number of classes for each CVSS v3 base metric.
        heads_config = {
            'attack_vector': 4, 'attack_complexity': 2, 'privileges_required': 3,
            'user_interaction': 2, 'scope': 2, 'confidentiality': 3,
            'integrity': 3, 'availability': 3
        }
        # NOTE(review): the trailing nn.Softmax means these heads output
        # probabilities, not logits — do NOT pair them with
        # nn.CrossEntropyLoss (which expects raw logits); use NLLLoss on
        # log-probabilities or drop the Softmax if training with CE. The
        # Softmax is kept here for backward compatibility with callers that
        # consume probabilities.
        self.cvss_heads = nn.ModuleDict({
            k: nn.Sequential(
                nn.LayerNorm(hidden_size), nn.Dropout(0.1), nn.Linear(hidden_size, hidden_size),
                nn.GELU(), nn.Dropout(0.1), nn.Linear(hidden_size, num_classes), nn.Softmax(dim=1)
            ) for k, num_classes in heads_config.items()
        })
        # CWE heads: bottleneck-free MLP projections back to encoder width;
        # their outputs are L2-normalized in forward() for cosine similarity.
        self.cwe_heads = nn.ModuleDict({
            k: nn.Sequential(
                nn.LayerNorm(hidden_size), nn.Linear(hidden_size, 2048), nn.GELU(), nn.Dropout(0.1),
                nn.Linear(2048, hidden_size), nn.GELU(), nn.Dropout(0.1), nn.Linear(hidden_size, hidden_size)
            ) for k in ['pillar', 'class', 'base', 'variant']
        })

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> dict:
        """Encode a batch and run every head on the pooled representation.

        Args:
            input_ids: Token id tensor of shape (batch, seq_len).
            attention_mask: Attention mask of shape (batch, seq_len).

        Returns:
            Dict mapping head name to its output: probability distributions of
            shape (batch, num_classes) for the eight CVSS heads, and
            L2-normalized embeddings of shape (batch, hidden_size) for the
            four CWE heads.
        """
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # First-token pooling — presumably the [CLS]/BOS position; confirm
        # this matches how the bi-encoder checkpoint was trained to pool.
        emb = out.last_hidden_state[:, 0, :]
        res = {}
        for k, head in self.cvss_heads.items():
            res[k] = head(emb)
        for k, head in self.cwe_heads.items():
            # Unit-norm so downstream dot products equal cosine similarity.
            res[k] = F.normalize(head(emb), p=2, dim=1)
        return res