Spaces:
Sleeping
Sleeping
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
class MusicLSTM(nn.Module):
    """Token-level recurrent model for music sequences.

    Embeds integer token ids, runs them through an LSTM or GRU (sequence-first
    layout, i.e. ``(seq_len, batch, hidden)``), and projects the recurrent
    features to ``output_size`` logits.
    """

    def __init__(self, input_size, hidden_size, output_size, model='lstm', num_layers=1, dropout_p=0):
        """Build the embedding, recurrent core, and output projection.

        Args:
            input_size: vocabulary size (number of distinct token ids).
            hidden_size: embedding and recurrent hidden width.
            output_size: number of output logits per time step.
            model: recurrent cell type, ``'lstm'`` or ``'gru'``.
            num_layers: number of stacked recurrent layers.
            dropout_p: dropout probability applied to the RNN output.

        Raises:
            NotImplementedError: if ``model`` is not ``'lstm'`` or ``'gru'``.
        """
        super(MusicLSTM, self).__init__()
        self.model = model
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.embeddings = nn.Embedding(input_size, hidden_size)
        if self.model == 'lstm':
            self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers)
        elif self.model == 'gru':
            self.rnn = nn.GRU(hidden_size, hidden_size, num_layers)
        else:
            raise NotImplementedError(f"unsupported model type: {model!r} (expected 'lstm' or 'gru')")
        self.out = nn.Linear(self.hidden_size, self.output_size)
        self.drop = nn.Dropout(p=dropout_p)

    def init_hidden(self, batch_size=1, device=None):
        """Reset the cached recurrent state to zeros.

        Args:
            batch_size: batch dimension of the new state.
            device: optional device for the state tensors (new, defaults to
                CPU for backward compatibility).

        Returns:
            The new state: an ``(h, c)`` tuple for LSTM, a single tensor for GRU,
            each shaped ``(num_layers, batch_size, hidden_size)``.
        """
        if self.model == 'lstm':
            self.hidden = (
                torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device),
                torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device),
            )
        elif self.model == 'gru':
            self.hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        return self.hidden

    def _hidden_batch_size(self):
        """Batch dimension of the currently cached hidden state."""
        state = self.hidden[0] if self.model == 'lstm' else self.hidden
        return state.size(1)

    def forward(self, x):
        """Run one forward pass over a sequence of token ids.

        Args:
            x: 1-D ``(seq_len,)`` or 2-D ``(seq_len, batch)`` integer tensor.
               Higher-rank inputs are squeezed first.

        Returns:
            Logits of shape ``(seq_len, output_size)`` for 1-D input
            (batch dim squeezed away) or ``(seq_len, batch, output_size)``
            for batched input.
        """
        # Ensure x is at most 2-D (sequence length, batch size).
        if x.dim() > 2:
            x = x.squeeze()
        # BUG FIX: for 2-D (seq_len, batch) input the batch dimension is
        # dim 1, not dim 0 — the original used x.size(0) (the sequence
        # length), which made init_hidden allocate a mismatched state and
        # crash the RNN call for any batched input.
        batch_size = 1 if x.dim() == 1 else x.size(1)
        x = x.long()
        embeds = self.embeddings(x)
        # Ensure embeds is 3-D for RNN input (seq_len, batch, hidden).
        if embeds.dim() == 2:
            embeds = embeds.unsqueeze(1)
        # (Re)initialize the cached state when absent or when the batch size
        # changed since the previous call (the original reused a stale state
        # unconditionally and crashed on a batch-size change).
        if not hasattr(self, 'hidden') or self._hidden_batch_size() != batch_size:
            self.init_hidden(batch_size, device=embeds.device)
        # NOTE(review): self.hidden deliberately persists across calls to
        # carry state through a piecewise-fed sequence; it is not detached,
        # so callers that backprop repeatedly should call init_hidden()
        # between optimization steps — confirm against training loop.
        rnn_out, self.hidden = self.rnn(embeds, self.hidden)
        # squeeze(1) removes the batch dim only when batch_size == 1,
        # preserving the original 2-D output for single sequences.
        rnn_out = self.drop(rnn_out.squeeze(1))
        output = self.out(rnn_out)
        return output