import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
class Vector2MIDI(nn.Module):
    def __init__(self, input_dim, hidden_dim, n_vocab, dropout=0.2):
        super().__init__()  # call the parent class constructor
        self.input_fc = nn.Linear(input_dim, hidden_dim)  # project input dim to hidden dim
        # two-layer LSTM with dropout to curb overfitting
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=2,
                            batch_first=True, dropout=dropout)
        self.fc_mid = nn.Linear(hidden_dim, 256)
        self.fc_out = nn.Linear(256, n_vocab)
    def forward(self, x, lengths, total_length=None):
        print("input to forward:", x.shape)  # debug: log the input shape
        B, feat_dim = x.size()
        T = int(lengths.max())
        # [B, feat_dim] -> [B, 1, feat_dim] -> [B, T, feat_dim]
        x = x.unsqueeze(1).expand(B, T, feat_dim)
        x = self.input_fc(x)
        packed_x = pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed_x)
        out, _ = pad_packed_sequence(
            packed_out, batch_first=True, total_length=total_length
        )
        out = self.fc_mid(out)
        out = self.fc_out(out)  # [B, max_len, n_vocab]
        return out
    def generate(self, x, lengths, total_length=None):
        out = self.forward(x, lengths, total_length)
        preds = torch.argmax(out, dim=-1)  # [B, T], pick the highest-scoring class per step
        external = preds - 2               # internal token ids -> external representation
        external[external == -2] = 0       # map PAD (internal id 0) back to 0
        return external
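
# A minimal usage sketch (not part of the original file): feeds one condition
# vector per sample and decodes token ids. The dimensions (input_dim=128,
# n_vocab=130) and the lengths below are illustrative assumptions, not values
# taken from the actual training setup.
if __name__ == "__main__":
    model = Vector2MIDI(input_dim=128, hidden_dim=256, n_vocab=130)
    model.eval()  # disable dropout for generation
    x = torch.randn(4, 128)                   # batch of 4 condition vectors
    lengths = torch.tensor([45, 32, 20, 10])  # true sequence length per sample
    with torch.no_grad():
        tokens = model.generate(x, lengths, total_length=45)
    print(tokens.shape)  # -> torch.Size([4, 45])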