| import re |
| import types |
| from dataclasses import dataclass |
| from typing import List, Optional, Sequence, Union |
|
|
| import numpy as np |
| import torch |
| from transformers import BatchFeature |
| from transformers.processing_utils import ProcessorMixin |
| from transformers.models.whisper.feature_extraction_whisper import ( |
| WhisperFeatureExtractor, |
| ) |
|
|
|
|
@dataclass
class MelConfig:
    """Settings for the log-mel front end used by ``MossAudioProcessor``.

    The four integer fields are forwarded to ``WhisperFeatureExtractor`` when
    the processor builds it; ``mel_dtype`` is the dtype the extracted features
    (and the padded feature batch) are stored in.
    """

    # Forwarded to WhisperFeatureExtractor as ``sampling_rate``.
    mel_sr: int = 16000
    # Forwarded as ``feature_size``; also the middle dimension of the padded
    # ``audio_data`` batch built by the processor.
    mel_dim: int = 128
    # Forwarded as ``n_fft``.
    mel_n_fft: int = 400
    # Forwarded as ``hop_length``.
    mel_hop_length: int = 160
    # Extracted mel features are cast to this dtype before being returned.
    mel_dtype: torch.dtype = torch.bfloat16
    # The processor raises ValueError during feature extraction when this is
    # false; the whisper extractor is the only implemented path.
    use_whisper_feature_extractor: bool = True
|
|
|
|
def _normalize_mel_config(mel_config) -> dict[str, object]:
    """Flatten a mel configuration into a plain dict keyed by ``MelConfig`` fields.

    Args:
        mel_config: ``None`` (use every default), a ``MelConfig`` instance, or
            anything ``dict()`` accepts. Keys that are not ``MelConfig``
            fields are dropped; missing fields take the ``MelConfig`` default.

    Returns:
        A dict with exactly the ``MelConfig`` field names, in declaration
        order. ``mel_dtype`` is stored as a string without the ``"torch."``
        prefix (e.g. ``"bfloat16"``) when it was given as a ``torch.dtype`` or
        as a ``"torch."``-prefixed string; any other value is kept as-is.
    """
    field_names = tuple(MelConfig.__dataclass_fields__)
    fallback = MelConfig()

    if isinstance(mel_config, MelConfig):
        overrides = {name: getattr(mel_config, name) for name in field_names}
    elif mel_config is None:
        overrides = {}
    else:
        overrides = dict(mel_config)

    result = {
        name: overrides.get(name, getattr(fallback, name)) for name in field_names
    }

    # Store the dtype by bare name so the dict holds only plain values.
    dtype_entry = result["mel_dtype"]
    if isinstance(dtype_entry, torch.dtype):
        dtype_entry = str(dtype_entry)
    if isinstance(dtype_entry, str) and dtype_entry.startswith("torch."):
        dtype_entry = dtype_entry.removeprefix("torch.")
    result["mel_dtype"] = dtype_entry
    return result
|
|
|
|
def _build_mel_config(mel_config_dict: dict[str, object]) -> MelConfig:
    """Rebuild a ``MelConfig`` from a plain dict of loosely typed values.

    Missing keys, and values whose type cannot be interpreted, fall back to
    the ``MelConfig`` defaults.

    Args:
        mel_config_dict: Mapping keyed by ``MelConfig`` field names. Integer
            fields accept ints, bools, NumPy integers, numeric strings and
            integral floats; the boolean field accepts bools, ints and the
            strings ``"1"``, ``"true"``, ``"yes"``, ``"on"`` (case-insensitive,
            anything else is false); ``mel_dtype`` accepts a ``torch.dtype``
            or its name with or without the ``"torch."`` prefix.

    Returns:
        The populated ``MelConfig``.

    Raises:
        ValueError: If an integer field is a non-numeric string, or if a
            ``mel_dtype`` string names a ``torch`` attribute that is not a
            dtype.
        AttributeError: If a ``mel_dtype`` string names no ``torch`` attribute.
    """
    defaults = MelConfig()

    def _int_value(key: str, default: int) -> int:
        value = mel_config_dict.get(key, default)
        # ``bool`` is an ``int`` subclass, so True/False are covered here too.
        if isinstance(value, (int, np.integer, str)):
            return int(value)
        # JSON round-trips can turn 16000 into 16000.0; accept integral floats
        # rather than silently discarding the configured value.
        if isinstance(value, float) and value.is_integer():
            return int(value)
        return default

    def _bool_value(key: str, default: bool) -> bool:
        value = mel_config_dict.get(key, default)
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            return value.lower() in {"1", "true", "yes", "on"}
        if isinstance(value, int):
            return bool(value)
        return default

    raw_dtype = mel_config_dict.get("mel_dtype", defaults.mel_dtype)
    if isinstance(raw_dtype, str):
        # Unknown names still surface as AttributeError from getattr.
        mel_dtype = getattr(torch, raw_dtype.removeprefix("torch."))
        # getattr also resolves modules and functions (e.g. "nn"); reject
        # those here instead of failing later inside ``Tensor.to``.
        if not isinstance(mel_dtype, torch.dtype):
            raise ValueError(f"mel_dtype {raw_dtype!r} does not name a torch dtype.")
    elif isinstance(raw_dtype, torch.dtype):
        mel_dtype = raw_dtype
    else:
        mel_dtype = defaults.mel_dtype

    return MelConfig(
        mel_sr=_int_value("mel_sr", defaults.mel_sr),
        mel_dim=_int_value("mel_dim", defaults.mel_dim),
        mel_n_fft=_int_value("mel_n_fft", defaults.mel_n_fft),
        mel_hop_length=_int_value("mel_hop_length", defaults.mel_hop_length),
        mel_dtype=mel_dtype,
        use_whisper_feature_extractor=_bool_value(
            "use_whisper_feature_extractor",
            defaults.use_whisper_feature_extractor,
        ),
    )
|
|
|
|
class MossAudioProcessor(ProcessorMixin):
    """Turn one text prompt plus optional audio clips into model inputs.

    For every audio clip the processor extracts mel features. For every
    ``<|audio_bos|><|AUDIO|>...<|audio_eos|>`` span found in the prompt it
    emits ``audio_start_id``, one ``audio_token_id`` per downsampled audio
    frame (optionally interleaved with time-marker digit tokens) and
    ``audio_end_id``. Text outside those spans is encoded with the tokenizer.
    Only a single prompt per call is supported.

    Note:
        Constructing the processor replaces ``convert_tokens_to_ids`` on the
        tokenizer *instance* that is passed in, so that the three audio marker
        strings resolve to the ids configured here.
    """

    attributes = ["tokenizer"]
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    # One audio placeholder span: bos, one or more <|AUDIO|>, eos.
    _AUDIO_SPAN_RE = re.compile(r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>")

    def __init__(
        self,
        tokenizer=None,
        mel_config=None,
        enable_time_marker: bool = False,
        audio_token_id: int = 151654,
        audio_start_id: int = 151669,
        audio_end_id: int = 151670,
        chat_template=None,
    ):
        """Create the processor and patch the tokenizer's marker lookup.

        Args:
            tokenizer: Tokenizer used to encode the text parts of the prompt.
                Required.
            mel_config: ``None``, a ``MelConfig`` or a mapping of its fields.
            enable_time_marker: When true, digit tokens marking elapsed
                seconds are interleaved with the audio placeholder tokens.
            audio_token_id: Id emitted once per downsampled audio frame.
            audio_start_id: Id emitted before each run of audio placeholders.
            audio_end_id: Id emitted after each run of audio placeholders.
            chat_template: Forwarded to ``ProcessorMixin.__init__``.

        Raises:
            ValueError: If ``tokenizer`` is ``None``.
        """
        # Validate before delegating: the tokenizer is handed to the base
        # initializer, which may reject ``None`` with its own, less specific
        # error before this message could ever be raised.
        if tokenizer is None:
            raise ValueError("MossAudioProcessor requires a tokenizer.")
        super().__init__(tokenizer, chat_template=chat_template)

        self._base_tokenizer = tokenizer
        # Keep both forms: the plain dict and the typed MelConfig built from it.
        self.mel_config = _normalize_mel_config(mel_config)
        self.config = _build_mel_config(self.mel_config)
        self.enable_time_marker = bool(enable_time_marker)
        self.audio_token_id = int(audio_token_id)
        self.audio_start_id = int(audio_start_id)
        self.audio_end_id = int(audio_end_id)
        # Built lazily by _get_whisper_feature_extractor().
        self._whisper_feature_extractor = None

        alias_map = {
            "<|AUDIO|>": self.audio_token_id,
            "<|audio_bos|>": self.audio_start_id,
            "<|audio_eos|>": self.audio_end_id,
        }
        orig_convert_tokens_to_ids = tokenizer.convert_tokens_to_ids

        def _patched_convert_tokens_to_ids(tokenizer_self, tokens):
            # Lists/tuples are converted element-wise so marker strings inside
            # a batch are aliased too; the container type is preserved.
            if isinstance(tokens, (list, tuple)):
                converted = [
                    _patched_convert_tokens_to_ids(tokenizer_self, token)
                    for token in tokens
                ]
                return converted if isinstance(tokens, list) else tuple(converted)
            if isinstance(tokens, str) and tokens in alias_map:
                return alias_map[tokens]
            return orig_convert_tokens_to_ids(tokens)

        # NOTE: mutates the caller's tokenizer instance. Building a second
        # processor around the same tokenizer wraps the already patched method.
        tokenizer.convert_tokens_to_ids = types.MethodType(
            _patched_convert_tokens_to_ids, tokenizer
        )

        # Hard-coded ids for the digit characters used in time markers.
        # NOTE(review): presumably these match the tokenizer vocabulary;
        # verify against the tokenizer actually in use.
        self._digit_token_ids = {
            "0": 15,
            "1": 16,
            "2": 17,
            "3": 18,
            "4": 19,
            "5": 20,
            "6": 21,
            "7": 22,
            "8": 23,
            "9": 24,
        }
        # NOTE(review): 12.5 equals the default MelConfig frame rate
        # (16000 / 160 = 100 frames/s) halved three times, but it is fixed
        # here and does not follow a non-default mel_config.
        self.audio_tokens_per_second = 12.5
        self.time_marker_every_seconds = 2
        self.time_marker_every_audio_tokens = int(
            self.audio_tokens_per_second * self.time_marker_every_seconds
        )

    @property
    def model_input_names(self):
        """Names of the entries ``__call__`` can place in its output."""
        return [
            "input_ids",
            "attention_mask",
            "audio_data",
            "audio_data_seqlens",
        ]

    @staticmethod
    def _conv3_downsample_len(raw_mel_len: int) -> int:
        """Map a raw mel length to the number of audio placeholder tokens.

        Applies ``(length - 1) // 2 + 1`` three times in a row.
        NOTE(review): presumably mirrors three stride-2 convolutions in the
        audio encoder; confirm against the model definition.
        """

        def conv_out_len(length: int) -> int:
            return (length - 1) // 2 + 1

        length = int(raw_mel_len)
        for _ in range(3):
            length = conv_out_len(length)
        return int(length)

    def _get_whisper_feature_extractor(self):
        """Return the ``WhisperFeatureExtractor``, creating it on first use."""
        if self._whisper_feature_extractor is None:
            self._whisper_feature_extractor = WhisperFeatureExtractor(
                feature_size=int(self.config.mel_dim),
                sampling_rate=int(self.config.mel_sr),
                hop_length=int(self.config.mel_hop_length),
                n_fft=int(self.config.mel_n_fft),
            )
        return self._whisper_feature_extractor

    def _extract_mel(self, audio: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Compute mel features for one audio clip.

        Args:
            audio: Waveform as a NumPy array or torch tensor. When the input
                is 2-D only its first row is used.

        Returns:
            The features of the clip cast to ``self.config.mel_dtype``.

        Raises:
            ValueError: If ``use_whisper_feature_extractor`` is disabled; no
                other extraction path is implemented.
        """
        wav = torch.from_numpy(audio) if isinstance(audio, np.ndarray) else audio
        wav = wav.to(dtype=torch.float32)
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)

        if not bool(getattr(self.config, "use_whisper_feature_extractor", False)):
            raise ValueError("MossAudioProcessor requires whisper feature extraction.")

        fe = self._get_whisper_feature_extractor()
        wav_np = wav.detach().to("cpu", torch.float32).contiguous().numpy()
        if wav_np.ndim == 2:
            # Multi-row input: keep the first row only.
            wav_np = wav_np[0]
        # NOTE(review): this is a private method of WhisperFeatureExtractor;
        # its signature may differ between transformers versions.
        feats = fe._np_extract_fbank_features(wav_np[None, ...], device="cpu")
        mel = torch.from_numpy(feats[0])
        return mel.to(dtype=self.config.mel_dtype)

    def _get_time_marker_token_ids(self, second: int) -> List[int]:
        """Return the digit token ids spelling ``second`` in decimal."""
        return [self._digit_token_ids[digit] for digit in str(second)]

    def _build_audio_tokens_with_time_markers(self, audio_seq_len: int) -> List[int]:
        """Build audio placeholder ids interleaved with time-marker digits.

        After every ``time_marker_every_audio_tokens`` placeholders the digits
        of the elapsed seconds (2, 4, 6, ...) are inserted. Markers are only
        emitted for multiples of ``time_marker_every_seconds`` that do not
        exceed ``int(audio_seq_len / audio_tokens_per_second)``. For example,
        30 audio tokens become 25 placeholders, the token for "2", and then
        5 placeholders.

        Args:
            audio_seq_len: Number of audio placeholder tokens to emit.

        Returns:
            Token ids containing exactly ``audio_seq_len`` placeholders plus
            the marker digits.
        """
        total_duration_seconds = audio_seq_len / self.audio_tokens_per_second
        num_full_seconds = int(total_duration_seconds)

        token_ids: List[int] = []
        audio_tokens_consumed = 0
        for second in range(
            self.time_marker_every_seconds,
            num_full_seconds + 1,
            self.time_marker_every_seconds,
        ):
            # Number of placeholders that must precede this marker.
            marker_pos = (
                second // self.time_marker_every_seconds
            ) * self.time_marker_every_audio_tokens
            audio_segment_len = marker_pos - audio_tokens_consumed
            if audio_segment_len > 0:
                token_ids.extend([self.audio_token_id] * audio_segment_len)
                audio_tokens_consumed += audio_segment_len
            token_ids.extend(self._get_time_marker_token_ids(second))

        remaining = audio_seq_len - audio_tokens_consumed
        if remaining > 0:
            token_ids.extend([self.audio_token_id] * remaining)
        return token_ids

    def _build_audio_placeholder_ids(self, num_audio_tokens: int) -> List[int]:
        """Return the placeholder ids for one clip, with markers if enabled."""
        if self.enable_time_marker:
            return self._build_audio_tokens_with_time_markers(num_audio_tokens)
        return [self.audio_token_id] * num_audio_tokens

    def _build_default_prompt(self, text: str, has_audio: bool) -> str:
        """Wrap ``text`` in the built-in system/user/assistant prompt.

        When ``has_audio`` is true, a single audio span is placed on its own
        line before the user text.
        """
        if has_audio:
            return (
                "<|im_start|>system\n"
                "You are a helpful assistant.<|im_end|>\n"
                "<|im_start|>user\n"
                "<|audio_bos|><|AUDIO|><|audio_eos|>\n"
                f"{text}<|im_end|>\n"
                "<|im_start|>assistant\n"
            )
        return (
            "<|im_start|>system\n"
            "You are a helpful assistant.<|im_end|>\n"
            "<|im_start|>user\n"
            f"{text}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

    def _build_input_from_prompt(self, prompt: str, token_lens: List[int]) -> List[int]:
        """Tokenize ``prompt``, expanding each audio span to its real length.

        Args:
            prompt: Prompt text containing one audio span per audio clip.
            token_lens: Placeholder count for each clip, in prompt order.

        Returns:
            The flat list of input ids.

        Raises:
            ValueError: If the number of spans in ``prompt`` differs from
                ``len(token_lens)``.
        """
        spans = list(self._AUDIO_SPAN_RE.finditer(prompt))
        if len(spans) != len(token_lens):
            raise ValueError(
                f"Audio placeholder count mismatch: found {len(spans)} spans in text, "
                f"but got {len(token_lens)} audio inputs."
            )

        input_ids: List[int] = []
        cursor = 0
        for match, token_len in zip(spans, token_lens):
            # Text between the previous span (or the start) and this span.
            prefix = prompt[cursor : match.start()]
            if prefix:
                input_ids.extend(
                    self._base_tokenizer.encode(prefix, add_special_tokens=False)
                )

            input_ids.append(self.audio_start_id)
            input_ids.extend(self._build_audio_placeholder_ids(int(token_len)))
            input_ids.append(self.audio_end_id)
            cursor = match.end()

        suffix = prompt[cursor:]
        if suffix:
            input_ids.extend(
                self._base_tokenizer.encode(suffix, add_special_tokens=False)
            )
        return input_ids

    def __call__(
        self,
        *args,
        text: Union[str, Sequence[str], None] = None,
        audios: Union[
            np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]], None
        ] = None,
        audio: Union[
            np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]], None
        ] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> BatchFeature:
        """Build model inputs from one prompt and zero or more audio clips.

        Args:
            *args: Accepted and ignored.
            text: The prompt, or a list/tuple holding exactly one prompt. If it
                contains no audio span it is wrapped in the default prompt.
            audios: Audio clips; takes precedence over ``audio`` when not
                ``None``. A single 1-D array or tensor is treated as one clip;
                any other value is iterated, one clip per element.
            audio: Alternative name for ``audios``.
            return_tensors: Passed to ``BatchFeature`` as ``tensor_type``.
            **kwargs: Accepted and ignored.

        Returns:
            A ``BatchFeature`` with ``input_ids`` and ``attention_mask`` (both
            with a leading batch dimension of 1). When audio was given it also
            holds ``audio_data``, the mel features zero-padded to the longest
            clip with shape ``(num_clips, mel_dim, max_len)``, and
            ``audio_data_seqlens``, the unpadded mel length of each clip.

        Raises:
            TypeError: If ``text`` is neither a string nor a list/tuple of one
                string.
            ValueError: If ``text`` is missing, holds more than one prompt, or
                the number of audio spans does not match the number of clips.
        """
        _ = args, kwargs  # accepted for call compatibility, deliberately unused

        if isinstance(text, str):
            prompt_text: Optional[str] = text
        elif isinstance(text, (list, tuple)):
            if len(text) != 1:
                raise ValueError(f"Expected text batch size 1, got {len(text)}")
            prompt_text = text[0]
            if not isinstance(prompt_text, str):
                raise TypeError("Expected text batch size 1 with string content.")
        elif text is None:
            prompt_text = None
        else:
            raise TypeError("MossAudioProcessor text must be a string or a batch of one string.")

        # Fail fast: do not run feature extraction for a call that cannot
        # succeed anyway.
        if prompt_text is None:
            raise ValueError(
                "MossAudioProcessor requires text input. Apply a chat template before calling the processor if needed."
            )

        audio_source = audios if audios is not None else audio
        if audio_source is None:
            audio_list = []
        elif isinstance(audio_source, (np.ndarray, torch.Tensor)) and audio_source.ndim == 1:
            # A bare 1-D waveform is one clip; iterating it would yield one
            # scalar "clip" per sample.
            audio_list = [audio_source]
        else:
            audio_list = list(audio_source)

        mels: List[torch.Tensor] = []
        raw_lengths: List[int] = []
        token_lens: List[int] = []
        for one_audio in audio_list:
            mel = self._extract_mel(one_audio)
            raw_len = int(mel.shape[-1])
            mels.append(mel)
            raw_lengths.append(raw_len)
            token_lens.append(self._conv3_downsample_len(raw_len))

        if mels:
            # Right-pad every clip with zeros up to the longest one.
            max_length = max(raw_lengths)
            audio_batch = torch.zeros(
                (len(mels), self.config.mel_dim, max_length),
                dtype=self.config.mel_dtype,
            )
            for index, mel in enumerate(mels):
                audio_batch[index, :, : mel.shape[-1]] = mel
            seqlens_tensor = torch.tensor(raw_lengths, dtype=torch.long)
        else:
            audio_batch = None
            seqlens_tensor = None

        # Prompts without an explicit audio span get the built-in template.
        if self._AUDIO_SPAN_RE.search(prompt_text) is None:
            prompt_text = self._build_default_prompt(
                prompt_text, has_audio=bool(audio_list)
            )
        input_ids_list = self._build_input_from_prompt(prompt_text, token_lens)

        input_ids_tensor = torch.tensor([input_ids_list], dtype=torch.long)
        attention_mask_tensor = torch.ones_like(input_ids_tensor)

        data = {
            "input_ids": input_ids_tensor,
            "attention_mask": attention_mask_tensor,
        }
        if audio_batch is not None and seqlens_tensor is not None:
            data["audio_data"] = audio_batch
            data["audio_data_seqlens"] = seqlens_tensor
        return BatchFeature(data=data, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        """Forward to the tokenizer's ``batch_decode``."""
        return self._base_tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to the tokenizer's ``decode``."""
        return self._base_tokenizer.decode(*args, **kwargs)
|
|
|
|
# Public names exported by a star-import of this module.
__all__ = [
    "MelConfig",
    "MossAudioProcessor",
]
|
|