# Diva / Models / Vector2MIDI.py
# Commit e5cb338: Rest was being mapped to -1; fixed by shifting all internal
# token ids by +2 (so external ids = internal ids - 2).
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn as nn
import torch
class Vector2MIDI(nn.Module):
    """Decode a fixed-size feature vector into a MIDI token sequence.

    The single conditioning vector is repeated at every time step,
    projected to the hidden size, run through a stacked LSTM, and mapped
    to per-step token logits over the (internal) vocabulary.
    """

    def __init__(self, input_dim, hidden_dim, n_vocab, dropout=0.2):
        """
        Args:
            input_dim: size of the conditioning feature vector.
            hidden_dim: LSTM hidden size (also the projection target size).
            n_vocab: number of output token classes (internal representation,
                i.e. external ids shifted by +2).
            dropout: dropout between stacked LSTM layers (overfitting guard).
        """
        super().__init__()  # initialize nn.Module machinery
        self.input_fc = nn.Linear(input_dim, hidden_dim)  # feature -> hidden projection
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=2,
                            batch_first=True, dropout=dropout)
        self.fc_mid = nn.Linear(hidden_dim, 256)
        self.fc_out = nn.Linear(256, n_vocab)

    def forward(self, x, lengths, total_length=None):
        """Return per-step token logits.

        Args:
            x: [B, input_dim] conditioning vectors, one per sequence.
            lengths: [B] valid sequence lengths (moved to CPU for packing,
                as pack_padded_sequence requires).
            total_length: optional fixed padded length of the output;
                defaults to lengths.max().

        Returns:
            [B, T, n_vocab] logits with T = total_length or lengths.max().
        """
        B, feat_dim = x.size()
        # expand() needs a plain int, not a 0-dim tensor.
        T = int(lengths.max())
        # Broadcast the single vector across time: [B, feat_dim] -> [B, T, feat_dim].
        x = x.unsqueeze(1).expand(B, T, feat_dim)
        x = self.input_fc(x)
        # Pack so the LSTM skips padded steps (use the names imported at top of file).
        packed_x = pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed_x)
        out, _ = pad_packed_sequence(
            packed_out, batch_first=True, total_length=total_length
        )
        out = self.fc_mid(out)
        out = self.fc_out(out)  # [B, T, n_vocab]
        return out

    @torch.no_grad()  # inference only — no autograd graph needed
    def generate(self, x, lengths, total_length=None):
        """Greedy-decode token ids and convert to the external representation.

        Internal ids are external ids shifted by +2 (so internal 0 = PAD,
        internal 1 = Rest/external -1). Argmax picks the internal id,
        subtracting 2 undoes the shift, and predicted PAD positions
        (external -2) are zeroed out.

        Returns:
            [B, T] tensor of external token ids (>= -1).
        """
        out = self.forward(x, lengths, total_length)
        preds = torch.argmax(out, dim=-1)  # [B, T] internal ids (highest logit)
        external = preds - 2               # internal -> external representation
        external[external == -2] = 0       # map predicted PAD to 0
        return external