from transformers import Wav2Vec2Config


class Wav2Vec2MultiHeadConfig(Wav2Vec2Config):
    """Wav2Vec2 configuration for a model with three classification heads.

    Extends ``Wav2Vec2Config`` with one label count per head
    (``num_labels_1``, ``num_labels_2``, ``num_labels_3``) and sets the
    inherited ``num_labels`` attribute to the sum of the three counts.

    Keyword Args:
        num_labels_1: Number of labels for the first head (default 0).
        num_labels_2: Number of labels for the second head (default 0).
        num_labels_3: Number of labels for the third head (default 0).
        **kwargs: Every other keyword is forwarded unchanged to
            ``Wav2Vec2Config.__init__``.

    Raises:
        ValueError: If any per-head label count is negative.
    """

    model_type = "wav2vec2multihead_3class"
    is_encoder_decoder = False

    def __init__(self, **kwargs):
        # Pop the head-specific counts *before* delegating, so these custom
        # keys are consumed here rather than also being forwarded to the
        # parent constructor (popping after super().__init__ had no effect on
        # what the parent received).
        num_labels_1 = kwargs.pop("num_labels_1", 0)
        num_labels_2 = kwargs.pop("num_labels_2", 0)
        num_labels_3 = kwargs.pop("num_labels_3", 0)

        for key, count in (
            ("num_labels_1", num_labels_1),
            ("num_labels_2", num_labels_2),
            ("num_labels_3", num_labels_3),
        ):
            if count < 0:
                raise ValueError(f"{key} must be non-negative, got {count}")

        super().__init__(**kwargs)

        self.num_labels_1 = num_labels_1
        self.num_labels_2 = num_labels_2
        self.num_labels_3 = num_labels_3
        # The combined label count across all three heads.
        # NOTE(review): when no per-head counts are supplied this yields 0,
        # overriding whatever num_labels the parent config set — confirm that
        # is intended for configs loaded without num_labels_* keys.
        self.num_labels = num_labels_1 + num_labels_2 + num_labels_3