Heinrich Dinkel committed on
Commit 3f1e105 · 1 Parent(s): 2dfe0e8

added notebook

Files changed (1)
  1. notebook.ipynb +282 -0
notebook.ipynb ADDED
@@ -0,0 +1,282 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install transformers torch torchaudio librosa pandas scikit-learn tqdm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.utils.data import Dataset, DataLoader\n",
+ "from transformers import AutoModel\n",
+ "import librosa\n",
+ "import os\n",
+ "import pandas as pd\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.metrics import accuracy_score\n",
+ "import numpy as np\n",
+ "from tqdm import tqdm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class ESC50Dataset(Dataset):\n",
+ "    def __init__(self, audio_dir, metadata_path, sr=16000, max_length=160000):\n",
+ "        self.audio_dir = audio_dir\n",
+ "        self.sr = sr\n",
+ "        self.max_length = max_length\n",
+ "        self.metadata = pd.read_csv(metadata_path)\n",
+ "        \n",
+ "    def __len__(self):\n",
+ "        return len(self.metadata)\n",
+ "    \n",
+ "    def __getitem__(self, idx):\n",
+ "        row = self.metadata.iloc[idx]\n",
+ "        filename = row['filename']\n",
+ "        label = row['target']  # integer class id in 0..49\n",
+ "        \n",
+ "        audio_path = os.path.join(self.audio_dir, filename)\n",
+ "        audio, sr = librosa.load(audio_path, sr=self.sr)\n",
+ "        \n",
+ "        # Pad or truncate to max_length so every clip batches to a fixed shape\n",
+ "        # (ESC-50 clips are uniformly 5 s, i.e. 80000 samples at 16 kHz)\n",
+ "        if len(audio) < self.max_length:\n",
+ "            audio = np.pad(audio, (0, self.max_length - len(audio)))\n",
+ "        else:\n",
+ "            audio = audio[:self.max_length]\n",
+ "        \n",
+ "        audio_tensor = torch.tensor(audio).float()\n",
+ "        label_tensor = torch.tensor(label).long()\n",
+ "        \n",
+ "        return audio_tensor, label_tensor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def download_esc50():\n",
+ "    import urllib.request\n",
+ "    import zipfile\n",
+ "    \n",
+ "    if not os.path.exists('ESC-50'):\n",
+ "        print(\"Downloading ESC-50 dataset...\")\n",
+ "        url = \"https://github.com/karoldvl/ESC-50/archive/master.zip\"\n",
+ "        urllib.request.urlretrieve(url, 'esc50.zip')\n",
+ "        \n",
+ "        with zipfile.ZipFile('esc50.zip', 'r') as zip_ref:\n",
+ "            zip_ref.extractall('.')\n",
+ "        os.rename('ESC-50-master', 'ESC-50')\n",
+ "        os.remove('esc50.zip')\n",
+ "        print(\"ESC-50 dataset downloaded and extracted\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_embedding_dim(model):\n",
+ "    dummy_input = torch.randn(1, 160000)\n",
+ "    with torch.no_grad():\n",
+ "        output = model(dummy_input)\n",
+ "    if isinstance(output, dict):\n",
+ "        for key in ['last_hidden_state', 'embeddings', 'audio']:\n",
+ "            if key in output:\n",
+ "                features = output[key]\n",
+ "                break\n",
+ "        else:\n",
+ "            features = list(output.values())[0]\n",
+ "    else:\n",
+ "        features = output\n",
+ "    \n",
+ "    # The last dimension is the feature size whether or not a time axis is present\n",
+ "    return features.shape[-1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Download dataset\n",
+ "download_esc50()\n",
+ "\n",
+ "# Load model\n",
+ "model = AutoModel.from_pretrained(\"mispeech/dashengtokenizer\", trust_remote_code=True)\n",
+ "\n",
+ "# Get embedding dimension\n",
+ "embedding_dim = get_embedding_dim(model)\n",
+ "print(f\"Model embedding dimension: {embedding_dim}\")\n",
+ "\n",
+ "# Freeze model\n",
+ "for param in model.parameters():\n",
+ "    param.requires_grad = False\n",
+ "\n",
+ "# Single linear layer\n",
+ "classifier = nn.Linear(embedding_dim, 50)  # 50 ESC-50 classes\n",
+ "\n",
+ "# Setup\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "model.to(device)\n",
+ "classifier.to(device)\n",
+ "print(f\"Using device: {device}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create datasets\n",
+ "audio_dir = 'ESC-50/audio'\n",
+ "metadata_path = 'ESC-50/meta/esc50.csv'\n",
+ "\n",
+ "dataset = ESC50Dataset(audio_dir, metadata_path)\n",
+ "\n",
+ "# Split into train/val\n",
+ "train_idx, val_idx = train_test_split(range(len(dataset)), test_size=0.2, random_state=42)\n",
+ "train_dataset = torch.utils.data.Subset(dataset, train_idx)\n",
+ "val_dataset = torch.utils.data.Subset(dataset, val_idx)\n",
+ "\n",
+ "train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)\n",
+ "val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)\n",
+ "\n",
+ "print(f\"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Training setup\n",
+ "optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)\n",
+ "criterion = nn.CrossEntropyLoss()\n",
+ "\n",
+ "# Training loop\n",
+ "for epoch in range(10):\n",
+ "    model.eval()\n",
+ "    classifier.train()\n",
+ "    \n",
+ "    # Training\n",
+ "    train_loss = 0\n",
+ "    train_preds = []\n",
+ "    train_labels = []\n",
+ "\n",
+ "    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/10 Training')\n",
+ "    for batch_audio, batch_labels in pbar:\n",
+ "        batch_audio = batch_audio.to(device)\n",
+ "        batch_labels = batch_labels.to(device)\n",
+ "\n",
+ "        # Forward through frozen model\n",
+ "        with torch.no_grad():\n",
+ "            # Use the model's forward pass here, consistent with\n",
+ "            # get_embedding_dim and the validation loop below\n",
+ "            features = model(batch_audio)\n",
+ "            if isinstance(features, dict):\n",
+ "                for key in ['last_hidden_state', 'embeddings', 'audio']:\n",
+ "                    if key in features:\n",
+ "                        features = features[key]\n",
+ "                        break\n",
+ "                else:\n",
+ "                    features = list(features.values())[0]\n",
+ "\n",
+ "        # Global average pooling if needed\n",
+ "        if features.dim() > 2:\n",
+ "            features = features.mean(dim=1)\n",
+ "\n",
+ "        # Classifier\n",
+ "        logits = classifier(features)\n",
+ "        loss = criterion(logits, batch_labels)\n",
+ "\n",
+ "        # Backward\n",
+ "        optimizer.zero_grad()\n",
+ "        loss.backward()\n",
+ "        optimizer.step()\n",
+ "\n",
+ "        train_loss += loss.item()\n",
+ "        preds = torch.argmax(logits, dim=1)\n",
+ "        train_preds.extend(preds.cpu().numpy())\n",
+ "        train_labels.extend(batch_labels.cpu().numpy())\n",
+ "\n",
+ "        # Update progress bar\n",
+ "        pbar.set_postfix({'loss': f'{loss.item():.4f}'})\n",
+ "\n",
+ "    train_acc = accuracy_score(train_labels, train_preds)\n",
+ "    \n",
+ "    # Validation\n",
+ "    classifier.eval()\n",
+ "    val_preds = []\n",
+ "    val_labels = []\n",
+ "\n",
+ "    with torch.no_grad():\n",
+ "        pbar_val = tqdm(val_loader, desc=f'Epoch {epoch+1}/10 Validation')\n",
+ "        for batch_audio, batch_labels in pbar_val:\n",
+ "            batch_audio = batch_audio.to(device)\n",
+ "            batch_labels = batch_labels.to(device)\n",
+ "\n",
+ "            features = model(batch_audio)\n",
+ "            if isinstance(features, dict):\n",
+ "                for key in ['last_hidden_state', 'embeddings', 'audio']:\n",
+ "                    if key in features:\n",
+ "                        features = features[key]\n",
+ "                        break\n",
+ "                else:\n",
+ "                    features = list(features.values())[0]\n",
+ "\n",
+ "            if features.dim() > 2:\n",
+ "                features = features.mean(dim=1)\n",
+ "\n",
+ "            logits = classifier(features)\n",
+ "            preds = torch.argmax(logits, dim=1)\n",
+ "            val_preds.extend(preds.cpu().numpy())\n",
+ "            val_labels.extend(batch_labels.cpu().numpy())\n",
+ "\n",
+ "            # Update validation progress bar\n",
+ "            batch_acc = (preds == batch_labels).float().mean().item()\n",
+ "            pbar_val.set_postfix({'batch_acc': f'{batch_acc:.4f}'})\n",
+ "\n",
+ "    val_acc = accuracy_score(val_labels, val_preds)\n",
+ "    \n",
+ "    print(f\"Epoch {epoch+1}/10 - Train Loss: {train_loss/len(train_loader):.4f} - Train Acc: {train_acc:.4f} - Val Acc: {val_acc:.4f}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }