mwirth7 commited on
Commit
09d7c8a
·
verified ·
1 Parent(s): 3c867d8

Update modeling_protonet.py

Browse files
Files changed (1) hide show
  1. modeling_protonet.py +6 -24
modeling_protonet.py CHANGED
@@ -854,7 +854,7 @@ class AudioProtoNetPreTrainedModel(PreTrainedModel):
854
  nn.init.trunc_normal_(module.weight, std=0.02)
855
  if module.bias is not None:
856
  nn.init.zeros_(module.bias)
857
- if isinstance(module, LinearLayerWithoutNegativeConnections) and self.incorrect_class_connection is None: # TODO missing initilization
858
  # Initialize all weights to the correct_class_connection value
859
  self.last_layer.weight.data.fill_(self.correct_class_connection)
860
 
@@ -870,14 +870,12 @@ class AudioProtoNetModel(AudioProtoNetPreTrainedModel):
870
  def forward(
871
  self,
872
  input_values: torch.Tensor,
873
- output_hidden_states: bool = None,
874
- return_dict: bool = None
875
- ) -> tuple | BaseModelOutputWithPoolingAndNoAttention:
876
  """
877
  Args:
878
  input_values:
879
  output_hidden_states:
880
- return_dict:
881
 
882
  Returns:
883
  last_hidden_state: torch.FloatTensor = None
@@ -885,7 +883,7 @@ class AudioProtoNetModel(AudioProtoNetPreTrainedModel):
885
  hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
886
 
887
  """
888
- return self.backbone(input_values, output_hidden_states, return_dict)
889
 
890
 
891
  class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
@@ -897,13 +895,6 @@ class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
897
  self.model = AudioProtoNetModel(config)
898
  self.head = AudioProtoNetClassificationHead(config)
899
 
900
-
901
- def freeze_backbone(self):
902
- pass
903
-
904
- def int2str(self): # TODO
905
- pass
906
-
907
  def forward(
908
  self,
909
  input_values: torch.Tensor,
@@ -911,10 +902,9 @@ class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
911
  prototypes_of_wrong_class: torch.Tensor = None,
912
  output_hidden_states: bool = None,
913
  output_prototypical_activations: bool = None,
914
- return_dict: bool = None,
915
- ) -> tuple | SequenceClassifierOutputWithProtoTypeActivations:
916
 
917
- backbone_outputs = self.model(input_values, output_hidden_states, return_dict)
918
 
919
  last_hidden_state = backbone_outputs[0]
920
 
@@ -936,14 +926,6 @@ class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
936
  if output_prototypical_activations is not None:
937
  prototype_activations = info[4]
938
 
939
- if return_dict:
940
- output = (logits,)
941
- output += (loss, ) if loss is not None else ()
942
- output += (last_hidden_state, )
943
- output += (hidden_states, ) if hidden_states is not None else ()
944
- output += (prototype_activations,) if prototype_activations is not None else ()
945
- return output
946
-
947
  return SequenceClassifierOutputWithProtoTypeActivations(
948
  logits=logits,
949
  loss=loss,
 
854
  nn.init.trunc_normal_(module.weight, std=0.02)
855
  if module.bias is not None:
856
  nn.init.zeros_(module.bias)
857
+ if isinstance(module, LinearLayerWithoutNegativeConnections) and self.incorrect_class_connection is None:
858
  # Initialize all weights to the correct_class_connection value
859
  self.last_layer.weight.data.fill_(self.correct_class_connection)
860
 
 
870
  def forward(
871
  self,
872
  input_values: torch.Tensor,
873
+ output_hidden_states: bool = None
874
+ ) -> BaseModelOutputWithPoolingAndNoAttention:
 
875
  """
876
  Args:
877
  input_values:
878
  output_hidden_states:
 
879
 
880
  Returns:
881
  last_hidden_state: torch.FloatTensor = None
 
883
  hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
884
 
885
  """
886
+ return self.backbone(input_values, output_hidden_states)
887
 
888
 
889
  class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
 
895
  self.model = AudioProtoNetModel(config)
896
  self.head = AudioProtoNetClassificationHead(config)
897
 
 
 
 
 
 
 
 
898
  def forward(
899
  self,
900
  input_values: torch.Tensor,
 
902
  prototypes_of_wrong_class: torch.Tensor = None,
903
  output_hidden_states: bool = None,
904
  output_prototypical_activations: bool = None,
905
+ ) -> SequenceClassifierOutputWithProtoTypeActivations:
 
906
 
907
+ backbone_outputs = self.model(input_values, output_hidden_states)
908
 
909
  last_hidden_state = backbone_outputs[0]
910
 
 
926
  if output_prototypical_activations is not None:
927
  prototype_activations = info[4]
928
 
 
 
 
 
 
 
 
 
929
  return SequenceClassifierOutputWithProtoTypeActivations(
930
  logits=logits,
931
  loss=loss,