File size: 1,881 Bytes
5ef46eb
 
 
 
 
6743bb5
2b0b6ad
5ef46eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f9ead2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn
import math

class EnhancedLSTM(nn.Module):
    """Stacked bidirectional LSTM encoder followed by a ReLU MLP head.

    Input is a batch-first sequence tensor of shape (batch, seq_len,
    input_size); output has shape (batch, output_size).

    Args:
        input_size: number of features per timestep.
        lstm_layer_sizes: hidden size of each stacked BiLSTM layer.
        linear_layer_size: output sizes of the fully connected layers.
        output_size: dimension of the final linear projection.
    """

    def __init__(self, input_size, lstm_layer_sizes, linear_layer_size, output_size):
        super(EnhancedLSTM, self).__init__()
        self.input_size = input_size

        # Stack of single-layer bidirectional LSTMs. Each layer after the
        # first consumes the previous layer's concatenated forward+backward
        # output, hence the *2 on the input width.
        self.lstm_layers = nn.ModuleList()
        for i in range(len(lstm_layer_sizes)):
            in_features = input_size if i == 0 else lstm_layer_sizes[i - 1] * 2
            self.lstm_layers.append(
                nn.LSTM(in_features, lstm_layer_sizes[i],
                        batch_first=True, bidirectional=True)
            )

        self.dropout = nn.Dropout(0.3)

        # MLP head; first layer takes the 2*hidden pooled BiLSTM vector.
        self.fc_layers = nn.ModuleList()
        for i in range(len(linear_layer_size)):
            in_features = (lstm_layer_sizes[-1] * 2 if i == 0
                           else linear_layer_size[i - 1])
            self.fc_layers.append(nn.Linear(in_features, linear_layer_size[i]))

        self.output_layer = nn.Linear(linear_layer_size[-1], output_size)
        self.relu = nn.ReLU()

        self.apply(self.initialize_weights)

    def forward(self, x):
        """Encode *x* (batch, seq, input_size) -> (batch, output_size).

        Returns the raw (un-activated) output of the final linear layer.
        """
        out = x
        hn = None
        for lstm in self.lstm_layers:
            out, (hn, _) = lstm(out)

        # BUG FIX: the original pooled with out[:, -1, :], which for a
        # bidirectional LSTM pairs the forward direction's final state with
        # the backward direction's state after seeing only the LAST token.
        # h_n holds each direction's true final state: for a single-layer
        # bidirectional nn.LSTM, hn[0] is the forward final state and hn[1]
        # the backward final state. Concatenate those instead; the pooled
        # width (2 * hidden) and the module interface are unchanged.
        out = torch.cat((hn[0], hn[1]), dim=1)
        out = self.dropout(out)

        for fc in self.fc_layers:
            out = self.relu(fc(out))

        return self.output_layer(out)

    def initialize_weights(self, layer):
        """Xavier-uniform weights and zero biases for Linear/LSTM modules.

        Passed to ``self.apply`` so it visits every submodule once.
        """
        if isinstance(layer, nn.Linear):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)
        elif isinstance(layer, nn.LSTM):
            # LSTM parameters are named weight_ih_l0, bias_hh_l0, ..._reverse;
            # all weight matrices are 2-D, so xavier_uniform_ applies cleanly.
            for name, param in layer.named_parameters():
                if 'weight' in name:
                    nn.init.xavier_uniform_(param.data)
                elif 'bias' in name:
                    nn.init.zeros_(param.data)