rrayy committed on
Commit 09f103b · 1 Parent(s): 9cff955

Changes to be committed: keep the in-progress all-at-once generation code (committed midway)

Files changed (1)
  1. Models/Vector2MIDI.py +35 -65
Models/Vector2MIDI.py CHANGED
@@ -1,6 +1,7 @@
-from torch import tanh, zeros, no_grad, full_like, topk, multinomial, cat, int64, nn, stack
 from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
 import torch.nn.functional as F
+from torch import nn
+import torch
 
 class Vector2MIDI(nn.Module):
     def __init__(self, hidden_dim, input_dim=25, dropout=0.2):
@@ -31,8 +32,8 @@ class Vector2MIDI(nn.Module):
 
     def init_hidden_states(self, x):
         """Create the initial hidden and cell states"""
-        h0 = tanh(self.init_hidden(x))  # added an activation (hyperbolic tangent)
-        c0 = tanh(self.init_cell(x))
+        h0 = torch.tanh(self.init_hidden(x))  # added an activation (hyperbolic tangent)
+        c0 = torch.tanh(self.init_cell(x))
 
         h0 = h0.unsqueeze(0).repeat(2, 1, 1)  # (num_layers, B, H)
         c0 = c0.unsqueeze(0).repeat(2, 1, 1)
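For context on the torch.tanh change above, here is a minimal, self-contained sketch (not part of the diff) of how learned initial states like these feed a 2-layer LSTM. It assumes init_hidden and init_cell are nn.Linear(input_dim, hidden_dim) layers; all names and sizes below are illustrative, not the model's real ones.

import torch
from torch import nn

input_dim, hidden_dim, batch = 25, 128, 4
init_hidden = nn.Linear(input_dim, hidden_dim)   # assumed: style vector -> initial hidden
init_cell = nn.Linear(input_dim, hidden_dim)     # assumed: style vector -> initial cell
lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=2, batch_first=True)

x = torch.randn(batch, input_dim)                                  # style vector
h0 = torch.tanh(init_hidden(x)).unsqueeze(0).repeat(2, 1, 1)       # (num_layers, B, H)
c0 = torch.tanh(init_cell(x)).unsqueeze(0).repeat(2, 1, 1)         # (num_layers, B, H)
seq = torch.randn(batch, 10, hidden_dim)                           # dummy embedded sequence
out, (h, c) = lstm(seq, (h0, c0))                                  # out: (B, 10, H)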
@@ -52,11 +53,11 @@ class Vector2MIDI(nn.Module):
         embeddings.append(dim_onehot)
 
         # Concatenate all dimensions
-        full_embedding = cat(embeddings, dim=-1)  # (B, T, total_vocab)
+        full_embedding = torch.cat(embeddings, dim=-1)  # (B, T, total_vocab)
 
         # Add the style context at every timestep
         style_expanded = style_context.unsqueeze(1).expand(-1, seq_len, -1)  # (B, T, hidden_dim//2)
-        combined_input = cat([full_embedding, style_expanded], dim=-1)  # (B, T, total_vocab + hidden_dim//2)
+        combined_input = torch.cat([full_embedding, style_expanded], dim=-1)  # (B, T, total_vocab + hidden_dim//2)
 
         return self.input_embedding(combined_input)
 
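The torch.cat changes above only swap the call site; the shape handling is unchanged. A small shape-only sketch (illustrative sizes, not the model's real vocabulary sizes) of the expand-and-concatenate step:

import torch

B, T, total_vocab, style_dim = 2, 16, 300, 64
full_embedding = torch.zeros(B, T, total_vocab)                   # concatenated one-hot dims
style_context = torch.zeros(B, style_dim)                         # (B, hidden_dim//2)
style_expanded = style_context.unsqueeze(1).expand(-1, T, -1)     # (B, T, style_dim)
combined = torch.cat([full_embedding, style_expanded], dim=-1)    # (B, T, total_vocab + style_dim)
print(combined.shape)  # torch.Size([2, 16, 364])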
@@ -85,68 +86,37 @@ class Vector2MIDI(nn.Module):
         outputs.append(dim_logits)
 
         return outputs
 
-    def generate(self, x, device, max_steps=1024, temperature:float=1.0, top_k=None):
-        self.eval()
+    def top_k_filtering(self, logits, top_k):
+        """Top-k filtering"""
+        if top_k > 0:
+            # select the top-k values along the last dimension of logits only
+            values, _ = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)
+            min_values = values[..., -1:]  # the k-th largest value, keeping all other dims
+            logits = torch.where(logits < min_values, torch.full_like(logits, float('-inf')), logits)
+        return logits
+
+    def generate(self, x, device, seq_len=64, temperature:float=1.2, top_k=5):
+        self.eval()  # changed: instead of generating autoregressively one timestep at a time, generate everything at once via forward
         x = x.to(device)
         batch_size = x.size(0)
-
-        h, c = self.init_hidden_states(x)
-        style_context = self.style_context(x)  # (B, hidden_dim//2)
-        current_tokens = self.generate_start_tokens_from_style(x)  # first token
-        generated_tokens = zeros(batch_size, max_steps, 7, dtype=int64, device=device)  # Tensor that stores the generated tokens
-
-        with no_grad():
-            for step in range(max_steps):
-                # convert the current tokens into embeddings
-                embedded = self.tokens_to_embedding(current_tokens, style_context)
-                lstm_out, (h, c) = self.lstm(embedded, (h, c))  # lstm_out: (B,1,H)
-                hidden = self.fc_mid(lstm_out[:, -1, :])  # (B, 256)
-
-                # generate the next token for each dimension
-                next_tokens = []
-                for head in self.output_heads:
-                    logits = head(hidden)  # (B, vocab_size_i)
-
-                    if temperature != 1.0:
-                        logits = logits / temperature
-
-                    if top_k is not None and top_k > 0:
-                        k = min(top_k, logits.size(-1))
-                        topk_vals, topk_idx = topk(logits, k, dim=-1)
-                        # create a mask filled with very low values
-                        low_val = -1e9
-                        mask = full_like(logits, low_val)
-                        logits = mask.scatter(-1, topk_idx, topk_vals)
-
-                    probs = F.softmax(logits, dim=-1)
-                    token = multinomial(probs, num_samples=1)  # (B,1)
-                    next_tokens.append(token)
-
-                current_tokens = cat(next_tokens, dim=1).unsqueeze(1)  # (B, 1, 7)
-                generated_tokens[:, step, :] = current_tokens.squeeze(1)
-
-                # stopping condition
-                if (current_tokens == -1).all():
-                    break
-
-        return generated_tokens
-
-    def generate_start_tokens_from_style(self, x):
-        """Generate the first tokens from the style vector"""
-        batch_size = x.size(0)
-        start_tokens = zeros(batch_size, 1, 7, dtype=int64, device=x.device)
-
-        for i, head in enumerate(self.start_token_heads):
-            logits = head(x)  # (B, vocab_size_i)
-
-            # sample the first token based on the style
-            if i in [1, 4, 6]:  # duration dimensions: more deterministic
-                probs = F.softmax(logits / 0.5, dim=-1)  # lower temperature
-            else:  # pitch, velocity, etc.: allow diversity
-                probs = F.softmax(logits / 1.2, dim=-1)  # slightly higher temperature
-
-            token = multinomial(probs, num_samples=1)  # (B, 1)
-            start_tokens[:, :, i] = token
-
-        return start_tokens
+
+        generated_sequence = torch.zeros((batch_size, seq_len, 7), dtype=torch.long, device=device)
+        lengths = torch.full((batch_size,), seq_len, dtype=torch.long, device=device)
+
+        with torch.no_grad():
+            logits_list = self.forward(x, lengths, generated_sequence)
+
+            # sample everything after the first token
+            for i, logits in enumerate(logits_list):
+                dim_logits = logits[:, :-1, :] / temperature  # (B, T-1, vocab)
+                dim_logits = self.top_k_filtering(dim_logits, top_k)
+                probs = F.softmax(dim_logits, dim=-1)
+
+                # sample at each timestep
+                for t in range(seq_len):
+                    if t < dim_logits.size(1):
+                        sampled = torch.multinomial(probs[:, t, :], 1).squeeze(-1)
+                        generated_sequence[:, t, i] = sampled
+
+        return generated_sequence
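The new top_k_filtering keeps the k largest logits per position and masks everything else to -inf before the softmax. A tiny standalone sketch of the same logic, outside the diff, with made-up example values:

import torch

def top_k_filtering(logits, top_k):
    # keep the k largest logits per row, set the rest to -inf (same logic as the new method)
    if top_k > 0:
        values, _ = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)
        min_values = values[..., -1:]  # the k-th largest value per row
        logits = torch.where(logits < min_values, torch.full_like(logits, float('-inf')), logits)
    return logits

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
print(top_k_filtering(logits, top_k=2))
# tensor([[2., -inf, 1., -inf]])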
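A hypothetical usage sketch of the new one-pass generate, assuming only what the diff shows: the constructor signature Vector2MIDI(hidden_dim, input_dim=25, dropout=0.2) and that forward accepts (x, lengths, tokens). The hidden size, batch size, and device handling below are arbitrary.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Vector2MIDI(hidden_dim=256).to(device)            # input_dim defaults to 25
x = torch.randn(8, 25)                                    # batch of style vectors
tokens = model.generate(x, device, seq_len=64, temperature=1.2, top_k=5)
print(tokens.shape)  # torch.Size([8, 64, 7]) -- 7 token dimensions per timestep

Note that, as committed, generate runs forward once over a zero-initialized token tensor and then samples each timestep from those logits, so the samples are not conditioned on previously sampled tokens; this matches the commit message describing the all-at-once generation code as being left midway.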