File size: 24,914 Bytes
0c1d6f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
# coding=utf-8
# Copyright 2026 NAVER Cloud Corp. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
HyperCLOVAX-SEED Audio Processor

Implements Whisper-compatible audio feature extraction:
- Log-mel spectrogram extraction from waveform
- Chunked processing for long audio clips
- Attention mask generation for padded sequences
- Discrete audio token count calculation (conv-based)
"""

from typing import List, Optional, Tuple, Union

import numpy as np
import torch
try:
    from transformers.image_processing_utils import BatchFeature
except ImportError:
    from transformers import BatchFeature
try:
    from torchaudio.functional import melscale_fbanks as _melscale_fbanks
except (ImportError, AttributeError):
    # fallback: transformers mel_filter_bank wrapped to return torch.Tensor
    from transformers.audio_utils import mel_filter_bank as _mel_filter_bank
    def _melscale_fbanks(n_freqs, f_min, f_max, n_mels, sample_rate, norm, mel_scale):
        return torch.from_numpy(_mel_filter_bank(
            num_frequency_bins=n_freqs,
            num_mel_filters=n_mels,
            min_frequency=f_min,
            max_frequency=f_max,
            sampling_rate=sample_rate,
            norm=norm,
            mel_scale=mel_scale,
        ))
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
try:
    from transformers.processing_utils import AudioKwargs
except ImportError:
    from typing import TypedDict as AudioKwargs  # transformers < 4.46


def _conv_output_length(
    input_length: int,
    kernel_size: int = 3,
    stride: int = 2,
    padding: int = 1,
    dilation: int = 1,
) -> int:
    """Compute output length of a 1D convolution.

    Formula: (input + 2*padding - dilation*(kernel-1) - 1) // stride + 1
    """
    return (input_length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


class HyperCLOVAXSeedAudioKwargs(AudioKwargs, total=False):
    """Optional per-call audio keyword arguments for HyperCLOVAX-SEED.

    All fields are optional (``total=False``). Each field name matches a
    ``HyperCLOVAXSeedAudioProcessor.__init__`` parameter of the same name.
    On transformers < 4.46 the base ``AudioKwargs`` falls back to a plain
    ``TypedDict`` (see this module's import fallback).

    NOTE(review): ``pool_kernel_size``, ``pool_stride`` and
    ``video_audio_pool_size`` are processor parameters with no entry here;
    confirm whether they are meant to be overridable per call.
    """

    # Number of mel bins (passed as ``n_mels`` to the mel filter bank).
    feature_size: Optional[int]
    # STFT hop size in samples.
    hop_length: Optional[int]
    # Chunk duration in seconds for continuous (log-mel) processing.
    chunk_length: Optional[int]
    # STFT window size in samples.
    n_fft: Optional[int]
    n_samples: Optional[int]
    nb_max_frames: Optional[int]
    # Chunk duration in seconds for discrete-token counting.
    chunk_unit: Optional[int]
    # Minimum discrete-audio length in samples (also the tail-merge threshold).
    min_chunk_size: Optional[int]
    # Std-dev scale of Gaussian noise added to the waveform before the STFT (0 disables).
    dither: Optional[float]
    # Token parameters
    audio_token: Optional[str]
    audio_start_token: Optional[str]
    audio_end_token: Optional[str]
    # Discrete audio parameters
    use_discrete_token: Optional[bool]
    discrete_audio_token: Optional[str]
    discrete_audio_start_token: Optional[str]
    discrete_audio_end_token: Optional[str]


class HyperCLOVAXSeedAudioProcessor(SequenceFeatureExtractor):
    """Audio processor for HyperCLOVAX-SEED.

    Extracts Whisper-compatible log-mel spectrogram features and computes
    attention masks for the audio encoder. Also supports discrete audio
    token count calculation.

    Continuous path: every clip is split into ``chunk_length``-second chunks.
    Each chunk is padded (or truncated) to ``self.chunk_length`` seconds and
    converted to a log-mel spectrogram plus a frame-level validity mask.

    Discrete path (optional): every clip is validated, and its discrete token
    count is derived from its mel-frame count after two stride-2 convolutions.
    """

    model_input_names = ["audio_values", "audio_masks", "discrete_audio_values"]

    def __init__(
        self,
        feature_size: int = 128,
        sampling_rate: int = 16000,
        hop_length: int = 160,
        chunk_length: int = 30,
        n_fft: int = 400,
        padding_value: float = 0.0,
        padding_side: str = "right",
        dither: float = 0.0,
        return_attention_mask: bool = False,
        n_samples: int = 480000,
        nb_max_frames: int = 3000,
        chunk_unit: int = 80,
        min_chunk_size: int = 1600,
        # Temporal pooling parameters
        pool_kernel_size: int = 5,
        pool_stride: int = 5,
        # Token parameters
        audio_token: str = "<|AUDIO_PAD|>",
        audio_start_token: str = "<|audio_start|>",
        audio_end_token: str = "<|audio_end|>",
        video_audio_pool_size: int = 25,
        # Discrete audio parameters
        use_discrete_token: bool = False,
        discrete_audio_token: str = "<|DISCRETE_AUDIO_PAD|>",
        discrete_audio_start_token: str = "<|discrete_audio_start|>",
        discrete_audio_end_token: str = "<|discrete_audio_end|>",
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            hop_length=hop_length,
            chunk_length=chunk_length,
            n_fft=n_fft,
            padding_value=padding_value,
            padding_side=padding_side,
            dither=dither,
            return_attention_mask=return_attention_mask,
            n_samples=n_samples,
            nb_max_frames=nb_max_frames,
            chunk_unit=chunk_unit,
            min_chunk_size=min_chunk_size,
            # Token parameters
            audio_token=audio_token,
            audio_start_token=audio_start_token,
            audio_end_token=audio_end_token,
            video_audio_pool_size=video_audio_pool_size,
            pool_kernel_size=pool_kernel_size,
            pool_stride=pool_stride,
            # Discrete audio parameters
            use_discrete_token=use_discrete_token,
            discrete_audio_token=discrete_audio_token,
            discrete_audio_start_token=discrete_audio_start_token,
            discrete_audio_end_token=discrete_audio_end_token,
            # Forward any remaining config keys (e.g. ``processor_class``) to the
            # base feature extractor instead of silently discarding them.
            **kwargs,
        )

        # Mel filter bank (Whisper-compatible) — torchaudio primary, transformers fallback.
        # NOTE: built once for the configured ``sampling_rate``; calls that pass a
        # different ``sampling_rate`` still use this bank.
        self.mel_filters = _melscale_fbanks(
            n_freqs=1 + n_fft // 2,
            f_min=0.0,
            f_max=8000.0,
            n_mels=feature_size,
            sample_rate=sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        )  # torch.Tensor, shape (n_freqs, n_mels)

    def _extract_fbank_features(
        self,
        waveform_batch: np.ndarray,
        device: str = "cpu",
    ) -> np.ndarray:
        """Extract log-mel spectrogram features from a waveform batch.

        Follows the OpenAI Whisper feature extraction pipeline.
        Reference: https://github.com/openai/whisper (MIT License)

        Adapted from WhisperFeatureExtractor._torch_extract_fbank_features.

        Args:
            waveform_batch: Waveform array of shape (batch_size, n_samples).
                It is never modified.
            device: Device for computation. Defaults to "cpu".

        Returns:
            Log-mel spectrogram of shape (batch_size, feature_size, num_frames).
        """
        waveform = torch.from_numpy(waveform_batch).to(device, torch.float32)
        window = torch.hann_window(self.n_fft, device=device)

        if self.dither != 0.0:
            # Out-of-place add: for float32 CPU input, ``.to`` returns a tensor that
            # shares memory with ``waveform_batch``; ``+=`` would mutate the caller's array.
            noise = torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device)
            waveform = waveform + self.dither * noise

        stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
        # Drop the last STFT frame (Whisper convention) and take the power spectrum.
        magnitudes = stft[..., :-1].abs() ** 2

        mel_filters = self.mel_filters.to(device=device, dtype=torch.float32)
        mel_spec = mel_filters.T @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        # Limit dynamic range to 8 (log10 units) below the maximum: per example
        # for a batch, globally for a single waveform.
        if waveform.dim() == 2:
            max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
            log_spec = torch.maximum(log_spec, max_val - 8.0)
        else:
            log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0

        if device != "cpu":
            log_spec = log_spec.detach().cpu()

        return log_spec.numpy()

    def _pad_and_extract_features(
        self,
        chunks: List[np.ndarray],
        sampling_rate: int,
    ) -> dict:
        """Pad audio chunks and extract mel-spectrogram features.

        Each chunk is padded with ``self.padding_value`` (or truncated) to
        ``self.chunk_length * sampling_rate`` samples. Log-mel features are then
        extracted, and a frame-level mask marks frames whose first sample lies
        inside the original chunk.

        Args:
            chunks: Non-empty list of 1D arrays, each representing an audio chunk.
            sampling_rate: Audio sampling rate.

        Returns:
            Dictionary with:
                - "input_features": Array of shape (num_chunks, feature_size, nb_max_frames).
                - "attention_mask": int32 array of shape (num_chunks, nb_max_frames).
        """
        n_samples = self.chunk_length * sampling_rate
        nb_max_frames = n_samples // self.hop_length

        padded_waveforms = []
        attention_masks = []

        for chunk in chunks:
            chunk = np.asarray(chunk, dtype=np.float32)
            chunk_len = len(chunk)

            # Pad or truncate
            if chunk_len < n_samples:
                padded = np.full(n_samples, self.padding_value, dtype=np.float32)
                padded[:chunk_len] = chunk
            else:
                padded = chunk[:n_samples]
                chunk_len = n_samples

            padded_waveforms.append(padded)

            # Attention mask (sample-level -> frame-level, one entry per hop)
            sample_mask = np.zeros(n_samples, dtype=np.int32)
            sample_mask[:chunk_len] = 1
            frame_mask = sample_mask[:: self.hop_length]
            if len(frame_mask) > nb_max_frames:
                frame_mask = frame_mask[:nb_max_frames]
            elif len(frame_mask) < nb_max_frames:
                frame_mask = np.pad(frame_mask, (0, nb_max_frames - len(frame_mask)))
            attention_masks.append(frame_mask)

        waveform_batch = np.stack(padded_waveforms, axis=0)
        input_features = self._extract_fbank_features(waveform_batch)
        attention_mask = np.stack(attention_masks, axis=0)

        return {
            "input_features": input_features,
            "attention_mask": attention_mask,
        }

    def _get_feature_lengths(self, audio_masks: torch.Tensor) -> torch.Tensor:
        """Compute valid sequence lengths after one stride-2 conv downsampling.

        Args:
            audio_masks: Frame-level mask of shape (batch, num_frames).

        Returns:
            Feature lengths tensor of shape (batch,), i.e. ``ceil(valid_frames / 2)``.
        """
        return (audio_masks.sum(-1) - 1) // 2 + 1

    def _get_attention_mask(self, audio_masks: torch.Tensor) -> torch.Tensor:
        """Generate an additive padding attention mask for the audio encoder.

        This is a key-padding mask, not a causal mask. Each entry is ``0.0`` where
        the key position is within the chunk's valid length and ``-inf`` where it
        is padding. The same row is repeated for every query position.

        Args:
            audio_masks: Frame-level mask of shape (batch, num_frames).

        Returns:
            float32 tensor of shape (batch, 1, max_seq_len, max_seq_len), where
            ``max_seq_len`` is the stride-2 conv output length of ``num_frames``.
        """
        feature_lengths = self._get_feature_lengths(audio_masks=audio_masks)
        # Derive the post-conv length from the actual mask width, so the mask
        # always matches the features it accompanies.
        max_seq_len = (audio_masks.shape[-1] - 2) // 2 + 1
        positions = torch.arange(max_seq_len, device=audio_masks.device)
        padding_mask = positions >= feature_lengths.unsqueeze(1)  # (batch, max_seq_len); True = padded
        batch_size = padding_mask.shape[0]
        # Fill a float tensor; filling ``-inf`` into the bool padding mask itself cannot
        # represent an additive mask.
        attention_mask = torch.zeros(
            batch_size, 1, max_seq_len, max_seq_len, dtype=torch.float32, device=audio_masks.device
        )
        return attention_mask.masked_fill(padding_mask[:, None, None, :], float("-inf"))

    def _count_continuous_tokens_per_chunk(self, num_frames: int) -> int:
        """Return the number of continuous audio tokens one chunk produces.

        ``num_frames`` is halved by a stride-2 conv (``ceil(num_frames / 2)``),
        then pooled with ``pool_kernel_size`` / ``pool_stride``. The count uses
        the full (padded) frame count, not the number of valid frames, so every
        chunk of the same width yields the same number of tokens.
        NOTE(review): presumably this matches an encoder that pools over the
        full padded sequence — confirm against the model.
        """
        encoder_length = (num_frames - 1) // 2 + 1
        return (encoder_length - self.pool_kernel_size) // self.pool_stride + 1

    def _count_discrete_tokens(
        self,
        audio_length: int,
        sampling_rate: int,
        chunk_unit: int,
        min_chunk_size: int,
    ) -> int:
        """Count discrete audio tokens for a waveform of ``audio_length`` samples.

        The waveform is split into ``chunk_unit``-second chunks. A trailing
        remainder shorter than ``min_chunk_size`` samples is merged into the
        preceding chunk. Each chunk's mel-frame count
        (``chunk_len // hop_length``) goes through two stride-2 convolutions
        to give its token count. Shared by preprocessing and
        ``get_num_audio_tokens`` so both always agree.
        """
        chunk_size = chunk_unit * sampling_rate
        total_code_len = 0
        for start in range(0, audio_length, chunk_size):
            end = min(start + chunk_size, audio_length)
            # Merge a too-short tail into this chunk instead of emitting a tiny chunk.
            if end < audio_length and audio_length - end < min_chunk_size:
                end = audio_length
            mel_len = (end - start) // self.hop_length
            total_code_len += _conv_output_length(_conv_output_length(mel_len))
            if end >= audio_length:
                break
        return total_code_len

    def _preprocess_continuous_audio(
        self,
        audio_clips: List[np.ndarray],
        sampling_rate: Optional[int] = None,
        chunk_length: Optional[int] = None,
    ) -> dict:
        """Preprocess audio clips for continuous audio features.

        Splits each audio clip into chunks of chunk_length seconds, extracts
        mel-spectrogram features, and computes per-clip token counts.

        Args:
            audio_clips: List of audio clips, each a 1D numpy array (mono, float32).
            sampling_rate: Audio sampling rate. Defaults to self.sampling_rate.
            chunk_length: Chunk duration in seconds. Defaults to self.chunk_length.
                Must not exceed ``self.chunk_length``, because every chunk is
                padded/truncated to ``self.chunk_length`` seconds.

        Returns:
            Dictionary with:
                - "audio_values": Tensor of shape (num_total_chunks, feature_size, nb_max_frames).
                - "audio_masks": int32 tensor of shape (num_total_chunks, nb_max_frames).
                - "audio_attention_mask": float32 tensor of shape
                  (num_total_chunks, 1, max_seq_len, max_seq_len) with ``-inf`` at padded keys.
                - "num_audio_tokens": Tensor of shape (N,) with per-clip continuous token counts.

        Raises:
            ValueError: If ``chunk_length`` exceeds ``self.chunk_length`` (longer
                chunks would silently lose audio), or if any clip is empty.
        """
        if sampling_rate is None:
            sampling_rate = self.sampling_rate
        if chunk_length is None:
            chunk_length = self.chunk_length
        if chunk_length > self.chunk_length:
            raise ValueError(
                f"chunk_length ({chunk_length}s) must not exceed the processor's chunk_length "
                f"({self.chunk_length}s); longer chunks would be truncated during padding."
            )

        if len(audio_clips) == 0:
            max_seq_len = (self.nb_max_frames - 2) // 2 + 1
            # Same ranks/dtypes as the non-empty path, just with zero chunks.
            return {
                "audio_values": torch.zeros(0, self.feature_size, self.nb_max_frames),
                "audio_masks": torch.zeros(0, self.nb_max_frames, dtype=torch.int32),
                "audio_attention_mask": torch.zeros(0, 1, max_seq_len, max_seq_len),
                "num_audio_tokens": torch.tensor([], dtype=torch.long),
            }

        chunk_samples = chunk_length * sampling_rate
        audio_values, audio_masks, num_audio_tokens = [], [], []
        for clip_idx, audio in enumerate(audio_clips):
            if len(audio) == 0:
                # An empty clip yields no chunks (np.stack would fail) and no valid frames.
                raise ValueError(f"Audio clip at index {clip_idx} is empty.")
            chunks = [audio[i : i + chunk_samples] for i in range(0, len(audio), chunk_samples)]

            result = self._pad_and_extract_features(chunks, sampling_rate)
            clip_masks = result["attention_mask"]
            num_tokens = sum(self._count_continuous_tokens_per_chunk(mask.shape[-1]) for mask in clip_masks)

            audio_values.append(torch.from_numpy(result["input_features"]))
            audio_masks.append(torch.from_numpy(clip_masks))
            num_audio_tokens.append(num_tokens)

        all_values = torch.cat(audio_values, dim=0)
        all_masks = torch.cat(audio_masks, dim=0)

        return {
            "audio_values": all_values,
            "audio_masks": all_masks,
            "audio_attention_mask": self._get_attention_mask(audio_masks=all_masks),
            "num_audio_tokens": torch.tensor(num_audio_tokens, dtype=torch.long),
        }

    def _preprocess_discrete_audio(
        self,
        audio_clips: List[np.ndarray],
        sampling_rate: Optional[int] = None,
        chunk_unit: Optional[int] = None,
        min_chunk_size: Optional[int] = None,
    ) -> dict:
        """Preprocess audio clips for discrete audio tokens.

        Validates each audio clip and computes its number of discrete tokens
        from conv-layer downsampling.

        Args:
            audio_clips: List of audio clips, each a 1D numpy array (mono, float32).
            sampling_rate: Audio sampling rate. Defaults to self.sampling_rate.
            chunk_unit: Chunk duration in seconds for long audio. Defaults to self.chunk_unit.
            min_chunk_size: Minimum audio length in samples. Defaults to self.min_chunk_size.

        Returns:
            Dictionary with:
                - "discrete_audio_values": 1D tensor holding all clips' waveforms
                  concatenated along dim 0 (not padded or stacked; empty for no clips).
                  NOTE(review): earlier docs described shape (N, max_audio_len);
                  confirm which layout the model expects.
                - "num_discrete_audio_tokens": Tensor of shape (N,) with per-clip discrete token counts.

        Raises:
            ValueError: If a clip is shorter than ``min_chunk_size`` samples,
                contains NaN/Inf, is longer than 600 seconds, or has values outside [-100, 100].
        """
        if sampling_rate is None:
            sampling_rate = self.sampling_rate
        if chunk_unit is None:
            chunk_unit = self.chunk_unit
        if min_chunk_size is None:
            min_chunk_size = self.min_chunk_size

        if len(audio_clips) == 0:
            # torch.cat([]) would raise; mirror the continuous path's empty handling.
            return {
                "discrete_audio_values": torch.zeros(0),
                "num_discrete_audio_tokens": torch.tensor([], dtype=torch.long),
            }

        max_audio_length = 600 * sampling_rate
        discrete_audio_values, num_discrete_audio_tokens = [], []
        for audio in audio_clips:
            audio_length = len(audio)
            audio_duration_sec = audio_length / sampling_rate

            if audio_length < min_chunk_size:
                raise ValueError(f"Discrete audio too short: {audio_length}")
            if np.isnan(audio).any() or np.isinf(audio).any():
                raise ValueError("Discrete audio contains NaN/Inf")
            if audio_length > max_audio_length:
                raise ValueError(
                    f"Discrete audio too long: {audio_length} samples = ({audio_duration_sec:.2f}s > 600s)"
                )

            audio_min, audio_max = audio.min().item(), audio.max().item()
            if audio_min < -100.0 or audio_max > 100.0:
                raise ValueError(f"Discrete audio values out of range: min {audio_min}, max {audio_max}")

            discrete_audio_values.append(torch.tensor(audio))
            num_discrete_audio_tokens.append(
                self._count_discrete_tokens(audio_length, sampling_rate, chunk_unit, min_chunk_size)
            )

        return {
            "discrete_audio_values": torch.cat(discrete_audio_values, dim=0),
            "num_discrete_audio_tokens": torch.tensor(num_discrete_audio_tokens, dtype=torch.long),
        }

    def preprocess(
        self,
        audios: List[np.ndarray],
        sampling_rate: Optional[int] = None,
        chunk_length: Optional[int] = None,
        chunk_unit: Optional[int] = None,
        min_chunk_size: Optional[int] = None,
        use_discrete_token: Optional[bool] = None,
        prefix: Optional[str] = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess a list of audio clips.

        Resolves all kwargs at the entry point, then routes to
        ``_preprocess_continuous_audio`` and optionally
        ``_preprocess_discrete_audio``. Extra ``**kwargs`` are accepted and ignored.

        Args:
            audios: List of audio clips, each a 1D numpy array.
            sampling_rate: Audio sampling rate. Defaults to self.sampling_rate.
            chunk_length: Chunk duration in seconds for continuous processing.
                Defaults to self.chunk_length and must not exceed it.
            chunk_unit: Chunk duration in seconds for discrete processing.
                Defaults to self.chunk_unit.
            min_chunk_size: Minimum audio length in samples for discrete processing.
                Defaults to self.min_chunk_size.
            use_discrete_token: Whether to run discrete audio processing.
                Defaults to self.use_discrete_token.
            prefix: Optional string to prefix all output keys. Keys starting with
                ``"num_"`` get the prefix inserted after ``"num_"`` (e.g. prefix
                ``"video_"`` turns ``"num_audio_tokens"`` into
                ``"num_video_audio_tokens"``); all other keys are simply prepended
                (e.g. ``"audio_values"`` → ``"video_audio_values"``).
                ``None`` (default) leaves keys unchanged.

        Returns:
            BatchFeature with:
                - audio_values: Tensor of shape (num_total_chunks, feature_size, nb_max_frames).
                - audio_masks: Tensor of shape (num_total_chunks, nb_max_frames).
                - audio_attention_mask: float32 tensor of shape
                  (num_total_chunks, 1, max_seq_len, max_seq_len) with ``-inf`` at padded keys.
                - num_audio_tokens: Tensor of shape (N,) with per-clip continuous token counts.
                - discrete_audio_values (optional): 1D tensor of all clips' waveforms concatenated.
                - num_discrete_audio_tokens (optional): Tensor of shape (N,) with per-clip discrete token counts.

            All keys are renamed according to ``prefix`` when provided.
        """
        # 1. Resolve all kwargs at the entry point
        sampling_rate = sampling_rate if sampling_rate is not None else self.sampling_rate
        chunk_length = chunk_length if chunk_length is not None else self.chunk_length
        chunk_unit = chunk_unit if chunk_unit is not None else self.chunk_unit
        min_chunk_size = min_chunk_size if min_chunk_size is not None else self.min_chunk_size
        use_discrete = use_discrete_token if use_discrete_token is not None else self.use_discrete_token

        # 2. Route to continuous sub-processor
        continuous_result = self._preprocess_continuous_audio(
            audios,
            sampling_rate=sampling_rate,
            chunk_length=chunk_length,
        )
        data = {
            "audio_values": continuous_result["audio_values"],
            "audio_attention_mask": continuous_result["audio_attention_mask"],
            "audio_masks": continuous_result["audio_masks"],
            "num_audio_tokens": continuous_result["num_audio_tokens"],
        }

        # 3. Optionally route to discrete sub-processor
        if use_discrete:
            discrete_result = self._preprocess_discrete_audio(
                audios,
                sampling_rate=sampling_rate,
                chunk_unit=chunk_unit,
                min_chunk_size=min_chunk_size,
            )
            data["discrete_audio_values"] = discrete_result["discrete_audio_values"]
            data["num_discrete_audio_tokens"] = discrete_result["num_discrete_audio_tokens"]

        if prefix is not None:
            data = {
                (f"num_{prefix}{k[len('num_'):]}" if k.startswith("num_") else f"{prefix}{k}"): v
                for k, v in data.items()
            }

        return BatchFeature(data=data, tensor_type="pt")

    def __call__(self, audios: List[np.ndarray], **kwargs) -> BatchFeature:
        """Alias for :meth:`preprocess`."""
        return self.preprocess(audios, **kwargs)

    def get_num_audio_tokens(
        self,
        audio_masks: torch.Tensor,
        discrete_audio_values: Optional[torch.Tensor] = None,
        include_boundary_tokens: bool = False,
        chunk_unit: Optional[int] = None,
        sampling_rate: Optional[int] = None,
        return_tuple: Optional[bool] = None,
        min_chunk_size: Optional[int] = None,
    ) -> Union[int, Tuple[int, int]]:
        """Compute the number of audio tokens for the given input.

        Uses the same counting helpers as :meth:`preprocess`, so the result
        matches ``num_audio_tokens`` / ``num_discrete_audio_tokens``.

        Args:
            audio_masks: Attention mask for continuous audio. Shape (N,) or (num_chunks, N).
            discrete_audio_values: Discrete audio waveform (its ``len`` is taken as the
                sample count). None to skip discrete computation; also skipped when
                ``self.use_discrete_token`` is False.
            include_boundary_tokens: Whether to include start/end boundary tokens
                (+2 for each counted modality).
            chunk_unit: Chunk duration in seconds for discrete processing.
                Defaults to self.chunk_unit.
            sampling_rate: Audio sampling rate. Defaults to self.sampling_rate.
            return_tuple: If True, return (continuous, discrete) tuple.
                Otherwise return the sum.
            min_chunk_size: Tail-merge threshold in samples for discrete processing.
                Defaults to self.min_chunk_size.

        Returns:
            Token count as int, or (continuous, discrete) tuple if return_tuple is True.
        """
        chunk_unit = chunk_unit if chunk_unit is not None else self.chunk_unit
        sampling_rate = sampling_rate if sampling_rate is not None else self.sampling_rate
        min_chunk_size = min_chunk_size if min_chunk_size is not None else self.min_chunk_size

        if audio_masks.ndim == 1:
            num_continuous_tokens = self._count_continuous_tokens_per_chunk(audio_masks.shape[-1])
        else:
            num_continuous_tokens = sum(self._count_continuous_tokens_per_chunk(m.shape[-1]) for m in audio_masks)
        if include_boundary_tokens:
            num_continuous_tokens += 2

        num_discrete_tokens = 0
        if self.use_discrete_token and discrete_audio_values is not None:
            num_discrete_tokens = self._count_discrete_tokens(
                len(discrete_audio_values), sampling_rate, chunk_unit, min_chunk_size
            )
            if include_boundary_tokens:
                num_discrete_tokens += 2

        if return_tuple:
            return (num_continuous_tokens, num_discrete_tokens)
        return num_continuous_tokens + num_discrete_tokens