Teeradej Sawettraporn committed on
Commit
d6f1da2
·
verified ·
1 Parent(s): 097e069

Upload torch_audio_classification_demo.ipynb

Browse files
Torch_audio_classification/torch_audio_classification_demo.ipynb ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 14,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "import librosa\n",
11
+ "import torch.nn as nn\n",
12
+ "import numpy as np\n",
13
+ "import os\n",
14
+ "import matplotlib.pyplot as plt\n",
15
+ "from torch.utils.data import DataLoader, Dataset, random_split\n",
16
+ "\n",
17
def _load_audio(wav_file_path, offset=0, duration=30):
    """Decode an audio file with librosa; return (waveform, sample_rate).

    The defaults reproduce the original behavior of every feature helper:
    the first 30 seconds starting at offset 0, at librosa's default
    resampling rate.
    """
    return librosa.load(wav_file_path, offset=offset, duration=duration)


def get_mfcc(wav_file_path):
    """Return the MFCC matrix (coefficients x frames) of the first 30 s."""
    y, sr = _load_audio(wav_file_path)
    return np.array(librosa.feature.mfcc(y=y, sr=sr))


def get_melspectrogram(wav_file_path):
    """Return the mel spectrogram (mel bands x frames) of the first 30 s."""
    y, sr = _load_audio(wav_file_path)
    return np.array(librosa.feature.melspectrogram(y=y, sr=sr))


def get_chroma_vector(wav_file_path):
    """Return the chroma STFT matrix (pitch classes x frames) of the first 30 s."""
    y, sr = _load_audio(wav_file_path)
    return np.array(librosa.feature.chroma_stft(y=y, sr=sr))


def get_tonnetz(wav_file_path):
    """Return the tonnetz (tonal centroid) features of the first 30 s."""
    y, sr = _load_audio(wav_file_path)
    return np.array(librosa.feature.tonnetz(y=y, sr=sr))
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": 23,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
def get_feature(file_path):
    """Build a fixed-size (128, 128) feature map for one audio file.

    For each of four librosa features — chroma STFT, mel spectrogram, MFCC,
    tonnetz — the per-coefficient mean, min and max over time frames are
    concatenated into one 1-D vector, which is then resized to 128x128.

    The audio is decoded ONCE here instead of once per feature (the original
    called the four get_* helpers, each of which reloaded the file); the
    librosa calls and their arguments are unchanged, so the output is
    identical while decoding cost drops ~4x.

    NOTE(review): np.resize repeats/truncates the concatenated vector to fill
    all 128*128 = 16384 cells — confirm this tiling is intentional rather
    than zero-padding.
    """
    y, sr = librosa.load(file_path, offset=0, duration=30)

    def _stats(matrix):
        # mean/min/max over the time axis, concatenated per coefficient
        return np.concatenate((matrix.mean(axis=1), matrix.min(axis=1), matrix.max(axis=1)))

    # Same feature extractors and parameters as the standalone helpers.
    mfcc_feature = _stats(np.array(librosa.feature.mfcc(y=y, sr=sr)))
    melspectrogram_feature = _stats(np.array(librosa.feature.melspectrogram(y=y, sr=sr)))
    chroma_feature = _stats(np.array(librosa.feature.chroma_stft(y=y, sr=sr)))
    tntz_feature = _stats(np.array(librosa.feature.tonnetz(y=y, sr=sr)))

    # Order matters: chroma, mel, mfcc, tonnetz — as in the original.
    feature = np.concatenate((chroma_feature, melspectrogram_feature, mfcc_feature, tntz_feature))

    # Reshape to a fixed size so every sample has the same tensor shape.
    return np.resize(feature, (128, 128))
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 32,
86
+ "metadata": {},
87
+ "outputs": [],
88
+ "source": [
89
+ "# Define a custom dataset\n",
90
class AudioDataset(Dataset):
    """Dataset of per-file audio feature maps with integer genre labels.

    Walks one sub-folder per genre under *directory*, extracts a feature map
    per file via get_feature, and labels each sample with the genre's index
    in *genres*.
    """

    def __init__(self, directory, genres):
        feature_list = []
        label_list = []
        for genre in genres:
            print("Calculating features for genre: " + genre)
            genre_dir = os.path.join(directory, genre)
            genre_label = genres.index(genre)
            for file in os.listdir(genre_dir):
                feature_list.append(get_feature(os.path.join(genre_dir, file)))
                label_list.append(genre_label)

        self.features = np.array(feature_list)
        self.labels = np.array(label_list)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Prepend a channel dimension -> (1, H, W) so Conv2d accepts it.
        x = torch.tensor(self.features[idx], dtype=torch.float32).unsqueeze(0)
        y = torch.tensor(self.labels[idx], dtype=torch.long)
        return x, y
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 33,
118
+ "metadata": {},
119
+ "outputs": [
120
+ {
121
+ "name": "stdout",
122
+ "output_type": "stream",
123
+ "text": [
124
+ "Calculating features for genre: blues\n",
125
+ "Calculating features for genre: classical\n",
126
+ "Calculating features for genre: metal\n"
127
+ ]
128
+ }
129
+ ],
130
+ "source": [
131
+ "# Data Preparation\n",
132
+ "directory = 'd:/Coding/audio_dl_tf/dataset'\n",
133
+ "genres = ['blues', 'classical', 'metal']\n",
134
+ "\n",
135
+ "dataset = AudioDataset(directory, genres)\n",
136
+ "\n"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": 34,
142
+ "metadata": {},
143
+ "outputs": [],
144
+ "source": [
145
# Random 60 / 20 / 20 split into train / validation / test subsets.
n_total = len(dataset)
train_size = int(0.6 * n_total)
val_size = int(0.2 * n_total)
test_size = n_total - train_size - val_size  # remainder goes to test

train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Shuffle only the training loader; evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": 35,
160
+ "metadata": {},
161
+ "outputs": [],
162
+ "source": [
163
# VGG configurations: ints are conv output-channel counts, 'M' is a 2x2 max-pool.
VGG_types = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

class VGG_net(nn.Module):
    """VGG-style CNN classifier for 2-D single/multi-channel inputs.

    Args:
        in_channels: number of input channels (1 for the feature maps here).
        num_classes: size of the final classification layer.
        vgg_type: key into VGG_types; default 'VGG16' preserves the original
            hard-coded behavior while allowing the other configs.
        input_size: (H, W) of the expected input, used only to size fc1;
            default (128, 128) matches get_feature's output.
    """

    def __init__(self, in_channels=1, num_classes=3, vgg_type='VGG16', input_size=(128, 128)):
        super(VGG_net, self).__init__()
        self.in_channels = in_channels
        self.input_size = input_size
        self.conv_layers = self.create_conv_layers(VGG_types[vgg_type])

        self.flatten = nn.Flatten()

        # fc1 input size is inferred by running a dummy tensor through the convs.
        self.fc1 = nn.Linear(self.calculate_flatten_dim(), 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return raw class logits for a batch x of shape (N, C, H, W)."""
        x = self.conv_layers(x)
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x

    def create_conv_layers(self, architecture):
        """Build the conv stack: Conv-BN-ReLU per int entry, MaxPool per 'M'."""
        layers = []
        in_channels = self.in_channels

        for x in architecture:
            if type(x) == int:
                out_channels = x

                layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                     kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                           nn.BatchNorm2d(x),
                           nn.ReLU()]
                in_channels = x
            elif x == 'M':
                layers += [nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))]
        return nn.Sequential(*layers)

    def calculate_flatten_dim(self):
        """Infer fc1's input width by tracing a zero tensor through the convs."""
        with torch.no_grad():
            # Fall back to the historical 128x128 if input_size is absent.
            h, w = getattr(self, 'input_size', (128, 128))
            sample_input = torch.zeros((1, self.in_channels, h, w))
            sample_output = self.conv_layers(sample_input)
            return sample_output.numel()
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": 36,
223
+ "metadata": {},
224
+ "outputs": [],
225
+ "source": [
226
+ "# Initialize the model\n",
227
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu' \n",
228
+ "model = VGG_net(in_channels=1, num_classes=len(genres)).to(device)\n",
229
+ "\n"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": 37,
235
+ "metadata": {},
236
+ "outputs": [
237
+ {
238
+ "name": "stdout",
239
+ "output_type": "stream",
240
+ "text": [
241
+ "torch.Size([32, 3])\n"
242
+ ]
243
+ }
244
+ ],
245
+ "source": [
246
+ "# Test forward pass\n",
247
+ "sample_data, _ = next(iter(train_loader))\n",
248
+ "sample_data = sample_data.to(device) # Add channel dimension\n",
249
+ "print(model(sample_data).shape)"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": 38,
255
+ "metadata": {},
256
+ "outputs": [],
257
+ "source": [
258
# Training and evaluation functions
def train(model, dataloader, criterion, optimizer, device):
    """Run one training epoch; return (mean loss per sample, accuracy)."""
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight by batch size to accumulate a sum
        loss_sum += batch_loss.item() * batch_x.size(0)
        preds = logits.argmax(dim=1)
        n_seen += batch_y.size(0)
        n_correct += (preds == batch_y).sum().item()

    return loss_sum / len(dataloader.dataset), n_correct / n_seen
278
+ "\n"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 39,
284
+ "metadata": {},
285
+ "outputs": [],
286
+ "source": [
287
def evaluate(model, dataloader, criterion, device):
    """Evaluate without gradient tracking; return (mean loss per sample, accuracy)."""
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            logits = model(batch_x)
            # weight the batch-mean loss by batch size to accumulate a sum
            loss_sum += criterion(logits, batch_y).item() * batch_x.size(0)
            preds = logits.argmax(dim=1)
            n_seen += batch_y.size(0)
            n_correct += (preds == batch_y).sum().item()
    return loss_sum / len(dataloader.dataset), n_correct / n_seen
304
+ ]
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "execution_count": 40,
309
+ "metadata": {},
310
+ "outputs": [
311
+ {
312
+ "name": "stdout",
313
+ "output_type": "stream",
314
+ "text": [
315
+ "Epoch 1/20\n",
316
+ "Train Loss: 8.2556, Train Acc: 0.3500\n",
317
+ "Val Loss: 4557.6196, Val Acc: 0.3000\n",
318
+ "Model saved!\n"
319
+ ]
320
+ },
321
+ {
322
+ "ename": "KeyboardInterrupt",
323
+ "evalue": "",
324
+ "output_type": "error",
325
+ "traceback": [
326
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
327
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
328
+ "Cell \u001b[1;32mIn[40], line 15\u001b[0m\n\u001b[0;32m 12\u001b[0m best_val_acc \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.0\u001b[39m\n\u001b[0;32m 14\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(num_epochs):\n\u001b[1;32m---> 15\u001b[0m train_loss, train_acc \u001b[38;5;241m=\u001b[39m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcriterion\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 16\u001b[0m val_loss, val_acc \u001b[38;5;241m=\u001b[39m evaluate(model, val_loader, criterion, device)\n\u001b[0;32m 18\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_epochs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
329
+ "Cell \u001b[1;32mIn[38], line 10\u001b[0m, in \u001b[0;36mtrain\u001b[1;34m(model, dataloader, criterion, optimizer, device)\u001b[0m\n\u001b[0;32m 8\u001b[0m inputs, labels \u001b[38;5;241m=\u001b[39m inputs\u001b[38;5;241m.\u001b[39mto(device), labels\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m 9\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m---> 10\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 11\u001b[0m loss \u001b[38;5;241m=\u001b[39m criterion(outputs, labels)\n\u001b[0;32m 12\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n",
330
+ "File \u001b[1;32md:\\Coding\\audio_dl_torch\\torch_dl\\lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
331
+ "File \u001b[1;32md:\\Coding\\audio_dl_torch\\torch_dl\\lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
332
+ "Cell \u001b[1;32mIn[35], line 25\u001b[0m, in \u001b[0;36mVGG_net.forward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 24\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[1;32m---> 25\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv_layers\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 26\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mflatten(x)\n\u001b[0;32m 27\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrelu(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfc1(x))\n",
333
+ "File \u001b[1;32md:\\Coding\\audio_dl_torch\\torch_dl\\lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
334
+ "File \u001b[1;32md:\\Coding\\audio_dl_torch\\torch_dl\\lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
335
+ "File \u001b[1;32md:\\Coding\\audio_dl_torch\\torch_dl\\lib\\site-packages\\torch\\nn\\modules\\container.py:217\u001b[0m, in \u001b[0;36mSequential.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 215\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[0;32m 216\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[1;32m--> 217\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 218\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
336
+ "File \u001b[1;32md:\\Coding\\audio_dl_torch\\torch_dl\\lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
337
+ "File \u001b[1;32md:\\Coding\\audio_dl_torch\\torch_dl\\lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
338
+ "File \u001b[1;32md:\\Coding\\audio_dl_torch\\torch_dl\\lib\\site-packages\\torch\\nn\\modules\\conv.py:460\u001b[0m, in \u001b[0;36mConv2d.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 459\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m--> 460\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
339
+ "File \u001b[1;32md:\\Coding\\audio_dl_torch\\torch_dl\\lib\\site-packages\\torch\\nn\\modules\\conv.py:456\u001b[0m, in \u001b[0;36mConv2d._conv_forward\u001b[1;34m(self, input, weight, bias)\u001b[0m\n\u001b[0;32m 452\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzeros\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m 453\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mconv2d(F\u001b[38;5;241m.\u001b[39mpad(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode),\n\u001b[0;32m 454\u001b[0m weight, bias, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstride,\n\u001b[0;32m 455\u001b[0m _pair(\u001b[38;5;241m0\u001b[39m), \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdilation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgroups)\n\u001b[1;32m--> 456\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv2d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 457\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m)\u001b[49m\n",
340
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
341
+ ]
342
+ }
343
+ ],
344
+ "source": [
345
import torch.optim as optim

# Hyperparameters
num_epochs = 20
learning_rate = 0.001

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training and validation: keep the checkpoint with the best validation accuracy.
best_val_acc = 0.0

for epoch in range(num_epochs):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Save the best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')
        print("Model saved!")

# Load the best model.
# map_location keeps this working when the checkpoint was written from a
# different device than the one we are loading on (e.g. saved on CUDA,
# reloaded on a CPU-only machine).
model.load_state_dict(torch.load('best_model.pth', map_location=device))
374
+ ]
375
+ },
376
+ {
377
+ "cell_type": "code",
378
+ "execution_count": null,
379
+ "metadata": {},
380
+ "outputs": [],
381
+ "source": [
382
+ "def train(model, dataloader, criterion, optimizer, device):\n",
383
+ " model.train()\n",
384
+ " running_loss = 0.0\n",
385
+ " correct = 0\n",
386
+ " total = 0\n",
387
+ " for inputs, labels in dataloader:\n",
388
+ " inputs, labels = inputs.to(device), labels.to(device)\n",
389
+ " optimizer.zero_grad()\n",
390
+ " outputs = model(inputs)\n",
391
+ " loss = criterion(outputs, labels)\n",
392
+ " loss.backward()\n",
393
+ " optimizer.step()\n",
394
+ " running_loss += loss.item() * inputs.size(0)\n",
395
+ " _, predicted = torch.max(outputs, 1)\n",
396
+ " total += labels.size(0)\n",
397
+ " correct += (predicted == labels).sum().item()\n",
398
+ " epoch_loss = running_loss / len(dataloader.dataset)\n",
399
+ " epoch_acc = correct / total\n",
400
+ " return epoch_loss, epoch_acc\n",
401
+ "\n",
402
+ "def evaluate(model, dataloader, criterion, device):\n",
403
+ " model.eval()\n",
404
+ " running_loss = 0.0\n",
405
+ " correct = 0\n",
406
+ " total = 0\n",
407
+ " with torch.no_grad():\n",
408
+ " for inputs, labels in dataloader:\n",
409
+ " inputs, labels = inputs.to(device), labels.to(device)\n",
410
+ " outputs = model(inputs)\n",
411
+ " loss = criterion(outputs, labels)\n",
412
+ " running_loss += loss.item() * inputs.size(0)\n",
413
+ " _, predicted = torch.max(outputs, 1)\n",
414
+ " total += labels.size(0)\n",
415
+ " correct += (predicted == labels).sum().item()\n",
416
+ " epoch_loss = running_loss / len(dataloader.dataset)\n",
417
+ " epoch_acc = correct / total\n",
418
+ " return epoch_loss, epoch_acc\n"
419
+ ]
420
+ }
421
+ ],
422
+ "metadata": {
423
+ "kernelspec": {
424
+ "display_name": "torch_dl",
425
+ "language": "python",
426
+ "name": "python3"
427
+ },
428
+ "language_info": {
429
+ "codemirror_mode": {
430
+ "name": "ipython",
431
+ "version": 3
432
+ },
433
+ "file_extension": ".py",
434
+ "mimetype": "text/x-python",
435
+ "name": "python",
436
+ "nbconvert_exporter": "python",
437
+ "pygments_lexer": "ipython3",
438
+ "version": "3.10.11"
439
+ }
440
+ },
441
+ "nbformat": 4,
442
+ "nbformat_minor": 2
443
+ }