import torch
import torch.nn as nn


class CyberClassifier(nn.Module):
    """Binary classification model for cybersecurity detection.

    Architecture: LayerNorm over the input features, then three hidden
    stages of Linear -> BatchNorm1d -> ReLU (512, 256 and 128 units), and a
    final Linear(128, 1) followed by a sigmoid.

    Args:
        input_dim: Size of the feature dimension of the input.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        # NOTE: attribute names define the state_dict keys; they must stay
        # stable so that existing checkpoints keep loading.
        self.normalizer = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network.

        Args:
            x: Input features; the layers are sized for a 2-D tensor of
                shape (batch, input_dim).

        Returns:
            Tensor of sigmoid outputs in [0, 1] with a trailing dimension
            of size 1 (one score per sample).
        """
        x = self.normalizer(x)
        # Each hidden stage applies Linear -> BatchNorm1d -> ReLU in order.
        hidden_stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
        )
        for linear, batch_norm in hidden_stages:
            x = torch.relu(batch_norm(linear(x)))
        return torch.sigmoid(self.fc4(x))


def load_model(checkpoint_path, input_dim, device='cpu'):
    """Load the model from checkpoint.

    Args:
        checkpoint_path: Path to model checkpoint (.pt file), or any other
            object accepted by ``torch.load``.
        input_dim: Input dimension (e.g., 768).
        device: 'cuda' or 'cpu'.

    Returns:
        The loaded ``CyberClassifier``, moved to ``device`` and switched to
        evaluation mode, ready for inference.

    Raises:
        KeyError: If the checkpoint has no 'model_state_dict' entry.
    """
    model = CyberClassifier(input_dim).to(device)

    # SECURITY NOTE(review): torch.load unpickles the file, which can run
    # arbitrary code. Only load checkpoints from trusted sources; consider
    # passing weights_only=True on PyTorch versions that support it.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    try:
        state_dict = checkpoint['model_state_dict']
    except KeyError:
        # Same exception type as before, but say what was actually found.
        raise KeyError(
            "checkpoint has no 'model_state_dict' entry; "
            f"available keys: {list(checkpoint)}"
        ) from None

    model.load_state_dict(state_dict)
    model.eval()
    return model