import torch
import torch.nn.functional as F
from torch import nn


class Vector2MIDI(nn.Module):
    """Autoregressive decoder that turns a style vector into a MIDI token sequence.

    Each time step is a 7-dimensional discrete token (one categorical per
    dimension).  A 2-layer LSTM is conditioned on a 25-dim style vector
    through its initial hidden/cell states and extends the sequence one
    step at a time.
    """

    # Vocabulary index used for padding in every dimension
    # (valid for all heads because every vocab size is > 16).
    PAD_IDX = 16
    # Dimensions sampled with a low temperature during generation
    # ("duration-like" dims that should be near-deterministic).
    # NOTE(review): the original listed (2, 5, 7), but index 7 is out of
    # range for 7 dimensions and never matched, so it is dropped here.
    _LOW_TEMP_DIMS = (2, 5)

    def __init__(self, hidden_dim, X_dim=25, dropout=0.3):
        """
        Args:
            hidden_dim: LSTM hidden size (also the per-dimension embedding size).
            X_dim: size of the conditioning style vector (default 25).
            dropout: inter-layer LSTM dropout, for regularization.
        """
        super().__init__()
        # Per-dimension vocabulary sizes, derived from the training data.
        self.vocab_sizes = [101, 17, 73, 17, 17, 59, 17]
        self.num_dims = len(self.vocab_sizes)

        # Project the style vector to the LSTM's initial hidden/cell states.
        self.init_hidden = nn.Linear(X_dim, hidden_dim)
        self.init_cell = nn.Linear(X_dim, hidden_dim)

        # One embedding table per token dimension; index PAD_IDX is padding.
        self.embeddings = nn.ModuleList([
            nn.Embedding(vocab_size, hidden_dim, padding_idx=self.PAD_IDX)
            for vocab_size in self.vocab_sizes
        ])

        # Embeddings of all dimensions are concatenated before the LSTM.
        self.lstm = nn.LSTM(hidden_dim * self.num_dims, hidden_dim,
                            num_layers=2, batch_first=True, dropout=dropout)

        # Independent classification head per token dimension.
        self.output_heads = nn.ModuleList([
            nn.Linear(hidden_dim, vocab_size) for vocab_size in self.vocab_sizes
        ])
        # Heads that predict the first token directly from the style vector.
        self.start_token_heads = nn.ModuleList([
            nn.Linear(X_dim, vocab_size) for vocab_size in self.vocab_sizes
        ])

    def forward_hnc(self, x: torch.Tensor):
        """Build the LSTM's initial (h0, c0) from the style vector.

        Args:
            x: (B, X_dim) style vectors.
        Returns:
            (h0, c0), each of shape (num_layers, B, hidden_dim).
        """
        h0 = torch.tanh(self.init_hidden(x))  # tanh keeps states in LSTM's range
        c0 = torch.tanh(self.init_cell(x))
        layers = self.lstm.num_layers  # was hard-coded to 2
        h0 = h0.unsqueeze(0).repeat(layers, 1, 1)  # (num_layers, B, H)
        c0 = c0.unsqueeze(0).repeat(layers, 1, 1)
        return h0, c0

    def forward_extend(self, y: torch.Tensor, h=None, c=None):
        """Run the LSTM over a token sequence and score the next token.

        Args:
            y: (B, T, 7) integer tokens (may contain EOS/padding indices).
            h, c: optional LSTM state, each (num_layers, B, hidden_dim).
        Returns:
            (logits, (h, c)) where logits is a list of 7 tensors, the i-th
            of shape (B, T, vocab_sizes[i]).
        """
        # Embed each dimension separately, then concatenate: (B, T, 7*hidden).
        emb = torch.cat(
            [emb_f(y[:, :, idx]) for idx, emb_f in enumerate(self.embeddings)],
            dim=-1,
        )
        if h is not None and c is not None:
            out, (h, c) = self.lstm(emb, (h, c))
        else:
            out, (h, c) = self.lstm(emb)
        return [head(out) for head in self.output_heads], (h, c)

    def forward_first(self, x: torch.Tensor):
        """Score the first token of the sequence directly from the style vector.

        Args:
            x: (B, X_dim) style vectors.
        Returns:
            List of 7 logit tensors, the i-th of shape (B, vocab_sizes[i]).
        """
        return [head(x) for head in self.start_token_heads]

    def calc_loss(self, style_vec: torch.Tensor, seq: torch.Tensor):
        """Training loss: first-token CE + next-token CE + 0.3 * soft-DTW.

        Args:
            style_vec: (B, 25) style vectors.
            seq: (B, T, 7) target token sequences (EOS + padding included).
        Returns:
            Scalar loss tensor.
        """
        # Imported lazily so the module can be used (e.g. for generation)
        # without pysdtw installed.
        from pysdtw import SoftDTW

        use_cuda = style_vec.device.type == "cuda"

        # --- first-token cross-entropy -------------------------------------
        first_logits = self.forward_first(style_vec)  # list of 7 tensors
        first_targets = seq[:, 0, :]  # (B, 7) class indices
        first_loss = sum(
            F.cross_entropy(first_logits[i], first_targets[:, i].long())
            for i in range(self.num_dims)
        ) / self.num_dims

        # --- teacher-forced next-token cross-entropy -----------------------
        h, c = self.forward_hnc(style_vec)
        pred_logits, _ = self.forward_extend(seq[:, :-1, :], h, c)
        target_seq = seq[:, 1:, :]

        extend_loss = 0
        pred_tokens = []
        for i in range(self.num_dims):
            extend_loss += F.cross_entropy(
                pred_logits[i].reshape(-1, pred_logits[i].size(-1)).float(),
                target_seq[:, :, i].reshape(-1).long(),
                ignore_index=self.PAD_IDX,  # don't penalize padding positions
            )
            # Greedy tokens for the soft-DTW term.
            # NOTE(review): argmax is non-differentiable, so no gradient flows
            # through the soft-DTW term into the network -- confirm intended.
            pred_tokens.append(pred_logits[i].argmax(-1, keepdim=True))
        extend_loss /= self.num_dims
        pred_tokens = torch.cat(pred_tokens, dim=-1)  # (B, T-1, 7)

        # --- soft-DTW alignment loss ---------------------------------------
        soft_dtw = SoftDTW(use_cuda=use_cuda)
        min_len = min(pred_tokens.shape[1], target_seq.shape[1])
        sdtw_loss = soft_dtw(
            pred_tokens[:, :min_len, :].float(),
            target_seq[:, :min_len, :].float(),
        ).mean()
        # log1p compresses the potentially large DTW cost; nan_to_num guards
        # against NaN when the cost falls near the log domain boundary.
        sdtw_loss = torch.nan_to_num(torch.log1p(sdtw_loss))

        return first_loss + extend_loss + 0.3 * sdtw_loss

    def _top_k_sampling(self, logits, top_k=5, temperature=1.0):
        """Sample one token id per batch row from the top-k of `logits`."""
        logits = logits / temperature
        topk_vals, topk_idx = torch.topk(logits, top_k, dim=-1)
        probs = F.softmax(topk_vals, dim=-1)
        idx = torch.multinomial(probs, 1).squeeze(-1)
        return topk_idx.gather(-1, idx.unsqueeze(-1)).squeeze(-1)

    @torch.no_grad()
    def generate(self, x: torch.Tensor, max_len=128, top_k=5):
        """Autoregressively decode a token sequence from one style vector.

        Args:
            x: (1, X_dim) style vector -- batch size must be 1.
            max_len: maximum number of generated steps (including the first).
            top_k: top-k cutoff for sampling subsequent tokens.
        Returns:
            (T, 7) long tensor, 1 <= T <= max_len.
        Side effect: puts the module in eval mode.
        """
        if x.size(0) != 1:
            # The decoding loop below (``.item()``) only supports batch size
            # 1; previously this failed with a cryptic RuntimeError instead.
            raise ValueError("generate() supports batch_size == 1 only")

        self.eval()  # NOTE: caller must restore train() mode if needed

        # --- sample the first token from the style vector ------------------
        start_tokens = torch.zeros(1, 1, self.num_dims, dtype=torch.int64,
                                   device=x.device)
        for i, head in enumerate(self.start_token_heads):
            logits = head(x)  # (1, vocab_sizes[i])
            if i in self._LOW_TEMP_DIMS:
                probs = F.softmax(logits / 0.5, dim=-1)  # near-deterministic
            else:
                probs = F.softmax(logits / 1.2, dim=-1)  # allow diversity
            start_tokens[:, :, i] = torch.multinomial(probs, num_samples=1)

        generated = [start_tokens.squeeze(0).squeeze(0).tolist()]

        # Initial LSTM state from the style vector.  Uses forward_hnc so the
        # tanh squashing matches training; the original generate() built the
        # state without tanh, diverging from calc_loss's conditioning.
        h, c = self.forward_hnc(x)

        for _ in range(max_len - 1):
            logits, (h, c) = self.forward_extend(start_tokens, h, c)
            last_logits = [log[:, -1, :] for log in logits]  # last time step

            sampled = []
            for i, logit in enumerate(last_logits):
                # Low temperature for near-deterministic dims, higher for the
                # rest (pitch, velocity, ...).
                temp = 0.5 if i in self._LOW_TEMP_DIMS else 1.2
                token = self._top_k_sampling(logit, top_k=top_k, temperature=temp)
                sampled.append(token.item())

            # Presumably the per-dimension EOS ids -- TODO confirm they match
            # the dataset's EOS encoding.
            if sampled == [100, 15, 72, 14, 15, 58, 15]:
                break
            generated.append(sampled)
            start_tokens = torch.tensor([[sampled]], device=x.device)  # (1, 1, 7)

        return torch.tensor(generated, device=x.device, dtype=torch.long)