Upload app.py
Browse files
app.py
CHANGED
|
@@ -40,12 +40,8 @@ class WaveLmStutterClassification(nn.Module):
|
|
| 40 |
for param in layer.parameters():
|
| 41 |
param.requires_grad = True
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
nn.ReLU(),
|
| 46 |
-
nn.Dropout(0.3),
|
| 47 |
-
nn.Linear(256, num_labels)
|
| 48 |
-
)
|
| 49 |
self.num_labels = num_labels
|
| 50 |
|
| 51 |
def forward(self, input_values, attention_mask=None):
|
|
|
|
| 40 |
for param in layer.parameters():
|
| 41 |
param.requires_grad = True
|
| 42 |
|
| 43 |
+
# Single linear layer to match the trained checkpoint
|
| 44 |
+
self.classifier = nn.Linear(self.hidden_size, num_labels)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
self.num_labels = num_labels
|
| 46 |
|
| 47 |
def forward(self, input_values, attention_mask=None):
|