Commit
·
eff8e23
1
Parent(s):
0750e3c
add support for optional num_quantizers argument
Browse files- modeling_moss_audio_tokenizer.py +50 -19
modeling_moss_audio_tokenizer.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
# Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
#
|
| 4 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -26,9 +25,8 @@ import torch
|
|
| 26 |
import torch.nn as nn
|
| 27 |
import torch.nn.functional as F
|
| 28 |
|
| 29 |
-
from
|
| 30 |
-
from
|
| 31 |
-
|
| 32 |
from .configuration_moss_audio_tokenizer import MossAudioTokenizerConfig
|
| 33 |
|
| 34 |
|
|
@@ -372,7 +370,7 @@ def create_sin_embedding(
|
|
| 372 |
dim: int,
|
| 373 |
max_period: float = 10000,
|
| 374 |
dtype: torch.dtype = torch.float32,
|
| 375 |
-
):
|
| 376 |
"""Create sinusoidal positional embedding with shape [B, T, C]."""
|
| 377 |
if dim % 2 != 0:
|
| 378 |
raise ValueError(f"Sinusoidal embedding requires even dim, got dim={dim}")
|
|
@@ -406,7 +404,7 @@ class KVCacheResult:
|
|
| 406 |
return iter((self.keys, self.values, self.positions))
|
| 407 |
|
| 408 |
@staticmethod
|
| 409 |
-
def from_kv(keys: torch.Tensor, values: torch.Tensor) ->
|
| 410 |
B, H, T, D = keys.shape
|
| 411 |
positions = torch.arange(T, device=keys.device, dtype=torch.long)
|
| 412 |
return KVCacheResult(keys, values, positions.expand(B, -1))
|
|
@@ -506,7 +504,7 @@ def apply_weights_per_step(
|
|
| 506 |
schedule: list[int] | None,
|
| 507 |
x: torch.Tensor,
|
| 508 |
offset: int | None,
|
| 509 |
-
):
|
| 510 |
"""Apply different weights for each time step."""
|
| 511 |
if len(modules) == 1:
|
| 512 |
return modules[0](x)
|
|
@@ -1088,7 +1086,6 @@ class MossAudioTokenizerLFQ(nn.Module):
|
|
| 1088 |
- 2 * encodings @ codebook.t()
|
| 1089 |
+ codebook.pow(2).sum(1, keepdim=True).t()
|
| 1090 |
)
|
| 1091 |
-
|
| 1092 |
indices = (-dist).max(1)[1]
|
| 1093 |
indices = indices.reshape(latents.size(0), -1)
|
| 1094 |
z_q = self.decode_code_wo_out_proj(indices).float()
|
|
@@ -1306,10 +1303,6 @@ class MossAudioTokenizerPreTrainedModel(PreTrainedAudioTokenizerBase):
|
|
| 1306 |
"MossAudioTokenizerResidualLFQ",
|
| 1307 |
]
|
| 1308 |
|
| 1309 |
-
def _init_weights(self, module: nn.Module) -> None:
|
| 1310 |
-
if isinstance(module, MossAudioTokenizerLayerScale):
|
| 1311 |
-
nn.init.constant_(module.scale, 1e-4)
|
| 1312 |
-
|
| 1313 |
|
| 1314 |
@auto_docstring(
|
| 1315 |
custom_intro="""
|
|
@@ -1348,7 +1341,7 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
|
|
| 1348 |
context=int(current_frame_rate * self.causal_transformer_context_duration),
|
| 1349 |
)
|
| 1350 |
)
|
| 1351 |
-
current_frame_rate /=
|
| 1352 |
|
| 1353 |
# Build quantizer
|
| 1354 |
quantizer_kwargs = dict(config.quantizer_kwargs)
|
|
@@ -1375,7 +1368,7 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
|
|
| 1375 |
context=int(current_frame_rate * self.causal_transformer_context_duration),
|
| 1376 |
)
|
| 1377 |
)
|
| 1378 |
-
current_frame_rate *=
|
| 1379 |
|
| 1380 |
self.post_init()
|
| 1381 |
|
|
@@ -1407,11 +1400,14 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
|
|
| 1407 |
self._stop_streaming()
|
| 1408 |
|
| 1409 |
@torch.no_grad()
|
| 1410 |
-
def batch_encode(
|
|
|
|
|
|
|
| 1411 |
"""Batch encode a list of audio waveforms.
|
| 1412 |
|
| 1413 |
Args:
|
| 1414 |
wav_list: List of audio tensors, each of shape `(num_samples,)`.
|
|
|
|
| 1415 |
|
| 1416 |
Returns:
|
| 1417 |
[`MossAudioTokenizerEncoderOutput`] with `audio_codes` and `audio_codes_lengths`.
|
|
@@ -1430,14 +1426,18 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
|
|
| 1430 |
input_values[i, 0, : wav.shape[-1]] = wav
|
| 1431 |
input_lengths[i] = wav.shape[-1]
|
| 1432 |
|
| 1433 |
-
return self._encode_frame(input_values, input_lengths)
|
| 1434 |
|
| 1435 |
@torch.no_grad()
|
| 1436 |
-
def batch_decode(
|
|
|
|
|
|
|
| 1437 |
"""Batch decode a list of audio codes.
|
| 1438 |
|
| 1439 |
Args:
|
| 1440 |
codes_list: List of audio code tensors, each of shape `(num_quantizers, codes_length)`.
|
|
|
|
|
|
|
| 1441 |
|
| 1442 |
Returns:
|
| 1443 |
[`MossAudioTokenizerDecoderOutput`] with `audio` and `audio_lengths`.
|
|
@@ -1447,13 +1447,28 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
|
|
| 1447 |
|
| 1448 |
batch_size = len(codes_list)
|
| 1449 |
device = codes_list[0].device
|
| 1450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1451 |
max_length = max(codes.shape[-1] for codes in codes_list)
|
| 1452 |
|
| 1453 |
audio_codes = torch.zeros(num_quantizers, batch_size, max_length, device=device, dtype=torch.long)
|
| 1454 |
audio_codes_lengths = torch.zeros(batch_size, device=device, dtype=torch.long)
|
| 1455 |
|
| 1456 |
for i, codes in enumerate(codes_list):
|
|
|
|
| 1457 |
audio_codes[:, i, : codes.shape[-1]] = codes
|
| 1458 |
audio_codes_lengths[i] = codes.shape[-1]
|
| 1459 |
|
|
@@ -1638,6 +1653,7 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
|
|
| 1638 |
padding_mask: torch.Tensor | None = None,
|
| 1639 |
return_dict: bool | None = None,
|
| 1640 |
chunk_duration: float | None = None,
|
|
|
|
| 1641 |
):
|
| 1642 |
"""
|
| 1643 |
Decodes the given codes into an output audio waveform.
|
|
@@ -1653,6 +1669,9 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
|
|
| 1653 |
If provided, decode the input codes in successive chunks of `chunk_duration` seconds while keeping a
|
| 1654 |
streaming KV cache for the causal transformers.
|
| 1655 |
|
|
|
|
|
|
|
|
|
|
| 1656 |
`chunk_duration` must be <= `config.causal_transformer_context_duration`, and
|
| 1657 |
`chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`.
|
| 1658 |
|
|
@@ -1664,6 +1683,13 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
|
|
| 1664 |
if audio_codes.dim() == 2:
|
| 1665 |
audio_codes = audio_codes.unsqueeze(1) # nq, T -> nq, B=1, T
|
| 1666 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1667 |
_, B, T = audio_codes.shape
|
| 1668 |
device = audio_codes.device
|
| 1669 |
|
|
@@ -1793,7 +1819,12 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
|
|
| 1793 |
if decoded_from_encoded_codes and output_audio_codes_lengths is not None:
|
| 1794 |
decoder_output = self._decode_frame(audio_codes, output_audio_codes_lengths)
|
| 1795 |
else:
|
| 1796 |
-
decoder_output = self.decode(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1797 |
decoder_output = cast(MossAudioTokenizerDecoderOutput, decoder_output)
|
| 1798 |
output_audio = decoder_output.audio
|
| 1799 |
output_audio_lengths = decoder_output.audio_lengths
|
|
|
|
|
|
|
| 1 |
# Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
|
| 2 |
#
|
| 3 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
| 25 |
import torch.nn as nn
|
| 26 |
import torch.nn.functional as F
|
| 27 |
|
| 28 |
+
from ...modeling_utils import PreTrainedAudioTokenizerBase
|
| 29 |
+
from ...utils import ModelOutput, auto_docstring, logging
|
|
|
|
| 30 |
from .configuration_moss_audio_tokenizer import MossAudioTokenizerConfig
|
| 31 |
|
| 32 |
|
|
|
|
| 370 |
dim: int,
|
| 371 |
max_period: float = 10000,
|
| 372 |
dtype: torch.dtype = torch.float32,
|
| 373 |
+
) -> torch.Tensor:
|
| 374 |
"""Create sinusoidal positional embedding with shape [B, T, C]."""
|
| 375 |
if dim % 2 != 0:
|
| 376 |
raise ValueError(f"Sinusoidal embedding requires even dim, got dim={dim}")
|
|
|
|
| 404 |
return iter((self.keys, self.values, self.positions))
|
| 405 |
|
| 406 |
@staticmethod
|
| 407 |
+
def from_kv(keys: torch.Tensor, values: torch.Tensor) -> KVCacheResult:
|
| 408 |
B, H, T, D = keys.shape
|
| 409 |
positions = torch.arange(T, device=keys.device, dtype=torch.long)
|
| 410 |
return KVCacheResult(keys, values, positions.expand(B, -1))
|
|
|
|
| 504 |
schedule: list[int] | None,
|
| 505 |
x: torch.Tensor,
|
| 506 |
offset: int | None,
|
| 507 |
+
) -> torch.Tensor:
|
| 508 |
"""Apply different weights for each time step."""
|
| 509 |
if len(modules) == 1:
|
| 510 |
return modules[0](x)
|
|
|
|
| 1086 |
- 2 * encodings @ codebook.t()
|
| 1087 |
+ codebook.pow(2).sum(1, keepdim=True).t()
|
| 1088 |
)
|
|
|
|
| 1089 |
indices = (-dist).max(1)[1]
|
| 1090 |
indices = indices.reshape(latents.size(0), -1)
|
| 1091 |
z_q = self.decode_code_wo_out_proj(indices).float()
|
|
|
|
| 1303 |
"MossAudioTokenizerResidualLFQ",
|
| 1304 |
]
|
| 1305 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1306 |
|
| 1307 |
@auto_docstring(
|
| 1308 |
custom_intro="""
|
|
|
|
| 1341 |
context=int(current_frame_rate * self.causal_transformer_context_duration),
|
| 1342 |
)
|
| 1343 |
)
|
| 1344 |
+
current_frame_rate /= self.encoder[-1].downsample_ratio
|
| 1345 |
|
| 1346 |
# Build quantizer
|
| 1347 |
quantizer_kwargs = dict(config.quantizer_kwargs)
|
|
|
|
| 1368 |
context=int(current_frame_rate * self.causal_transformer_context_duration),
|
| 1369 |
)
|
| 1370 |
)
|
| 1371 |
+
current_frame_rate *= self.decoder[-1].downsample_ratio
|
| 1372 |
|
| 1373 |
self.post_init()
|
| 1374 |
|
|
|
|
| 1400 |
self._stop_streaming()
|
| 1401 |
|
| 1402 |
@torch.no_grad()
|
| 1403 |
+
def batch_encode(
|
| 1404 |
+
self, wav_list: list[torch.Tensor], num_quantizers: int | None = None
|
| 1405 |
+
) -> MossAudioTokenizerEncoderOutput:
|
| 1406 |
"""Batch encode a list of audio waveforms.
|
| 1407 |
|
| 1408 |
Args:
|
| 1409 |
wav_list: List of audio tensors, each of shape `(num_samples,)`.
|
| 1410 |
+
num_quantizers: Number of quantizers to use. By default, all quantizers are used.
|
| 1411 |
|
| 1412 |
Returns:
|
| 1413 |
[`MossAudioTokenizerEncoderOutput`] with `audio_codes` and `audio_codes_lengths`.
|
|
|
|
| 1426 |
input_values[i, 0, : wav.shape[-1]] = wav
|
| 1427 |
input_lengths[i] = wav.shape[-1]
|
| 1428 |
|
| 1429 |
+
return self._encode_frame(input_values, input_lengths, n_quantizers=num_quantizers)
|
| 1430 |
|
| 1431 |
@torch.no_grad()
|
| 1432 |
+
def batch_decode(
|
| 1433 |
+
self, codes_list: list[torch.Tensor], num_quantizers: int | None = None
|
| 1434 |
+
) -> MossAudioTokenizerDecoderOutput:
|
| 1435 |
"""Batch decode a list of audio codes.
|
| 1436 |
|
| 1437 |
Args:
|
| 1438 |
codes_list: List of audio code tensors, each of shape `(num_quantizers, codes_length)`.
|
| 1439 |
+
num_quantizers: If provided, decode only the first `num_quantizers` quantizers from each element in
|
| 1440 |
+
`codes_list`. If omitted, all elements in `codes_list` must have the same number of quantizers.
|
| 1441 |
|
| 1442 |
Returns:
|
| 1443 |
[`MossAudioTokenizerDecoderOutput`] with `audio` and `audio_lengths`.
|
|
|
|
| 1447 |
|
| 1448 |
batch_size = len(codes_list)
|
| 1449 |
device = codes_list[0].device
|
| 1450 |
+
nqs = [codes.shape[0] for codes in codes_list]
|
| 1451 |
+
if num_quantizers is None:
|
| 1452 |
+
num_quantizers = nqs[0]
|
| 1453 |
+
if any(nq != num_quantizers for nq in nqs):
|
| 1454 |
+
raise ValueError(
|
| 1455 |
+
"All elements in `codes_list` must have the same number of quantizers when `num_quantizers` is None. "
|
| 1456 |
+
"Pass `num_quantizers=...` to decode a common prefix."
|
| 1457 |
+
)
|
| 1458 |
+
else:
|
| 1459 |
+
min_nq = min(nqs)
|
| 1460 |
+
if min_nq < num_quantizers:
|
| 1461 |
+
raise ValueError(
|
| 1462 |
+
"`num_quantizers` must be <= the number of quantizers for every element in `codes_list`. "
|
| 1463 |
+
f"Got num_quantizers={num_quantizers}, min(codes.shape[0])={min_nq}."
|
| 1464 |
+
)
|
| 1465 |
max_length = max(codes.shape[-1] for codes in codes_list)
|
| 1466 |
|
| 1467 |
audio_codes = torch.zeros(num_quantizers, batch_size, max_length, device=device, dtype=torch.long)
|
| 1468 |
audio_codes_lengths = torch.zeros(batch_size, device=device, dtype=torch.long)
|
| 1469 |
|
| 1470 |
for i, codes in enumerate(codes_list):
|
| 1471 |
+
codes = codes[:num_quantizers]
|
| 1472 |
audio_codes[:, i, : codes.shape[-1]] = codes
|
| 1473 |
audio_codes_lengths[i] = codes.shape[-1]
|
| 1474 |
|
|
|
|
| 1653 |
padding_mask: torch.Tensor | None = None,
|
| 1654 |
return_dict: bool | None = None,
|
| 1655 |
chunk_duration: float | None = None,
|
| 1656 |
+
num_quantizers: int | None = None,
|
| 1657 |
):
|
| 1658 |
"""
|
| 1659 |
Decodes the given codes into an output audio waveform.
|
|
|
|
| 1669 |
If provided, decode the input codes in successive chunks of `chunk_duration` seconds while keeping a
|
| 1670 |
streaming KV cache for the causal transformers.
|
| 1671 |
|
| 1672 |
+
num_quantizers (`int`, *optional*):
|
| 1673 |
+
Number of quantizers to use. By default, all quantizers in `audio_codes` are used.
|
| 1674 |
+
|
| 1675 |
`chunk_duration` must be <= `config.causal_transformer_context_duration`, and
|
| 1676 |
`chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`.
|
| 1677 |
|
|
|
|
| 1683 |
if audio_codes.dim() == 2:
|
| 1684 |
audio_codes = audio_codes.unsqueeze(1) # nq, T -> nq, B=1, T
|
| 1685 |
|
| 1686 |
+
if num_quantizers is not None:
|
| 1687 |
+
if num_quantizers > audio_codes.shape[0]:
|
| 1688 |
+
raise ValueError(
|
| 1689 |
+
f"`num_quantizers` ({num_quantizers}) must be <= audio_codes.shape[0] ({audio_codes.shape[0]})."
|
| 1690 |
+
)
|
| 1691 |
+
audio_codes = audio_codes[:num_quantizers]
|
| 1692 |
+
|
| 1693 |
_, B, T = audio_codes.shape
|
| 1694 |
device = audio_codes.device
|
| 1695 |
|
|
|
|
| 1819 |
if decoded_from_encoded_codes and output_audio_codes_lengths is not None:
|
| 1820 |
decoder_output = self._decode_frame(audio_codes, output_audio_codes_lengths)
|
| 1821 |
else:
|
| 1822 |
+
decoder_output = self.decode(
|
| 1823 |
+
audio_codes,
|
| 1824 |
+
padding_mask=padding_mask,
|
| 1825 |
+
return_dict=True,
|
| 1826 |
+
num_quantizers=num_quantizers,
|
| 1827 |
+
)
|
| 1828 |
decoder_output = cast(MossAudioTokenizerDecoderOutput, decoder_output)
|
| 1829 |
output_audio = decoder_output.audio
|
| 1830 |
output_audio_lengths = decoder_output.audio_lengths
|