# HyperCLOVAX-SEED-Think-4B / modeling_hyperclovax_seed_audio_encoder.py
# Uploaded by bigshanedogg via huggingface_hub (commit 0c1d6f8, verified)
# coding=utf-8
# Copyright 2026 NAVER Cloud Corp. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HyperCLOVAX-SEED Audio Encoder model.
Extends WhisperEncoder with the following design choices:
- Trained from random initialization with CTC loss
(not using pretrained ASR weights)
- Temporal pooling (Conv1d, kernel=5, stride=5) applied after the encoder
to reduce output rate from 50 Hz to 10 Hz for multimodal integration
Acknowledgements:
- Audio encoder uses WhisperEncoder from the HuggingFace transformers library
(https://github.com/huggingface/transformers), Apache 2.0 License.
Original Whisper model: https://github.com/openai/whisper (MIT License).
"""
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel, WhisperConfig
from transformers.modeling_outputs import BaseModelOutput
try:
from transformers import WhisperEncoder
except ImportError:
from transformers.models.whisper.modeling_whisper import WhisperEncoder
from .configuration_hyperclovax_seed_audio_encoder import HyperCLOVAXSeedAudioEncoderConfig
class HyperCLOVAXSeedAudioEncoder(PreTrainedModel):
    """Audio encoder based on WhisperEncoder with temporal pooling (50Hz -> 10Hz).

    Wraps a HuggingFace ``WhisperEncoder`` (built from this config's fields,
    trained from random initialization) and appends a strided ``Conv1d``
    pooling layer plus a final ``LayerNorm`` that reduce the encoder's 50 Hz
    output rate to 10 Hz for multimodal integration.
    """

    config_class = HyperCLOVAXSeedAudioEncoderConfig
    supports_gradient_checkpointing = True

    def __init__(self, config: HyperCLOVAXSeedAudioEncoderConfig):
        super().__init__(config)
        # Whisper encoder (using HF WhisperEncoder directly); only the
        # encoder-side fields of WhisperConfig are populated.
        whisper_config = WhisperConfig(
            d_model=config.d_model,
            encoder_layers=config.encoder_layers,
            encoder_attention_heads=config.encoder_attention_heads,
            encoder_ffn_dim=config.encoder_ffn_dim,
            num_mel_bins=config.num_mel_bins,
            max_source_positions=config.max_source_positions,
            dropout=config.dropout,
            attention_dropout=config.attention_dropout,
        )
        self.encoder = WhisperEncoder(whisper_config)
        # Temporal pooling: 50Hz -> 10Hz. No padding, so output lengths
        # follow the standard (L - kernel) // stride + 1 formula.
        self.temporal_pool = nn.Conv1d(
            config.d_model,
            config.d_model,
            kernel_size=config.pool_kernel_size,
            stride=config.pool_stride,
        )
        self.layer_norm = nn.LayerNorm(config.d_model)
        # Compatibility: modeling_vlm.py accesses audio_model.conv1.weight.{dtype,device}
        self.conv1 = self.encoder.conv1
        self.post_init()

    def _get_feat_extract_output_lengths(
        self, input_lengths: torch.LongTensor
    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
        """Compute output sequence lengths after Whisper conv + temporal pooling.

        Whisper conv frontend:
            Conv1d(128, 768, k=3, s=1, p=1) -> same length
            Conv1d(768, 768, k=3, s=2, p=1) -> (L - 1) // 2 + 1
        Temporal pool:
            Conv1d(768, 768, k=pool_kernel_size, s=pool_stride) -> (L - k) // s + 1

        Args:
            input_lengths: (B,) number of valid mel frames per sample

        Returns:
            (feat_lengths, output_lengths) - encoder output lengths and
            post-pooling lengths
        """
        # After Whisper conv frontend (second conv has stride 2, padding 1)
        feat_lengths = (input_lengths - 1) // 2 + 1
        # After temporal pooling (no padding in self.temporal_pool)
        output_lengths = (feat_lengths - self.config.pool_kernel_size) // self.config.pool_stride + 1
        return feat_lengths, output_lengths

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """Encode mel spectrogram features and apply temporal pooling.

        Args:
            input_features: (B, num_mel_bins, T) mel spectrogram (128, 3000)
            attention_mask: (B, T) binary mask of valid mel frames; forwarded
                to WhisperEncoder.
            output_attentions: if True, include the encoder's attention maps
                in the output.
            output_hidden_states: if True, include the encoder's per-layer
                hidden states in the output. NOTE: these are the pre-pooling
                (50 Hz) states and have a longer time axis than
                ``last_hidden_state``.
            return_dict: whether to return a ``BaseModelOutput`` instead of
                a plain tuple; defaults to ``config.use_return_dict``.

        Returns:
            BaseModelOutput with last_hidden_state of shape (B, T_10hz, d_model),
            plus hidden_states/attentions when requested; or the equivalent
            tuple with the pooled hidden state as element 0.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Whisper encoder: (B, 128, 3000) -> (B, 1500, 768).
        # Cast inputs to the encoder's parameter dtype (e.g. under fp16/bf16).
        input_features = input_features.to(self.encoder.conv1.weight.dtype)
        encoder_output = self.encoder(
            input_features,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        x = encoder_output.last_hidden_state
        # Temporal pooling: (B, 1500, 768) -> (B, 300, 768)
        x = x.transpose(1, 2)       # (B, d_model, T_enc)
        x = self.temporal_pool(x)   # (B, d_model, T_pool)
        x = x.transpose(1, 2)       # (B, T_pool, d_model)
        x = self.layer_norm(x)
        if not return_dict:
            # Backward compatible: the pooled hidden state stays element 0;
            # hidden_states / attentions are appended only when requested
            # (previously they were silently dropped).
            outputs: Tuple = (x,)
            if output_hidden_states:
                outputs = outputs + (encoder_output.hidden_states,)
            if output_attentions:
                outputs = outputs + (encoder_output.attentions,)
            return outputs
        return BaseModelOutput(
            last_hidden_state=x,
            # Pre-pooling (50 Hz) states/maps from the inner encoder; None
            # unless the corresponding output_* flag was set.
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )
# Register with the Auto API so AutoModel.from_pretrained/from_config can
# resolve HyperCLOVAXSeedAudioEncoderConfig to this encoder class.
AutoModel.register(HyperCLOVAXSeedAudioEncoderConfig, HyperCLOVAXSeedAudioEncoder)
__all__ = ["HyperCLOVAXSeedAudioEncoder"]