mwirth7 commited on
Commit
128ce47
·
verified ·
1 Parent(s): 35ede9e

Update modeling_protonet.py

Browse files
Files changed (1) hide show
  1. modeling_protonet.py +5 -16
modeling_protonet.py CHANGED
@@ -870,14 +870,12 @@ class AudioProtoNetModel(AudioProtoNetPreTrainedModel):
870
  def forward(
871
  self,
872
  input_values: torch.Tensor,
873
- output_hidden_states: bool = None,
874
- return_dict: bool = None
875
- ) -> tuple | BaseModelOutputWithPoolingAndNoAttention:
876
  """
877
  Args:
878
  input_values:
879
  output_hidden_states:
880
- return_dict:
881
 
882
  Returns:
883
  last_hidden_state: torch.FloatTensor = None
@@ -885,7 +883,7 @@ class AudioProtoNetModel(AudioProtoNetPreTrainedModel):
885
  hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
886
 
887
  """
888
- return self.backbone(input_values, output_hidden_states, return_dict)
889
 
890
 
891
  class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
@@ -904,10 +902,9 @@ class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
904
  prototypes_of_wrong_class: torch.Tensor = None,
905
  output_hidden_states: bool = None,
906
  output_prototypical_activations: bool = None,
907
- return_dict: bool = None,
908
- ) -> tuple | SequenceClassifierOutputWithProtoTypeActivations:
909
 
910
- backbone_outputs = self.model(input_values, output_hidden_states, return_dict)
911
 
912
  last_hidden_state = backbone_outputs[0]
913
 
@@ -929,14 +926,6 @@ class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
929
  if output_prototypical_activations is not None:
930
  prototype_activations = info[4]
931
 
932
- if return_dict:
933
- output = (logits,)
934
- output += (loss, ) if loss is not None else ()
935
- output += (last_hidden_state, )
936
- output += (hidden_states, ) if hidden_states is not None else ()
937
- output += (prototype_activations,) if prototype_activations is not None else ()
938
- return output
939
-
940
  return SequenceClassifierOutputWithProtoTypeActivations(
941
  logits=logits,
942
  loss=loss,
 
870
  def forward(
871
  self,
872
  input_values: torch.Tensor,
873
+ output_hidden_states: bool = None
874
+ ) -> BaseModelOutputWithPoolingAndNoAttention:
 
875
  """
876
  Args:
877
  input_values:
878
  output_hidden_states:
 
879
 
880
  Returns:
881
  last_hidden_state: torch.FloatTensor = None
 
883
  hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
884
 
885
  """
886
+ return self.backbone(input_values, output_hidden_states)
887
 
888
 
889
  class AudioProtoNetForSequenceClassification(AudioProtoNetPreTrainedModel):
 
902
  prototypes_of_wrong_class: torch.Tensor = None,
903
  output_hidden_states: bool = None,
904
  output_prototypical_activations: bool = None,
905
+ ) -> SequenceClassifierOutputWithProtoTypeActivations:
 
906
 
907
+ backbone_outputs = self.model(input_values, output_hidden_states)
908
 
909
  last_hidden_state = backbone_outputs[0]
910
 
 
926
  if output_prototypical_activations is not None:
927
  prototype_activations = info[4]
928
 
 
 
 
 
 
 
 
 
929
  return SequenceClassifierOutputWithProtoTypeActivations(
930
  logits=logits,
931
  loss=loss,