Spaces:
Build error
Build error
File size: 1,952 Bytes
7ccf60d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 | import torch
from torch import nn
class DroughtNetLSTM(nn.Module):
def __init__(self, time_dim=20, lstm_dim=256, num_layers=2, dropout=0.15,
static_dim=29, staticfc_dim=16, hidden_dim=256, output_size=6):
super(DroughtNetLSTM, self).__init__()
# Define LSTM network for time features
self.lstm = nn.LSTM(
time_dim,
lstm_dim,
num_layers=num_layers,
batch_first=True,
dropout=dropout
)
# Define neural network for static features
self.static_fc = nn.Sequential(
nn.Linear(static_dim, staticfc_dim),
nn.ReLU(),
nn.Linear(staticfc_dim, staticfc_dim)
)
# Define final fully connected layers
self.final_fc = nn.Sequential(
nn.Linear(lstm_dim + staticfc_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_size)
)
def forward(self, x, x_static):
"""
Forward pass through the network
Args:
x: Time series data of shape (batch_size, seq_len, time_dim)
x_static: Static data of shape (batch_size, static_dim)
Returns:
out: Output of shape (batch_size, output_size)
"""
# Process time series data through LSTM
lstm_out, _ = self.lstm(x)
# Take only the last output of the LSTM
lstm_out = lstm_out[:, -1, :]
# Process static data
static_out = self.static_fc(x_static)
# Concatenate LSTM output and static output
combined = torch.cat((lstm_out, static_out), 1)
# Final fully connected layers
out = self.final_fc(combined)
return out |