import torch
import torch.nn as nn


class CardiovascularRNN(nn.Module):
    """GRU-based sequence classifier.

    A stacked GRU reads the input sequence. The GRU output at the final time
    step goes through a linear layer, which produces one unnormalized score
    (logit) per class. No softmax is applied.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the GRU hidden state.
        num_layers: Number of stacked GRU layers.
        num_classes: Number of output scores produced per sequence.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        num_classes: int,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Recurrent layer (GRU chosen over a plain nn.RNN). batch_first=True
        # means input is laid out as (batch, seq, feature).
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)

        # Maps the last-step hidden features to per-class scores.
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of sequences.

        Args:
            x: Batched input of shape (batch, seq_len, input_size), as
                required by the ``batch_first=True`` GRU.

        Returns:
            Tensor of shape (batch, num_classes) holding raw logits.
        """
        # Zero initial hidden state, allocated directly on the input's device
        # and with the input's dtype. That way the model also works after
        # .double()/.half() and avoids a CPU allocation plus copy.
        h0 = torch.zeros(
            self.num_layers,
            x.size(0),
            self.hidden_size,
            dtype=x.dtype,
            device=x.device,
        )

        # out: (batch, seq_len, hidden_size); the final hidden state is unused.
        out, _ = self.gru(x, h0)

        # Keep only the GRU output at the last time step of each sequence.
        last_step = out[:, -1, :]

        return self.fc(last_step)