import torch
import torch.nn as nn
import torch.nn.functional as F
from src.models.bigru import CNNBiGRU
from src.models.cnn import ImprovedCNN
from src.config import config


class CombinedMalwareDetectionModel(nn.Module):
    """Fuse an API-call-sequence branch and an image branch into one classifier.

    The padded sequence input is encoded by ``CNNBiGRU`` and the image input
    by ``ImprovedCNN``. The two branch outputs are concatenated along the
    last dimension and passed through a three-layer fully-connected head
    (``fusion_dim -> 64 -> 32 -> num_classes``). ReLU and dropout are applied
    after the first two layers. The final layer's raw outputs are returned;
    no softmax is applied here.

    Args:
        vocab_size: Forwarded to ``CNNBiGRU``.
        embedding_dim: Forwarded to ``CNNBiGRU``.
        num_filters: Forwarded to ``CNNBiGRU``.
        kernel_size: Forwarded to ``CNNBiGRU``.
        num_classes: Output width of the final layer. Defaults to 7, the
            value that was previously hard-coded.
        dropout: Dropout probability used after ``fc1`` and ``fc2``.
            Defaults to 0.2, the previously hard-coded value.
        fusion_dim: Input width of ``fc1``. It must equal the sum of the
            last-dimension sizes of the two branch outputs. Defaults to 256,
            the previously hard-coded value. NOTE(review): how those 256
            features split between ``CNNBiGRU`` and ``ImprovedCNN`` is not
            visible here; confirm against those modules.

    Remaining ``CNNBiGRU`` hyper-parameters are read from
    ``config.configuration``.
    """

    def __init__(
        self,
        vocab_size,
        embedding_dim,
        num_filters,
        kernel_size,
        *,
        num_classes: int = 7,
        dropout: float = 0.2,
        fusion_dim: int = 256,
    ):
        super().__init__()
        cfg = config.configuration
        # Submodules are created in the original order, so parameter
        # initialisation consumes the RNG identically and state_dict keys
        # are unchanged.
        self.malware_detection_model = CNNBiGRU(
            vocab_size,
            embedding_dim,
            cfg["input_length"],
            num_filters,
            kernel_size,
            cfg["num_gated_units"],
            cfg["hidden_neurons"],
            cfg["dropout_cnn"],
            cfg["dropout_fc"],
        )
        self.improved_cnn = ImprovedCNN(input_channels=1, hidden_units=32)
        self.fc1 = nn.Linear(fusion_dim, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, padded_sequences, img_x) -> torch.Tensor:
        """Run both branches, fuse their features and classify.

        Args:
            padded_sequences: Input for the ``CNNBiGRU`` branch (presumably a
                batch of padded token-id sequences; confirm with the caller).
            img_x: Input for the ``ImprovedCNN`` branch.

        Returns:
            The output of ``fc3``, whose last dimension is ``num_classes``.
        """
        output_api = self.malware_detection_model(padded_sequences)
        output_img = self.improved_cnn(img_x)
        # Image features first, API features second. fc1's weight columns
        # follow this order, so swapping it invalidates trained weights.
        # The cast guarantees the head receives float32 whatever dtypes the
        # branches produce.
        input_multi = torch.cat([output_img, output_api], dim=-1).to(torch.float32)
        x = self.dropout(F.relu(self.fc1(input_multi)))
        x = self.dropout(F.relu(self.fc2(x)))
        return self.fc3(x)