Spaces:
Runtime error
Runtime error
| from typing import Dict | |
| import torch | |
| import torch.nn as nn | |
# Default compute device, kept as a plain string (not a torch.device object).
device: str = "cpu"
class SeqClassifier(nn.Module):
    """GRU sequence classifier on top of (pretrained) word embeddings.

    Token ids are embedded with a trainable ``nn.Embedding`` initialised from
    ``embeddings``, encoded by a (optionally bidirectional, optionally
    stacked) GRU, and the final encoder state is projected to ``num_class``
    logits by a linear layer.

    Args:
        embeddings: 2-D embedding matrix ``(vocab_size, embed_dim)``; its
            second dimension becomes the GRU input size.
        hidden_size: GRU hidden size per direction.
        num_layers: number of stacked GRU layers.
        dropout: dropout probability, applied between stacked GRU layers
            (only when ``num_layers > 1``) and to the GRU output before
            classification.
        bidirectional: whether the GRU runs in both directions.
        num_class: number of output classes (logit dimension).
    """

    def __init__(
        self,
        embeddings: torch.Tensor,
        hidden_size: int,
        num_layers: int,
        dropout: float,
        bidirectional: bool,
        num_class: int,
    ) -> None:
        super().__init__()
        # Model hyper-parameters.
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.num_class = num_class

        # Word-embedding layer; freeze=False keeps the vectors trainable.
        self.embed = nn.Embedding.from_pretrained(embeddings, freeze=False)
        # nn.GRU applies its dropout only between stacked layers, so with a
        # single layer it is a no-op that merely triggers a PyTorch warning.
        self.rnn = nn.GRU(
            input_size=embeddings.size(1),
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=bidirectional,
            batch_first=True,
        )
        # Dropout on the GRU output before the classifier.
        self.dropout_layer = nn.Dropout(p=dropout)
        # encoder_output_size is a method and must be *called*; passing the
        # bound method itself made nn.Linear fail at construction time.
        self.fc = nn.Linear(self.encoder_output_size(), num_class)

    def encoder_output_size(self) -> int:
        """Return the GRU output feature size (doubled when bidirectional)."""
        return self.hidden_size * (2 if self.bidirectional else 1)

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Compute class logits for token-id sequences.

        Args:
            batch: token ids, either batched ``(batch, seq_len)`` (the GRU is
                ``batch_first``) or a single unbatched sequence ``(seq_len,)``.

        Returns:
            Logits of shape ``(batch, num_class)``, or ``(num_class,)`` for an
            unbatched sequence. Works identically in train and eval mode.
        """
        embedded = self.embed(batch)
        # Pass through the GRU encoder; the final hidden state is unused
        # because the needed states are sliced from the per-step output.
        rnn_output, _ = self.rnn(embedded)
        rnn_output = self.dropout_layer(rnn_output)

        # The leading "..." makes the same indexing valid for batched
        # (batch, seq, dirs * hidden) and unbatched (seq, dirs * hidden) output.
        # Forward direction finishes at the last time step.
        # NOTE(review): with right-padded batches the last step may be padding.
        features = rnn_output[..., -1, : self.hidden_size]
        if self.bidirectional:
            # Backward direction finishes at the first time step.
            backward = rnn_output[..., 0, self.hidden_size :]
            features = torch.cat((features, backward), dim=-1)

        return self.fc(features)