Spaces:
Build error
Build error
| import torch | |
| from torch import nn | |
class DroughtNetLSTM(nn.Module):
    """LSTM over time-series features fused with an MLP over static features.

    The LSTM output at the last time step is concatenated with an embedding of
    the static features, and the result is passed through a fully connected
    head. The head ends in a plain ``nn.Linear``, so ``forward`` returns raw
    (un-activated) outputs.

    Args:
        time_dim: Number of features per time step fed to the LSTM.
        lstm_dim: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between stacked LSTM layers.
            PyTorch applies LSTM dropout only between layers, so it is passed
            as 0 when ``num_layers == 1``. This avoids PyTorch's UserWarning
            for that configuration and does not change results.
        static_dim: Number of static (non-temporal) input features.
        staticfc_dim: Width of the static-feature embedding.
        hidden_dim: Width of the hidden layers in the final head.
        output_size: Number of outputs produced per sample.
    """

    def __init__(self, time_dim=20, lstm_dim=256, num_layers=2, dropout=0.15,
                 static_dim=29, staticfc_dim=16, hidden_dim=256, output_size=6):
        super().__init__()
        # Recurrent encoder for the time-series input (batch-first layout).
        self.lstm = nn.LSTM(
            time_dim,
            lstm_dim,
            num_layers=num_layers,
            batch_first=True,
            # Inter-layer dropout has no effect with a single layer, and
            # nn.LSTM warns about it; pass 0 in that case.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        # MLP embedding of the static features. Attribute names and the
        # Sequential layer order are kept fixed so state_dict keys
        # (e.g. "static_fc.0.weight") stay compatible with saved checkpoints.
        self.static_fc = nn.Sequential(
            nn.Linear(static_dim, staticfc_dim),
            nn.ReLU(),
            nn.Linear(staticfc_dim, staticfc_dim),
        )
        # Head over the concatenated [LSTM last step | static embedding].
        self.final_fc = nn.Sequential(
            nn.Linear(lstm_dim + staticfc_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_size),
        )

    def forward(self, x, x_static):
        """Run the forward pass through the network.

        Args:
            x: Time series data of shape (batch_size, seq_len, time_dim).
            x_static: Static data of shape (batch_size, static_dim).

        Returns:
            Tensor of shape (batch_size, output_size).
        """
        lstm_out, _ = self.lstm(x)
        # Keep only the output at the final time step. NOTE(review): if
        # sequences are right-padded to a common length, this index may point
        # at padding rather than the last real step -- confirm with callers.
        last_step = lstm_out[:, -1, :]
        static_out = self.static_fc(x_static)
        combined = torch.cat((last_step, static_out), dim=1)
        return self.final_fc(combined)