techsword commited on
Commit
a2f6223
·
verified ·
1 Parent(s): de16083

Upload Wav2Vec2ForMultiHeadMultiLabelClassification

Browse files
Files changed (1) hide show
  1. modeling_wav2vec2multihead.py +9 -0
modeling_wav2vec2multihead.py CHANGED
@@ -8,6 +8,7 @@ from transformers import (
8
  Wav2Vec2PreTrainedModel,
9
  )
10
  from transformers.utils import ModelOutput
 
11
 
12
 
13
  @dataclass
@@ -21,6 +22,14 @@ class Wav2Vec2MultiHeadMultiLabelOutput(ModelOutput):
21
 
22
 
23
  class Wav2Vec2ForMultiHeadMultiLabelClassification(Wav2Vec2PreTrainedModel):
 
 
 
 
 
 
 
 
24
  def __init__(self, config):
25
  super().__init__(config)
26
  self.wav2vec2 = Wav2Vec2Model(config)
 
8
  Wav2Vec2PreTrainedModel,
9
  )
10
  from transformers.utils import ModelOutput
11
+ from .configuration_wav2vec2multihead import Wav2Vec2MultiHeadConfig
12
 
13
 
14
  @dataclass
 
22
 
23
 
24
  class Wav2Vec2ForMultiHeadMultiLabelClassification(Wav2Vec2PreTrainedModel):
25
+
26
+ """Wav2Vec2ForMultiHeadMultiLabelClassification is a model for multi-label classification using Wav2Vec2 using multiple classifier heads. Three classifier heads are hard-coded for three different tasks, such as action, object, and location classification in FSC-IC dataset.
27
+
28
+ Returns:
29
+ Wav2Vec2MultiHeadMultiLabelOutput: Contains the loss and logits for each of the three tasks, as well as hidden states and attentions if requested.
30
+ """
31
+
32
+ config_class = Wav2Vec2MultiHeadConfig
33
  def __init__(self, config):
34
  super().__init__(config)
35
  self.wav2vec2 = Wav2Vec2Model(config)