import torch
import torch.nn.functional as F
from pysdtw import SoftDTW
from torch import nn
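
# Vector2MIDI: an autoregressive two-layer LSTM decoder that maps a 25-dim
# style vector to a sequence of 7-dimensional MIDI event tokens. The style
# vector initializes the LSTM states and the first token; training combines
# per-dimension cross-entropy with a Soft-DTW sequence loss, and generation
# uses temperature-scaled top-k sampling.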

class Vector2MIDI(nn.Module):
    def __init__(self, hidden_dim, X_dim=25, dropout=0.3):
        super().__init__()
        
        self.vocab_sizes = [101, 17, 73, 17, 17, 59, 17] # vocab size per token dimension, derived from the actual dataset
        
        self.init_hidden = nn.Linear(X_dim, hidden_dim)
        self.init_cell = nn.Linear(X_dim, hidden_dim)

        # one embedding table per token dimension; index 16 is the shared padding id
        self.embeddings = nn.ModuleList([
            nn.Embedding(vocab_size, hidden_dim, padding_idx=16) for vocab_size in self.vocab_sizes
        ])

        # ๊ณผ์ ํ•ฉ ๋ฐฉ์ง€ ๋“œ๋กญ์•„์›ƒ LSTM
        self.lstm = nn.LSTM(hidden_dim * len(self.vocab_sizes), hidden_dim, num_layers=2, batch_first=True, dropout=dropout)
        
        self.output_heads = nn.ModuleList([ # independent output head for each token dimension
            nn.Linear(hidden_dim, vocab_size) for vocab_size in self.vocab_sizes
        ])

        self.start_token_heads = nn.ModuleList([ # per-dimension heads that generate the first token from the style vector
            nn.Linear(X_dim, vocab_size) for vocab_size in self.vocab_sizes
        ])

    def forward_hnc(self, x):
        """hidden๊ณผ cell state ์ƒ์„ฑ"""
        h0 = torch.tanh(self.init_hidden(x)) # ํ™œ์„ฑํ™” ํ•จ์ˆ˜ ์ถ”๊ฐ€ (hyperbolic tangent)
        c0 = torch.tanh(self.init_cell(x))
        
        h0 = h0.unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)  # (num_layers, B, H)
        c0 = c0.unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
        
        return h0, c0

    def forward_extend(self, y:torch.Tensor, h=None, c=None):
        """
        y: (B, T, 7) - 7-dimensional integer tokens (including EOS and padding)
        """
        emb_list = []
        for idx, emb_f in enumerate(self.embeddings):
            emb_list.append(emb_f(y[:, :, idx]))  # (B, T, hidden_dim)
        emb = torch.cat(emb_list, dim=-1)  # (B, T, 7 * hidden_dim)

        if h is not None and c is not None:
            out, (h, c) = self.lstm(emb, (h, c))
        else:
            out, (h, c) = self.lstm(emb)

        output = [head(out) for head in self.output_heads] # list of [B, T, V_i]
        return output, (h, c)

    def forward_first(self, x:torch.Tensor):
        """x: 25์ฐจ์› ์Šคํƒ€์ผ ๋ฒกํ„ฐ"""
        logits_list = []
        for head in self.start_token_heads:
            logits = head(x)  # (B, vocab_size_i)
            logits_list.append(logits)
        return logits_list  # List of 7 tensors, each (B, vocab_size_i)

    def calc_loss(self, style_vec:torch.Tensor, seq:torch.Tensor):
        """
        style_vec: (B, 25)
        seq: (B, T, 7)
        """
        is_cuda = style_vec.device.type == "cuda" # run Soft-DTW on CUDA when the inputs live on GPU

        # ์‹œ์ž‘ ํ† ํฐ loss (cross-entropy)
        logits_list = self.forward_first(style_vec)  # list of 7 tensors
        target_first = seq[:, 0, :]  # (B, 7), ์ •๋‹ต ํด๋ž˜์Šค ์ธ๋ฑ์Šค

        first_loss = 0
        for i in range(7):
            logits_i = logits_list[i]                # (B, vocab_size_i)
            target_i = target_first[:, i].long()     # (B,)
            first_loss += F.cross_entropy(logits_i, target_i)

        first_loss /= 7

        # predict the initial hidden/cell states from the style vector
        h, c = self.forward_hnc(style_vec)

        # sequence-extension loss (cross-entropy with teacher forcing)
        pred_logits, _ = self.forward_extend(seq[:, :-1, :], h, c)
        target_seq = seq[:, 1:, :]

        extend_loss = 0
        pred_tokens = []
        for i in range(7):
            extend_loss += F.cross_entropy(
                pred_logits[i].reshape(-1, pred_logits[i].size(-1)).float(),
                target_seq[:, :, i].reshape(-1),
                ignore_index=16
            )
            # extract predicted tokens with argmax for the Soft-DTW term (argmax is non-differentiable, so this term adds no gradient)
            pred_tokens.append(pred_logits[i].argmax(-1, keepdim=True))  # (B, T-1, 1)
        pred_tokens = torch.cat(pred_tokens, dim=-1)  # list -> Tensor (B, T-1, 7)
        extend_loss /= 7

        # Soft-DTW loss
        soft_dtw = SoftDTW(use_cuda=is_cuda)
        min_len = min(pred_tokens.shape[1], target_seq.shape[1])
        sdtw_loss = soft_dtw(pred_tokens[:, :min_len, :].float(), target_seq[:, :min_len, :].float()).mean()
        sdtw_loss = torch.nan_to_num(torch.log1p(sdtw_loss)) # Soft-DTW can drop below -1, where log1p returns NaN, so guard with nan_to_num

        return first_loss + extend_loss + 0.3 * sdtw_loss  # down-weight the Soft-DTW term
    
    def _top_k_sampling(self, logits, top_k=5, temperature=1.0):
        """Sample one token id per batch element from the temperature-scaled top-k logits."""
        logits = logits / temperature
        topk_vals, topk_idx = torch.topk(logits, top_k, dim=-1)
        probs = F.softmax(topk_vals, dim=-1)
        idx = torch.multinomial(probs, 1).squeeze(-1)
        return topk_idx.gather(-1, idx.unsqueeze(-1)).squeeze(-1)
    
    def generate(self, x:torch.Tensor, max_len=128, top_k=5): # TODO: replace the start-token heads with a dedicated start token and feed tokens autoregressively from there
        """x: (1, 25) style vector. Sampling below uses .item(), so only batch size 1 is supported."""
        self.eval()
        batch_size = x.size(0)
        assert batch_size == 1, "generate() currently supports batch size 1 only"
        h, c = None, None
        start_tokens = torch.zeros(batch_size, 1, 7, dtype=torch.int64, device=x.device)
        
        for i, head in enumerate(self.start_token_heads):
            logits = head(x)  # (B, vocab_size_i)
            
            # ์Šคํƒ€์ผ ๊ธฐ๋ฐ˜ ์ฒซ ํ† ํฐ ์ƒ˜ํ”Œ๋ง
            if i in [2, 5, 7]:  # duration ์ฐจ์›: ๋” ํ™•์ •์ ์œผ๋กœ
                probs = F.softmax(logits / 0.5, dim=-1)  # ๋‚ฎ์€ ์˜จ๋„
            else:  # pitch, velocity ๋“ฑ: ๋‹ค์–‘์„ฑ ํ—ˆ์šฉ
                probs = F.softmax(logits / 1.2, dim=-1)  # ์•ฝ๊ฐ„ ๋†’์€ ์˜จ๋„
            
            token = torch.multinomial(probs, num_samples=1)  # (B, 1)
            start_tokens[:, :, i] = token
            
        generated = [start_tokens.squeeze(0).squeeze(0).tolist()]

        for _ in range(max_len - 1):
            if h is None and c is None:
                h, c = self.forward_hnc(x)  # same tanh-initialized states as in training
            logits, (h, c) = self.forward_extend(start_tokens, h, c)  # logits: list of [B, T, V_i]
            last_logits = [log[:, -1, :] for log in logits]  # logits at the last timestep

            sampled = []
            for i, logit in enumerate(last_logits):
                if i in [2, 5]:  # duration dimensions: sample more deterministically
                    token = self._top_k_sampling(logit, top_k=top_k, temperature=0.5)  # low temperature
                else:  # pitch, velocity, etc.: allow more diversity
                    token = self._top_k_sampling(logit, top_k=top_k, temperature=1.2)  # slightly higher temperature
                
                sampled.append(token.item())

            if sampled == [100, 15, 72, 14, 15, 58, 15]: # EOS token
                break
            else:
                generated.append(sampled)
                start_tokens = torch.tensor([[sampled]], device=x.device)  # [1,1,7]

        return torch.tensor(generated, device=x.device, dtype=torch.long)  # (T, 7), T <= max_len
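
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). hidden_dim, the
# batch size, and the sequence length below are illustrative assumptions.
if __name__ == "__main__":
    model = Vector2MIDI(hidden_dim=128)
    style = torch.randn(4, 25)                       # (B, 25) style vectors
    seq = torch.randint(0, 17, (4, 32, 7))           # (B, T, 7) dummy token ids
    loss = model.calc_loss(style, seq)               # CE + Soft-DTW training loss
    loss.backward()                                  # gradients for one optimizer step
    with torch.no_grad():
        sample = model.generate(torch.randn(1, 25))  # (T, 7) generated token sequence
    print(loss.item(), sample.shape)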