# internvl_ola / configuration_audio.py
# Provenance: uploaded by jjw0126 ("Upload files", revision 62d115a, verified)
# --------------------------------------------------------
# InternVL with Audio Support
# Audio Configuration
# --------------------------------------------------------
import copy
import os
from typing import Union
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class AudioConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an Audio Encoder Model.
    It is used to instantiate an audio encoder according to the specified arguments,
    defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
    Read the documentation from [`PretrainedConfig`] for more information.

    Args:
        speech_encoder (`str`, *optional*, defaults to `"whisper-base"`):
            Path or name of the speech encoder model.
        speech_encoder_type (`str`, *optional*, defaults to `"whisper"`):
            Type of speech encoder to use.
        speech_projector_type (`str`, *optional*, defaults to `"linear"`):
            Type of speech projector to use for feature alignment.
        speech_encoder_ds_rate (`int`, *optional*, defaults to 5):
            Downsampling rate for speech features.
        speech_encoder_hidden_size (`int`, *optional*, defaults to 1280):
            Hidden size of the speech encoder.
        mel_bins (`int`, *optional*, defaults to 80):
            Number of mel-frequency bins for spectrogram features.
        sample_rate (`int`, *optional*, defaults to 16000):
            Audio sample rate in Hz.
        frame_length (`float`, *optional*, defaults to 25.0):
            Frame length in milliseconds for audio processing.
        frame_shift (`float`, *optional*, defaults to 10.0):
            Frame shift in milliseconds for audio processing.
        use_beats (`bool`, *optional*, defaults to `False`):
            Whether to use BEATs model for audio feature extraction.
        beats_model_path (`str`, *optional*, defaults to `None`):
            Path to BEATs model if `use_beats` is True.
        whisper_config (`dict`, *optional*, defaults to `None`):
            Configuration dictionary for Whisper model parameters. A `None`
            value is normalized to an empty dict.
    """
    model_type = 'audio_encoder'

    def __init__(
        self,
        speech_encoder="whisper-base",
        speech_encoder_type="whisper",
        speech_projector_type="linear",
        speech_encoder_ds_rate=5,
        speech_encoder_hidden_size=1280,
        mel_bins=80,
        sample_rate=16000,
        frame_length=25.0,
        frame_shift=10.0,
        use_beats=False,
        beats_model_path=None,
        whisper_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.speech_encoder = speech_encoder
        self.speech_encoder_type = speech_encoder_type
        self.speech_projector_type = speech_projector_type
        self.speech_encoder_ds_rate = speech_encoder_ds_rate
        self.speech_encoder_hidden_size = speech_encoder_hidden_size
        self.mel_bins = mel_bins
        self.sample_rate = sample_rate
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.use_beats = use_beats
        self.beats_model_path = beats_model_path
        # Normalize None to an empty dict so downstream code can assume a mapping.
        self.whisper_config = whisper_config or {}

        # Lazy %-style args: the message is only formatted when INFO is enabled.
        logger.info('Audio Config - Speech Encoder: %s', self.speech_encoder)
        logger.info('Audio Config - Encoder Type: %s', self.speech_encoder_type)
        logger.info('Audio Config - Projector Type: %s', self.speech_projector_type)
        logger.info('Audio Config - Downsampling Rate: %s', self.speech_encoder_ds_rate)
        logger.info('Audio Config - Hidden Size: %s', self.speech_encoder_hidden_size)
        logger.info('Audio Config - Mel Bins: %s', self.mel_bins)
        logger.info('Audio Config - Sample Rate: %s', self.sample_rate)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
        """
        Instantiate an [`AudioConfig`] from a pretrained model configuration.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Identifier or local path of the pretrained configuration.

        Returns:
            [`PretrainedConfig`]: The configuration object instantiated from
            the resolved config dictionary.
        """
        cls._set_token_in_kwargs(kwargs)
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `dict`: A deep copy of the instance attributes, with
            `model_type` taken from the class attribute.
        """
        output = copy.deepcopy(self.__dict__)
        output['model_type'] = self.__class__.model_type
        return output