from typing import Any

from transformers import Wav2Vec2Config
class Wav2Vec2MultiHeadConfig(Wav2Vec2Config):
    """``Wav2Vec2Config`` extended with three separate label counts.

    Adds ``num_labels_1``, ``num_labels_2`` and ``num_labels_3`` (presumably
    one per classification head, judging by the class/model-type name --
    confirm against the model that consumes this config). The inherited
    ``num_labels`` attribute is set to the sum of the three counts.
    """

    model_type = "wav2vec2multihead_3class"
    is_encoder_decoder = False

    def __init__(self, **kwargs: Any) -> None:
        """Build the configuration.

        Args:
            **kwargs: ``num_labels_1``, ``num_labels_2`` and ``num_labels_3``
                (each defaulting to 0) are consumed here. Every other
                keyword argument is forwarded unchanged to
                ``Wav2Vec2Config.__init__``.
        """
        # Consume the per-head counts before delegating so that they are not
        # forwarded to the parent constructor; popping them afterwards would
        # have no effect on what the parent already received.
        num_labels_1 = kwargs.pop("num_labels_1", 0)
        num_labels_2 = kwargs.pop("num_labels_2", 0)
        num_labels_3 = kwargs.pop("num_labels_3", 0)

        super().__init__(**kwargs)

        self.num_labels_1 = num_labels_1
        self.num_labels_2 = num_labels_2
        self.num_labels_3 = num_labels_3
        # Assigned last, so any explicit ``num_labels`` kwarg handled by the
        # parent is overridden by the total across the three counts.
        self.num_labels = num_labels_1 + num_labels_2 + num_labels_3