import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig


class SecureBERTMultiHead(nn.Module):
    def __init__(self, pretrained_model_name="cisco-ai/SecureBERT2.0-biencoder"):
        super().__init__()
        self.modernbert_config = AutoConfig.from_pretrained(pretrained_model_name)
        # ModernBERT exposes a `reference_compile` flag; disabling it avoids
        # torch.compile-related issues during fine-tuning.
        if hasattr(self.modernbert_config, "reference_compile"):
            self.modernbert_config.reference_compile = False
        self.bert = AutoModel.from_pretrained(pretrained_model_name, config=self.modernbert_config)

        hidden = self.modernbert_config.hidden_size  # 768 for this checkpoint

        # One classification head per CVSS v3.1 base metric; the value is the
        # number of classes for that metric.
        heads_config = {
            'attack_vector': 4,
            'attack_complexity': 2,
            'privileges_required': 3,
            'user_interaction': 2,
            'scope': 2,
            'confidentiality': 3,
            'integrity': 3,
            'availability': 3,
        }
        self.cvss_heads = nn.ModuleDict({
            k: nn.Sequential(
                nn.LayerNorm(hidden),
                nn.Dropout(0.1),
                nn.Linear(hidden, hidden),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(hidden, num_classes),
                # Heads emit probabilities directly; if training with
                # nn.CrossEntropyLoss, drop this layer and feed raw logits.
                nn.Softmax(dim=1),
            )
            for k, num_classes in heads_config.items()
        })

        # One projection head per CWE hierarchy level. Each maps the [CLS]
        # embedding into a 768-d space for similarity matching.
        self.cwe_heads = nn.ModuleDict({
            k: nn.Sequential(
                nn.LayerNorm(hidden),
                nn.Linear(hidden, 2048),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(2048, hidden),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(hidden, hidden),
            )
            for k in ['pillar', 'class', 'base', 'variant']
        })

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        emb = out.last_hidden_state[:, 0, :]  # [CLS] token embedding
        res = {}
        for k, head in self.cvss_heads.items():
            res[k] = head(emb)
        for k, head in self.cwe_heads.items():
            # L2-normalize so CWE matching can use cosine similarity.
            res[k] = F.normalize(head(emb), p=2, dim=1)
        return res
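For reference, here is a minimal inference sketch. It assumes the "cisco-ai/SecureBERT2.0-biencoder" checkpoint ships a compatible tokenizer, and the vulnerability description is purely illustrative.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("cisco-ai/SecureBERT2.0-biencoder")
model = SecureBERTMultiHead()
model.eval()

text = (
    "A SQL injection vulnerability in the login endpoint allows remote "
    "attackers to execute arbitrary SQL commands via the username parameter."
)
batch = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

with torch.no_grad():
    outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

# Each CVSS head returns a probability distribution over its classes.
print(outputs["attack_vector"].shape)  # torch.Size([1, 4])
# Each CWE head returns a unit-length embedding for similarity lookup.
print(outputs["pillar"].norm(dim=1))   # ~1.0

Note that the CWE heads produce embeddings rather than class scores, so in practice you would compare them (e.g. by cosine similarity) against precomputed embeddings of candidate CWE entries at each hierarchy level.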