# coding=utf-8
# Copyright 2026 NAVER Cloud Corp. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HyperCLOVAX-SEED Audio Encoder model.

Extends WhisperEncoder with the following design choices:
- Trained from random initialization with CTC loss
  (not using pretrained ASR weights)
- Temporal pooling (Conv1d, kernel=5, stride=5) applied after the encoder
  to reduce output rate from 50 Hz to 10 Hz for multimodal integration

Acknowledgements:
    - Audio encoder uses WhisperEncoder from the HuggingFace transformers library
      (https://github.com/huggingface/transformers), Apache 2.0 License.
      Original Whisper model: https://github.com/openai/whisper (MIT License).
"""

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel, WhisperConfig
from transformers.modeling_outputs import BaseModelOutput

try:
    from transformers import WhisperEncoder
except ImportError:
    from transformers.models.whisper.modeling_whisper import WhisperEncoder

from .configuration_hyperclovax_seed_audio_encoder import HyperCLOVAXSeedAudioEncoderConfig


class HyperCLOVAXSeedAudioEncoder(PreTrainedModel):
    """Audio encoder based on WhisperEncoder with temporal pooling (50Hz -> 10Hz).

    A standard HF WhisperEncoder processes log-mel features, then a strided
    Conv1d pools the 50 Hz encoder output down to 10 Hz (kernel/stride come from
    the config) for multimodal integration, followed by a LayerNorm.
    """

    config_class = HyperCLOVAXSeedAudioEncoderConfig
    supports_gradient_checkpointing = True

    def __init__(self, config: HyperCLOVAXSeedAudioEncoderConfig):
        """Build the Whisper encoder backbone and the temporal-pooling head.

        Args:
            config: Encoder hyperparameters (Whisper dims plus pooling kernel/stride).
        """
        super().__init__(config)

        # Whisper encoder (using HF WhisperEncoder directly)
        whisper_config = WhisperConfig(
            d_model=config.d_model,
            encoder_layers=config.encoder_layers,
            encoder_attention_heads=config.encoder_attention_heads,
            encoder_ffn_dim=config.encoder_ffn_dim,
            num_mel_bins=config.num_mel_bins,
            max_source_positions=config.max_source_positions,
            dropout=config.dropout,
            attention_dropout=config.attention_dropout,
        )
        self.encoder = WhisperEncoder(whisper_config)

        # Temporal pooling: 50Hz -> 10Hz (no padding, so lengths follow the
        # standard (L - k) // s + 1 Conv1d formula used below).
        self.temporal_pool = nn.Conv1d(
            config.d_model,
            config.d_model,
            kernel_size=config.pool_kernel_size,
            stride=config.pool_stride,
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

        # Compatibility: modeling_vlm.py accesses audio_model.conv1.weight.{dtype,device}
        self.conv1 = self.encoder.conv1

        self.post_init()

    def _get_feat_extract_output_lengths(
        self, input_lengths: torch.LongTensor
    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
        """Compute output sequence lengths after Whisper conv + temporal pooling.

        Whisper conv frontend:
            Conv1d(128, 768, k=3, s=1, p=1) -> same length
            Conv1d(768, 768, k=3, s=2, p=1) -> (L - 1) // 2 + 1

        Temporal pool:
            Conv1d(768, 768, k=pool_kernel_size, s=pool_stride) -> (L - k) // s + 1

        Args:
            input_lengths: (B,) number of valid mel frames per sample
        Returns:
            (feat_lengths, output_lengths) - encoder output lengths and post-pooling lengths
        """
        # After Whisper conv frontend (second conv has stride 2, padding 1)
        feat_lengths = (input_lengths - 1) // 2 + 1

        # After temporal pooling (Conv1d with no padding)
        output_lengths = (feat_lengths - self.config.pool_kernel_size) // self.config.pool_stride + 1

        return feat_lengths, output_lengths

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """Encode mel spectrogram features and apply temporal pooling.

        Args:
            input_features: (B, num_mel_bins, T) mel spectrogram (128, 3000)
            attention_mask: (B, T) binary mask of valid mel frames; forwarded to WhisperEncoder.
            output_attentions: If True, include the inner encoder's attention maps.
            output_hidden_states: If True, include the inner encoder's hidden states.
            return_dict: If False, return a plain tuple instead of BaseModelOutput.
        Returns:
            BaseModelOutput with last_hidden_state of shape (B, T_10hz, d_model).
            NOTE: hidden_states/attentions, when requested, come from the inner
            WhisperEncoder and are therefore at the pre-pooling (50 Hz) rate.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Whisper encoder: (B, 128, 3000) -> (B, 1500, 768)
        # Cast inputs to the encoder's parameter dtype (e.g. fp16/bf16 checkpoints).
        input_features = input_features.to(self.encoder.conv1.weight.dtype)
        encoder_output = self.encoder(
            input_features,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        x = encoder_output.last_hidden_state

        # Temporal pooling: (B, 1500, 768) -> (B, 300, 768)
        x = x.transpose(1, 2)  # (B, d_model, T_enc)
        x = self.temporal_pool(x)  # (B, d_model, T_pool)
        x = x.transpose(1, 2)  # (B, T_pool, d_model)
        x = self.layer_norm(x)

        # Fix: previously hidden_states/attentions were computed (when requested)
        # but silently dropped from the returned output; propagate them.
        if not return_dict:
            return tuple(
                v
                for v in (x, encoder_output.hidden_states, encoder_output.attentions)
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )


# Register the config -> model mapping so AutoModel.from_config / from_pretrained
# can resolve HyperCLOVAXSeedAudioEncoderConfig to this encoder class.
AutoModel.register(HyperCLOVAXSeedAudioEncoderConfig, HyperCLOVAXSeedAudioEncoder)

__all__ = ["HyperCLOVAXSeedAudioEncoder"]