Moremoholo2 commited on
Commit
2ca0e96
·
verified ·
1 Parent(s): 3d83fd2

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +20 -12
model.py CHANGED
@@ -5,28 +5,36 @@ import torch.nn.functional as F
5
  class AudioCNN(nn.Module):
6
  def __init__(self, num_classes, input_length):
7
  super(AudioCNN, self).__init__()
8
- # Ensure input_length is an integer, not a tuple
 
9
  if isinstance(input_length, (tuple, list)):
10
  input_length = input_length[0]
11
 
12
- # 1D Convolution for raw audio
13
  self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1)
14
- self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
15
-
16
- # Dynamically compute flattened size
 
 
17
  with torch.no_grad():
18
- dummy_input = torch.zeros((1, 1, input_length)) # batch=1, channels=1
19
- x = self.pool(F.relu(self.conv1(dummy_input)))
 
20
  flattened_size = x.numel() // x.size(0)
21
-
22
- # Fully connected layer
23
- self.fc1 = nn.Linear(flattened_size, num_classes)
 
24
 
25
  def forward(self, x):
26
- x = self.pool(F.relu(self.conv1(x)))
 
27
  x = x.view(x.size(0), -1)
28
- x = self.fc1(x)
 
29
  return x
30
 
31
 
32
 
 
 
5
  class AudioCNN(nn.Module):
6
  def __init__(self, num_classes, input_length):
7
  super(AudioCNN, self).__init__()
8
+
9
+ # Ensure input_length is integer
10
  if isinstance(input_length, (tuple, list)):
11
  input_length = input_length[0]
12
 
13
+ # Convolutional layers
14
  self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1)
15
+ self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
16
+ self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=1, padding=1)
17
+ self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
18
+
19
+ # Compute output size after conv + pooling dynamically
20
  with torch.no_grad():
21
+ dummy_input = torch.zeros(1, 1, input_length) # batch=1, channel=1
22
+ x = self.pool1(F.relu(self.conv1(dummy_input)))
23
+ x = self.pool2(F.relu(self.conv2(x)))
24
  flattened_size = x.numel() // x.size(0)
25
+
26
+ # Fully connected layers
27
+ self.fc1 = nn.Linear(flattened_size, 128)
28
+ self.fc2 = nn.Linear(128, num_classes)
29
 
30
  def forward(self, x):
31
+ x = self.pool1(F.relu(self.conv1(x)))
32
+ x = self.pool2(F.relu(self.conv2(x)))
33
  x = x.view(x.size(0), -1)
34
+ x = F.relu(self.fc1(x))
35
+ x = self.fc2(x)
36
  return x
37
 
38
 
39
 
40
+