File size: 1,684 Bytes
5054620
 
 
 
 
 
 
 
 
e5cb338
 
5054620
 
 
 
e5cb338
 
 
 
 
 
 
 
5054620
 
e5cb338
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5054620
e5cb338
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn as nn
import torch

class Vector2MIDI(nn.Module):
    """Expand a fixed-size feature vector into a variable-length MIDI token sequence.

    Each sample's single conditioning vector is projected to the LSTM hidden
    size, repeated along the time axis, run through a stacked LSTM (with
    length-aware packing so padded steps are skipped), and decoded into
    per-step vocabulary logits.
    """

    def __init__(self, input_dim, hidden_dim, n_vocab, dropout=0.2):
        """
        Args:
            input_dim: Size of the conditioning feature vector.
            hidden_dim: LSTM hidden size (also the projected input size).
            n_vocab: Number of output token classes (internal vocabulary).
            dropout: Inter-layer LSTM dropout to curb overfitting.
        """
        super().__init__()  # call parent class constructor
        self.input_fc = nn.Linear(input_dim, hidden_dim)  # input dim -> hidden dim

        # 2-layer LSTM; dropout applies between the stacked layers.
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=2,
                            batch_first=True, dropout=dropout)

        self.fc_mid = nn.Linear(hidden_dim, 256)
        self.fc_out = nn.Linear(256, n_vocab)

    def forward(self, x, lengths, total_length=None):
        """Compute per-step token logits.

        Args:
            x: [B, input_dim] conditioning vectors, one per sample.
            lengths: [B] valid sequence length for each sample.
            total_length: Optional fixed pad length for the output
                (forwarded to ``pad_packed_sequence``).

        Returns:
            Logits of shape [B, T, n_vocab] where T is ``lengths.max()``
            (or ``total_length`` when given).
        """
        B, feat_dim = x.size()
        T = int(lengths.max())  # longest sequence in the batch, as a plain int

        # Repeat the conditioning vector across time:
        # [B, feat_dim] -> [B, T, feat_dim] (expand is a zero-copy view).
        x = x.unsqueeze(1).expand(B, T, feat_dim)

        x = self.input_fc(x)

        # Pack so the LSTM skips padded steps; lengths must live on CPU.
        packed_x = pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed_x)

        out, _ = pad_packed_sequence(
            packed_out, batch_first=True, total_length=total_length
        )

        out = self.fc_mid(out)
        out = self.fc_out(out)  # [B, T, n_vocab]
        return out

    def generate(self, x, lengths, total_length=None):
        """Greedy-decode token ids and map them to the external representation.

        Internal id k maps to external id k - 2; internal id 0 (PAD) is
        mapped to external 0.
        """
        logits = self.forward(x, lengths, total_length)

        preds = torch.argmax(logits, dim=-1)  # [B, T], pick highest-scoring class
        external = preds - 2                  # internal -> external representation
        external[external == -2] = 0          # internal PAD (0) stays 0 externally
        return external