# Malware_Classifier/src/models/multimodal.py
# Last commit a9640f8 ("final") by mulasagg.
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.models.bigru import CNNBiGRU
from src.models.cnn import ImprovedCNN
from src.config import config
class CombinedMalwareDetectionModel(nn.Module):
    """Multimodal classifier that fuses a sequence branch with an image branch.

    The sequence branch (``CNNBiGRU``, its output is named ``output_api`` in
    ``forward`` -- presumably features of API-call sequences) and the image
    branch (``ImprovedCNN``) each produce a feature tensor.  The two are
    concatenated along the last dimension (image features first) and fed to a
    three-layer fully connected head that emits raw logits.

    Args:
        vocab_size: Forwarded to ``CNNBiGRU``.
        embedding_dim: Forwarded to ``CNNBiGRU``.
        num_filters: Forwarded to ``CNNBiGRU``.
        kernel_size: Forwarded to ``CNNBiGRU``.
        num_classes: Number of output logits.  Defaults to 7, the value that
            was previously hard-coded.
        fusion_dim: Width of the concatenated branch outputs, i.e. the input
            size of ``fc1``.  Must equal the sum of the two branches' last
            dimensions.  Defaults to 256, the value previously hard-coded.
        dropout: Dropout probability applied after ``fc1`` and ``fc2``.
            Defaults to 0.2, the value previously hard-coded.
    """

    def __init__(self, vocab_size, embedding_dim, num_filters, kernel_size,
                 *, num_classes=7, fusion_dim=256, dropout=0.2):
        super().__init__()
        cfg = config.configuration
        # Sequence branch: the remaining CNNBiGRU hyper-parameters come from
        # the shared project configuration rather than from the caller.
        self.malware_detection_model = CNNBiGRU(
            vocab_size,
            embedding_dim,
            cfg["input_length"],
            num_filters,
            kernel_size,
            cfg["num_gated_units"],
            cfg["hidden_neurons"],
            cfg["dropout_cnn"],
            cfg["dropout_fc"],
        )
        # Image branch with a single input channel.
        self.improved_cnn = ImprovedCNN(input_channels=1, hidden_units=32)
        # Built after both branches so parameter-initialisation order (and
        # hence RNG consumption under a fixed seed) is unchanged.
        self._build_head(fusion_dim, num_classes, dropout)

    def _build_head(self, fusion_dim, num_classes, dropout):
        """Create the fusion head: fusion_dim -> 64 -> 32 -> num_classes."""
        self.fc1 = nn.Linear(fusion_dim, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, padded_sequences, img_x):
        """Return class logits for a batch of (sequence, image) inputs.

        Args:
            padded_sequences: Input to the sequence branch (presumably padded
                token-id tensors -- shape is not checked here).
            img_x: Input to the image branch (shape not checked here).

        Returns:
            Raw logits (no softmax) whose last dimension is ``num_classes``.
        """
        output_api = self.malware_detection_model(padded_sequences)
        output_img = self.improved_cnn(img_x)
        # Order matters: fc1's weights are laid out for image features first,
        # then sequence features.  The cast ensures fc1 always sees float32
        # regardless of the dtype each branch returns.
        fused = torch.cat([output_img, output_api], dim=-1).to(torch.float32)
        hidden = self.dropout(F.relu(self.fc1(fused)))
        hidden = self.dropout(F.relu(self.fc2(hidden)))
        return self.fc3(hidden)