Li-Ruixiao committed on
Commit
eff8e23
·
1 Parent(s): 0750e3c

add support for optional num_quantizers argument

Browse files
Files changed (1) hide show
  1. modeling_moss_audio_tokenizer.py +50 -19
modeling_moss_audio_tokenizer.py CHANGED
@@ -1,4 +1,3 @@
1
- # coding=utf-8
2
  # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,9 +25,8 @@ import torch
26
  import torch.nn as nn
27
  import torch.nn.functional as F
28
 
29
- from transformers.modeling_utils import PreTrainedAudioTokenizerBase
30
- from transformers.utils import ModelOutput, auto_docstring, logging
31
-
32
  from .configuration_moss_audio_tokenizer import MossAudioTokenizerConfig
33
 
34
 
@@ -372,7 +370,7 @@ def create_sin_embedding(
372
  dim: int,
373
  max_period: float = 10000,
374
  dtype: torch.dtype = torch.float32,
375
- ):
376
  """Create sinusoidal positional embedding with shape [B, T, C]."""
377
  if dim % 2 != 0:
378
  raise ValueError(f"Sinusoidal embedding requires even dim, got dim={dim}")
@@ -406,7 +404,7 @@ class KVCacheResult:
406
  return iter((self.keys, self.values, self.positions))
407
 
408
  @staticmethod
409
- def from_kv(keys: torch.Tensor, values: torch.Tensor) -> "KVCacheResult":
410
  B, H, T, D = keys.shape
411
  positions = torch.arange(T, device=keys.device, dtype=torch.long)
412
  return KVCacheResult(keys, values, positions.expand(B, -1))
@@ -506,7 +504,7 @@ def apply_weights_per_step(
506
  schedule: list[int] | None,
507
  x: torch.Tensor,
508
  offset: int | None,
509
- ):
510
  """Apply different weights for each time step."""
511
  if len(modules) == 1:
512
  return modules[0](x)
@@ -1088,7 +1086,6 @@ class MossAudioTokenizerLFQ(nn.Module):
1088
  - 2 * encodings @ codebook.t()
1089
  + codebook.pow(2).sum(1, keepdim=True).t()
1090
  )
1091
-
1092
  indices = (-dist).max(1)[1]
1093
  indices = indices.reshape(latents.size(0), -1)
1094
  z_q = self.decode_code_wo_out_proj(indices).float()
@@ -1306,10 +1303,6 @@ class MossAudioTokenizerPreTrainedModel(PreTrainedAudioTokenizerBase):
1306
  "MossAudioTokenizerResidualLFQ",
1307
  ]
1308
 
1309
- def _init_weights(self, module: nn.Module) -> None:
1310
- if isinstance(module, MossAudioTokenizerLayerScale):
1311
- nn.init.constant_(module.scale, 1e-4)
1312
-
1313
 
1314
  @auto_docstring(
1315
  custom_intro="""
@@ -1348,7 +1341,7 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
1348
  context=int(current_frame_rate * self.causal_transformer_context_duration),
1349
  )
1350
  )
1351
- current_frame_rate /= cast(MossAudioTokenizerPatchedPretransform, self.encoder[-1]).downsample_ratio
1352
 
1353
  # Build quantizer
1354
  quantizer_kwargs = dict(config.quantizer_kwargs)
@@ -1375,7 +1368,7 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
1375
  context=int(current_frame_rate * self.causal_transformer_context_duration),
1376
  )
1377
  )
1378
- current_frame_rate *= cast(MossAudioTokenizerPatchedPretransform, self.decoder[-1]).downsample_ratio
1379
 
1380
  self.post_init()
1381
 
@@ -1407,11 +1400,14 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
1407
  self._stop_streaming()
1408
 
1409
  @torch.no_grad()
1410
- def batch_encode(self, wav_list: list[torch.Tensor]) -> MossAudioTokenizerEncoderOutput:
 
 
1411
  """Batch encode a list of audio waveforms.
1412
 
1413
  Args:
1414
  wav_list: List of audio tensors, each of shape `(num_samples,)`.
 
1415
 
1416
  Returns:
1417
  [`MossAudioTokenizerEncoderOutput`] with `audio_codes` and `audio_codes_lengths`.
@@ -1430,14 +1426,18 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
1430
  input_values[i, 0, : wav.shape[-1]] = wav
1431
  input_lengths[i] = wav.shape[-1]
1432
 
1433
- return self._encode_frame(input_values, input_lengths)
1434
 
1435
  @torch.no_grad()
1436
- def batch_decode(self, codes_list: list[torch.Tensor]) -> MossAudioTokenizerDecoderOutput:
 
 
1437
  """Batch decode a list of audio codes.
1438
 
1439
  Args:
1440
  codes_list: List of audio code tensors, each of shape `(num_quantizers, codes_length)`.
 
 
1441
 
1442
  Returns:
1443
  [`MossAudioTokenizerDecoderOutput`] with `audio` and `audio_lengths`.
@@ -1447,13 +1447,28 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
1447
 
1448
  batch_size = len(codes_list)
1449
  device = codes_list[0].device
1450
- num_quantizers = codes_list[0].shape[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1451
  max_length = max(codes.shape[-1] for codes in codes_list)
1452
 
1453
  audio_codes = torch.zeros(num_quantizers, batch_size, max_length, device=device, dtype=torch.long)
1454
  audio_codes_lengths = torch.zeros(batch_size, device=device, dtype=torch.long)
1455
 
1456
  for i, codes in enumerate(codes_list):
 
1457
  audio_codes[:, i, : codes.shape[-1]] = codes
1458
  audio_codes_lengths[i] = codes.shape[-1]
1459
 
@@ -1638,6 +1653,7 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
1638
  padding_mask: torch.Tensor | None = None,
1639
  return_dict: bool | None = None,
1640
  chunk_duration: float | None = None,
 
1641
  ):
1642
  """
1643
  Decodes the given codes into an output audio waveform.
@@ -1653,6 +1669,9 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
1653
  If provided, decode the input codes in successive chunks of `chunk_duration` seconds while keeping a
1654
  streaming KV cache for the causal transformers.
1655
 
 
 
 
1656
  `chunk_duration` must be <= `config.causal_transformer_context_duration`, and
1657
  `chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`.
1658
 
@@ -1664,6 +1683,13 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
1664
  if audio_codes.dim() == 2:
1665
  audio_codes = audio_codes.unsqueeze(1) # nq, T -> nq, B=1, T
1666
 
 
 
 
 
 
 
 
1667
  _, B, T = audio_codes.shape
1668
  device = audio_codes.device
1669
 
@@ -1793,7 +1819,12 @@ class MossAudioTokenizerModel(MossAudioTokenizerPreTrainedModel):
1793
  if decoded_from_encoded_codes and output_audio_codes_lengths is not None:
1794
  decoder_output = self._decode_frame(audio_codes, output_audio_codes_lengths)
1795
  else:
1796
- decoder_output = self.decode(audio_codes, padding_mask=padding_mask, return_dict=True)
 
 
 
 
 
1797
  decoder_output = cast(MossAudioTokenizerDecoderOutput, decoder_output)
1798
  output_audio = decoder_output.audio
1799
  output_audio_lengths = decoder_output.audio_lengths
 
 
1
  # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
 
25
  import torch.nn as nn
26
  import torch.nn.functional as F
27
 
28
+ from ...modeling_utils import PreTrainedAudioTokenizerBase
29
+ from ...utils import ModelOutput, auto_docstring, logging
 
30
  from .configuration_moss_audio_tokenizer import MossAudioTokenizerConfig
31
 
32
 
 
370
  dim: int,
371
  max_period: float = 10000,
372
  dtype: torch.dtype = torch.float32,
373
+ ) -> torch.Tensor:
374
  """Create sinusoidal positional embedding with shape [B, T, C]."""
375
  if dim % 2 != 0:
376
  raise ValueError(f"Sinusoidal embedding requires even dim, got dim={dim}")
 
404
  return iter((self.keys, self.values, self.positions))
405
 
406
  @staticmethod
407
+ def from_kv(keys: torch.Tensor, values: torch.Tensor) -> KVCacheResult:
408
  B, H, T, D = keys.shape
409
  positions = torch.arange(T, device=keys.device, dtype=torch.long)
410
  return KVCacheResult(keys, values, positions.expand(B, -1))
 
504
  schedule: list[int] | None,
505
  x: torch.Tensor,
506
  offset: int | None,
507
+ ) -> torch.Tensor:
508
  """Apply different weights for each time step."""
509
  if len(modules) == 1:
510
  return modules[0](x)
 
1086
  - 2 * encodings @ codebook.t()
1087
  + codebook.pow(2).sum(1, keepdim=True).t()
1088
  )
 
1089
  indices = (-dist).max(1)[1]
1090
  indices = indices.reshape(latents.size(0), -1)
1091
  z_q = self.decode_code_wo_out_proj(indices).float()
 
1303
  "MossAudioTokenizerResidualLFQ",
1304
  ]
1305
 
 
 
 
 
1306
 
1307
  @auto_docstring(
1308
  custom_intro="""
 
1341
  context=int(current_frame_rate * self.causal_transformer_context_duration),
1342
  )
1343
  )
1344
+ current_frame_rate /= self.encoder[-1].downsample_ratio
1345
 
1346
  # Build quantizer
1347
  quantizer_kwargs = dict(config.quantizer_kwargs)
 
1368
  context=int(current_frame_rate * self.causal_transformer_context_duration),
1369
  )
1370
  )
1371
+ current_frame_rate *= self.decoder[-1].downsample_ratio
1372
 
1373
  self.post_init()
1374
 
 
1400
  self._stop_streaming()
1401
 
1402
  @torch.no_grad()
1403
+ def batch_encode(
1404
+ self, wav_list: list[torch.Tensor], num_quantizers: int | None = None
1405
+ ) -> MossAudioTokenizerEncoderOutput:
1406
  """Batch encode a list of audio waveforms.
1407
 
1408
  Args:
1409
  wav_list: List of audio tensors, each of shape `(num_samples,)`.
1410
+ num_quantizers: Number of quantizers to use. By default, all quantizers are used.
1411
 
1412
  Returns:
1413
  [`MossAudioTokenizerEncoderOutput`] with `audio_codes` and `audio_codes_lengths`.
 
1426
  input_values[i, 0, : wav.shape[-1]] = wav
1427
  input_lengths[i] = wav.shape[-1]
1428
 
1429
+ return self._encode_frame(input_values, input_lengths, n_quantizers=num_quantizers)
1430
 
1431
  @torch.no_grad()
1432
+ def batch_decode(
1433
+ self, codes_list: list[torch.Tensor], num_quantizers: int | None = None
1434
+ ) -> MossAudioTokenizerDecoderOutput:
1435
  """Batch decode a list of audio codes.
1436
 
1437
  Args:
1438
  codes_list: List of audio code tensors, each of shape `(num_quantizers, codes_length)`.
1439
+ num_quantizers: If provided, decode only the first `num_quantizers` quantizers from each element in
1440
+ `codes_list`. If omitted, all elements in `codes_list` must have the same number of quantizers.
1441
 
1442
  Returns:
1443
  [`MossAudioTokenizerDecoderOutput`] with `audio` and `audio_lengths`.
 
1447
 
1448
  batch_size = len(codes_list)
1449
  device = codes_list[0].device
1450
+ nqs = [codes.shape[0] for codes in codes_list]
1451
+ if num_quantizers is None:
1452
+ num_quantizers = nqs[0]
1453
+ if any(nq != num_quantizers for nq in nqs):
1454
+ raise ValueError(
1455
+ "All elements in `codes_list` must have the same number of quantizers when `num_quantizers` is None. "
1456
+ "Pass `num_quantizers=...` to decode a common prefix."
1457
+ )
1458
+ else:
1459
+ min_nq = min(nqs)
1460
+ if min_nq < num_quantizers:
1461
+ raise ValueError(
1462
+ "`num_quantizers` must be <= the number of quantizers for every element in `codes_list`. "
1463
+ f"Got num_quantizers={num_quantizers}, min(codes.shape[0])={min_nq}."
1464
+ )
1465
  max_length = max(codes.shape[-1] for codes in codes_list)
1466
 
1467
  audio_codes = torch.zeros(num_quantizers, batch_size, max_length, device=device, dtype=torch.long)
1468
  audio_codes_lengths = torch.zeros(batch_size, device=device, dtype=torch.long)
1469
 
1470
  for i, codes in enumerate(codes_list):
1471
+ codes = codes[:num_quantizers]
1472
  audio_codes[:, i, : codes.shape[-1]] = codes
1473
  audio_codes_lengths[i] = codes.shape[-1]
1474
 
 
1653
  padding_mask: torch.Tensor | None = None,
1654
  return_dict: bool | None = None,
1655
  chunk_duration: float | None = None,
1656
+ num_quantizers: int | None = None,
1657
  ):
1658
  """
1659
  Decodes the given codes into an output audio waveform.
 
1669
  If provided, decode the input codes in successive chunks of `chunk_duration` seconds while keeping a
1670
  streaming KV cache for the causal transformers.
1671
 
1672
+ num_quantizers (`int`, *optional*):
1673
+ Number of quantizers to use. By default, all quantizers in `audio_codes` are used.
1674
+
1675
  `chunk_duration` must be <= `config.causal_transformer_context_duration`, and
1676
  `chunk_duration * config.sampling_rate` must be divisible by `config.downsample_rate`.
1677
 
 
1683
  if audio_codes.dim() == 2:
1684
  audio_codes = audio_codes.unsqueeze(1) # nq, T -> nq, B=1, T
1685
 
1686
+ if num_quantizers is not None:
1687
+ if num_quantizers > audio_codes.shape[0]:
1688
+ raise ValueError(
1689
+ f"`num_quantizers` ({num_quantizers}) must be <= audio_codes.shape[0] ({audio_codes.shape[0]})."
1690
+ )
1691
+ audio_codes = audio_codes[:num_quantizers]
1692
+
1693
  _, B, T = audio_codes.shape
1694
  device = audio_codes.device
1695
 
 
1819
  if decoded_from_encoded_codes and output_audio_codes_lengths is not None:
1820
  decoder_output = self._decode_frame(audio_codes, output_audio_codes_lengths)
1821
  else:
1822
+ decoder_output = self.decode(
1823
+ audio_codes,
1824
+ padding_mask=padding_mask,
1825
+ return_dict=True,
1826
+ num_quantizers=num_quantizers,
1827
+ )
1828
  decoder_output = cast(MossAudioTokenizerDecoderOutput, decoder_output)
1829
  output_audio = decoder_output.audio
1830
  output_audio_lengths = decoder_output.audio_lengths