mwirth7 commited on
Commit
c22197d
·
verified ·
1 Parent(s): 82af4a5

Update modeling_protonet.py

Browse files
Files changed (1) hide show
  1. modeling_protonet.py +1 -2
modeling_protonet.py CHANGED
@@ -854,8 +854,7 @@ class AudioProtoNetPreTrainedModel(PreTrainedModel):
854
  nn.init.trunc_normal_(module.weight, std=0.02)
855
  if module.bias is not None:
856
  nn.init.zeros_(module.bias)
857
- if self.incorrect_class_connection is None and isinstance(self.last_layer, LinearLayerWithoutNegativeConnections): # TODO missing initilization
858
- # Initialize all weights to the correct_class_connection value
859
  self.last_layer.weight.data.fill_(self.correct_class_connection)
860
 
861
 
 
854
  nn.init.trunc_normal_(module.weight, std=0.02)
855
  if module.bias is not None:
856
  nn.init.zeros_(module.bias)
857
+ if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections) and self.incorrect_class_connection is None:
 
858
  self.last_layer.weight.data.fill_(self.correct_class_connection)
859
 
860