Upload Wav2Vec2ForMultiHeadMultiLabelClassification
Browse files
modeling_wav2vec2multihead.py
CHANGED
|
@@ -8,6 +8,7 @@ from transformers import (
|
|
| 8 |
Wav2Vec2PreTrainedModel,
|
| 9 |
)
|
| 10 |
from transformers.utils import ModelOutput
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
@dataclass
|
|
@@ -21,6 +22,14 @@ class Wav2Vec2MultiHeadMultiLabelOutput(ModelOutput):
|
|
| 21 |
|
| 22 |
|
| 23 |
class Wav2Vec2ForMultiHeadMultiLabelClassification(Wav2Vec2PreTrainedModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
def __init__(self, config):
|
| 25 |
super().__init__(config)
|
| 26 |
self.wav2vec2 = Wav2Vec2Model(config)
|
|
|
|
| 8 |
Wav2Vec2PreTrainedModel,
|
| 9 |
)
|
| 10 |
from transformers.utils import ModelOutput
|
| 11 |
+
from .configuration_wav2vec2multihead import Wav2Vec2MultiHeadConfig
|
| 12 |
|
| 13 |
|
| 14 |
@dataclass
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
class Wav2Vec2ForMultiHeadMultiLabelClassification(Wav2Vec2PreTrainedModel):
|
| 25 |
+
|
| 26 |
+
"""Wav2Vec2ForMultiHeadMultiLabelClassification is a model for multi-label classification using Wav2Vec2 using multiple classifier heads. Three classifier heads are hard-coded for three different tasks, such as action, object, and location classification in FSC-IC dataset.
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
Wav2Vec2MultiHeadMultiLabelOutput: Contains the loss and logits for each of the three tasks, as well as hidden states and attentions if requested.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
config_class = Wav2Vec2MultiHeadConfig
|
| 33 |
def __init__(self, config):
|
| 34 |
super().__init__(config)
|
| 35 |
self.wav2vec2 = Wav2Vec2Model(config)
|