| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
"""HyperCLOVAX-SEED Audio Encoder model.

Extends WhisperEncoder with the following design choices:
- Trained from random initialization with CTC loss
  (not using pretrained ASR weights)
- Temporal pooling (Conv1d, kernel=5, stride=5) applied after the encoder
  to reduce the output rate from 50 Hz to 10 Hz for multimodal integration

Acknowledgements:
- Audio encoder uses WhisperEncoder from the HuggingFace transformers library
  (https://github.com/huggingface/transformers), Apache 2.0 License.
  Original Whisper model: https://github.com/openai/whisper (MIT License).
"""
|
|
| from typing import Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import AutoModel, PreTrainedModel, WhisperConfig |
| from transformers.modeling_outputs import BaseModelOutput |
|
|
| try: |
| from transformers import WhisperEncoder |
| except ImportError: |
| from transformers.models.whisper.modeling_whisper import WhisperEncoder |
|
|
| from .configuration_hyperclovax_seed_audio_encoder import HyperCLOVAXSeedAudioEncoderConfig |
|
|
|
|
class HyperCLOVAXSeedAudioEncoder(PreTrainedModel):
    """Audio encoder based on WhisperEncoder with temporal pooling (50 Hz -> 10 Hz).

    Wraps a randomly initialized ``WhisperEncoder`` and appends a strided
    ``Conv1d`` pooling layer plus ``LayerNorm`` so the output frame rate is
    reduced for multimodal integration (10 Hz with the default
    kernel_size=5 / stride=5 configuration).
    """

    config_class = HyperCLOVAXSeedAudioEncoderConfig
    supports_gradient_checkpointing = True

    def __init__(self, config: HyperCLOVAXSeedAudioEncoderConfig):
        """Build the inner Whisper encoder and the temporal pooling head.

        Args:
            config: Encoder configuration; its encoder-relevant fields are
                forwarded into a fresh ``WhisperConfig``.
        """
        super().__init__(config)

        # Whisper encoder built from scratch (trained with CTC loss, not
        # from pretrained ASR weights -- see module docstring). Only the
        # encoder-side fields of the config are forwarded.
        whisper_config = WhisperConfig(
            d_model=config.d_model,
            encoder_layers=config.encoder_layers,
            encoder_attention_heads=config.encoder_attention_heads,
            encoder_ffn_dim=config.encoder_ffn_dim,
            num_mel_bins=config.num_mel_bins,
            max_source_positions=config.max_source_positions,
            dropout=config.dropout,
            attention_dropout=config.attention_dropout,
        )
        self.encoder = WhisperEncoder(whisper_config)

        # Strided conv pooling over time reduces the encoder's 50 Hz output
        # to 10 Hz (default kernel_size=5, stride=5), followed by LayerNorm.
        self.temporal_pool = nn.Conv1d(
            config.d_model,
            config.d_model,
            kernel_size=config.pool_kernel_size,
            stride=config.pool_stride,
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

        # Top-level alias of the encoder's first conv layer -- presumably so
        # external code can probe dtype/device or tie weights.
        # NOTE(review): confirm this alias is relied on by callers.
        self.conv1 = self.encoder.conv1

        self.post_init()

    def _get_feat_extract_output_lengths(
        self, input_lengths: torch.LongTensor
    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
        """Compute output sequence lengths after Whisper conv + temporal pooling.

        Whisper conv frontend:
            Conv1d(128, 768, k=3, s=1, p=1) -> same length
            Conv1d(768, 768, k=3, s=2, p=1) -> (L - 1) // 2 + 1

        Temporal pool (no padding):
            Conv1d(768, 768, k=pool_kernel_size, s=pool_stride) -> (L - k) // s + 1

        Args:
            input_lengths: (B,) number of valid mel frames per sample.
        Returns:
            (feat_lengths, output_lengths) - encoder output lengths and
            post-pooling lengths.
        """
        # Length after the stride-2 conv in the Whisper frontend.
        feat_lengths = (input_lengths - 1) // 2 + 1

        # Length after the unpadded strided pooling conv.
        output_lengths = (feat_lengths - self.config.pool_kernel_size) // self.config.pool_stride + 1

        return feat_lengths, output_lengths

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """Encode mel spectrogram features and apply temporal pooling.

        Args:
            input_features: (B, num_mel_bins, T) mel spectrogram (128, 3000).
            attention_mask: (B, T) binary mask of valid mel frames; forwarded
                to WhisperEncoder.
            output_attentions: If True, also return the encoder's attention maps.
            output_hidden_states: If True, also return the encoder's per-layer
                hidden states (at the pre-pooling 50 Hz rate).
            return_dict: Whether to return a BaseModelOutput instead of a tuple.
        Returns:
            BaseModelOutput (or tuple) whose last_hidden_state has shape
            (B, T_10hz, d_model); hidden_states/attentions are included when
            requested.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Match the encoder's parameter dtype (e.g. under fp16/bf16 inference).
        input_features = input_features.to(self.encoder.conv1.weight.dtype)
        encoder_output = self.encoder(
            input_features,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        x = encoder_output.last_hidden_state

        # Conv1d pools over the time axis: (B, T, D) -> (B, D, T) -> pool ->
        # back to (B, T', D), then normalize.
        x = x.transpose(1, 2)
        x = self.temporal_pool(x)
        x = x.transpose(1, 2)
        x = self.layer_norm(x)

        # Fix: previously the requested hidden_states/attentions were
        # discarded; propagate them through both return paths. Element 0 /
        # last_hidden_state is unchanged, so existing callers are unaffected.
        if not return_dict:
            return tuple(
                v
                for v in (x, encoder_output.hidden_states, encoder_output.attentions)
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )
|
|
|
|
# Register with AutoModel so the encoder is loadable via
# AutoModel.from_pretrained(...) using this config class.
AutoModel.register(HyperCLOVAXSeedAudioEncoderConfig, HyperCLOVAXSeedAudioEncoder)

# Explicit public API of this module.
__all__ = ["HyperCLOVAXSeedAudioEncoder"]
|
|