mwirth7 commited on
Commit
35ede9e
·
verified ·
1 Parent(s): c22197d

Update modeling_protonet.py

Browse files
Files changed (1) hide show
  1. modeling_protonet.py +2 -8
modeling_protonet.py CHANGED
@@ -854,7 +854,8 @@ class AudioProtoNetPreTrainedModel(PreTrainedModel):
854
  nn.init.trunc_normal_(module.weight, std=0.02)
855
  if module.bias is not None:
856
  nn.init.zeros_(module.bias)
857
- if isinstance(self.last_layer, LinearLayerWithoutNegativeConnections) and self.incorrect_class_connection is None:
 
858
  self.last_layer.weight.data.fill_(self.correct_class_connection)
859
 
860
 
@@ -896,13 +897,6 @@ class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
896
  self.model = AudioProtoNetModel(config)
897
  self.head = AudioProtoNetClassificationHead(config)
898
 
899
-
900
- def freeze_backbone(self):
901
- pass
902
-
903
- def int2str(self): # TODO
904
- pass
905
-
906
  def forward(
907
  self,
908
  input_values: torch.Tensor,
 
854
  nn.init.trunc_normal_(module.weight, std=0.02)
855
  if module.bias is not None:
856
  nn.init.zeros_(module.bias)
857
+ if isinstance(module, LinearLayerWithoutNegativeConnections) and self.incorrect_class_connection is None:
858
+ # Initialize all weights to the correct_class_connection value
859
  self.last_layer.weight.data.fill_(self.correct_class_connection)
860
 
861
 
 
897
  self.model = AudioProtoNetModel(config)
898
  self.head = AudioProtoNetClassificationHead(config)
899
 
 
 
 
 
 
 
 
900
  def forward(
901
  self,
902
  input_values: torch.Tensor,