# coding=utf-8
# Copyright 2026 NAVER Cloud Corp. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HyperCLOVAX-SEED Audio Encoder model.
Extends WhisperEncoder with the following design choices:
- Trained from random initialization with CTC loss
(not using pretrained ASR weights)
- Temporal pooling (Conv1d, kernel=5, stride=5) applied after the encoder
to reduce output rate from 50 Hz to 10 Hz for multimodal integration
Acknowledgements:
- Audio encoder uses WhisperEncoder from the HuggingFace transformers library
(https://github.com/huggingface/transformers), Apache 2.0 License.
Original Whisper model: https://github.com/openai/whisper (MIT License).
"""
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel, WhisperConfig
from transformers.modeling_outputs import BaseModelOutput
try:
from transformers import WhisperEncoder
except ImportError:
from transformers.models.whisper.modeling_whisper import WhisperEncoder
from .configuration_hyperclovax_seed_audio_encoder import HyperCLOVAXSeedAudioEncoderConfig
class HyperCLOVAXSeedAudioEncoder(PreTrainedModel):
    """Audio encoder based on WhisperEncoder with temporal pooling (50Hz -> 10Hz).

    A WhisperEncoder processes log-mel features at ~50 Hz; a Conv1d temporal
    pooling layer (kernel=pool_kernel_size, stride=pool_stride) then reduces the
    frame rate to ~10 Hz for multimodal integration, followed by a LayerNorm.
    """

    config_class = HyperCLOVAXSeedAudioEncoderConfig
    supports_gradient_checkpointing = True

    def __init__(self, config: HyperCLOVAXSeedAudioEncoderConfig):
        super().__init__(config)
        # Whisper encoder (using HF WhisperEncoder directly), configured from
        # the fields of our own config class.
        whisper_config = WhisperConfig(
            d_model=config.d_model,
            encoder_layers=config.encoder_layers,
            encoder_attention_heads=config.encoder_attention_heads,
            encoder_ffn_dim=config.encoder_ffn_dim,
            num_mel_bins=config.num_mel_bins,
            max_source_positions=config.max_source_positions,
            dropout=config.dropout,
            attention_dropout=config.attention_dropout,
        )
        self.encoder = WhisperEncoder(whisper_config)
        # Temporal pooling: 50Hz -> 10Hz (with the default kernel=stride=5).
        self.temporal_pool = nn.Conv1d(
            config.d_model,
            config.d_model,
            kernel_size=config.pool_kernel_size,
            stride=config.pool_stride,
        )
        self.layer_norm = nn.LayerNorm(config.d_model)
        # Compatibility: modeling_vlm.py accesses audio_model.conv1.weight.{dtype,device}
        self.conv1 = self.encoder.conv1
        self.post_init()

    def _get_feat_extract_output_lengths(
        self, input_lengths: torch.LongTensor
    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
        """Compute output sequence lengths after Whisper conv + temporal pooling.

        Whisper conv frontend:
            Conv1d(128, 768, k=3, s=1, p=1) -> same length
            Conv1d(768, 768, k=3, s=2, p=1) -> (L - 1) // 2 + 1
        Temporal pool:
            Conv1d(768, 768, k=pool_kernel_size, s=pool_stride) -> (L - k) // s + 1

        Args:
            input_lengths: (B,) number of valid mel frames per sample

        Returns:
            (feat_lengths, output_lengths) - encoder output lengths and post-pooling lengths
        """
        # After Whisper conv frontend (second conv has stride 2, padding 1)
        feat_lengths = (input_lengths - 1) // 2 + 1
        # After temporal pooling (no padding)
        output_lengths = (feat_lengths - self.config.pool_kernel_size) // self.config.pool_stride + 1
        return feat_lengths, output_lengths

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """Encode mel spectrogram features and apply temporal pooling.

        Args:
            input_features: (B, num_mel_bins, T) mel spectrogram (128, 3000)
            attention_mask: (B, T) binary mask of valid mel frames; forwarded to WhisperEncoder.
            output_attentions: if True, include encoder self-attention weights in the output.
            output_hidden_states: if True, include encoder hidden states (pre-pooling rate)
                in the output.
            return_dict: if False, return a plain tuple instead of BaseModelOutput.

        Returns:
            BaseModelOutput with last_hidden_state of shape (B, T_10hz, d_model), plus
            hidden_states / attentions from the inner encoder when requested. NOTE:
            hidden_states and attentions are at the pre-pooling (50 Hz) frame rate.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Whisper encoder: (B, 128, 3000) -> (B, 1500, 768)
        input_features = input_features.to(self.encoder.conv1.weight.dtype)
        encoder_output = self.encoder(
            input_features,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        x = encoder_output.last_hidden_state
        # Temporal pooling: (B, 1500, 768) -> (B, 300, 768)
        x = x.transpose(1, 2)       # (B, d_model, T_enc)
        x = self.temporal_pool(x)   # (B, d_model, T_pool)
        x = x.transpose(1, 2)       # (B, T_pool, d_model)
        x = self.layer_norm(x)
        # Bug fix: previously the encoder's hidden_states/attentions were dropped
        # even when explicitly requested. Propagate them through both return paths.
        if not return_dict:
            # HF tuple convention: (last_hidden_state, hidden_states, attentions)
            # with None entries removed; stays (x,) when nothing extra was requested.
            return tuple(
                v
                for v in (x, encoder_output.hidden_states, encoder_output.attentions)
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )
# Make this encoder loadable via AutoModel.from_config / from_pretrained
# when the config is a HyperCLOVAXSeedAudioEncoderConfig.
AutoModel.register(HyperCLOVAXSeedAudioEncoderConfig, HyperCLOVAXSeedAudioEncoder)
# Explicit public API of this module.
__all__ = ["HyperCLOVAXSeedAudioEncoder"]