# MOSS-Audio-4B-Thinking / processing_moss_audio.py
# Uploaded by kiiic via huggingface_hub (commit 3dfd141, verified).
import re
import types
from dataclasses import dataclass
from typing import List, Optional, Sequence, Union
import numpy as np
import torch
from transformers import BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.models.whisper.feature_extraction_whisper import (
WhisperFeatureExtractor,
)
@dataclass
class MelConfig:
    """Configuration for log-mel spectrogram extraction.

    Defaults describe a Whisper-style front end: at 16 kHz, n_fft=400 is a
    25 ms analysis window and hop_length=160 is a 10 ms hop.
    """

    # Expected audio sample rate in Hz.
    mel_sr: int = 16000
    # Number of mel frequency bins per frame.
    mel_dim: int = 128
    # FFT size (analysis window length in samples).
    mel_n_fft: int = 400
    # Hop between successive frames, in samples.
    mel_hop_length: int = 160
    # dtype of the mel tensors handed to the model.
    mel_dtype: torch.dtype = torch.bfloat16
    # Whether to use transformers' WhisperFeatureExtractor for mel extraction.
    use_whisper_feature_extractor: bool = True
def _normalize_mel_config(mel_config) -> dict[str, object]:
    """Coerce a ``MelConfig`` / mapping / ``None`` into a plain dict.

    The result contains every dataclass field of :class:`MelConfig`; entries
    absent from the input fall back to the dataclass defaults.  ``mel_dtype``
    is stored as a short string such as ``"bfloat16"`` (no ``torch.`` prefix)
    so the dict stays serializable.
    """
    defaults = MelConfig()
    field_names = list(MelConfig.__dataclass_fields__)

    if mel_config is None:
        source = {}
    elif isinstance(mel_config, MelConfig):
        source = {name: getattr(mel_config, name) for name in field_names}
    else:
        source = dict(mel_config)

    def _dtype_as_name(value):
        # torch dtypes stringify as e.g. "torch.bfloat16"; keep the suffix only.
        if isinstance(value, torch.dtype):
            return str(value).removeprefix("torch.")
        if isinstance(value, str):
            return value.removeprefix("torch.")
        return value

    normalized: dict[str, object] = {}
    for name in field_names:
        value = source.get(name, getattr(defaults, name))
        normalized[name] = _dtype_as_name(value) if name == "mel_dtype" else value
    return normalized
def _build_mel_config(mel_config_dict: dict[str, object]) -> MelConfig:
    """Materialize a :class:`MelConfig` from a loosely-typed dict.

    Values are coerced defensively: ints may arrive as strings or bools,
    booleans as strings ("1"/"true"/"yes"/"on") or ints, and the dtype as a
    ``torch.dtype`` or its (optionally ``torch.``-prefixed) name.  Values of
    any other type fall back to the field default.
    """
    defaults = MelConfig()

    def _as_int(key: str, fallback: int) -> int:
        raw = mel_config_dict.get(key, fallback)
        if isinstance(raw, bool):
            return int(raw)
        if isinstance(raw, (int, str)):
            return int(raw)
        return fallback

    def _as_bool(key: str, fallback: bool) -> bool:
        raw = mel_config_dict.get(key, fallback)
        if isinstance(raw, bool):
            return raw
        if isinstance(raw, str):
            return raw.lower() in {"1", "true", "yes", "on"}
        if isinstance(raw, int):
            return bool(raw)
        return fallback

    raw_dtype = mel_config_dict.get("mel_dtype", defaults.mel_dtype)
    if isinstance(raw_dtype, torch.dtype):
        dtype = raw_dtype
    elif isinstance(raw_dtype, str):
        dtype = getattr(torch, raw_dtype.removeprefix("torch."))
    else:
        dtype = defaults.mel_dtype

    return MelConfig(
        mel_sr=_as_int("mel_sr", defaults.mel_sr),
        mel_dim=_as_int("mel_dim", defaults.mel_dim),
        mel_n_fft=_as_int("mel_n_fft", defaults.mel_n_fft),
        mel_hop_length=_as_int("mel_hop_length", defaults.mel_hop_length),
        mel_dtype=dtype,
        use_whisper_feature_extractor=_as_bool(
            "use_whisper_feature_extractor",
            defaults.use_whisper_feature_extractor,
        ),
    )
class MossAudioProcessor(ProcessorMixin):
    """Prepare text + audio inputs for the MOSS audio model.

    Pairs a Qwen2 tokenizer with Whisper-style log-mel feature extraction:
    audio clips become a zero-padded mel batch, and each
    ``<|audio_bos|>(<|AUDIO|>)+<|audio_eos|>`` span in the prompt is replaced
    by fixed marker ids surrounding one placeholder id per downsampled audio
    frame (optionally interleaved with digit tokens marking elapsed seconds).
    """

    # ProcessorMixin wiring: which sub-component attributes exist and which
    # tokenizer classes may fill the "tokenizer" slot.
    attributes = ["tokenizer"]
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
    # One audio placeholder span in prompt text: a bos marker, one or more
    # <|AUDIO|> tokens, then an eos marker.
    _AUDIO_SPAN_RE = re.compile(r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>")
    def __init__(
        self,
        tokenizer=None,
        mel_config=None,
        enable_time_marker: bool = False,
        audio_token_id: int = 151654,
        audio_start_id: int = 151669,
        audio_end_id: int = 151670,
        chat_template=None,
    ):
        """Wrap ``tokenizer`` and configure audio placeholder handling.

        Args:
            tokenizer: Qwen2 tokenizer instance. Required despite the
                ``None`` default (a ``ValueError`` is raised otherwise).
            mel_config: ``MelConfig``, mapping, or ``None``. Normalized to a
                plain dict (``self.mel_config``) and rebuilt into a typed
                ``MelConfig`` (``self.config``).
            enable_time_marker: if True, digit tokens marking elapsed seconds
                are interleaved with the audio placeholder tokens.
            audio_token_id: vocabulary id used for each audio placeholder.
            audio_start_id: vocabulary id for the audio bos marker.
            audio_end_id: vocabulary id for the audio eos marker.
            chat_template: forwarded to ``ProcessorMixin``.
        """
        # NOTE(review): super().__init__ receives the tokenizer before the
        # None check below, so a None tokenizer may already fail inside
        # ProcessorMixin with a less helpful error — confirm intended order.
        super().__init__(tokenizer, chat_template=chat_template)
        if tokenizer is None:
            raise ValueError("MossAudioProcessor requires a tokenizer.")
        self._base_tokenizer = tokenizer
        # Serializable dict form, plus the typed MelConfig view of it.
        self.mel_config = _normalize_mel_config(mel_config)
        self.config = _build_mel_config(self.mel_config)
        self.enable_time_marker = bool(enable_time_marker)
        self.audio_token_id = int(audio_token_id)
        self.audio_start_id = int(audio_start_id)
        self.audio_end_id = int(audio_end_id)
        # Built lazily on first mel extraction.
        self._whisper_feature_extractor = None
        # Force these marker strings to the configured ids regardless of what
        # the tokenizer's own vocabulary would map them to.
        alias_map = {
            "<|AUDIO|>": self.audio_token_id,
            "<|audio_bos|>": self.audio_start_id,
            "<|audio_eos|>": self.audio_end_id,
        }
        orig_convert_tokens_to_ids = tokenizer.convert_tokens_to_ids
        def _patched_convert_tokens_to_ids(tokenizer_self, tokens):
            # Recurse over lists/tuples, preserving the container type.
            if isinstance(tokens, (list, tuple)):
                converted = [
                    _patched_convert_tokens_to_ids(tokenizer_self, token)
                    for token in tokens
                ]
                return converted if isinstance(tokens, list) else tuple(converted)
            if isinstance(tokens, str) and tokens in alias_map:
                return alias_map[tokens]
            # Defer to the original (already-bound) implementation.
            return orig_convert_tokens_to_ids(tokens)
        # Monkey-patch the tokenizer INSTANCE so any downstream use of
        # convert_tokens_to_ids sees the alias ids above.
        tokenizer.convert_tokens_to_ids = types.MethodType(
            _patched_convert_tokens_to_ids, tokenizer
        )
        # Token ids used to spell out second markers digit-by-digit.
        # NOTE(review): presumably the vocab ids of ASCII digits "0".."9" for
        # this tokenizer — verify against the actual Qwen2 vocabulary.
        self._digit_token_ids = {
            "0": 15,
            "1": 16,
            "2": 17,
            "3": 18,
            "4": 19,
            "5": 20,
            "6": 21,
            "7": 22,
            "8": 23,
            "9": 24,
        }
        # 12.5 audio tokens per second of audio; a time marker every 2 s,
        # i.e. every int(12.5 * 2) = 25 placeholder tokens.
        self.audio_tokens_per_second = 12.5
        self.time_marker_every_seconds = 2
        self.time_marker_every_audio_tokens = int(
            self.audio_tokens_per_second * self.time_marker_every_seconds
        )
    @property
    def model_input_names(self):
        """Names of the tensors this processor can place in a BatchFeature."""
        return [
            "input_ids",
            "attention_mask",
            "audio_data",
            "audio_data_seqlens",
        ]
    @staticmethod
    def _conv3_downsample_len(raw_mel_len: int) -> int:
        """Sequence length after three stride-2 convolutions over mel frames.

        Each stage maps a length L to (L - 1) // 2 + 1; applied three times,
        this converts raw mel frame counts to audio token counts.
        """
        def conv_out_len(length: int) -> int:
            return (length - 1) // 2 + 1
        length1 = conv_out_len(int(raw_mel_len))
        length2 = conv_out_len(length1)
        length3 = conv_out_len(length2)
        return int(length3)
    def _get_whisper_feature_extractor(self):
        """Return the cached WhisperFeatureExtractor, creating it on first use."""
        if self._whisper_feature_extractor is not None:
            return self._whisper_feature_extractor
        self._whisper_feature_extractor = WhisperFeatureExtractor(
            feature_size=int(self.config.mel_dim),
            sampling_rate=int(self.config.mel_sr),
            hop_length=int(self.config.mel_hop_length),
            n_fft=int(self.config.mel_n_fft),
        )
        return self._whisper_feature_extractor
    def _extract_mel(self, audio: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Convert a waveform into a (mel_dim, frames) tensor of mel features.

        Accepts a 1-D waveform or a 2-D array, in which case only the first
        row (channel) is used.  Raises ValueError if the Whisper feature
        extractor is disabled in the config, since no fallback exists.
        """
        if isinstance(audio, np.ndarray):
            wav = torch.from_numpy(audio)
        else:
            wav = audio
        wav = wav.to(dtype=torch.float32)
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)
        if bool(getattr(self.config, "use_whisper_feature_extractor", False)):
            fe = self._get_whisper_feature_extractor()
            wav_np = wav.detach().to("cpu", torch.float32).contiguous().numpy()
            if wav_np.ndim == 2:
                wav_np = wav_np[0]
            # NOTE(review): relies on a private transformers API
            # (_np_extract_fbank_features); may break across versions.
            feats = fe._np_extract_fbank_features(wav_np[None, ...], device="cpu")
            mel = torch.from_numpy(feats[0])
        else:
            raise ValueError("MossAudioProcessor requires whisper feature extraction.")
        return mel.to(dtype=self.config.mel_dtype)
    def _get_time_marker_token_ids(self, second: int) -> List[int]:
        """Spell out *second* (e.g. 12) as its digit token ids (e.g. [16, 17])."""
        return [self._digit_token_ids[digit] for digit in str(second)]
    def _build_audio_tokens_with_time_markers(self, audio_seq_len: int) -> List[int]:
        """Return placeholder ids with digit markers inserted every 2 seconds.

        After every ``time_marker_every_audio_tokens`` placeholder tokens,
        the elapsed whole seconds are emitted as digit tokens; remaining
        placeholder tokens past the last full marker are appended at the end.
        """
        total_duration_seconds = audio_seq_len / self.audio_tokens_per_second
        num_full_seconds = int(total_duration_seconds)
        token_ids: List[int] = []
        audio_tokens_consumed = 0
        for second in range(
            self.time_marker_every_seconds,
            num_full_seconds + 1,
            self.time_marker_every_seconds,
        ):
            # Absolute placeholder-token position at which this marker lands.
            marker_pos = (
                second // self.time_marker_every_seconds
            ) * self.time_marker_every_audio_tokens
            audio_segment_len = marker_pos - audio_tokens_consumed
            if audio_segment_len > 0:
                token_ids.extend([self.audio_token_id] * audio_segment_len)
                audio_tokens_consumed += audio_segment_len
            token_ids.extend(self._get_time_marker_token_ids(second))
        remaining = audio_seq_len - audio_tokens_consumed
        if remaining > 0:
            token_ids.extend([self.audio_token_id] * remaining)
        return token_ids
    def _build_audio_placeholder_ids(self, num_audio_tokens: int) -> List[int]:
        """Placeholder ids for one clip, with or without time markers."""
        if self.enable_time_marker:
            return self._build_audio_tokens_with_time_markers(num_audio_tokens)
        return [self.audio_token_id] * num_audio_tokens
    def _build_default_prompt(self, text: str, has_audio: bool) -> str:
        """Wrap bare user text in a default ChatML-style prompt.

        When *has_audio* is True a single audio span is inserted ahead of the
        text inside the user turn.
        """
        if has_audio:
            return (
                "<|im_start|>system\n"
                "You are a helpful assistant.<|im_end|>\n"
                "<|im_start|>user\n"
                "<|audio_bos|><|AUDIO|><|audio_eos|>\n"
                f"{text}<|im_end|>\n"
                "<|im_start|>assistant\n"
            )
        return (
            "<|im_start|>system\n"
            "You are a helpful assistant.<|im_end|>\n"
            "<|im_start|>user\n"
            f"{text}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
    def _build_input_from_prompt(self, prompt: str, token_lens: List[int]) -> List[int]:
        """Tokenize *prompt*, expanding each audio span into placeholder ids.

        The i-th regex span is replaced by ``audio_start_id``, then
        ``token_lens[i]`` placeholder ids (the number of <|AUDIO|> tokens
        written in the span text itself is ignored), then ``audio_end_id``.
        Raises ValueError when span count and ``token_lens`` length differ.
        """
        spans = list(self._AUDIO_SPAN_RE.finditer(prompt))
        if len(spans) != len(token_lens):
            raise ValueError(
                f"Audio placeholder count mismatch: found {len(spans)} spans in text, "
                f"but got {len(token_lens)} audio inputs."
            )
        input_ids: List[int] = []
        cursor = 0
        for index, match in enumerate(spans):
            # Encode the plain text between the previous span and this one.
            prefix = prompt[cursor : match.start()]
            if prefix:
                input_ids.extend(
                    self._base_tokenizer.encode(prefix, add_special_tokens=False)
                )
            input_ids.append(self.audio_start_id)
            input_ids.extend(self._build_audio_placeholder_ids(int(token_lens[index])))
            input_ids.append(self.audio_end_id)
            cursor = match.end()
        suffix = prompt[cursor:]
        if suffix:
            input_ids.extend(
                self._base_tokenizer.encode(suffix, add_special_tokens=False)
            )
        return input_ids
    def __call__(
        self,
        *args,
        text: Union[str, Sequence[str], None] = None,
        audios: Optional[Sequence[Union[np.ndarray, torch.Tensor]]] = None,
        audio: Optional[Sequence[Union[np.ndarray, torch.Tensor]]] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> BatchFeature:
        """Build model inputs from one prompt string and optional audio clips.

        Args:
            text: a prompt string, or a batch containing exactly one string.
                Required; if it contains no audio span, a default chat prompt
                is built around it.
            audios / audio: optional sequence of waveforms (``audios`` takes
                precedence when both are given).
            return_tensors: tensor type passed to ``BatchFeature``.

        Returns:
            BatchFeature with ``input_ids`` and ``attention_mask`` (batch of
            one), plus ``audio_data`` (zero-padded mel batch) and
            ``audio_data_seqlens`` when audio was provided.  Note that
            ``audio_data_seqlens`` holds raw mel frame counts, while the
            placeholder counts in ``input_ids`` use the conv3-downsampled
            lengths.

        Raises:
            ValueError: on missing text, batch size != 1, or a mismatch
                between audio spans in the text and the number of clips.
            TypeError: when *text* is neither a string nor a one-string batch.
        """
        # Extra positional/keyword arguments are accepted but ignored.
        _ = args, kwargs
        if isinstance(text, str):
            prompt_text: Optional[str] = text
        elif isinstance(text, (list, tuple)):
            if len(text) != 1:
                raise ValueError(f"Expected text batch size 1, got {len(text)}")
            prompt_text = text[0]
            if not isinstance(prompt_text, str):
                raise TypeError("Expected text batch size 1 with string content.")
        elif text is None:
            prompt_text = None
        else:
            raise TypeError("MossAudioProcessor text must be a string or a batch of one string.")
        audio_list = audios if audios is not None else audio
        audio_list = [] if audio_list is None else list(audio_list)
        mels: List[torch.Tensor] = []
        raw_lengths: List[int] = []
        token_lens: List[int] = []
        for one_audio in audio_list:
            mel = self._extract_mel(one_audio)
            raw_len = int(mel.shape[-1])
            mels.append(mel)
            raw_lengths.append(raw_len)
            token_lens.append(self._conv3_downsample_len(raw_len))
        if mels:
            # Right-pad every mel with zeros to the longest clip.
            max_length = max(raw_lengths)
            audio_batch = torch.zeros(
                (len(mels), self.config.mel_dim, max_length),
                dtype=self.config.mel_dtype,
            )
            for index, mel in enumerate(mels):
                audio_batch[index, :, : mel.shape[-1]] = mel
            seqlens_tensor = torch.tensor(raw_lengths, dtype=torch.long)
        else:
            audio_batch = None
            seqlens_tensor = None
        if prompt_text is None:
            raise ValueError(
                "MossAudioProcessor requires text input. Apply a chat template before calling the processor if needed."
            )
        # If the caller did not supply audio spans, wrap the text in the
        # default chat prompt (with a span only when audio is present).
        if self._AUDIO_SPAN_RE.search(prompt_text) is None and audio_list:
            prompt_text = self._build_default_prompt(prompt_text, has_audio=True)
        elif self._AUDIO_SPAN_RE.search(prompt_text) is None and not audio_list:
            prompt_text = self._build_default_prompt(prompt_text, has_audio=False)
        input_ids_list = self._build_input_from_prompt(prompt_text, token_lens)
        input_ids_tensor = torch.tensor([input_ids_list], dtype=torch.long)
        attention_mask_tensor = torch.ones_like(input_ids_tensor)
        data = {
            "input_ids": input_ids_tensor,
            "attention_mask": attention_mask_tensor,
        }
        if audio_batch is not None and seqlens_tensor is not None:
            data["audio_data"] = audio_batch
            data["audio_data_seqlens"] = seqlens_tensor
        return BatchFeature(data=data, tensor_type=return_tensors)
    def batch_decode(self, *args, **kwargs):
        """Delegate batch decoding to the wrapped tokenizer."""
        return self._base_tokenizer.batch_decode(*args, **kwargs)
    def decode(self, *args, **kwargs):
        """Delegate decoding to the wrapped tokenizer."""
        return self._base_tokenizer.decode(*args, **kwargs)
__all__ = ["MelConfig", "MossAudioProcessor"]