| | |
| | |
| | |
| | |
| |
|
| | import copy |
| | import os |
| | from typing import Union |
| |
|
| | from transformers.configuration_utils import PretrainedConfig |
| | from transformers.utils import logging |
| |
|
# Module-level logger obtained from transformers' logging utilities, named after this module.
logger = logging.get_logger(name=__name__)
| |
|
| |
|
class AudioConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an Audio Encoder Model.
    It is used to instantiate an audio encoder according to the specified arguments,
    defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
    Read the documentation from [`PretrainedConfig`] for more information.

    Args:
        speech_encoder (`str`, *optional*, defaults to `"whisper-base"`):
            Path or name of the speech encoder model.
        speech_encoder_type (`str`, *optional*, defaults to `"whisper"`):
            Type of speech encoder to use.
        speech_projector_type (`str`, *optional*, defaults to `"linear"`):
            Type of speech projector to use for feature alignment.
        speech_encoder_ds_rate (`int`, *optional*, defaults to 5):
            Downsampling rate for speech features.
        speech_encoder_hidden_size (`int`, *optional*, defaults to 1280):
            Hidden size of the speech encoder.
        mel_bins (`int`, *optional*, defaults to 80):
            Number of mel-frequency bins for spectrogram features.
        sample_rate (`int`, *optional*, defaults to 16000):
            Audio sample rate in Hz.
        frame_length (`float`, *optional*, defaults to 25.0):
            Frame length in milliseconds for audio processing.
        frame_shift (`float`, *optional*, defaults to 10.0):
            Frame shift in milliseconds for audio processing.
        use_beats (`bool`, *optional*, defaults to `False`):
            Whether to use BEATs model for audio feature extraction.
        beats_model_path (`str`, *optional*, defaults to `None`):
            Path to BEATs model if `use_beats` is `True`.
        whisper_config (`dict`, *optional*, defaults to `None`):
            Configuration dictionary for Whisper model parameters. A falsy value
            (including `None`) is stored as an empty dict.
        kwargs (*optional*):
            Remaining keyword arguments, forwarded unchanged to [`PretrainedConfig`].
    """

    model_type = 'audio_encoder'

    def __init__(
        self,
        speech_encoder="whisper-base",
        speech_encoder_type="whisper",
        speech_projector_type="linear",
        speech_encoder_ds_rate=5,
        # NOTE(review): 1280 looks larger than the hidden size usually associated with the
        # default "whisper-base" encoder; confirm these two defaults are meant to go together.
        speech_encoder_hidden_size=1280,
        mel_bins=80,
        sample_rate=16000,
        frame_length=25.0,
        frame_shift=10.0,
        use_beats=False,
        beats_model_path=None,
        whisper_config=None,
        **kwargs,
    ):
        # Let the base class consume the generic config kwargs before setting our own fields.
        super().__init__(**kwargs)

        self.speech_encoder = speech_encoder
        self.speech_encoder_type = speech_encoder_type
        self.speech_projector_type = speech_projector_type
        self.speech_encoder_ds_rate = speech_encoder_ds_rate
        self.speech_encoder_hidden_size = speech_encoder_hidden_size
        self.mel_bins = mel_bins
        self.sample_rate = sample_rate
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.use_beats = use_beats
        self.beats_model_path = beats_model_path
        # Never store None here, so downstream code can treat the attribute as a dict.
        self.whisper_config = whisper_config or {}

        # Lazy %-style arguments: messages are only formatted when INFO logging is enabled,
        # and the rendered text is the same as the previous f-string version.
        logger.info('Audio Config - Speech Encoder: %s', self.speech_encoder)
        logger.info('Audio Config - Encoder Type: %s', self.speech_encoder_type)
        logger.info('Audio Config - Projector Type: %s', self.speech_projector_type)
        logger.info('Audio Config - Downsampling Rate: %s', self.speech_encoder_ds_rate)
        logger.info('Audio Config - Hidden Size: %s', self.speech_encoder_hidden_size)
        logger.info('Audio Config - Mel Bins: %s', self.mel_bins)
        logger.info('Audio Config - Sample Rate: %s', self.sample_rate)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
        """
        Load an audio configuration from a pretrained model name or a local path.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Model identifier or path to pass to `get_config_dict`.
            **kwargs:
                Loading options for `get_config_dict`; whatever it returns as remaining
                kwargs is forwarded to `from_dict`.

        Returns:
            The configuration instance built by `cls.from_dict`.
        """
        # Base-class helper; presumably normalizes auth-token kwargs before loading — verify
        # against the installed transformers version.
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `dict`: The dictionary produced by `PretrainedConfig.to_dict()`, with
            `model_type` set to this class's `model_type`.
        """
        # Delegate to the base implementation instead of deep-copying `__dict__` directly:
        # the base class also drops private bookkeeping attributes (e.g. `_commit_hash`),
        # records `transformers_version`, and serializes nested config objects, keeping the
        # output consistent with other transformers configs.
        output = super().to_dict()
        output['model_type'] = self.__class__.model_type
        return output
| |
|