File size: 1,952 Bytes
7ccf60d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from torch import nn

class DroughtNetLSTM(nn.Module):
    """LSTM encoder for a sequence input combined with an MLP over static inputs.

    The sequence ``x`` is encoded by a (possibly stacked) LSTM and the output
    at its final time step is taken. The non-sequential input ``x_static`` is
    embedded by a small MLP. Both representations are concatenated and
    mapped to ``output_size`` values by a three-layer MLP. No activation is
    applied to the final output.

    Args:
        time_dim: Number of features per time step of ``x``.
        lstm_dim: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability between stacked LSTM layers. PyTorch
            only applies it between layers, so it has no effect (and is
            passed as 0) when ``num_layers == 1``.
        static_dim: Number of features in ``x_static``.
        staticfc_dim: Width of the static-feature embedding.
        hidden_dim: Width of the hidden layers of the final MLP.
        output_size: Number of output values per sample.

    Raises:
        ValueError: If ``dropout`` is not in ``[0, 1]``.
    """

    def __init__(
        self,
        time_dim: int = 20,
        lstm_dim: int = 256,
        num_layers: int = 2,
        dropout: float = 0.15,
        static_dim: int = 29,
        staticfc_dim: int = 16,
        hidden_dim: int = 256,
        output_size: int = 6,
    ):
        super().__init__()

        # Validate here so an out-of-range value is still rejected even when
        # it is replaced by 0 below for a single-layer LSTM.
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {dropout}")

        # nn.LSTM applies dropout only between stacked layers and warns when
        # dropout > 0 with a single layer (the value would be ignored anyway).
        lstm_dropout = dropout if num_layers > 1 else 0.0

        # LSTM over the sequence input; batch_first means x is (batch, seq, feature).
        self.lstm = nn.LSTM(
            time_dim,
            lstm_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=lstm_dropout,
        )

        # MLP embedding of the static input.
        self.static_fc = nn.Sequential(
            nn.Linear(static_dim, staticfc_dim),
            nn.ReLU(),
            nn.Linear(staticfc_dim, staticfc_dim),
        )

        # Head over the concatenated [last LSTM output, static embedding].
        self.final_fc = nn.Sequential(
            nn.Linear(lstm_dim + staticfc_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_size),
        )

    def forward(self, x: torch.Tensor, x_static: torch.Tensor) -> torch.Tensor:
        """Run a forward pass.

        Args:
            x: Sequence input of shape ``(batch_size, seq_len, time_dim)``.
            x_static: Static input of shape ``(batch_size, static_dim)``.

        Returns:
            Tensor of shape ``(batch_size, output_size)``.
        """
        # Per-step LSTM outputs; the (h_n, c_n) state tuple is not needed.
        lstm_out, _ = self.lstm(x)

        # With batch_first=True axis 1 is time: keep only the final step.
        last_step = lstm_out[:, -1, :]

        static_out = self.static_fc(x_static)

        combined = torch.cat((last_step, static_out), dim=1)
        return self.final_fc(combined)