Diva/Models/Vector2MIDI.py
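"""Vector2MIDI: an LSTM decoder that maps a 25-dimensional style vector to a
sequence of 7-dimensional MIDI tokens.

The style vector seeds the LSTM's initial hidden/cell states and a set of
per-dimension "start token" heads; decoding then proceeds autoregressively,
with one embedding table and one softmax head per token dimension. Training
combines per-dimension cross-entropy with a Soft-DTW term (pysdtw).
"""
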
import torch
import torch.nn.functional as F
from pysdtw import SoftDTW
from torch import nn


class Vector2MIDI(nn.Module):
def __init__(self, hidden_dim, X_dim=25, dropout=0.3):
super().__init__()
        self.vocab_sizes = [101, 17, 73, 17, 17, 59, 17]  # vocab sizes derived from the actual dataset
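        # Note: the exact meaning of each of the 7 token dimensions is defined by the
        # dataset preprocessing (not shown in this file); the sampling comments in
        # generate() suggest indices 2 and 5 are duration-like fields and the others
        # carry pitch/velocity-style attributes.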
self.init_hidden = nn.Linear(X_dim, hidden_dim)
self.init_cell = nn.Linear(X_dim, hidden_dim)
self.embeddings = nn.ModuleList([
nn.Embedding(vocab_size, hidden_dim, padding_idx=16) for vocab_size in self.vocab_sizes
])
# ๊ณผ์ ํ•ฉ ๋ฐฉ์ง€ ๋“œ๋กญ์•„์›ƒ LSTM
self.lstm = nn.LSTM(hidden_dim * len(self.vocab_sizes), hidden_dim, num_layers=2, batch_first=True, dropout=dropout)
        self.output_heads = nn.ModuleList([  # independent output head per token dimension
nn.Linear(hidden_dim, vocab_size) for vocab_size in self.vocab_sizes
])
        self.start_token_heads = nn.ModuleList([  # per-dimension heads for generating the first token
nn.Linear(X_dim, vocab_size) for vocab_size in self.vocab_sizes
])
def forward_hnc(self, x):
"""hidden๊ณผ cell state ์ƒ์„ฑ"""
h0 = torch.tanh(self.init_hidden(x)) # ํ™œ์„ฑํ™” ํ•จ์ˆ˜ ์ถ”๊ฐ€ (hyperbolic tangent)
c0 = torch.tanh(self.init_cell(x))
h0 = h0.unsqueeze(0).repeat(2, 1, 1) # (num_layers, B, H)
c0 = c0.unsqueeze(0).repeat(2, 1, 1)
return h0, c0
def forward_extend(self, y:torch.Tensor, h=None, c=None):
"""
y: (B, T, 7) - 7์ฐจ์› ์ •์ˆ˜ ํ† ํฐ (EOS + ํŒจ๋”ฉ)
"""
emb_list = []
for idx, emb_f in enumerate(self.embeddings):
emb_list.append(emb_f(y[:, :, idx])) # [B, T, 1]
emb = torch.cat(emb_list, dim=-1) # [B, T, 7]
if h is not None and c is not None:
out, (h, c) = self.lstm(emb, (h, c))
else:
out, (h, c) = self.lstm(emb)
output = [head(out) for head in self.output_heads] # list of [B, T, V_i]
return output, (h, c)
def forward_first(self, x:torch.Tensor):
"""x: 25์ฐจ์› ์Šคํƒ€์ผ ๋ฒกํ„ฐ"""
logits_list = []
for head in self.start_token_heads:
logits = head(x) # (B, vocab_size_i)
logits_list.append(logits)
return logits_list # List of 7 tensors, each (B, vocab_size_i)
def calc_loss(self, style_vec:torch.Tensor, seq:torch.Tensor):
"""
style_vec: (B, 25)
seq: (B, T, 7)
"""
        is_cuda = style_vec.device.type == "cuda"  # whether CUDA is being used
        # First-token loss (cross-entropy)
        logits_list = self.forward_first(style_vec)  # list of 7 tensors
        target_first = seq[:, 0, :]  # (B, 7), ground-truth class indices
first_loss = 0
for i in range(7):
logits_i = logits_list[i] # (B, vocab_size_i)
target_i = target_first[:, i].long() # (B,)
first_loss += F.cross_entropy(logits_i, target_i)
first_loss /= 7
        # Predict the initial hidden/cell states from the style vector
        h, c = self.forward_hnc(style_vec)
        # Sequence-extension loss (cross-entropy)
pred_logits, _ = self.forward_extend(seq[:, :-1, :], h, c)
target_seq = seq[:, 1:, :]
extend_loss = 0
pred_tokens = []
for i in range(7):
extend_loss += F.cross_entropy(
pred_logits[i].reshape(-1, pred_logits[i].size(-1)).float(),
target_seq[:, :, i].reshape(-1),
ignore_index=16
)
            # Extract predicted tokens via argmax (for the Soft-DTW term)
pred_tokens.append(pred_logits[i].argmax(-1, keepdim=True)) # (B, T-1, 1)
pred_tokens = torch.cat(pred_tokens, dim=-1) # list -> Tensor (B, T-1, 7)
extend_loss /= 7
# Soft-DTW loss
soft_dtw = SoftDTW(use_cuda=is_cuda)
min_len = min(pred_tokens.shape[1], target_seq.shape[1])
sdtw_loss = soft_dtw(pred_tokens[:, :min_len, :].float(), target_seq[:, :min_len, :].float()).mean()
        sdtw_loss = torch.nan_to_num(torch.log1p(sdtw_loss))  # log1p seems to produce nan when its input is too small
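        # Total loss: averaged first-token CE + averaged continuation CE + 0.3 * log1p(Soft-DTW)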
return first_loss + extend_loss + 0.3 * sdtw_loss
def _top_k_sampling(self, logits, top_k=5, temperature=1.0):
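        """Temperature-scale the logits, keep the top_k candidates, and sample one of them."""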
logits = logits / temperature
topk_vals, topk_idx = torch.topk(logits, top_k, dim=-1)
probs = F.softmax(topk_vals, dim=-1)
idx = torch.multinomial(probs, 1).squeeze(-1)
return topk_idx.gather(-1, idx.unsqueeze(-1)).squeeze(-1)
    def generate(self, x: torch.Tensor, max_len=128, top_k=5):  # TODO: replace the start-token heads with a plain start token and keep feeding the next token in autoregressively
        """x: 25-dimensional style vector"""
self.eval()
batch_size = x.size(0)
h, c = None, None
start_tokens = torch.zeros(batch_size, 1, 7, dtype=torch.int64, device=x.device)
for i, head in enumerate(self.start_token_heads):
logits = head(x) # (B, vocab_size_i)
# ์Šคํƒ€์ผ ๊ธฐ๋ฐ˜ ์ฒซ ํ† ํฐ ์ƒ˜ํ”Œ๋ง
if i in [2, 5, 7]: # duration ์ฐจ์›: ๋” ํ™•์ •์ ์œผ๋กœ
probs = F.softmax(logits / 0.5, dim=-1) # ๋‚ฎ์€ ์˜จ๋„
else: # pitch, velocity ๋“ฑ: ๋‹ค์–‘์„ฑ ํ—ˆ์šฉ
probs = F.softmax(logits / 1.2, dim=-1) # ์•ฝ๊ฐ„ ๋†’์€ ์˜จ๋„
token = torch.multinomial(probs, num_samples=1) # (B, 1)
start_tokens[:, :, i] = token
generated = [start_tokens.squeeze(0).squeeze(0).tolist()]
for _ in range(max_len - 1):
            if h is None and c is None:
                h, c = self.forward_hnc(x)  # same tanh-initialized state as used in calc_loss
logits, (h, c) = self.forward_extend(start_tokens, h, c) # logits: list of [B, T, V_i]
            last_logits = [log[:, -1, :] for log in logits]  # last time step
            sampled = []
            for i, logit in enumerate(last_logits):
                if i in [2, 5]:  # duration-like dimensions: sample more deterministically
                    token = self._top_k_sampling(logit, top_k=top_k, temperature=0.5)  # low temperature
                else:  # pitch, velocity, etc.: allow more diversity
                    token = self._top_k_sampling(logit, top_k=top_k, temperature=1.2)  # slightly higher temperature
                sampled.append(token.item())
            if sampled == [100, 15, 72, 14, 15, 58, 15]:  # EOS token
break
else:
generated.append(sampled)
start_tokens = torch.tensor([[sampled]], device=x.device) # [1,1,7]
        return torch.tensor(generated, device=x.device, dtype=torch.long)  # (T, 7) with T <= max_len
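

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original training pipeline).
    # hidden_dim=256, the batch size, and the sequence length below are illustrative
    # assumptions; the real values come from the project's training configuration.
    model = Vector2MIDI(hidden_dim=256)

    style_vec = torch.randn(4, 25)  # (B, X_dim) style vectors
    seq = torch.stack(              # (B, T, 7) integer tokens, one vocab per dimension
        [torch.randint(0, size, (4, 32)) for size in model.vocab_sizes], dim=-1
    )

    loss = model.calc_loss(style_vec, seq)
    print("loss:", loss.item())

    with torch.no_grad():
        tokens = model.generate(style_vec[:1], max_len=16)  # generate() assumes batch size 1
    print("generated:", tokens.shape)  # (T, 7), T <= max_len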