AlexTolstenko committed on
Commit
eb9689f
·
1 Parent(s): 9f8556d

Upload 2 files

Browse files
Files changed (2) hide show
  1. models/__init__.py +0 -0
  2. models/model.py +196 -0
models/__init__.py ADDED
File without changes
models/model.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np # linear algebra
2
+ import pandas as pd
3
+ import os
4
+ import matplotlib.pylab as plt
5
+
6
+ import librosa as lr
7
+ import torch
8
+ import torch.nn as nn
9
+ import pytorch_lightning as pl
10
+ import gradio
11
+
12
+ # HYPERPARAMS
13
+ EPOCHS = 200
14
+ BATCH_SIZE = 32
15
+ NUM_OF_CLASSES = 14
16
+
17
+ class MFCC_CNN(pl.LightningModule):
18
+ def __init__(self, num_of_classes):
19
+ super(MFCC_CNN, self).__init__()
20
+
21
+ self.example_input_array = torch.Tensor(32, 1, 50, 94)
22
+ self.train_loss_output = []
23
+ self.train_acc_output = []
24
+ self.val_acc_output = []
25
+ self.val_loss_output = []
26
+
27
+ self.number_of_classes = num_of_classes
28
+
29
+ self.conv_1 = nn.Sequential(
30
+ nn.Conv2d(in_channels = 1,
31
+ out_channels = 64,
32
+ kernel_size =3,
33
+ padding = 1,
34
+ stride = 1),
35
+ nn.BatchNorm2d(64),
36
+ nn.LeakyReLU(),
37
+ nn.MaxPool2d(kernel_size=2),
38
+ nn.Dropout(0.1)
39
+ )
40
+
41
+ self.conv_2 = nn.Sequential(
42
+ nn.Conv2d(in_channels = 64,
43
+ out_channels = 128,
44
+ kernel_size = 3,
45
+ padding = 1,
46
+ stride = 1),
47
+ nn.BatchNorm2d(128),
48
+ nn.LeakyReLU(),
49
+ nn.MaxPool2d(kernel_size=2),
50
+ nn.Dropout(0.1)
51
+ )
52
+
53
+ self.conv_3 = nn.Sequential(
54
+ nn.Conv2d(in_channels = 128,
55
+ out_channels = 256,
56
+ kernel_size = 3,
57
+ padding = 1,
58
+ stride = 1),
59
+ nn.BatchNorm2d(256),
60
+ nn.LeakyReLU(),
61
+ nn.MaxPool2d(kernel_size=2),
62
+ nn.Dropout(0.1)
63
+ )
64
+
65
+ self.conv_4 = nn.Sequential(
66
+ nn.Conv2d(in_channels = 256,
67
+ out_channels = 512,
68
+ kernel_size = 3,
69
+ padding = 1,
70
+ stride = 1),
71
+ nn.BatchNorm2d(512),
72
+ nn.LeakyReLU(),
73
+ nn.MaxPool2d(kernel_size=2)
74
+ )
75
+
76
+ self.conv_5 = nn.Sequential(
77
+ nn.Conv2d(in_channels = 512,
78
+ out_channels = 512,
79
+ kernel_size = 2,
80
+ padding = 0,
81
+ stride = 1),
82
+ nn.BatchNorm2d(512),
83
+ nn.LeakyReLU(),
84
+ nn.MaxPool2d(kernel_size=2)
85
+ )
86
+
87
+ self.drop = nn.Dropout(0.1)
88
+ self.lin_1 = nn.Linear(1024, 128)
89
+ self.lin_2 = nn.Linear(128, 64)
90
+ self.lin_3 = nn.Linear(64, num_of_classes)
91
+
92
+ self.relu = nn.ReLU()
93
+ self.softmax = nn.Softmax()
94
+
95
+ def forward(self, x):
96
+ out = self.conv_1(x)
97
+ out = self.conv_2(out)
98
+ out = self.conv_3(out)
99
+ out = self.conv_4(out)
100
+ out = self.conv_5(out)
101
+
102
+ out = torch.flatten(out, start_dim=1)
103
+
104
+ out = self.drop(self.lin_1(self.relu(out)))
105
+ out = self.drop(self.lin_2(self.relu(out)))
106
+ out = self.drop(self.lin_3(self.relu(out)))
107
+
108
+ out = self.softmax(out)
109
+
110
+ return out
111
+
112
+ def loss_fn(self, out, target):
113
+ return nn.CrossEntropyLoss()(input=out.view(-1, self.number_of_classes),
114
+ target=target)
115
+
116
+ def configure_optimizers(self):
117
+ LR=1e-3
118
+ optimizer = torch.optim.Adam(self.parameters(), lr=LR, weight_decay=1e-3)
119
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
120
+ mode='min',
121
+ factor=0.5,
122
+ patience=5,
123
+ verbose=True)
124
+
125
+ return {
126
+ 'optimizer': optimizer,
127
+ "lr_scheduler": {
128
+ "scheduler": scheduler,
129
+ "monitor": "val_loss",
130
+ 'interval': 'epoch',
131
+ 'frequency': 1
132
+ },
133
+ }
134
+
135
+ def training_step(self, batch, batch_idx):
136
+ mfcc, lable = batch
137
+ mfcc = mfcc.view(-1, 1, 50, 94)
138
+ lable = lable.view(-1, self.number_of_classes)
139
+
140
+ out = self(mfcc)
141
+
142
+ loss = self.loss_fn(out=out, target=lable)
143
+
144
+ lable = torch.argmax(lable,dim=1)
145
+ predictions = torch.argmax(out,dim=1)
146
+ accuracy = torch.sum(lable==predictions)/float(len(lable))
147
+
148
+ self.train_acc_output.append(accuracy.detach().numpy())
149
+ self.train_loss_output.append(loss.detach().numpy())
150
+ #wandb.log({'train_accuracy_step': accuracy, 'train_loss_step':loss})\
151
+
152
+ self.log('train_accuracy', accuracy, prog_bar=True, on_epoch=True, on_step=False)
153
+ self.log('train_loss', loss, prog_bar=True, on_epoch=True, on_step=False)
154
+ return loss
155
+
156
+ def validation_step(self, batch, batch_idx):
157
+ mfcc, lable = batch
158
+ mfcc = mfcc.view(-1, 1, 50, 94)
159
+ lable = lable.view(-1, self.number_of_classes)
160
+
161
+ out = self(mfcc)
162
+
163
+ loss = self.loss_fn(out=out, target=lable)
164
+
165
+ lable = torch.argmax(lable,dim=1)
166
+ predictions = torch.argmax(out,dim=1)
167
+ accuracy = torch.sum(lable==predictions)/float(len(lable))
168
+
169
+ self.val_acc_output.append(accuracy.detach().numpy())
170
+ self.val_loss_output.append(loss.detach().numpy())
171
+ #wandb.log({'val_accuracy_step': accuracy, 'val_loss_step':loss})
172
+
173
+ self.log('val_accuracy', accuracy, prog_bar=True, on_epoch=True)
174
+ self.log('val_loss', loss, prog_bar=True, on_epoch=True)
175
+
176
+ return loss
177
+
178
+ def on_train_epoch_end(self):
179
+ train_loss_epoch = self.train_loss_output
180
+ train_acc_epoch = self.train_acc_output
181
+
182
+ #wandb.log({'train_loss_epoch':np.mean(train_loss_epoch),
183
+ # 'train_acc_epoch':np.mean(train_acc_epoch)})
184
+
185
+ self.train_loss_output.clear()
186
+ self.train_acc_output.clear()
187
+
188
+ def on_validation_epoch_end(self):
189
+ val_loss_epoch = self.val_loss_output
190
+ val_acc_epoch = self.val_acc_output
191
+
192
+ #wandb.log({'val_loss_epoch':np.mean(val_loss_epoch),
193
+ # 'val_acc_epoch':np.mean(val_acc_epoch)})
194
+
195
+ self.val_acc_output.clear()
196
+ self.val_loss_output.clear()