# natSecLabse / model_architecture.py
# kristiangnordby's picture
# Upload model_architecture.py with huggingface_hub
# 7ff7c92 verified
import torch
import torch.nn as nn
class CyberClassifier(nn.Module):
    """Binary classification model for cybersecurity detection.

    Architecture: ``LayerNorm`` over the input features, then three
    ``Linear -> BatchNorm1d -> ReLU`` stages (512, 256 and 128 units), then
    a ``Linear(128, 1)`` head followed by a sigmoid, so every output value
    lies in ``[0, 1]``.

    Args:
        input_dim: Size of the last (feature) dimension of the input.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Attribute names and creation order are kept exactly as they were:
        # they define the ``state_dict`` keys that saved checkpoints are
        # matched against.
        self.normalizer = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, 1)

    def forward(self, x):
        """Forward pass through the network.

        Args:
            x: Tensor whose last dimension has size ``input_dim``. The usual
                layout is ``(batch, input_dim)``; a single un-batched vector
                ``(input_dim,)`` or extra leading dimensions
                ``(..., input_dim)`` are also accepted.

        Returns:
            Tensor of sigmoid outputs with the same leading dimensions as
            ``x`` and a trailing dimension of size 1, e.g. ``(batch, 1)`` for
            batched input and ``(1,)`` for a single vector.

        Note:
            In training mode ``BatchNorm1d`` still needs more than one row
            once the leading dimensions are flattened.
        """
        # BatchNorm1d rejects 1-D input and treats dim 1 of 3-D input as the
        # channel axis, so only (batch, features) passes through unchanged.
        # Any other rank is flattened to 2-D here and restored afterwards.
        # 0-D input is left alone so it fails exactly as it did before.
        lead_shape = None
        if x.dim() == 1 or x.dim() > 2:
            lead_shape = x.shape[:-1]
            x = x.reshape(-1, x.shape[-1])

        x = self.normalizer(x)
        x = torch.relu(self.bn1(self.fc1(x)))
        x = torch.relu(self.bn2(self.fc2(x)))
        x = torch.relu(self.bn3(self.fc3(x)))
        x = torch.sigmoid(self.fc4(x))

        if lead_shape is not None:
            x = x.reshape(*lead_shape, 1)
        return x
def load_model(checkpoint_path, input_dim, device='cpu'):
    """
    Build a CyberClassifier and restore its weights from a checkpoint.

    Args:
        checkpoint_path: Path to model checkpoint (.pt file). The file must
            deserialize to a mapping with a 'model_state_dict' entry.
        input_dim: Input dimension (e.g., 768)
        device: 'cuda' or 'cpu'; used both for the model and as the
            ``map_location`` when reading the checkpoint.

    Returns:
        The model on ``device``, switched to evaluation mode.
    """
    net = CyberClassifier(input_dim).to(device)
    # NOTE(review): torch.load unpickles the file, so only checkpoints from a
    # trusted source should be passed here.
    saved = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(saved['model_state_dict'])
    # nn.Module.eval() returns the module itself.
    return net.eval()