throgletworld commited on
Commit
431e771
·
verified ·
1 Parent(s): b8bfaf3

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -6
app.py CHANGED
@@ -40,12 +40,8 @@ class WaveLmStutterClassification(nn.Module):
40
  for param in layer.parameters():
41
  param.requires_grad = True
42
 
43
- self.classifier = nn.Sequential(
44
- nn.Linear(self.hidden_size, 256),
45
- nn.ReLU(),
46
- nn.Dropout(0.3),
47
- nn.Linear(256, num_labels)
48
- )
49
  self.num_labels = num_labels
50
 
51
  def forward(self, input_values, attention_mask=None):
 
40
  for param in layer.parameters():
41
  param.requires_grad = True
42
 
43
+ # Single linear layer to match the trained checkpoint
44
+ self.classifier = nn.Linear(self.hidden_size, num_labels)
 
 
 
 
45
  self.num_labels = num_labels
46
 
47
  def forward(self, input_values, attention_mask=None):