File size: 1,425 Bytes
a9640f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.models.bigru import CNNBiGRU
from src.models.cnn import ImprovedCNN
from src.config import config


class CombinedMalwareDetectionModel(nn.Module):
    """Two-branch classifier that fuses a sequence branch and an image branch.

    A ``CNNBiGRU`` processes ``padded_sequences`` and an ``ImprovedCNN``
    processes ``img_x``. Their outputs are concatenated along the last
    dimension and passed through a three-layer fully connected head
    (``fusion_dim -> 64 -> 32 -> num_classes``) with ReLU and dropout after
    the first two layers.

    Args:
        vocab_size: Forwarded to ``CNNBiGRU`` as its first argument.
        embedding_dim: Forwarded to ``CNNBiGRU``.
        num_filters: Forwarded to ``CNNBiGRU``.
        kernel_size: Forwarded to ``CNNBiGRU``.
        fusion_dim: Input width of the first fully connected layer. It must
            equal the combined last-dimension size of the two branch outputs.
            Defaults to 256, the value previously hard-coded.
        num_classes: Number of output units of the final layer. Defaults to
            7, the value previously hard-coded.
        dropout: Dropout probability used in the head. Defaults to 0.2.

    The remaining ``CNNBiGRU`` hyper-parameters (``input_length``,
    ``num_gated_units``, ``hidden_neurons``, ``dropout_cnn``, ``dropout_fc``)
    are read from ``config.configuration``.
    """

    def __init__(self, vocab_size, embedding_dim, num_filters, kernel_size,
                 *, fusion_dim=256, num_classes=7, dropout=0.2):
        super().__init__()

        # Read the shared configuration mapping once instead of per argument.
        cfg = config.configuration

        # Sub-modules are registered in the same order and under the same
        # attribute names as before so existing checkpoints keep loading and
        # seeded weight initialisation is unchanged.
        self.malware_detection_model = CNNBiGRU(
            vocab_size,
            embedding_dim,
            cfg["input_length"],
            num_filters,
            kernel_size,
            cfg["num_gated_units"],
            cfg["hidden_neurons"],
            cfg["dropout_cnn"],
            cfg["dropout_fc"],
        )

        self.improved_cnn = ImprovedCNN(input_channels=1, hidden_units=32)

        # Classification head over the concatenated branch features.
        self.fc1 = nn.Linear(fusion_dim, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_classes)

        self.dropout = nn.Dropout(dropout)

    def forward(self, padded_sequences, img_x) -> torch.Tensor:
        """Run both branches and classify their fused features.

        Args:
            padded_sequences: Input for the ``CNNBiGRU`` branch
                (presumably a batch of padded token-id sequences; confirm
                against the data loader).
            img_x: Input for the ``ImprovedCNN`` branch (presumably a batch
                of single-channel images, given ``input_channels=1``;
                confirm against the data loader).

        Returns:
            The output of the final linear layer. No softmax or other
            activation is applied to it here.
        """
        output_api = self.malware_detection_model(padded_sequences)
        output_img = self.improved_cnn(img_x)

        # Image features come first, then sequence features; the order must
        # stay fixed because fc1's weights are tied to this layout. The cast
        # guarantees the linear layers receive float32 input.
        input_multi = torch.cat([output_img, output_api], dim=-1).to(torch.float32)

        x = self.dropout(F.relu(self.fc1(input_multi)))
        x = self.dropout(F.relu(self.fc2(x)))
        return self.fc3(x)