rrayy committed on
Commit 2a6b7c9 · 1 Parent(s): 42650ad

Changes to be committed: fix preprocessing errors, set up the training loop


modified: DIVA_dataset.pt
modified: Models/Vector2MIDI.py
modified: preprocessing.ipynb
modified: train.ipynb

Files changed (4)
  1. DIVA_dataset.pt +2 -2
  2. Models/Vector2MIDI.py +66 -28
  3. preprocessing.ipynb +10 -65
  4. train.ipynb +125 -1035
DIVA_dataset.pt CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:51a657804e01360dbf4ae774d45e959d3955e3be0b7f9a84e467c5911d5f7cc3
- size 243341
+ oid sha256:1f4db440c793f1db309541cace07cf4f2b83290173f9d5889ca31349fbde0377
+ size 243790
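
The LFS pointer change records a re-saved DIVA_dataset.pt whose tensors now match the shapes printed by the notebooks. A quick sanity check of the new artifact, assuming the keys that train.ipynb reads below:

import torch

# Load the re-saved dataset and confirm the shapes the notebooks expect.
data = torch.load("DIVA_dataset.pt")
print(data["X"].shape)        # expected: torch.Size([34, 25])
print(data["Y"].shape)        # expected: torch.Size([34, 125, 7])
print(data["lengths"].shape)  # expected: torch.Size([34])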
Models/Vector2MIDI.py CHANGED
@@ -1,45 +1,83 @@
- from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
- import torch.nn as nn
  import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
 
 
  class Vector2MIDI(nn.Module):
      def __init__(self, input_dim, hidden_dim, n_vocab, dropout=0.2):
-         super().__init__()  # call the parent class constructor
-         self.input_fc = nn.Linear(input_dim, hidden_dim)  # map the input dim to the hidden dim
+         super().__init__()
+
+         self.n_vocab = n_vocab
+
+         self.init_hidden = nn.Linear(input_dim, hidden_dim)
+         self.init_cell = nn.Linear(input_dim, hidden_dim)
 
          # dropout LSTM to prevent overfitting
-         self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=2, batch_first=True, dropout=dropout)
+         self.lstm = nn.LSTM(n_vocab, hidden_dim, num_layers=2, batch_first=True, dropout=dropout)
 
          self.fc_mid = nn.Linear(hidden_dim, 256)
          self.fc_out = nn.Linear(256, n_vocab)
 
-     def forward(self, x, lengths, total_length=None):
-         print("input to forward:", x.shape)
-         B, feat_dim = x.size()
-         T = lengths.max()
-
-         # [B, 1, feat_dim] -> [B, T, feat_dim]
-         x = x.unsqueeze(1).expand(B, T, feat_dim)
-
-         x = self.input_fc(x)
-
-         packed_x = nn.utils.rnn.pack_padded_sequence(
-             x, lengths.cpu(), batch_first=True, enforce_sorted=False
-         )
-         packed_out, _ = self.lstm(packed_x)
-
-         out, _ = nn.utils.rnn.pad_packed_sequence(
-             packed_out, batch_first=True, total_length=total_length
-         )
+     def forward(self, x, lengths, target_tokens):
+         """
+         x: (B, input_dim) - input vectors
+         lengths: [B] - sequence lengths
+         target_tokens: (B, T, n_vocab) - one-hot or embedded token inputs
+         """
+         B = x.size(0)
+
+         h0 = self.init_hidden(x).unsqueeze(0).repeat(2, 1, 1)  # (num_layers, B, H)
+         c0 = self.init_cell(x).unsqueeze(0).repeat(2, 1, 1)
+
+         packed_input = pack_padded_sequence(target_tokens, lengths.cpu(), batch_first=True, enforce_sorted=False)
+         packed_out, _ = self.lstm(packed_input, (h0, c0))
+         out, _ = pad_packed_sequence(packed_out, batch_first=True)
 
          out = self.fc_mid(out)
-         out = self.fc_out(out)  # [B, max_len, vocab_size]
+         out = self.fc_out(out)  # (B, T, vocab_size)
+
          return out
 
-     def generate(self, x, lengths, total_length=None):
-         out = self.forward(x, lengths, total_length)
-
-         preds = torch.argmax(out, dim=-1)  # [B, T], pick the highest-scoring class
-         external = preds - 2  # internal representation -> external representation
-         external[external == -2] = 0  # map to PAD
-         return external
+     def generate(self, vector, device, max_len=512, temperature=1.0, top_k=None, start_token_idx=0, end_token_idx=None):
+         """
+         Generate a sequence from a single style vector (stochastic sampling).
+         """
+         self.eval()
+
+         vector = vector.to(device)
+         h = self.init_hidden(vector).unsqueeze(0).repeat(2, 1, 1)  # (num_layers, 1, hidden)
+         c = self.init_cell(vector).unsqueeze(0).repeat(2, 1, 1)
+
+         # one-hot start token
+         x = F.one_hot(torch.tensor([start_token_idx], device=device), num_classes=self.n_vocab).float()
+         x = x.unsqueeze(1)  # (1, 1, n_vocab)
+
+         outputs = []
+
+         for _ in range(max_len):
+             out, (h, c) = self.lstm(x, (h, c))  # (1, 1, hidden)
+             out = self.fc_mid(out)
+             logits = self.fc_out(out).squeeze(0).squeeze(0)  # (n_vocab,)
+
+             # temperature scaling
+             logits = logits / temperature
+             probs = F.softmax(logits, dim=-1)
+
+             # top-k filtering
+             if top_k is not None:
+                 top_vals, top_idx = torch.topk(probs, top_k)
+                 probs = torch.zeros_like(probs).scatter_(0, top_idx, top_vals)
+                 probs = probs / probs.sum()
+
+             pred_token = torch.multinomial(probs, 1).item()
+
+             if end_token_idx is not None and pred_token == end_token_idx:
+                 break
+
+             outputs.append(pred_token)
+
+             # feed back as the input at the next timestep
+             x = F.one_hot(torch.tensor([pred_token], device=device), num_classes=self.n_vocab).float().unsqueeze(1)
+
+         return outputs
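
The rewrite turns Vector2MIDI from a classifier that expanded the style vector across timesteps into a conditional LSTM: the style vector now initializes the hidden and cell states, the LSTM consumes the target sequence itself during training (teacher forcing), and generation samples autoregressively with temperature and top-k. A minimal smoke test of the new interface, assuming toy sizes rather than the notebook's real ones:

import torch
from Models.Vector2MIDI import Vector2MIDI

# Toy sizes for illustration; train.ipynb itself uses Vector2MIDI(25, 1024, 7).
model = Vector2MIDI(input_dim=25, hidden_dim=64, n_vocab=7)

B, T = 4, 10
x = torch.randn(B, 25)                 # one style vector per sequence
lengths = torch.tensor([10, 8, 7, 5])  # true (unpadded) sequence lengths
target_tokens = torch.rand(B, T, 7)    # teacher-forced inputs, (B, T, n_vocab)

out = model(x, lengths, target_tokens)
print(out.shape)                       # torch.Size([4, 10, 7])

# Autoregressive sampling from a single style vector.
seq = model.generate(x[0], device=torch.device("cpu"), max_len=16, top_k=3)
print(seq)                             # list of sampled token ids in [0, 7)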
 
 
 
preprocessing.ipynb CHANGED
@@ -307,66 +307,10 @@
  },
  {
  "cell_type": "code",
- "execution_count": 3,
+ "execution_count": 7,
  "id": "f7b77c0c",
  "metadata": {},
- "outputs": [
-  {
-  "name": "stdout",
-  "output_type": "stream",
-  "text": [
-   "tensor([[[81,  3, 65,  ...,  3, 53,  3],\n",
-   ... (the rest of the printed padded_Y tensor dump, ~50 lines, was removed) ...
-   "        [-1, -1, -1,  ..., -1, -1, -1]]])\n"
-  ]
-  }
- ],
+ "outputs": [],
  "source": [
  "from sklearn.preprocessing import OneHotEncoder, MinMaxScaler\n",
  "from sklearn.compose import ColumnTransformer\n",
@@ -388,16 +332,15 @@
  "X_tensor = torch.tensor(X, dtype=torch.float32)\n",
  "Y_tensor = [torch.tensor(item['token'], dtype=torch.long) for item in tokenized_data]\n",
  "\n",
- "seq_lengths = [len(seq) for seq in Y_tensor]\n",
+ "seq_lengths = torch.tensor([len(seq) for seq in Y_tensor])\n",
  "\n",
  "# apply padding\n",
- "padded_Y = pad_sequence(Y_tensor, batch_first=True, padding_value=-1) # (batch_size, max_len, 7)\n",
- "print(padded_Y)"
+ "padded_Y = pad_sequence(Y_tensor, batch_first=True, padding_value=-1) # (batch_size, max_len, 7)"
  ]
  },
  {
  "cell_type": "code",
- "execution_count": 4,
+ "execution_count": 8,
  "id": "dd840788",
  "metadata": {},
  "outputs": [
@@ -406,18 +349,20 @@
  "output_type": "stream",
  "text": [
  "X shape: torch.Size([34, 25])\n",
- "Y shape: torch.Size([34, 125, 7])\n"
+ "Y shape: torch.Size([34, 125, 7])\n",
+ "l shape: torch.Size([34])\n"
  ]
  }
  ],
  "source": [
  "print(\"X shape:\", X_tensor.shape)\n",
- "print(\"Y shape:\", padded_Y.shape)"
+ "print(\"Y shape:\", padded_Y.shape)\n",
+ "print(\"l shape:\", seq_lengths.shape)"
  ]
  },
  {
  "cell_type": "code",
- "execution_count": 5,
+ "execution_count": 9,
  "id": "4f5f5dc1",
  "metadata": {},
  "outputs": [],
 
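The substantive preprocessing change is that seq_lengths is now built as a tensor (both pack_padded_sequence in the model and the loss mask in train.ipynb expect one) and the debug print of padded_Y is gone. In isolation, with hypothetical toy sequences, the padding step behaves like this:

import torch
from torch.nn.utils.rnn import pad_sequence

# Ragged sequences of 7-field tokens, lengths 5, 3 and 2.
Y_tensor = [torch.randint(0, 100, (n, 7)) for n in (5, 3, 2)]
seq_lengths = torch.tensor([len(seq) for seq in Y_tensor])

padded_Y = pad_sequence(Y_tensor, batch_first=True, padding_value=-1)
print(padded_Y.shape)  # torch.Size([3, 5, 7]) - padded up to the longest sequence
print(seq_lengths)     # tensor([5, 3, 2])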
train.ipynb CHANGED
@@ -32,21 +32,24 @@
  },
  {
  "cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
  "id": "630dd7ad",
  "metadata": {},
  "outputs": [],
  "source": [
  "from Models.Vector2MIDI import Vector2MIDI # the class definition is required\n",
  "import torch.optim as optim\n",
- "import torch.nn as nn\n",
+ "from torch.nn import HuberLoss\n",
+ "from pysdtw import SoftDTW\n",
+ "#from utility.lossf import get_loss_function # try writing this myself later\n",
  "import torch\n",
  "\n",
  "device = torch.device(\"cuda\") # use the GPU\n",
  "#device = torch.device(\"cpu\") # use the CPU\n",
  "\n",
- "model = Vector2MIDI(25, 128, 303).to(device)\n",
- "criterion = nn.CrossEntropyLoss(ignore_index=0) # the loss ignores padding (0)\n",
+ "model = Vector2MIDI(25, 1024, 7).to(device)\n",
+ "sdtw = SoftDTW(0.6) # Soft Dynamic Time Warping (compares timesteps to compute the loss, so gradients can flow) https://judy-son.tistory.com/3\n",
+ "huber = HuberLoss(reduction='none', delta=1.0).to(device) # HuberLoss (reduction='none' keeps per-timestep losses)\n",
  "optimizer = optim.Adam(model.parameters(), lr=1e-3)"
  ]
  },
@@ -61,7 +64,7 @@
  "output_type": "stream",
  "text": [
  "X_tensor shape: torch.Size([34, 25])\n",
- "Y_tensor shape: torch.Size([34, 1185])\n",
+ "Y_tensor shape: torch.Size([34, 125, 7])\n",
  "lengths shape: torch.Size([34])\n"
  ]
  }
@@ -69,12 +72,12 @@
  "source": [
  "# load the preprocessed data\n",
  "from torch.utils.data import DataLoader\n",
- "from dataset import MIDIDataset\n",
+ "from utility.dataset import MIDIDataset\n",
  "import torch\n",
  "\n",
  "data = torch.load(\"DIVA_dataset.pt\")\n",
- "X_tensor = data[\"X\"]\n",
- "Y_tensor = data[\"Y\"]\n",
+ "X_tensor = data[\"X\"].float()\n",
+ "Y_tensor = data[\"Y\"].float()\n",
  "lengths = data[\"lengths\"]\n",
  "\n",
  "print(\"X_tensor shape:\", X_tensor.shape)\n",
@@ -124,521 +127,38 @@
  },
  {
  "cell_type": "code",
- "execution_count": 4,
+ "execution_count": 10,
  "id": "16a14b5f",
  "metadata": {},
  "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+  "c:\\Users\\rrayy\\anaconda3\\envs\\diva\\Lib\\site-packages\\numba\\cuda\\dispatcher.py:536: NumbaPerformanceWarning: Grid size 8 will likely result in GPU under-utilization due to low occupancy.\n",
+  "  warn(NumbaPerformanceWarning(msg))\n"
+ ]
+ },
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "Epoch 1, Loss: 5.5722\n",
- "Epoch 2, Loss: 4.4941\n",
- ... (epoch logs 3-499 removed; the loss falls steadily from 3.06 to about 2.03) ...
- "Epoch 500, Loss: 2.0277\n"
+ "Epoch 1, Loss: 123961.2219\n"
  ]
  },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+  "c:\\Users\\rrayy\\anaconda3\\envs\\diva\\Lib\\site-packages\\numba\\cuda\\dispatcher.py:536: NumbaPerformanceWarning: Grid size 2 will likely result in GPU under-utilization due to low occupancy.\n",
+  "  warn(NumbaPerformanceWarning(msg))\n"
+ ]
+ }
  ],
  "source": [
  "# training loop\n",
  "\n",
- "EPOCH = 500\n",
+ "EPOCH = 1\n",
  "\n",
  "for i in range(EPOCH):\n",
  " total_loss = 0\n",
@@ -648,17 +168,24 @@
  " lengths_batch = lengths_batch.to(device)\n",
  "\n",
  " optimizer.zero_grad()\n",
- " outputs = model.forward(X_batch, lengths_batch, total_length=Y_batch.size(1))\n",
+ " outputs = model(X_batch, lengths_batch, Y_batch)\n",
+ "\n",
+ " min_len = min(outputs.size(1), Y_batch.size(1))\n",
  "\n",
- " # compute the loss: (B*T, vocab) vs (B*T)\n",
- " outputs = outputs.view(-1, outputs.size(-1))\n",
- " targets = Y_batch.view(-1)\n",
+ " loss_HL = huber(outputs[:, :min_len, :], Y_batch[:, :min_len, :]) # slicing lets the loss be computed even when output and target (Y) lengths differ\n",
+ " loss_HL = loss_HL.mean(dim=2) # (B, T), mean over the 7 features\n",
  "\n",
- " loss_f = criterion(outputs, targets)\n",
- " loss_f.backward()\n",
+ " max_len = Y_batch.size(1)\n",
+ " mask = torch.arange(max_len, device=device).unsqueeze(0) < lengths_batch.unsqueeze(1) # (B, T)\n",
+ " loss_HL = (loss_HL * mask[:, :min_len]).sum() / mask[:, :min_len].sum() # exclude padding for the Huber term only (its shape differs from sdtw's)\n",
+ "\n",
+ " loss_sdtw = sdtw(outputs, Y_batch).mean() # scalar\n",
+ " loss = 0.7*loss_HL + 0.3*loss_sdtw # weighted sum (several loss functions can be combined)\n",
+ "\n",
+ " loss.backward()\n",
  " optimizer.step()\n",
  "\n",
- " total_loss += loss_f.item()\n",
+ " total_loss += loss.item()\n",
  "\n",
  " print(f\"Epoch {i+1}, Loss: {total_loss/len(dataloader):.4f}\")"
  ]
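
The new objective weights a per-timestep Huber term, masked so padded positions contribute nothing, against a SoftDTW term. The masking logic can be checked on its own; this sketch uses hypothetical shapes and omits the SoftDTW term, which needs pysdtw:

import torch
from torch.nn import HuberLoss

B, T, F = 4, 6, 7
outputs = torch.randn(B, T, F)
Y_batch = torch.randn(B, T, F)
lengths_batch = torch.tensor([6, 4, 3, 2])

huber = HuberLoss(reduction='none', delta=1.0)
loss_HL = huber(outputs, Y_batch).mean(dim=2)  # (B, T): mean over the 7 features

# True on real timesteps, False on padding.
mask = torch.arange(T).unsqueeze(0) < lengths_batch.unsqueeze(1)
loss_HL = (loss_HL * mask).sum() / mask.sum()  # average over non-padded positions only
print(loss_HL)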
@@ -691,10 +218,22 @@
  },
  {
  "cell_type": "code",
- "execution_count": 5,
+ "execution_count": 2,
  "id": "da89b45a",
  "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "ename": "NameError",
+ "evalue": "name 'model' is not defined",
+ "output_type": "error",
+ "traceback": [
+  "---------------------------------------------------------------------------",
+  "NameError                                 Traceback (most recent call last)",
+  "Cell In[2], line 3",
+  "      1 import torch",
+  "----> 3 torch.save(model.state_dict(), 'DIVA_Model_dict.pt') # save the model weights and parameters",
+  "      4 torch.save(model, 'DIVA_Model_full.pt') # save the whole model",
+  "NameError: name 'model' is not defined"
+ ]
+ }
+ ],
  "source": [
  "import torch\n",
  "\n",
@@ -704,536 +243,87 @@
  },
  {
  "cell_type": "code",
- "execution_count": 14,
+ "execution_count": 11,
  "id": "75530554",
  "metadata": {},
  "outputs": [
  {
  "data": {
  "text/plain": [
- "[296,\n",
- " 12,\n",
- ... (the remaining ~440 lines of this printed token list were removed) ...
- " 12]"
+ "<All keys matched successfully>"
  ]
  },
- "execution_count": 14,
+ "execution_count": 11,
  "metadata": {},
  "output_type": "execute_result"
  }
  ],
  "source": [
- "model.load_state_dict(torch.load('DIVA_Model_dict.pt')) # load the model weights and parameters\n",
- "\n",
- "model.generate(X_tensor[0], device=device) # generate a sequence from one style vector"
+ "model.load_state_dict(torch.load('DIVA_Model_dict.pt')) # load the model weights and parameters"
  ]
  },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "6c7f2aa0",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+  "[128, 100, 10, -1, 2, 10, 3, -1, 10, 10, 81, 10, -1, 10, 84, 1, 81, 10, -1, 79, 10, 10, 10, -1, ... (the full ~500-id sample continues) ..., 65, -1, 1, 67]\n"
+ ]
+ },
+ {
+ "ename": "ValueError",
+ "evalue": "invalid literal for int() with base 10: ''",
+ "output_type": "error",
+ "traceback": [
+  "---------------------------------------------------------------------------",
+  "ValueError                                Traceback (most recent call last)",
+  "Cell In[14], line 8",
+  "      5 print(token)",
+  "      7 MIDI = Tokenizer()",
+  "----> 8 MIDI.set_id(token)",
+  "     10 midi = MIDI.to_midi() # This should generate MIDI from the stored melody and chords",
+  "     11 midi.write('midi', fp='test_output.mid') # Save the generated MIDI to a file",
+  "File c:\\Users\\rrayy\\anaconda3\\envs\\diva\\Lib\\site-packages\\HarmonyMIDIToken\\tokenizer.py:189, in HarmonyMIDIToken.set_id(self, token_id)",
+  "    186 bass_tokens = token_id[token_id.index(300)+1:]",
+  "    188 self.melody = self._detokenize_note(melody_tokens)",
+  "--> 189 self.chords = self._detokenize_chord(chords_tokens)",
+  "    190 self.bass = self._detokenize_note(bass_tokens)",
+  "File c:\\Users\\rrayy\\anaconda3\\envs\\diva\\Lib\\site-packages\\HarmonyMIDIToken\\tokenizer.py:166, in HarmonyMIDIToken._detokenize_chord(self, token)",
+  "    164 output.append({\"chord\": \"\", \"duration\": float(chord_list[-2])/4})",
+  "    165 else:",
+  "--> 166 output.append({\"chord\": self._intpitch_to_note_name(int(chord_list[1]))[:-1]+inverse_quality_map[int(chord_list[2])], \"duration\": float(chord_list[-2])/4})",
+  "ValueError: invalid literal for int() with base 10: ''"
+ ]
+ }
+ ],
+ "source": [
+ "from HarmonyMIDIToken import HarmonyMIDIToken as Tokenizer\n",
+ "\n",
+ "Y = model.generate(X_tensor[0], device=device) # generate a sequence from one style vector\n",
+ "token = [i-2 for i in Y]\n",
+ "print(token)\n",
+ "\n",
+ "MIDI = Tokenizer()\n",
+ "MIDI.set_id(token)\n",
+ "\n",
+ "midi = MIDI.to_midi() # This should generate MIDI from the stored melody and chords\n",
+ "midi.write('midi', fp='test_output.mid') # Save the generated MIDI to a file"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "b2a75eeb",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+  "X shape: torch.Size([10, 5])\n",
+  "Y shape: torch.Size([10, 9])\n"
+ ]
+ }
+ ],
+ "source": [
+ "X = torch.rand((10, 5, 7), device=device, requires_grad=True)\n",
+ "Y = torch.rand((10, 9, 7), device=device)\n",
  "\n",
+ "print(\"X shape:\", X.shape[:2])\n",
+ "print(\"Y shape:\", Y.shape[:2])"
  ]
  }
  ],
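
The ValueError above comes from HarmonyMIDIToken failing to parse the sampled ids as a structured token stream (the chord detokenizer hits an empty field when a section is malformed). Until sampling reliably emits well-formed streams, the detokenization step could be guarded; this sketch assumes only the set_id, to_midi, and write calls shown in the cell above, with a hypothetical token list:

from HarmonyMIDIToken import HarmonyMIDIToken as Tokenizer

token = [81, 3, 65, 3, 53, 3]  # hypothetical ids; replace with the model's shifted output
MIDI = Tokenizer()
try:
    MIDI.set_id(token)
    midi = MIDI.to_midi()
    midi.write('midi', fp='test_output.mid')
except (ValueError, IndexError) as e:
    # Malformed samples (e.g. missing the 300 section marker seen in the
    # traceback) land here instead of crashing deep inside the tokenizer.
    print(f"sampled sequence is not a valid token stream: {e}")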
 