import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class Vector2MIDI(nn.Module):
    def __init__(self, input_dim, hidden_dim, n_vocab, dropout=0.2):
        super().__init__()  # call the parent class constructor
        self.input_fc = nn.Linear(input_dim, hidden_dim)  # project input features to the hidden dimension

        # 2-layer LSTM with inter-layer dropout to reduce overfitting
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=2, batch_first=True, dropout=dropout)

        self.fc_mid = nn.Linear(hidden_dim, 256)
        self.fc_out = nn.Linear(256, n_vocab)

    def forward(self, x, lengths):
        # x: [batch, seq_len, input_dim]
        x = self.input_fc(x)

        # Token lengths differ per MIDI file, so convert to a PackedSequence before the LSTM
        packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)  # pack the padded batch
        lstm_out, _ = self.lstm(packed)  # run the LSTM over the packed sequence
        padded, _ = pad_packed_sequence(lstm_out, batch_first=True)  # restore padding

        # final projection to vocabulary logits
        x = self.fc_mid(padded)
        return self.fc_out(x)  # [B, T, vocab_size]
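

# A minimal smoke test, assuming an input feature dimension of 128, a hidden
# size of 256, and a vocabulary of 512 tokens; these dimensions are
# illustrative placeholders, not the values used in training.
if __name__ == "__main__":
    model = Vector2MIDI(input_dim=128, hidden_dim=256, n_vocab=512)
    batch, max_len = 4, 20
    x = torch.randn(batch, max_len, 128)        # padded batch of feature vectors
    lengths = torch.tensor([20, 17, 12, 5])     # true (unpadded) length of each sequence
    logits = model(x, lengths)
    print(logits.shape)  # torch.Size([4, 20, 512])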