import torch
import torch.nn.functional as F
from torch import nn

from pysdtw import SoftDTW

class Vector2MIDI(nn.Module):
def __init__(self, hidden_dim, X_dim=25, dropout=0.3):
super().__init__()
        self.vocab_sizes = [101, 17, 73, 17, 17, 59, 17]  # per-dimension vocab sizes, set from the actual data
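        # index 16 is shared as the padding token across all 7 dimensions
        # (see padding_idx below and ignore_index=16 in calc_loss)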
self.init_hidden = nn.Linear(X_dim, hidden_dim)
self.init_cell = nn.Linear(X_dim, hidden_dim)
self.embeddings = nn.ModuleList([
nn.Embedding(vocab_size, hidden_dim, padding_idx=16) for vocab_size in self.vocab_sizes
])
# ๊ณผ์ ํฉ ๋ฐฉ์ง ๋๋กญ์์ LSTM
self.lstm = nn.LSTM(hidden_dim * len(self.vocab_sizes), hidden_dim, num_layers=2, batch_first=True, dropout=dropout)
        self.output_heads = nn.ModuleList([  # independent output head per dimension
            nn.Linear(hidden_dim, vocab_size) for vocab_size in self.vocab_sizes
        ])
        self.start_token_heads = nn.ModuleList([  # one head per dimension for generating the first token
            nn.Linear(X_dim, vocab_size) for vocab_size in self.vocab_sizes
        ])
    def forward_hnc(self, x):
        """Build the initial hidden and cell states from the style vector."""
        h0 = torch.tanh(self.init_hidden(x))  # squash with an activation (hyperbolic tangent)
        c0 = torch.tanh(self.init_cell(x))
        h0 = h0.unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)  # (num_layers, B, H)
        c0 = c0.unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
        return h0, c0
    def forward_extend(self, y: torch.Tensor, h=None, c=None):
        """
        y: (B, T, 7) - integer tokens for the 7 dimensions (incl. EOS and padding)
        """
        emb_list = []
        for idx, emb_f in enumerate(self.embeddings):
            emb_list.append(emb_f(y[:, :, idx]))  # (B, T, H)
        emb = torch.cat(emb_list, dim=-1)  # (B, T, 7 * H)
        if h is not None and c is not None:
            out, (h, c) = self.lstm(emb, (h, c))
        else:
            out, (h, c) = self.lstm(emb)
        output = [head(out) for head in self.output_heads]  # list of (B, T, V_i)
        return output, (h, c)
    def forward_first(self, x: torch.Tensor):
        """x: 25-dim style vector"""
logits_list = []
for head in self.start_token_heads:
logits = head(x) # (B, vocab_size_i)
logits_list.append(logits)
return logits_list # List of 7 tensors, each (B, vocab_size_i)
def calc_loss(self, style_vec:torch.Tensor, seq:torch.Tensor):
"""
style_vec: (B, 25)
seq: (B, T, 7)
"""
        is_cuda = style_vec.device.type == "cuda"  # whether to run Soft-DTW on CUDA
        # first-token loss (cross-entropy)
        logits_list = self.forward_first(style_vec)  # list of 7 tensors
        target_first = seq[:, 0, :]  # (B, 7), ground-truth class indices
first_loss = 0
for i in range(7):
logits_i = logits_list[i] # (B, vocab_size_i)
target_i = target_first[:, i].long() # (B,)
first_loss += F.cross_entropy(logits_i, target_i)
first_loss /= 7
        # predict the initial hidden/cell states from the style vector
        h, c = self.forward_hnc(style_vec)
        # sequence-extension loss (cross-entropy), teacher-forced on the shifted sequence
pred_logits, _ = self.forward_extend(seq[:, :-1, :], h, c)
target_seq = seq[:, 1:, :]
extend_loss = 0
pred_tokens = []
for i in range(7):
extend_loss += F.cross_entropy(
pred_logits[i].reshape(-1, pred_logits[i].size(-1)).float(),
target_seq[:, :, i].reshape(-1),
ignore_index=16
)
# argmax๋ก ์์ธก ํ ํฐ ์ถ์ถ (Soft-DTW์ฉ)
pred_tokens.append(pred_logits[i].argmax(-1, keepdim=True)) # (B, T-1, 1)
pred_tokens = torch.cat(pred_tokens, dim=-1) # list -> Tensor (B, T-1, 7)
extend_loss /= 7
# Soft-DTW loss
soft_dtw = SoftDTW(use_cuda=is_cuda)
min_len = min(pred_tokens.shape[1], target_seq.shape[1])
sdtw_loss = soft_dtw(pred_tokens[:, :min_len, :].float(), target_seq[:, :min_len, :].float()).mean()
        sdtw_loss = torch.nan_to_num(torch.log1p(sdtw_loss))  # Soft-DTW can go negative, and log1p(x) is nan for x < -1
return first_loss + extend_loss + 0.3 * sdtw_loss
def _top_k_sampling(self, logits, top_k=5, temperature=1.0):
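        """Sample one token per batch row from the top-k logits after temperature scaling."""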
logits = logits / temperature
topk_vals, topk_idx = torch.topk(logits, top_k, dim=-1)
probs = F.softmax(topk_vals, dim=-1)
idx = torch.multinomial(probs, 1).squeeze(-1)
return topk_idx.gather(-1, idx.unsqueeze(-1)).squeeze(-1)
    @torch.no_grad()
    def generate(self, x: torch.Tensor, max_len=128, top_k=5):  # TODO: swap the start-token heads for a plain start token and keep feeding the next token back in autoregressively
        """x: 25-dim style vector; generation currently assumes batch size 1."""
self.eval()
batch_size = x.size(0)
start_tokens = torch.zeros(batch_size, 1, 7, dtype=torch.int64, device=x.device)
for i, head in enumerate(self.start_token_heads):
logits = head(x) # (B, vocab_size_i)
# ์คํ์ผ ๊ธฐ๋ฐ ์ฒซ ํ ํฐ ์ํ๋ง
if i in [2, 5, 7]: # duration ์ฐจ์: ๋ ํ์ ์ ์ผ๋ก
probs = F.softmax(logits / 0.5, dim=-1) # ๋ฎ์ ์จ๋
else: # pitch, velocity ๋ฑ: ๋ค์์ฑ ํ์ฉ
probs = F.softmax(logits / 1.2, dim=-1) # ์ฝ๊ฐ ๋์ ์จ๋
token = torch.multinomial(probs, num_samples=1) # (B, 1)
start_tokens[:, :, i] = token
        generated = [start_tokens.squeeze(0).squeeze(0).tolist()]
        # initialize via forward_hnc so the states get the same tanh as in training
        h, c = self.forward_hnc(x)
        for _ in range(max_len - 1):
            logits, (h, c) = self.forward_extend(start_tokens, h, c)  # logits: list of (B, T, V_i)
            last_logits = [log[:, -1, :] for log in logits]  # last time step only
            sampled = []
            for i, logit in enumerate(last_logits):
                if i in (2, 5):  # duration dimensions: sample more deterministically
                    token = self._top_k_sampling(logit, top_k=top_k, temperature=0.5)  # low temperature
                else:  # pitch, velocity, etc.: allow diversity
                    token = self._top_k_sampling(logit, top_k=top_k, temperature=1.2)  # slightly higher temperature
sampled.append(token.item())
            if sampled == [100, 15, 72, 14, 15, 58, 15]:  # per-dimension EOS token
break
else:
generated.append(sampled)
            start_tokens = torch.tensor([[sampled]], device=x.device)  # (1, 1, 7)
        return torch.tensor(generated, device=x.device, dtype=torch.long)  # (L, 7), L <= max_len
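

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original training pipeline):
    # hidden_dim, batch size, and sequence length below are arbitrary
    # assumptions, and the random token sequence merely respects each
    # dimension's vocab size.
    model = Vector2MIDI(hidden_dim=128)
    style_vec = torch.randn(4, 25)  # (B, 25) style vectors
    seq = torch.stack(
        [torch.randint(0, v, (4, 32)) for v in model.vocab_sizes], dim=-1
    )  # (B, T, 7), each dimension sampled within its own vocab
    loss = model.calc_loss(style_vec, seq)
    loss.backward()
    print("loss:", loss.item())
    # generation currently assumes batch size 1
    sample = model.generate(torch.randn(1, 25), max_len=16)
    print("generated:", tuple(sample.shape))  # (L, 7) with L <= 16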