# HyperCLOVAX-SEED-Think-4B / audio_processing_hyperclovax_seed.py
# Uploaded with huggingface_hub (commit 0c1d6f8, verified)
# coding=utf-8
# Copyright 2026 NAVER Cloud Corp. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
HyperCLOVAX-SEED Audio Processor
Implements Whisper-compatible audio feature extraction:
- Log-mel spectrogram extraction from waveform
- Chunked processing for long audio clips
- Attention mask generation for padded sequences
- Discrete audio token count calculation (conv-based)
"""
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
try:
from transformers.image_processing_utils import BatchFeature
except ImportError:
from transformers import BatchFeature
try:
    from torchaudio.functional import melscale_fbanks as _melscale_fbanks
except (ImportError, AttributeError):
    # fallback: transformers mel_filter_bank wrapped to return torch.Tensor
    from transformers.audio_utils import mel_filter_bank as _mel_filter_bank
    def _melscale_fbanks(n_freqs, f_min, f_max, n_mels, sample_rate, norm, mel_scale):
        """Drop-in replacement for ``torchaudio.functional.melscale_fbanks``.

        Builds the mel filter bank via ``transformers.audio_utils.mel_filter_bank``
        (numpy) and converts it to a ``torch.Tensor``. Both implementations
        return a matrix of shape ``(n_freqs, n_mels)``.
        """
        # Map torchaudio-style argument names onto the transformers API.
        return torch.from_numpy(_mel_filter_bank(
            num_frequency_bins=n_freqs,
            num_mel_filters=n_mels,
            min_frequency=f_min,
            max_frequency=f_max,
            sampling_rate=sample_rate,
            norm=norm,
            mel_scale=mel_scale,
        ))
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
try:
from transformers.processing_utils import AudioKwargs
except ImportError:
from typing import TypedDict as AudioKwargs # transformers < 4.46
def _conv_output_length(
input_length: int,
kernel_size: int = 3,
stride: int = 2,
padding: int = 1,
dilation: int = 1,
) -> int:
"""Compute output length of a 1D convolution.
Formula: (input + 2*padding - dilation*(kernel-1) - 1) // stride + 1
"""
return (input_length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
class HyperCLOVAXSeedAudioKwargs(AudioKwargs, total=False):
    """Optional keyword arguments accepted by HyperCLOVAX-SEED audio processing.

    All keys are optional (``total=False``); unset keys fall back to the
    defaults configured on :class:`HyperCLOVAXSeedAudioProcessor`.
    """
    feature_size: Optional[int]  # number of mel bins per frame
    hop_length: Optional[int]  # STFT hop length in samples
    chunk_length: Optional[int]  # continuous-audio chunk duration in seconds
    n_fft: Optional[int]  # STFT window size in samples
    n_samples: Optional[int]  # samples per padded chunk
    nb_max_frames: Optional[int]  # mel frames per padded chunk
    chunk_unit: Optional[int]  # discrete-audio chunk duration in seconds
    min_chunk_size: Optional[int]  # minimum audio length in samples
    dither: Optional[float]  # amplitude of optional dithering noise
    # Token parameters
    audio_token: Optional[str]  # placeholder token for continuous audio
    audio_start_token: Optional[str]  # boundary token before continuous audio
    audio_end_token: Optional[str]  # boundary token after continuous audio
    # Discrete audio parameters
    use_discrete_token: Optional[bool]  # whether discrete processing runs
    discrete_audio_token: Optional[str]  # placeholder token for discrete audio
    discrete_audio_start_token: Optional[str]  # boundary token before discrete audio
    discrete_audio_end_token: Optional[str]  # boundary token after discrete audio
class HyperCLOVAXSeedAudioProcessor(SequenceFeatureExtractor):
    """Audio processor for HyperCLOVAX-SEED.

    Extracts Whisper-compatible log-mel spectrogram features and computes
    attention masks for the audio encoder. Also supports discrete audio
    token count calculation based on a two-layer conv downsampling stack.
    """
    model_input_names = ["audio_values", "audio_masks", "discrete_audio_values"]

    def __init__(
        self,
        feature_size: int = 128,
        sampling_rate: int = 16000,
        hop_length: int = 160,
        chunk_length: int = 30,
        n_fft: int = 400,
        padding_value: float = 0.0,
        padding_side: str = "right",
        dither: float = 0.0,
        return_attention_mask: bool = False,
        n_samples: int = 480000,
        nb_max_frames: int = 3000,
        chunk_unit: int = 80,
        min_chunk_size: int = 1600,
        # Temporal pooling parameters
        pool_kernel_size: int = 5,
        pool_stride: int = 5,
        # Token parameters
        audio_token: str = "<|AUDIO_PAD|>",
        audio_start_token: str = "<|audio_start|>",
        audio_end_token: str = "<|audio_end|>",
        video_audio_pool_size: int = 25,
        # Discrete audio parameters
        use_discrete_token: bool = False,
        discrete_audio_token: str = "<|DISCRETE_AUDIO_PAD|>",
        discrete_audio_start_token: str = "<|discrete_audio_start|>",
        discrete_audio_end_token: str = "<|discrete_audio_end|>",
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            hop_length=hop_length,
            chunk_length=chunk_length,
            n_fft=n_fft,
            padding_value=padding_value,
            padding_side=padding_side,
            dither=dither,
            return_attention_mask=return_attention_mask,
            n_samples=n_samples,
            nb_max_frames=nb_max_frames,
            chunk_unit=chunk_unit,
            min_chunk_size=min_chunk_size,
            # Token parameters
            audio_token=audio_token,
            audio_start_token=audio_start_token,
            audio_end_token=audio_end_token,
            video_audio_pool_size=video_audio_pool_size,
            pool_kernel_size=pool_kernel_size,
            pool_stride=pool_stride,
            # Discrete audio parameters
            use_discrete_token=use_discrete_token,
            discrete_audio_token=discrete_audio_token,
            discrete_audio_start_token=discrete_audio_start_token,
            discrete_audio_end_token=discrete_audio_end_token,
            # Fix: forward remaining kwargs so attributes stored in a saved
            # processor config (e.g. ``processor_class``) are not silently
            # dropped when the processor is re-instantiated.
            **kwargs,
        )
        # Mel filter bank (Whisper-compatible) — torchaudio primary,
        # transformers fallback. f_max is fixed at 8 kHz as in Whisper,
        # independent of sampling_rate.
        self.mel_filters = _melscale_fbanks(
            n_freqs=1 + n_fft // 2,
            f_min=0.0,
            f_max=8000.0,
            n_mels=feature_size,
            sample_rate=sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        )  # torch.Tensor, shape (n_freqs, n_mels)

    def _extract_fbank_features(
        self,
        waveform_batch: np.ndarray,
        device: str = "cpu",
    ) -> np.ndarray:
        """Extract log-mel spectrogram features from a waveform batch.

        Follows the OpenAI Whisper feature extraction pipeline.
        Reference: https://github.com/openai/whisper (MIT License)
        Adapted from WhisperFeatureExtractor._torch_extract_fbank_features.

        Args:
            waveform_batch: Waveform array of shape (batch_size, n_samples).
            device: Device for computation. Defaults to "cpu".

        Returns:
            Log-mel spectrogram of shape (batch_size, feature_size, num_frames).
        """
        waveform = torch.from_numpy(waveform_batch).to(device, torch.float32)
        window = torch.hann_window(self.n_fft, device=device)
        if self.dither != 0.0:
            # Optional dithering noise, as in WhisperFeatureExtractor.
            waveform += self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device)
        stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
        # Drop the final STFT frame (Whisper convention) and take power.
        magnitudes = stft[..., :-1].abs() ** 2
        mel_filters = self.mel_filters.to(device=device, dtype=torch.float32)
        # (n_mels, n_freqs) @ (..., n_freqs, frames) -> (..., n_mels, frames)
        mel_spec = mel_filters.T @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        if waveform.dim() == 2:
            # Per-sample dynamic-range compression to 80 dB below the max.
            max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
            log_spec = torch.maximum(log_spec, max_val - 8.0)
        else:
            log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        if device != "cpu":
            log_spec = log_spec.detach().cpu()
        return log_spec.numpy()

    def _pad_and_extract_features(
        self,
        chunks: List[np.ndarray],
        sampling_rate: int,
        chunk_length: Optional[int] = None,
    ) -> dict:
        """Pad audio chunks and extract mel-spectrogram features.

        Each chunk is padded (or truncated) to ``chunk_length * sampling_rate``
        samples, then mel-spectrogram features and a frame-level attention
        mask are computed.

        Args:
            chunks: List of 1D numpy arrays, each representing an audio chunk.
            sampling_rate: Audio sampling rate.
            chunk_length: Chunk duration in seconds. Defaults to
                self.chunk_length. (Fix: previously self.chunk_length was
                always used, so a caller-supplied chunk length was ignored
                at padding time.)

        Returns:
            Dictionary with:
                - "input_features": Array of shape (num_chunks, feature_size, nb_max_frames).
                - "attention_mask": Array of shape (num_chunks, nb_max_frames).
        """
        if chunk_length is None:
            chunk_length = self.chunk_length
        n_samples = chunk_length * sampling_rate
        nb_max_frames = n_samples // self.hop_length
        padded_waveforms = []
        attention_masks = []
        for chunk in chunks:
            chunk = np.asarray(chunk, dtype=np.float32)
            chunk_len = len(chunk)
            # Pad with padding_value, or truncate to the fixed chunk size.
            if chunk_len < n_samples:
                padded = np.full(n_samples, self.padding_value, dtype=np.float32)
                padded[:chunk_len] = chunk
            else:
                padded = chunk[:n_samples]
                chunk_len = n_samples
            padded_waveforms.append(padded)
            # Attention mask: sample-level validity downsampled to frame rate.
            sample_mask = np.zeros(n_samples, dtype=np.int32)
            sample_mask[:chunk_len] = 1
            frame_mask = sample_mask[:: self.hop_length]
            if len(frame_mask) > nb_max_frames:
                frame_mask = frame_mask[:nb_max_frames]
            elif len(frame_mask) < nb_max_frames:
                frame_mask = np.pad(frame_mask, (0, nb_max_frames - len(frame_mask)))
            attention_masks.append(frame_mask)
        waveform_batch = np.stack(padded_waveforms, axis=0)
        input_features = self._extract_fbank_features(waveform_batch)
        attention_mask = np.stack(attention_masks, axis=0)
        return {
            "input_features": input_features,
            "attention_mask": attention_mask,
        }

    def _get_feature_lengths(self, audio_masks: torch.Tensor) -> torch.Tensor:
        """Compute feature lengths after a stride-2 conv downsampling.

        Args:
            audio_masks: Attention mask of shape (batch, nb_max_frames).

        Returns:
            Feature lengths tensor of shape (batch,).
        """
        return (audio_masks.sum(-1) - 1) // 2 + 1

    def _get_attention_mask(self, audio_masks: torch.Tensor) -> torch.Tensor:
        """Generate an additive attention mask for the audio encoder.

        Padded key positions are filled with -inf, valid positions are 0.

        Args:
            audio_masks: Attention mask of shape (batch, nb_max_frames).

        Returns:
            Float attention mask of shape (batch, 1, max_seq_len, max_seq_len).
        """
        feature_lengths = self._get_feature_lengths(audio_masks=audio_masks)
        max_seq_len = (self.nb_max_frames - 2) // 2 + 1
        # True where the key position is beyond the valid feature length.
        padding_mask = torch.arange(max_seq_len) >= feature_lengths.unsqueeze(1)
        expanded = padding_mask[:, None, None, :].expand(padding_mask.shape[0], 1, max_seq_len, max_seq_len)
        # Bug fix: the previous code called masked_fill(-inf) on the *bool*
        # mask itself, which raises ("value cannot be converted to type Bool
        # without overflow"). Fill a float tensor instead.
        attention_mask = torch.zeros(expanded.shape, dtype=torch.float32).masked_fill(expanded, float("-inf"))
        return attention_mask

    def _preprocess_continuous_audio(
        self,
        audio_clips: List[np.ndarray],
        sampling_rate: Optional[int] = None,
        chunk_length: Optional[int] = None,
    ) -> dict:
        """Preprocess audio clips for continuous audio features.

        Splits each audio clip into chunks of chunk_length seconds, extracts
        mel-spectrogram features, and computes token counts from attention masks.

        Args:
            audio_clips: List of audio clips, each a 1D numpy array (mono, float32).
            sampling_rate: Audio sampling rate. Defaults to self.sampling_rate.
            chunk_length: Chunk duration in seconds. Defaults to self.chunk_length.

        Returns:
            Dictionary with:
                - "audio_values": Tensor of shape (num_total_chunks, feature_size, nb_max_frames).
                - "audio_masks": Tensor of shape (num_total_chunks, nb_max_frames).
                - "audio_attention_mask": Tensor of shape (num_total_chunks, 1, max_seq_len, max_seq_len).
                - "num_audio_tokens": Tensor of shape (N,) with per-clip continuous token counts.
        """
        if sampling_rate is None:
            sampling_rate = self.sampling_rate
        if chunk_length is None:
            chunk_length = self.chunk_length
        if len(audio_clips) == 0:
            # Empty batch: return correctly-shaped empty tensors.
            max_seq_len = (self.nb_max_frames - 2) // 2 + 1
            return {
                "audio_values": torch.zeros(0, self.feature_size, self.nb_max_frames),
                "audio_masks": torch.zeros(0, self.nb_max_frames),
                "audio_attention_mask": torch.zeros(0, max_seq_len, max_seq_len),
                "num_audio_tokens": torch.tensor([], dtype=torch.long),
            }
        _audio_values, _audio_masks, _num_audio_tokens = [], [], []
        for _audio in audio_clips:
            chunks = []
            chunk_samples = chunk_length * sampling_rate
            for i in range(0, len(_audio), chunk_samples):
                chunks.append(_audio[i : i + chunk_samples])
            result = self._pad_and_extract_features(chunks, sampling_rate, chunk_length=chunk_length)
            _audio_value = result["input_features"]
            _audio_mask = result["attention_mask"]
            _num_audio_token = 0
            for _mask in _audio_mask:
                # NOTE(review): uses the padded frame count (shape[-1]), not
                # the valid length (mask.sum()), so every chunk contributes a
                # fixed token count. This matches get_num_audio_tokens —
                # presumably intentional; confirm against the model's pooling.
                _input_length = (_mask.shape[-1] - 1) // 2 + 1
                _num_audio_token += (_input_length - self.pool_kernel_size) // self.pool_stride + 1
            _audio_values.append(torch.from_numpy(_audio_value))
            _audio_masks.append(torch.from_numpy(_audio_mask))
            _num_audio_tokens.append(_num_audio_token)
        _audio_values = torch.cat(_audio_values, dim=0)
        _audio_masks = torch.cat(_audio_masks, dim=0)
        _audio_attention_mask = self._get_attention_mask(audio_masks=_audio_masks)
        return {
            "audio_values": _audio_values,
            "audio_masks": _audio_masks,
            "audio_attention_mask": _audio_attention_mask,
            "num_audio_tokens": torch.tensor(_num_audio_tokens, dtype=torch.long),
        }

    def _preprocess_discrete_audio(
        self,
        audio_clips: List[np.ndarray],
        sampling_rate: Optional[int] = None,
        chunk_unit: Optional[int] = None,
        min_chunk_size: Optional[int] = None,
    ) -> dict:
        """Preprocess audio clips for discrete audio tokens.

        Validates each audio clip and computes the number of discrete tokens
        based on conv layer downsampling. Returns the raw waveforms.

        Args:
            audio_clips: List of audio clips, each a 1D numpy array (mono, float32).
            sampling_rate: Audio sampling rate. Defaults to self.sampling_rate.
            chunk_unit: Chunk duration in seconds for long audio. Defaults to self.chunk_unit.
            min_chunk_size: Minimum audio length in samples. Defaults to self.min_chunk_size.

        Returns:
            Dictionary with:
                - "discrete_audio_values": 1D tensor with all clips concatenated
                  along the sample axis (total_samples,).
                - "num_discrete_audio_tokens": Tensor of shape (N,) with per-clip
                  discrete token counts.

        Raises:
            ValueError: If a clip is too short, too long (> 600s), contains
                NaN/Inf, or has values outside [-100, 100].
        """
        if sampling_rate is None:
            sampling_rate = self.sampling_rate
        if chunk_unit is None:
            chunk_unit = self.chunk_unit
        if min_chunk_size is None:
            min_chunk_size = self.min_chunk_size
        if len(audio_clips) == 0:
            # Fix: torch.cat([]) raises; mirror the empty-input handling of
            # the continuous path.
            return {
                "discrete_audio_values": torch.zeros(0),
                "num_discrete_audio_tokens": torch.tensor([], dtype=torch.long),
            }
        _discrete_audio_values, _num_discrete_audio_tokens = [], []
        for _audio in audio_clips:
            audio_length = len(_audio)
            max_audio_length = 600 * sampling_rate
            audio_duration_sec = audio_length / sampling_rate
            # Validation: length bounds, finiteness, amplitude range.
            if audio_length < min_chunk_size:
                raise ValueError(f"Discrete audio too short: {audio_length}")
            if np.isnan(_audio).any() or np.isinf(_audio).any():
                raise ValueError("Discrete audio contains NaN/Inf")
            if audio_length > max_audio_length:
                raise ValueError(
                    f"Discrete audio too long: {audio_length} samples = ({audio_duration_sec:.2f}s > 600s)"
                )
            audio_min, audio_max = _audio.min().item(), _audio.max().item()
            if audio_min < -100.0 or audio_max > 100.0:
                raise ValueError(f"Discrete audio values out of range: min {audio_min}, max {audio_max}")
            if audio_length > chunk_unit * sampling_rate:
                # Long audio: count tokens chunk by chunk; a tail shorter than
                # min_chunk_size is merged into the preceding chunk.
                total_code_len = 0
                chunk_size = chunk_unit * sampling_rate
                for start in range(0, audio_length, chunk_size):
                    end = min(start + chunk_size, audio_length)
                    if end < audio_length and audio_length - end < min_chunk_size:
                        end = audio_length
                    chunk_len = end - start
                    # Mel frames, then two halving conv layers.
                    mel_len = chunk_len // self.hop_length
                    after_conv1 = _conv_output_length(mel_len)
                    code_len = _conv_output_length(after_conv1)
                    total_code_len += code_len
                    if end >= audio_length:
                        break
                _num_discrete = total_code_len
            else:
                mel_len = audio_length // self.hop_length
                after_conv1 = _conv_output_length(mel_len)
                code_len = _conv_output_length(after_conv1)
                _num_discrete = code_len
            _discrete_audio_values.append(torch.tensor(_audio))
            _num_discrete_audio_tokens.append(_num_discrete)
        return {
            "discrete_audio_values": torch.cat(_discrete_audio_values, dim=0),
            "num_discrete_audio_tokens": torch.tensor(_num_discrete_audio_tokens, dtype=torch.long),
        }

    def preprocess(
        self,
        audios: List[np.ndarray],
        sampling_rate: Optional[int] = None,
        chunk_length: Optional[int] = None,
        chunk_unit: Optional[int] = None,
        min_chunk_size: Optional[int] = None,
        use_discrete_token: Optional[bool] = None,
        prefix: Optional[str] = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess a list of audio clips.

        Resolves all kwargs at the entry point, then routes to
        ``_preprocess_continuous_audio`` and optionally
        ``_preprocess_discrete_audio``.

        Args:
            audios: List of audio clips, each a 1D numpy array.
            sampling_rate: Audio sampling rate. Defaults to self.sampling_rate.
            chunk_length: Chunk duration in seconds for continuous processing.
                Defaults to self.chunk_length.
            chunk_unit: Chunk duration in seconds for discrete processing.
                Defaults to self.chunk_unit.
            min_chunk_size: Minimum audio length in samples for discrete processing.
                Defaults to self.min_chunk_size.
            use_discrete_token: Whether to run discrete audio processing.
                Defaults to self.use_discrete_token.
            prefix: Optional string to prefix all output keys. Keys starting with
                ``"num_"`` get the prefix inserted after ``"num_"`` (e.g. prefix
                ``"video_"`` turns ``"num_audio_tokens"`` into
                ``"num_video_audio_tokens"``); all other keys are simply prepended
                (e.g. ``"audio_values"`` → ``"video_audio_values"``).
                ``None`` (default) leaves keys unchanged.

        Returns:
            BatchFeature with:
                - audio_values: Tensor of shape (num_total_chunks, feature_size, nb_max_frames).
                - audio_masks: Tensor of shape (num_total_chunks, nb_max_frames).
                - audio_attention_mask: Tensor of shape (num_total_chunks, 1, max_seq_len, max_seq_len).
                - num_audio_tokens: Tensor of shape (N,) with per-clip continuous token counts.
                - discrete_audio_values (optional): 1D tensor of all clips concatenated.
                - num_discrete_audio_tokens (optional): Tensor of shape (N,) with per-clip discrete token counts.
            All keys are renamed according to ``prefix`` when provided.
        """
        # 1. Resolve all kwargs at the entry point
        sampling_rate = sampling_rate if sampling_rate is not None else self.sampling_rate
        chunk_length = chunk_length if chunk_length is not None else self.chunk_length
        chunk_unit = chunk_unit if chunk_unit is not None else self.chunk_unit
        min_chunk_size = min_chunk_size if min_chunk_size is not None else self.min_chunk_size
        use_discrete = use_discrete_token if use_discrete_token is not None else self.use_discrete_token
        # 2. Route to continuous sub-processor
        continuous_result = self._preprocess_continuous_audio(
            audios,
            sampling_rate=sampling_rate,
            chunk_length=chunk_length,
        )
        data = {
            "audio_values": continuous_result["audio_values"],
            "audio_attention_mask": continuous_result["audio_attention_mask"],
            "audio_masks": continuous_result["audio_masks"],
            "num_audio_tokens": continuous_result["num_audio_tokens"],
        }
        # 3. Optionally route to discrete sub-processor
        if use_discrete:
            discrete_result = self._preprocess_discrete_audio(
                audios,
                sampling_rate=sampling_rate,
                chunk_unit=chunk_unit,
                min_chunk_size=min_chunk_size,
            )
            data["discrete_audio_values"] = discrete_result["discrete_audio_values"]
            data["num_discrete_audio_tokens"] = discrete_result["num_discrete_audio_tokens"]
        if prefix is not None:
            # "num_audio_tokens" -> "num_<prefix>audio_tokens"; others -> "<prefix><key>".
            data = {
                (f"num_{prefix}{k[len('num_'):]}" if k.startswith("num_") else f"{prefix}{k}"): v
                for k, v in data.items()
            }
        return BatchFeature(data=data, tensor_type="pt")

    def __call__(self, audios: List[np.ndarray], **kwargs) -> BatchFeature:
        """Alias for :meth:`preprocess`."""
        return self.preprocess(audios, **kwargs)

    def get_num_audio_tokens(
        self,
        audio_masks: torch.Tensor,
        discrete_audio_values: Optional[torch.Tensor] = None,
        include_boundary_tokens: bool = False,
        chunk_unit: Optional[int] = None,
        sampling_rate: Optional[int] = None,
        return_tuple: Optional[bool] = None,
    ) -> Union[int, Tuple[int, int]]:
        """Compute the number of audio tokens for the given input.

        Args:
            audio_masks: Attention mask for continuous audio. Shape (N,) or (num_chunks, N).
            discrete_audio_values: Discrete audio waveform. None to skip discrete computation.
            include_boundary_tokens: Whether to include start/end boundary tokens.
            chunk_unit: Chunk duration in seconds for discrete processing.
                Defaults to self.chunk_unit.
            sampling_rate: Audio sampling rate. Defaults to self.sampling_rate.
            return_tuple: If True, return (continuous, discrete) tuple.
                Otherwise return the sum.

        Returns:
            Token count as int, or (continuous, discrete) tuple if return_tuple is True.
        """
        chunk_unit = chunk_unit if chunk_unit is not None else self.chunk_unit
        sampling_rate = sampling_rate if sampling_rate is not None else self.sampling_rate

        def _compute_continuous_tokens(audio_mask: torch.Tensor) -> int:
            # Same fixed-per-chunk count as _preprocess_continuous_audio:
            # conv halving followed by temporal pooling, on the padded length.
            input_length = (audio_mask.shape[-1] - 1) // 2 + 1
            return (input_length - self.pool_kernel_size) // self.pool_stride + 1

        num_continuous_tokens, num_discrete_tokens = 0, 0
        if len(audio_masks.shape) == 1:
            num_continuous_tokens = _compute_continuous_tokens(audio_masks)
        else:
            num_continuous_tokens = sum(_compute_continuous_tokens(m) for m in audio_masks)
        if include_boundary_tokens:
            num_continuous_tokens += 2
        if self.use_discrete_token and discrete_audio_values is not None:
            audio_length = len(discrete_audio_values)
            chunk_size = chunk_unit * sampling_rate
            # NOTE(review): unlike _preprocess_discrete_audio, this loop does
            # NOT merge a tail shorter than min_chunk_size into the previous
            # chunk, so counts can diverge for long audio — confirm which
            # behavior the model expects before unifying.
            for _start in range(0, audio_length, chunk_size):
                _end = min(_start + chunk_size, audio_length)
                _chunked_length = _end - _start
                mel_len = _chunked_length // self.hop_length
                after_conv1 = _conv_output_length(mel_len)
                code_len = _conv_output_length(after_conv1)
                num_discrete_tokens += code_len
            if include_boundary_tokens:
                num_discrete_tokens += 2
        if return_tuple:
            return (num_continuous_tokens, num_discrete_tokens)
        else:
            return num_continuous_tokens + num_discrete_tokens