rrayy committed on
Commit
68d09e4
·
1 Parent(s): 4b7244e

Changes to be committed: 모델 학습 완료

Browse files

new file: DIVA_Model_dict.pt
new file: DIVA_Model_full.pt
modified: DIVA_dataset.pt
modified: preprocessing.ipynb
modified: train.ipynb

Files changed (5) hide show
  1. DIVA_Model_dict.pt +3 -0
  2. DIVA_Model_full.pt +3 -0
  3. DIVA_dataset.pt +1 -1
  4. preprocessing.ipynb +10 -10
  5. train.ipynb +164 -25
DIVA_Model_dict.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:547e6a824560bb6f5ef6b097f468fbe6a5ec24efc9ff3d028d1e1ecedb35a0d0
3
+ size 1517753
DIVA_Model_full.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be154505e564927f6d12e0832bf43bb3f17082c97f7408b5692b8af4e9eb851c
3
+ size 1519609
DIVA_dataset.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5d5843f8e01521ce7f7177a76674a83c6444abf577418df5de3db9a34cb5e08f
3
  size 328142
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e956b6342df72c271210930fb6ce094c75c61b5c9d4e155966687599912791b
3
  size 328142
preprocessing.ipynb CHANGED
@@ -1502,7 +1502,7 @@
1502
  },
1503
  {
1504
  "cell_type": "code",
1505
- "execution_count": 3,
1506
  "id": "f7b77c0c",
1507
  "metadata": {},
1508
  "outputs": [],
@@ -1540,7 +1540,7 @@
1540
  },
1541
  {
1542
  "cell_type": "code",
1543
- "execution_count": 4,
1544
  "id": "769af33a",
1545
  "metadata": {},
1546
  "outputs": [
@@ -1548,13 +1548,13 @@
1548
  "name": "stdout",
1549
  "output_type": "stream",
1550
  "text": [
1551
- "min target: -2\n",
1552
- "unique values: tensor([ -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,\n",
1553
- " 12, 16, 22, 31, 35, 38, 43, 44, 45, 46, 47, 48, 49, 50,\n",
1554
- " 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 62, 63, 64, 65,\n",
1555
- " 66, 67, 68, 69, 70, 71, 72, 73, 75, 76, 77, 78, 79, 80,\n",
1556
- " 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94,\n",
1557
- " 95, 96, 100, 102, 130, 202, 302])\n"
1558
  ]
1559
  }
1560
  ],
@@ -1565,7 +1565,7 @@
1565
  },
1566
  {
1567
  "cell_type": "code",
1568
- "execution_count": 5,
1569
  "id": "4f5f5dc1",
1570
  "metadata": {},
1571
  "outputs": [],
 
1502
  },
1503
  {
1504
  "cell_type": "code",
1505
+ "execution_count": 7,
1506
  "id": "f7b77c0c",
1507
  "metadata": {},
1508
  "outputs": [],
 
1540
  },
1541
  {
1542
  "cell_type": "code",
1543
+ "execution_count": 8,
1544
  "id": "769af33a",
1545
  "metadata": {},
1546
  "outputs": [
 
1548
  "name": "stdout",
1549
  "output_type": "stream",
1550
  "text": [
1551
+ "min target: 0\n",
1552
+ "unique values: tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16,\n",
1553
+ " 22, 31, 35, 38, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52,\n",
1554
+ " 53, 54, 55, 56, 57, 58, 59, 62, 63, 64, 65, 66, 67, 68,\n",
1555
+ " 69, 70, 71, 72, 73, 75, 76, 77, 78, 79, 80, 81, 82, 83,\n",
1556
+ " 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 100,\n",
1557
+ " 102, 130, 202, 302])\n"
1558
  ]
1559
  }
1560
  ],
 
1565
  },
1566
  {
1567
  "cell_type": "code",
1568
+ "execution_count": 9,
1569
  "id": "4f5f5dc1",
1570
  "metadata": {},
1571
  "outputs": [],
train.ipynb CHANGED
@@ -32,7 +32,7 @@
32
  },
33
  {
34
  "cell_type": "code",
35
- "execution_count": 11,
36
  "id": "630dd7ad",
37
  "metadata": {},
38
  "outputs": [],
@@ -42,17 +42,17 @@
42
  "import torch.nn as nn\n",
43
  "import torch\n",
44
  "\n",
45
- "#device = torch.device(\"cuda\") # GPU 사용\n",
46
- "device = torch.device(\"cpu\") # CPU 사용\n",
47
  "\n",
48
- "model = Vector2MIDI(25, 128, 320).to(device)\n",
49
  "criterion = nn.CrossEntropyLoss(ignore_index=0) # 손실함수 패딩(0) 무시\n",
50
  "optimizer = optim.Adam(model.parameters(), lr=1e-3)"
51
  ]
52
  },
53
  {
54
  "cell_type": "code",
55
- "execution_count": 5,
56
  "id": "f8c4a838",
57
  "metadata": {},
58
  "outputs": [
@@ -87,7 +87,7 @@
87
  },
88
  {
89
  "cell_type": "code",
90
- "execution_count": 3,
91
  "id": "4e0ea127",
92
  "metadata": {},
93
  "outputs": [],
@@ -124,7 +124,7 @@
124
  },
125
  {
126
  "cell_type": "code",
127
- "execution_count": 12,
128
  "id": "16a14b5f",
129
  "metadata": {},
130
  "outputs": [
@@ -132,26 +132,142 @@
132
  "name": "stdout",
133
  "output_type": "stream",
134
  "text": [
135
- "input to forward: torch.Size([8, 25])\n",
136
- "outputs shape: torch.Size([8, 1185, 320])\n",
137
- "Y_batch shape: torch.Size([8, 1185])\n",
138
- "outputs(view) shape: torch.Size([9480, 320])\n",
139
- "targets(view) shape: torch.Size([9480])\n"
140
  ]
141
  },
142
  {
143
- "ename": "IndexError",
144
- "evalue": "Target -1 is out of bounds.",
145
- "output_type": "error",
146
- "traceback": [
147
- "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
148
- "\u001b[31mIndexError\u001b[39m Traceback (most recent call last)",
149
- "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[12]\u001b[39m\u001b[32m, line 25\u001b[39m\n\u001b[32m 22\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33moutputs(view) shape:\u001b[39m\u001b[33m\"\u001b[39m, outputs.shape)\n\u001b[32m 23\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mtargets(view) shape:\u001b[39m\u001b[33m\"\u001b[39m, targets.shape)\n\u001b[32m---> \u001b[39m\u001b[32m25\u001b[39m loss_f = \u001b[43mcriterion\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtargets\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 26\u001b[39m loss_f.backward()\n\u001b[32m 27\u001b[39m optimizer.step()\n",
150
- "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\rrayy\\anaconda3\\envs\\diva\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1751\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1749\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1750\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1751\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
151
- "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\rrayy\\anaconda3\\envs\\diva\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1762\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1757\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1758\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1759\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1760\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1761\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1762\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1764\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1765\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
152
- "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\rrayy\\anaconda3\\envs\\diva\\Lib\\site-packages\\torch\\nn\\modules\\loss.py:1297\u001b[39m, in \u001b[36mCrossEntropyLoss.forward\u001b[39m\u001b[34m(self, input, target)\u001b[39m\n\u001b[32m 1296\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor, target: Tensor) -> Tensor:\n\u001b[32m-> \u001b[39m\u001b[32m1297\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcross_entropy\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 1298\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 1299\u001b[39m \u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1300\u001b[39m \u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1301\u001b[39m \u001b[43m \u001b[49m\u001b[43mignore_index\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mignore_index\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1302\u001b[39m \u001b[43m \u001b[49m\u001b[43mreduction\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mreduction\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1303\u001b[39m \u001b[43m \u001b[49m\u001b[43mlabel_smoothing\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mlabel_smoothing\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1304\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
153
- "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\rrayy\\anaconda3\\envs\\diva\\Lib\\site-packages\\torch\\nn\\functional.py:3494\u001b[39m, in \u001b[36mcross_entropy\u001b[39m\u001b[34m(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)\u001b[39m\n\u001b[32m 3492\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m size_average \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m reduce \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 3493\u001b[39m reduction = _Reduction.legacy_get_string(size_average, reduce)\n\u001b[32m-> \u001b[39m\u001b[32m3494\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_C\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_nn\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcross_entropy_loss\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 3495\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 3496\u001b[39m \u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3497\u001b[39m \u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3498\u001b[39m \u001b[43m \u001b[49m\u001b[43m_Reduction\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget_enum\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreduction\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3499\u001b[39m \u001b[43m \u001b[49m\u001b[43mignore_index\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3500\u001b[39m \u001b[43m \u001b[49m\u001b[43mlabel_smoothing\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3501\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
154
- "\u001b[31mIndexError\u001b[39m: Target -1 is out of bounds."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  ]
156
  }
157
  ],
@@ -202,7 +318,30 @@
202
  "\n",
203
  "2. 토큰 매핑 수정\n",
204
  "- 지금 vocab_size=128이면 유효 인덱스는 0 ~ 127만 가능\n",
205
- "- Rest나 특수 심볼 때문에 128이 들어갔다면 vocab_size를 129 이상으로 늘려야 "
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  ]
207
  }
208
  ],
 
32
  },
33
  {
34
  "cell_type": "code",
35
+ "execution_count": 2,
36
  "id": "630dd7ad",
37
  "metadata": {},
38
  "outputs": [],
 
42
  "import torch.nn as nn\n",
43
  "import torch\n",
44
  "\n",
45
+ "device = torch.device(\"cuda\") # GPU 사용\n",
46
+ "#device = torch.device(\"cpu\") # CPU 사용\n",
47
  "\n",
48
+ "model = Vector2MIDI(25, 128, 303).to(device)\n",
49
  "criterion = nn.CrossEntropyLoss(ignore_index=0) # 손실함수 패딩(0) 무시\n",
50
  "optimizer = optim.Adam(model.parameters(), lr=1e-3)"
51
  ]
52
  },
53
  {
54
  "cell_type": "code",
55
+ "execution_count": 3,
56
  "id": "f8c4a838",
57
  "metadata": {},
58
  "outputs": [
 
87
  },
88
  {
89
  "cell_type": "code",
90
+ "execution_count": 4,
91
  "id": "4e0ea127",
92
  "metadata": {},
93
  "outputs": [],
 
124
  },
125
  {
126
  "cell_type": "code",
127
+ "execution_count": 5,
128
  "id": "16a14b5f",
129
  "metadata": {},
130
  "outputs": [
 
132
  "name": "stdout",
133
  "output_type": "stream",
134
  "text": [
135
+ "input to forward: torch.Size([8, 25])\n"
 
 
 
 
136
  ]
137
  },
138
  {
139
+ "name": "stdout",
140
+ "output_type": "stream",
141
+ "text": [
142
+ "outputs shape: torch.Size([8, 1185, 303])\n",
143
+ "Y_batch shape: torch.Size([8, 1185])\n",
144
+ "outputs(view) shape: torch.Size([9480, 303])\n",
145
+ "targets(view) shape: torch.Size([9480])\n",
146
+ "input to forward: torch.Size([8, 25])\n",
147
+ "outputs shape: torch.Size([8, 1185, 303])\n",
148
+ "Y_batch shape: torch.Size([8, 1185])\n",
149
+ "outputs(view) shape: torch.Size([9480, 303])\n",
150
+ "targets(view) shape: torch.Size([9480])\n",
151
+ "input to forward: torch.Size([8, 25])\n",
152
+ "outputs shape: torch.Size([8, 1185, 303])\n",
153
+ "Y_batch shape: torch.Size([8, 1185])\n",
154
+ "outputs(view) shape: torch.Size([9480, 303])\n",
155
+ "targets(view) shape: torch.Size([9480])\n",
156
+ "input to forward: torch.Size([8, 25])\n",
157
+ "outputs shape: torch.Size([8, 1185, 303])\n",
158
+ "Y_batch shape: torch.Size([8, 1185])\n",
159
+ "outputs(view) shape: torch.Size([9480, 303])\n",
160
+ "targets(view) shape: torch.Size([9480])\n",
161
+ "input to forward: torch.Size([2, 25])\n",
162
+ "outputs shape: torch.Size([2, 1185, 303])\n",
163
+ "Y_batch shape: torch.Size([2, 1185])\n",
164
+ "outputs(view) shape: torch.Size([2370, 303])\n",
165
+ "targets(view) shape: torch.Size([2370])\n",
166
+ "Epoch 1, Loss: 5.5885\n",
167
+ "input to forward: torch.Size([8, 25])\n",
168
+ "outputs shape: torch.Size([8, 1185, 303])\n",
169
+ "Y_batch shape: torch.Size([8, 1185])\n",
170
+ "outputs(view) shape: torch.Size([9480, 303])\n",
171
+ "targets(view) shape: torch.Size([9480])\n",
172
+ "input to forward: torch.Size([8, 25])\n",
173
+ "outputs shape: torch.Size([8, 1185, 303])\n",
174
+ "Y_batch shape: torch.Size([8, 1185])\n",
175
+ "outputs(view) shape: torch.Size([9480, 303])\n",
176
+ "targets(view) shape: torch.Size([9480])\n",
177
+ "input to forward: torch.Size([8, 25])\n",
178
+ "outputs shape: torch.Size([8, 1185, 303])\n",
179
+ "Y_batch shape: torch.Size([8, 1185])\n",
180
+ "outputs(view) shape: torch.Size([9480, 303])\n",
181
+ "targets(view) shape: torch.Size([9480])\n",
182
+ "input to forward: torch.Size([8, 25])\n",
183
+ "outputs shape: torch.Size([8, 1185, 303])\n",
184
+ "Y_batch shape: torch.Size([8, 1185])\n",
185
+ "outputs(view) shape: torch.Size([9480, 303])\n",
186
+ "targets(view) shape: torch.Size([9480])\n",
187
+ "input to forward: torch.Size([2, 25])\n",
188
+ "outputs shape: torch.Size([2, 1185, 303])\n",
189
+ "Y_batch shape: torch.Size([2, 1185])\n",
190
+ "outputs(view) shape: torch.Size([2370, 303])\n",
191
+ "targets(view) shape: torch.Size([2370])\n",
192
+ "Epoch 2, Loss: 4.6946\n",
193
+ "input to forward: torch.Size([8, 25])\n",
194
+ "outputs shape: torch.Size([8, 1185, 303])\n",
195
+ "Y_batch shape: torch.Size([8, 1185])\n",
196
+ "outputs(view) shape: torch.Size([9480, 303])\n",
197
+ "targets(view) shape: torch.Size([9480])\n",
198
+ "input to forward: torch.Size([8, 25])\n",
199
+ "outputs shape: torch.Size([8, 1185, 303])\n",
200
+ "Y_batch shape: torch.Size([8, 1185])\n",
201
+ "outputs(view) shape: torch.Size([9480, 303])\n",
202
+ "targets(view) shape: torch.Size([9480])\n",
203
+ "input to forward: torch.Size([8, 25])\n",
204
+ "outputs shape: torch.Size([8, 1185, 303])\n",
205
+ "Y_batch shape: torch.Size([8, 1185])\n",
206
+ "outputs(view) shape: torch.Size([9480, 303])\n",
207
+ "targets(view) shape: torch.Size([9480])\n",
208
+ "input to forward: torch.Size([8, 25])\n",
209
+ "outputs shape: torch.Size([8, 1185, 303])\n",
210
+ "Y_batch shape: torch.Size([8, 1185])\n",
211
+ "outputs(view) shape: torch.Size([9480, 303])\n",
212
+ "targets(view) shape: torch.Size([9480])\n",
213
+ "input to forward: torch.Size([2, 25])\n",
214
+ "outputs shape: torch.Size([2, 1185, 303])\n",
215
+ "Y_batch shape: torch.Size([2, 1185])\n",
216
+ "outputs(view) shape: torch.Size([2370, 303])\n",
217
+ "targets(view) shape: torch.Size([2370])\n",
218
+ "Epoch 3, Loss: 3.0288\n",
219
+ "input to forward: torch.Size([8, 25])\n",
220
+ "outputs shape: torch.Size([8, 1185, 303])\n",
221
+ "Y_batch shape: torch.Size([8, 1185])\n",
222
+ "outputs(view) shape: torch.Size([9480, 303])\n",
223
+ "targets(view) shape: torch.Size([9480])\n",
224
+ "input to forward: torch.Size([8, 25])\n",
225
+ "outputs shape: torch.Size([8, 1185, 303])\n",
226
+ "Y_batch shape: torch.Size([8, 1185])\n",
227
+ "outputs(view) shape: torch.Size([9480, 303])\n",
228
+ "targets(view) shape: torch.Size([9480])\n",
229
+ "input to forward: torch.Size([8, 25])\n",
230
+ "outputs shape: torch.Size([8, 1185, 303])\n",
231
+ "Y_batch shape: torch.Size([8, 1185])\n",
232
+ "outputs(view) shape: torch.Size([9480, 303])\n",
233
+ "targets(view) shape: torch.Size([9480])\n",
234
+ "input to forward: torch.Size([8, 25])\n",
235
+ "outputs shape: torch.Size([8, 1185, 303])\n",
236
+ "Y_batch shape: torch.Size([8, 1185])\n",
237
+ "outputs(view) shape: torch.Size([9480, 303])\n",
238
+ "targets(view) shape: torch.Size([9480])\n",
239
+ "input to forward: torch.Size([2, 25])\n",
240
+ "outputs shape: torch.Size([2, 1185, 303])\n",
241
+ "Y_batch shape: torch.Size([2, 1185])\n",
242
+ "outputs(view) shape: torch.Size([2370, 303])\n",
243
+ "targets(view) shape: torch.Size([2370])\n",
244
+ "Epoch 4, Loss: 2.9275\n",
245
+ "input to forward: torch.Size([8, 25])\n",
246
+ "outputs shape: torch.Size([8, 1185, 303])\n",
247
+ "Y_batch shape: torch.Size([8, 1185])\n",
248
+ "outputs(view) shape: torch.Size([9480, 303])\n",
249
+ "targets(view) shape: torch.Size([9480])\n",
250
+ "input to forward: torch.Size([8, 25])\n",
251
+ "outputs shape: torch.Size([8, 1185, 303])\n",
252
+ "Y_batch shape: torch.Size([8, 1185])\n",
253
+ "outputs(view) shape: torch.Size([9480, 303])\n",
254
+ "targets(view) shape: torch.Size([9480])\n",
255
+ "input to forward: torch.Size([8, 25])\n",
256
+ "outputs shape: torch.Size([8, 1185, 303])\n",
257
+ "Y_batch shape: torch.Size([8, 1185])\n",
258
+ "outputs(view) shape: torch.Size([9480, 303])\n",
259
+ "targets(view) shape: torch.Size([9480])\n",
260
+ "input to forward: torch.Size([8, 25])\n",
261
+ "outputs shape: torch.Size([8, 1185, 303])\n",
262
+ "Y_batch shape: torch.Size([8, 1185])\n",
263
+ "outputs(view) shape: torch.Size([9480, 303])\n",
264
+ "targets(view) shape: torch.Size([9480])\n",
265
+ "input to forward: torch.Size([2, 25])\n",
266
+ "outputs shape: torch.Size([2, 1185, 303])\n",
267
+ "Y_batch shape: torch.Size([2, 1185])\n",
268
+ "outputs(view) shape: torch.Size([2370, 303])\n",
269
+ "targets(view) shape: torch.Size([2370])\n",
270
+ "Epoch 5, Loss: 2.8112\n"
271
  ]
272
  }
273
  ],
 
318
  "\n",
319
  "2. 토큰 매핑 수정\n",
320
  "- 지금 vocab_size=128이면 유효 인덱스는 0 ~ 127만 가능\n",
321
+ "- Rest나 특수 심볼 때문에 128이 들어갔다면 vocab_size를 129 이상으로 늘려야 함\n",
322
+ "\n",
323
+ "고침!!!!!!"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "markdown",
328
+ "id": "e610b924",
329
+ "metadata": {},
330
+ "source": [
331
+ "## 모델 저장"
332
+ ]
333
+ },
334
+ {
335
+ "cell_type": "code",
336
+ "execution_count": 7,
337
+ "id": "da89b45a",
338
+ "metadata": {},
339
+ "outputs": [],
340
+ "source": [
341
+ "import torch\n",
342
+ "\n",
343
+ "torch.save(model.state_dict(), 'DIVA_Model_dict.pt') # 모델 가중치, 매개변수 저장\n",
344
+ "torch.save(model, 'DIVA_Model_full.pt') # 모델 전체 저장"
345
  ]
346
  }
347
  ],