import logging

import whisper
import torch.nn.functional as F
from torch import Tensor
from whisper.model import AudioEncoder

logger = logging.getLogger(__name__)


class WhisperAudioEncoder(AudioEncoder):
    """
    Subclass of the original Whisper encoder that drops its fixed 30-second
    padding: inputs of any length (up to the enlarged positional-embedding
    table) are encoded directly, which improves training and inference
    efficiency.
    """

    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)
        self.audio_emb_dim = n_state

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = [B, T, n_mels]
            The log-mel spectrogram of the audio; T may vary instead of being
            padded to the fixed 30-second length.
        """
        x = x.transpose(1, 2)       # [B, n_mels, T] for the 1-D convolutions
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))   # stride-2 conv halves the time dimension
        x = x.permute(0, 2, 1)      # back to [B, T', n_state]

        # Slice the positional embedding to the actual sequence length rather
        # than asserting the full fixed-length context.
        positional_embedding = self.positional_embedding[:x.shape[1], :]
        assert x.shape[1:] == positional_embedding.shape, "incorrect audio shape"
        x = (x + positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x

    @classmethod
    def from_pretrained(cls, model_path, **kwargs):
        whisper_model = whisper.load_model(model_path)
        audio_encoder = cls(
            whisper_model.dims.n_mels,
            # Allocate 10x the pretrained audio context (1500 frames = 30 s for
            # the released checkpoints) so inputs up to ~300 s fit without padding.
            whisper_model.dims.n_audio_ctx * 10,
            whisper_model.dims.n_audio_state,
            whisper_model.dims.n_audio_head,
            whisper_model.dims.n_audio_layer,
        ).to(whisper_model.device)
        # Copy the pretrained encoder weights except `positional_embedding`,
        # whose shape no longer matches; the enlarged sinusoidal table built in
        # __init__ is used instead.
        state_dict = whisper_model.encoder.state_dict()
        state_dict.pop('positional_embedding')
        ret = audio_encoder.load_state_dict(state_dict, strict=False)
        logger.warning(f'Whisper encoder loaded without `positional_embedding`: {ret}')
        audio_encoder.audio_emb_dim = whisper_model.dims.n_audio_state
        return audio_encoder
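

# A minimal usage sketch, not part of the original module. It assumes the
# openai-whisper package's `base` checkpoint name and uses random noise in
# place of real audio.
if __name__ == '__main__':
    import torch

    encoder = WhisperAudioEncoder.from_pretrained('base')
    audio = torch.randn(16000 * 5)            # ~5 s of fake 16 kHz audio
    mel = whisper.log_mel_spectrogram(audio)  # [n_mels, T] with no 30 s padding
    x = mel.unsqueeze(0).transpose(1, 2)      # [B, T, n_mels], as forward() expects
    p = next(encoder.parameters())
    features = encoder(x.to(p.device, p.dtype))
    print(features.shape)                     # [1, ~T // 2, encoder.audio_emb_dim]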