Li-Ruixiao committed on
Commit
f0e23bb
·
1 Parent(s): e4aa3d2

Refactor processing_moss_tts.py: Improve type hints, enhance message classes, and streamline audio token handling

Browse files
Files changed (1) hide show
  1. processing_moss_tts.py +432 -131
processing_moss_tts.py CHANGED
@@ -14,14 +14,22 @@
14
  # limitations under the License.
15
 
16
  import os
17
- from typing import Dict, List, Optional, Tuple, Type, Union, Literal, Final
18
  from dataclasses import dataclass
19
  from pathlib import Path
20
  import re
21
  import torchaudio
22
 
23
  import torch
24
- from transformers import PreTrainedTokenizerBase, BatchFeature, ProcessorMixin, logging, AutoConfig, AutoModel, AutoTokenizer
 
 
 
 
 
 
 
 
25
 
26
  from .configuration_moss_tts import MossTTSDelayConfig
27
 
@@ -34,8 +42,8 @@ AUDIO_PLACEHOLDER = "<|audio|>"
34
 
35
  @dataclass
36
  class Message:
37
- pass
38
-
39
 
40
 
41
  @dataclass
@@ -78,13 +86,16 @@ class UserMessage(Message):
78
  if speaker_reference is not None:
79
  reference.append(f"[S{speaker_idx}]:\n{AUDIO_PLACEHOLDER}")
80
  reference = "\n".join(reference)
81
- audio_codes_list = [speaker_reference for speaker_reference in self.reference if speaker_reference is not None]
 
 
 
 
82
  else:
83
  raise TypeError("`reference` should be exactly a list when it is not None.")
84
 
85
  content = (
86
- template
87
- .replace("{reference}", str(reference))
88
  .replace("{instruction}", str(self.instruction))
89
  .replace("{tokens}", str(self.tokens))
90
  .replace("{quality}", str(self.quality))
@@ -101,22 +112,23 @@ class UserMessage(Message):
101
  return {
102
  "role": "user",
103
  "content": self._content,
104
- "audio_codes_list": self._audio_codes_list
105
  }
106
 
107
 
108
  @dataclass
109
  class AssistantMessage(Message):
110
  audio_codes_list: List[Union[str, torch.Tensor]]
111
- content: str = AUDIO_PLACEHOLDER
112
 
113
  def to_dict(self):
114
  return {
115
  "role": "assistant",
116
  "content": self.content,
117
- "audio_codes_list": self.audio_codes_list
118
  }
119
 
 
120
  USER_MESSAGE_FIELDS = (
121
  "text",
122
  "reference",
@@ -129,27 +141,25 @@ USER_MESSAGE_FIELDS = (
129
  )
130
 
131
 
132
-
133
-
134
-
135
-
136
-
137
  class MossTTSDelayProcessor(ProcessorMixin):
138
  tokenizer_class = "AutoTokenizer"
139
  audio_tokenizer_class = "AutoModel"
140
 
 
 
 
141
  def __init__(
142
  self,
143
  tokenizer: PreTrainedTokenizerBase,
144
- audio_tokenizer: AutoModel = None,
145
  model_config: Optional[MossTTSDelayConfig] = None,
146
- **kwargs
147
  ):
148
- super().__init__(
149
- tokenizer=tokenizer,
150
- audio_tokenizer=audio_tokenizer,
151
- **kwargs
152
- )
153
  if model_config is None:
154
  model_config = MossTTSDelayConfig()
155
  self.model_config = model_config
@@ -158,68 +168,107 @@ class MossTTSDelayProcessor(ProcessorMixin):
158
  self.imend_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
159
  self.newline_token_id = 198
160
 
161
- self.audio_user_slot_token = tokenizer.convert_ids_to_tokens(self.model_config.audio_user_slot_token_id)
162
- self.audio_assistant_gen_slot_token = tokenizer.convert_ids_to_tokens(self.model_config.audio_assistant_gen_slot_token_id)
163
- self.audio_assistant_delay_slot_token = tokenizer.convert_ids_to_tokens(self.model_config.audio_assistant_delay_slot_token_id)
164
- self.audio_start_token = tokenizer.convert_ids_to_tokens(self.model_config.audio_start_token_id)
165
- self.audio_end_token = tokenizer.convert_ids_to_tokens(self.model_config.audio_end_token_id)
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  @classmethod
168
- def from_pretrained(cls, pretrained_model_name_or_path, trust_remote_code=True, **kwargs):
169
- kwargs.pop("_from_auto")
 
 
 
 
 
 
170
  pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
171
- model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
172
- tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
- audio_tokenizer_name_or_path = kwargs.pop("codec_path", "OpenMOSS-Team/MOSS-Audio-Tokenizer")
175
- assert isinstance(audio_tokenizer_name_or_path, str), f"Unsupported audio_tokenizer_path input format: {type(audio_tokenizer_name_or_path)}"
176
- audio_tokenizer = AutoModel.from_pretrained(audio_tokenizer_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
177
-
178
  return cls(
179
  tokenizer=tokenizer,
180
  audio_tokenizer=audio_tokenizer,
181
  model_config=model_config,
182
- **kwargs
183
  )
184
-
185
- def __call__(
186
- self,
187
- conversations: Union[Message, Dict, List[Message], List[Dict], List[List[Message]], List[List[Dict]]],
188
- mode: str = "generation",
189
- apply_chat_template: bool = True,
190
- n_vq: Optional[int] = None
191
- ) -> BatchFeature:
192
-
 
 
 
193
  """
194
- mode 只会在将 Message 转换为 to_dict 时起作用;
195
  """
196
-
197
  if mode not in {"generation", "continuation"}:
198
  raise RuntimeError
199
-
200
  if isinstance(conversations, (Message, Dict)):
201
  conversations = [conversations]
202
-
203
  truncation = False
204
  if mode == "continuation":
205
  truncation = True
206
-
207
  input_ids_list = []
208
  for conversation in conversations:
209
  if isinstance(conversation, (Message, Dict)):
210
  conversation = [conversation]
211
 
 
 
 
212
  if (mode == "generation") ^ (len(conversation) % 2 != 0):
213
  raise ValueError
214
 
215
- if (mode == "generation") ^ (conversation[-1]['role'] == "user"):
216
  raise ValueError
217
 
218
  unified_codes = []
219
  for message_idx, message in enumerate(conversation):
220
- message = self._normalize_message(message)
221
  if apply_chat_template:
222
- add_generation_prompt = mode == "generation" and message_idx == len(conversation) - 1
 
 
223
  try:
224
  content = self.tokenizer.apply_chat_template(
225
  [{"role": message["role"], "content": message["content"]}],
@@ -229,29 +278,76 @@ class MossTTSDelayProcessor(ProcessorMixin):
229
  except TypeError:
230
  try:
231
  content = self.tokenizer.apply_chat_template(
232
- [{"role": message["role"], "content": message["content"]}],
 
 
 
 
 
233
  add_generation_prompt=add_generation_prompt,
234
  )
235
  except Exception:
236
- logger.warning("apply_chat_template failed; fallback to raw content.")
 
 
237
  content = message["content"]
238
  else:
239
- content = message['content']
240
-
241
- audio_codes_list = []
242
- for audio_codes in message["audio_codes_list"]:
243
- if isinstance(audio_codes, torch.Tensor):
244
- if n_vq is not None and audio_codes.shape[1] != n_vq:
245
- raise RuntimeError("audio_codes's n_vq is not equal to the parameter `n_vq`. Your can set the parameter `n_vq` as None if you have already tokenzied the wavs.")
246
- else:
247
- audio_codes = self.encode_audios_from_path(audio_codes, n_vq)[0]
248
- audio_codes_list.append(audio_codes)
249
- unified_codes.append(self._get_unified_codes(message['role'], content, audio_codes_list, truncation))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  unified_codes = torch.cat(unified_codes)
252
  input_ids_list.append(unified_codes)
253
 
254
- return self._pad(input_ids_list)
255
 
256
  @staticmethod
257
  def build_user_message(
@@ -310,14 +406,23 @@ class MossTTSDelayProcessor(ProcessorMixin):
310
  def _pad(self, input_ids_list: List[torch.Tensor]):
311
  device = input_ids_list[0].device
312
  lengths = torch.tensor([w.shape[0] for w in input_ids_list], device=device)
313
- pad_input_ids = torch.nn.utils.rnn.pad_sequence(input_ids_list, batch_first=True, padding_value=self.model_config.audio_pad_code, padding_side="left")
314
- other_channel_mask = (pad_input_ids.shape[1] - lengths).unsqueeze(1) > torch.arange(pad_input_ids.shape[1], device=device).unsqueeze(0)
 
 
 
 
 
 
 
315
  pad_input_ids[..., 0][other_channel_mask] = self.model_config.pad_token_id
316
- attention_mask = torch.zeros(pad_input_ids.shape[0], pad_input_ids.shape[1], device=device)
 
 
317
  attention_mask[~other_channel_mask] = 1
318
  attention_mask = attention_mask.bool()
319
  return {
320
- "input_ids": pad_input_ids, # [batch_size, seqlen, n_vq]
321
  "attention_mask": attention_mask,
322
  }
323
 
@@ -329,7 +434,7 @@ class MossTTSDelayProcessor(ProcessorMixin):
329
  gen_slot_token: str,
330
  delay_slot_token: str,
331
  audio_start_token: str,
332
- audio_end_token: str
333
  ) -> str:
334
  if n_vq < 1:
335
  raise ValueError(f"n_vq must be >= 1, got {n_vq}")
@@ -371,7 +476,9 @@ class MossTTSDelayProcessor(ProcessorMixin):
371
  return content, audio_codes_list
372
 
373
  if len(matches) != len(audio_codes_list):
374
- raise ValueError("Audio placeholders do not match the provided audio codes list.")
 
 
375
 
376
  new_audio_codes_list = []
377
  new_parts = []
@@ -381,18 +488,20 @@ class MossTTSDelayProcessor(ProcessorMixin):
381
  j = i
382
  while (
383
  j + 1 < len(matches)
384
- and content[matches[j].end():matches[j + 1].start()].strip() == ""
385
  ):
386
  j += 1
387
 
388
- new_parts.append(content[last_pos:matches[i].start()])
389
  new_parts.append(AUDIO_PLACEHOLDER)
390
  last_pos = matches[j].end()
391
 
392
  if j == i:
393
  new_audio_codes_list.append(audio_codes_list[i])
394
  else:
395
- new_audio_codes_list.append(torch.cat(audio_codes_list[i:j + 1], dim=0))
 
 
396
 
397
  i = j + 1
398
 
@@ -408,9 +517,9 @@ class MossTTSDelayProcessor(ProcessorMixin):
408
  dtype=codes.dtype,
409
  )
410
  for i in range(codes.shape[1]):
411
- delayed_tokens[i: i + codes.shape[0], i] = codes[:, i]
412
  return delayed_tokens
413
-
414
  @staticmethod
415
  def apply_de_delay_pattern(delay_codes: torch.Tensor) -> torch.Tensor:
416
  tokens = torch.full(
@@ -420,11 +529,16 @@ class MossTTSDelayProcessor(ProcessorMixin):
420
  dtype=delay_codes.dtype,
421
  )
422
  for i in range(delay_codes.shape[1]):
423
- tokens[:, i] = delay_codes[i: i + tokens.shape[0], i]
424
  return tokens
425
 
426
-
427
- def _get_unified_codes(self, role: str, content: str, audio_codes_list: List[Union[str, torch.Tensor]], truncation: bool) -> torch.Tensor:
 
 
 
 
 
428
  """
429
  此时的 content 已经是带上了对话格式
430
  """
@@ -452,12 +566,23 @@ class MossTTSDelayProcessor(ProcessorMixin):
452
  audio_start_token=self.audio_start_token,
453
  audio_end_token=self.audio_end_token,
454
  )
455
- text_codes = torch.tensor(self.tokenizer.encode(content), device=audio_codes_list[0].device if audio_codes_list else None)
 
 
 
456
 
457
- audio_start_indices = torch.where(text_codes == self.model_config.audio_start_token_id)[0]
458
- audio_end_indices = torch.where(text_codes == self.model_config.audio_end_token_id)[0]
459
- if len(audio_start_indices) != len(audio_codes_list) or len(audio_end_indices) != len(audio_codes_list):
460
- raise ValueError("Audio placeholders do not match the provided audio codes list.")
 
 
 
 
 
 
 
 
461
 
462
  delay_audio_codes_list = []
463
  if len(audio_codes_list) == 0:
@@ -469,8 +594,14 @@ class MossTTSDelayProcessor(ProcessorMixin):
469
  )
470
  else:
471
  prefix_idx = 0
472
- for audio_start_idx, audio_end_idx, audio_codes in zip(audio_start_indices, audio_end_indices, audio_codes_list):
473
- delay_audio_codes = self.apply_delay_pattern(audio_codes, self.model_config.audio_pad_code)
 
 
 
 
 
 
474
  pad_codes = torch.full(
475
  (audio_start_idx - prefix_idx + 1, n_vq),
476
  self.model_config.audio_pad_code,
@@ -481,10 +612,13 @@ class MossTTSDelayProcessor(ProcessorMixin):
481
  prefix_idx = audio_end_idx
482
 
483
  if truncation:
484
- delay_audio_codes_list[-1] = delay_audio_codes_list[-1][:-(n_vq - 1), :]
 
 
485
  else:
 
486
  pad_codes = torch.full(
487
- (len(text_codes) - audio_end_indices[-1], n_vq),
488
  self.model_config.audio_pad_code,
489
  device=audio_codes_list[0].device,
490
  dtype=audio_codes_list[0].dtype,
@@ -492,34 +626,36 @@ class MossTTSDelayProcessor(ProcessorMixin):
492
  delay_audio_codes_list.append(pad_codes)
493
 
494
  delay_audio_codes_list = torch.cat(delay_audio_codes_list)
495
-
496
  if text_codes.shape[0] != delay_audio_codes_list.shape[0]:
497
- text_codes = text_codes[:delay_audio_codes_list.shape[0]]
498
 
499
- unified_codes = torch.cat([text_codes.unsqueeze(1), delay_audio_codes_list], dim=1)
 
 
500
  return unified_codes
501
 
502
  def _parse_text_codes(self, start_length, text_codes):
503
- text = self.tokenizer.decode(text_codes)
504
- prefix = self.tokenizer.decode(text_codes[:start_length])
505
- text = text[len(prefix):]
506
 
507
  AUDIO_PATTERN = re.compile(
508
- rf'(?:{self.audio_start_token})?'
509
- rf'(?:{self.audio_assistant_gen_slot_token})*'
510
- rf'(?:{self.audio_assistant_delay_slot_token})*'
511
- rf'{self.audio_end_token}'
512
  )
513
 
514
  def normalize_audio_segments(text: str) -> str:
515
  def repl(match: re.Match) -> str:
516
  seg = match.group(0)
517
- # 如果片段内包含至少一个 gen_slot,则替换为 <|audio|>
518
  if self.audio_assistant_gen_slot_token in seg:
519
  return AUDIO_PLACEHOLDER
520
- # 否则直接删除
521
  return ""
522
-
523
  return AUDIO_PATTERN.sub(repl, text)
524
 
525
  return normalize_audio_segments(text)
@@ -543,18 +679,21 @@ class MossTTSDelayProcessor(ProcessorMixin):
543
 
544
  audio_codes_list = [audio_codes[s] for s in segments_idx]
545
 
546
- decoded_audio_list = []
547
- for segment_codes in audio_codes_list:
548
- decoded_segment = self.decode_audio_codes([segment_codes])
549
- if len(decoded_segment) > 0:
550
- decoded_audio_list.append(decoded_segment[0])
551
 
552
  # Keep codec causal context by decoding the whole first segment first,
553
  # then trim at waveform level according to start_length ratio.
554
- if start_length > 0 and len(audio_codes_list) > 0 and len(decoded_audio_list) > 0:
 
 
 
 
555
  first_codes_length = audio_codes_list[0].shape[0]
556
  if first_codes_length > 0:
557
- trim_ratio = max(0.0, min(float(start_length) / float(first_codes_length), 1.0))
 
 
558
  first_audio = decoded_audio_list[0]
559
  if trim_ratio >= 1.0:
560
  decoded_audio_list = decoded_audio_list[1:]
@@ -564,7 +703,6 @@ class MossTTSDelayProcessor(ProcessorMixin):
564
 
565
  return decoded_audio_list
566
 
567
-
568
  def decode(self, output: List[Tuple[int, torch.Tensor]]):
569
  """
570
  1. 这里不管怎样,都需要一个完整的 assistant generation ids;
@@ -574,56 +712,219 @@ class MossTTSDelayProcessor(ProcessorMixin):
574
  genearted_messages = []
575
  for start_length, generation_ids in output:
576
  content = self._parse_text_codes(start_length, generation_ids[:, 0])
577
- audio_codes_list = self._parse_audio_codes(start_length, generation_ids[:, 1:])
 
 
578
  if content == "":
579
  message = None
580
  else:
581
  message = AssistantMessage(
582
  content=content,
583
- audio_codes_list=audio_codes_list
 
 
584
  )
585
  genearted_messages.append(message)
586
  return genearted_messages
587
 
588
  @staticmethod
589
- def loudness_normalize(wav: torch.Tensor, target_dbfs: float = -20, gain_range: tuple[float, float] = (-3.0, 3.0)) -> torch.Tensor:
 
 
 
 
590
  wav = wav.to(torch.float32)
591
- if wav.numel() == 0: return wav
592
- rms = torch.sqrt(torch.mean(wav ** 2))
593
- current_dbfs = 20.0 * torch.log10(rms + 1e-9)
594
  gain = float(target_dbfs - current_dbfs)
595
  gain = max(gain_range[0], min(gain, gain_range[1]))
596
  factor = 10.0 ** (gain / 20.0)
597
  return wav * factor
598
 
599
- def encode_audios_from_wav(self, wav_list: List[torch.Tensor], sampling_rate: int, n_vq: int = None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
600
  if isinstance(wav_list, torch.Tensor):
601
  wav_list = [wav_list]
602
  wav_list_ = []
603
  resample = False
604
  if sampling_rate != self.model_config.sampling_rate:
605
  resample = True
 
606
  for wav in wav_list:
607
  if wav.shape[0] > 1:
608
  wav = torch.mean(wav, dim=0, keepdim=True)
609
  if resample:
610
- wav = torchaudio.functional.resample(waveform=wav, orig_freq=sampling_rate, new_freq=self.model_config.sampling_rate)
 
 
 
 
 
611
  wav_list_.append(self.loudness_normalize(wav.squeeze(0)))
612
- return self.audio_tokenizer.encode(wav_list_, n_vq)
613
 
614
- def encode_audios_from_path(self, wav_path_list: List[str], n_vq: int = None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
615
  if isinstance(wav_path_list, str):
616
  wav_path_list = [wav_path_list]
617
- wav_list = []
618
- sampling_rate = None
 
 
 
 
 
 
 
619
  for wav_path in wav_path_list:
620
  wav, sr = torchaudio.load(wav_path)
621
- if sampling_rate is None:
622
- sampling_rate = sr
623
- elif sampling_rate != sr:
624
- raise ValueError("sampling_rate of audios in the same batch should be the same.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
625
  wav_list.append(wav)
626
- return self.encode_audios_from_wav(wav_list, sampling_rate, n_vq)
627
-
628
- def decode_audio_codes(self, audio_tokens_list: List[torch.Tensor]):
629
- return self.audio_tokenizer.decode(audio_tokens_list)
 
14
  # limitations under the License.
15
 
16
  import os
17
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal, Final, cast
18
  from dataclasses import dataclass
19
  from pathlib import Path
20
  import re
21
  import torchaudio
22
 
23
  import torch
24
+ from transformers import (
25
+ PreTrainedTokenizerBase,
26
+ BatchFeature,
27
+ ProcessorMixin,
28
+ logging,
29
+ AutoConfig,
30
+ AutoModel,
31
+ AutoTokenizer,
32
+ )
33
 
34
  from .configuration_moss_tts import MossTTSDelayConfig
35
 
 
42
 
43
  @dataclass
44
  class Message:
45
+ def to_dict(self) -> Dict[str, Any]:
46
+ raise NotImplementedError
47
 
48
 
49
  @dataclass
 
86
  if speaker_reference is not None:
87
  reference.append(f"[S{speaker_idx}]:\n{AUDIO_PLACEHOLDER}")
88
  reference = "\n".join(reference)
89
+ audio_codes_list = [
90
+ speaker_reference
91
+ for speaker_reference in self.reference
92
+ if speaker_reference is not None
93
+ ]
94
  else:
95
  raise TypeError("`reference` should be exactly a list when it is not None.")
96
 
97
  content = (
98
+ template.replace("{reference}", str(reference))
 
99
  .replace("{instruction}", str(self.instruction))
100
  .replace("{tokens}", str(self.tokens))
101
  .replace("{quality}", str(self.quality))
 
112
  return {
113
  "role": "user",
114
  "content": self._content,
115
+ "audio_codes_list": self._audio_codes_list,
116
  }
117
 
118
 
119
  @dataclass
120
  class AssistantMessage(Message):
121
  audio_codes_list: List[Union[str, torch.Tensor]]
122
+ content: str = AUDIO_PLACEHOLDER
123
 
124
  def to_dict(self):
125
  return {
126
  "role": "assistant",
127
  "content": self.content,
128
+ "audio_codes_list": self.audio_codes_list,
129
  }
130
 
131
+
132
  USER_MESSAGE_FIELDS = (
133
  "text",
134
  "reference",
 
141
  )
142
 
143
 
 
 
 
 
 
144
  class MossTTSDelayProcessor(ProcessorMixin):
145
  tokenizer_class = "AutoTokenizer"
146
  audio_tokenizer_class = "AutoModel"
147
 
148
+ tokenizer: PreTrainedTokenizerBase
149
+ audio_tokenizer: Any
150
+
151
  def __init__(
152
  self,
153
  tokenizer: PreTrainedTokenizerBase,
154
+ audio_tokenizer: Any = None,
155
  model_config: Optional[MossTTSDelayConfig] = None,
156
+ **kwargs,
157
  ):
158
+ super().__init__(tokenizer=tokenizer, audio_tokenizer=audio_tokenizer, **kwargs)
159
+
160
+ # Explicit assignments for type-checkers; ProcessorMixin sets these too.
161
+ self.tokenizer = tokenizer
162
+ self.audio_tokenizer = audio_tokenizer
163
  if model_config is None:
164
  model_config = MossTTSDelayConfig()
165
  self.model_config = model_config
 
168
  self.imend_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
169
  self.newline_token_id = 198
170
 
171
+ def _id_to_token(token_id: int) -> str:
172
+ tok = tokenizer.convert_ids_to_tokens(int(token_id))
173
+ if isinstance(tok, list):
174
+ return tok[0] if len(tok) > 0 else ""
175
+ return cast(str, tok)
176
+
177
+ self.audio_user_slot_token = _id_to_token(
178
+ self.model_config.audio_user_slot_token_id
179
+ )
180
+ self.audio_assistant_gen_slot_token = _id_to_token(
181
+ self.model_config.audio_assistant_gen_slot_token_id
182
+ )
183
+ self.audio_assistant_delay_slot_token = _id_to_token(
184
+ self.model_config.audio_assistant_delay_slot_token_id
185
+ )
186
+ self.audio_start_token = _id_to_token(self.model_config.audio_start_token_id)
187
+ self.audio_end_token = _id_to_token(self.model_config.audio_end_token_id)
188
 
189
  @classmethod
190
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
191
+ trust_remote_code = kwargs.pop("trust_remote_code", True)
192
+ kwargs.pop("_from_auto", None)
193
+
194
+ audio_tokenizer_name_or_path = kwargs.pop(
195
+ "codec_path", "OpenMOSS-Team/MOSS-Audio-Tokenizer"
196
+ )
197
+
198
  pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
199
+ model_config = cast(
200
+ MossTTSDelayConfig,
201
+ AutoConfig.from_pretrained(
202
+ pretrained_model_name_or_path,
203
+ *args,
204
+ trust_remote_code=trust_remote_code,
205
+ **kwargs,
206
+ ),
207
+ )
208
+ tokenizer = AutoTokenizer.from_pretrained(
209
+ pretrained_model_name_or_path,
210
+ *args,
211
+ trust_remote_code=trust_remote_code,
212
+ **kwargs,
213
+ )
214
+ audio_tokenizer = AutoModel.from_pretrained(
215
+ audio_tokenizer_name_or_path,
216
+ trust_remote_code=trust_remote_code,
217
+ **kwargs,
218
+ )
219
 
 
 
 
 
220
  return cls(
221
  tokenizer=tokenizer,
222
  audio_tokenizer=audio_tokenizer,
223
  model_config=model_config,
224
+ **kwargs,
225
  )
226
+
227
+ def __call__(self, *args, **kwargs) -> BatchFeature:
228
+ conversations = args[0] if len(args) > 0 else kwargs.pop("conversations")
229
+ mode: str = kwargs.pop("mode", "generation")
230
+ apply_chat_template: bool = kwargs.pop("apply_chat_template", True)
231
+ n_vq: Optional[int] = kwargs.pop("n_vq", None)
232
+
233
+ # Common ProcessorMixin kwargs that we ignore because we always return torch tensors.
234
+ kwargs.pop("return_tensors", None)
235
+ kwargs.pop("padding", None)
236
+ kwargs.pop("truncation", None)
237
+
238
  """
239
+ mode only works when a Message is converted to a dict.
240
  """
241
+
242
  if mode not in {"generation", "continuation"}:
243
  raise RuntimeError
244
+
245
  if isinstance(conversations, (Message, Dict)):
246
  conversations = [conversations]
247
+
248
  truncation = False
249
  if mode == "continuation":
250
  truncation = True
251
+
252
  input_ids_list = []
253
  for conversation in conversations:
254
  if isinstance(conversation, (Message, Dict)):
255
  conversation = [conversation]
256
 
257
+ # Normalize early so downstream logic always deals with dict messages.
258
+ conversation = [self._normalize_message(m) for m in conversation]
259
+
260
  if (mode == "generation") ^ (len(conversation) % 2 != 0):
261
  raise ValueError
262
 
263
+ if (mode == "generation") ^ (conversation[-1]["role"] == "user"):
264
  raise ValueError
265
 
266
  unified_codes = []
267
  for message_idx, message in enumerate(conversation):
 
268
  if apply_chat_template:
269
+ add_generation_prompt = (
270
+ mode == "generation" and message_idx == len(conversation) - 1
271
+ )
272
  try:
273
  content = self.tokenizer.apply_chat_template(
274
  [{"role": message["role"], "content": message["content"]}],
 
278
  except TypeError:
279
  try:
280
  content = self.tokenizer.apply_chat_template(
281
+ [
282
+ {
283
+ "role": message["role"],
284
+ "content": message["content"],
285
+ }
286
+ ],
287
  add_generation_prompt=add_generation_prompt,
288
  )
289
  except Exception:
290
+ logger.warning(
291
+ "apply_chat_template failed; fallback to raw content."
292
+ )
293
  content = message["content"]
294
  else:
295
+ content = message["content"]
296
+
297
+ if not isinstance(content, str):
298
+ content = str(content)
299
+
300
+ # Batch-encode all path-based references in one call when possible.
301
+ # This ensures we actually exercise audio_tokenizer.batch_encode for multi-reference prompts,
302
+ # instead of repeatedly calling it with batch=1.
303
+ raw_audio_items = message.get("audio_codes_list", [])
304
+
305
+ audio_codes_list: List[torch.Tensor] = []
306
+ if len(raw_audio_items) > 0:
307
+ encoded_items: List[Optional[torch.Tensor]] = [None] * len(
308
+ raw_audio_items
309
+ )
310
+ paths: List[str] = []
311
+ path_positions: List[int] = []
312
+
313
+ for idx, item in enumerate(raw_audio_items):
314
+ if isinstance(item, torch.Tensor):
315
+ if n_vq is not None and item.shape[1] != n_vq:
316
+ raise RuntimeError(
317
+ "audio_codes's n_vq is not equal to the parameter `n_vq`. Your can set the parameter `n_vq` as None if you have already tokenzied the wavs."
318
+ )
319
+ encoded_items[idx] = item
320
+ continue
321
+
322
+ if isinstance(item, (str, os.PathLike)):
323
+ paths.append(str(item))
324
+ path_positions.append(idx)
325
+ continue
326
+
327
+ raise TypeError(
328
+ "Each audio item must be a torch.Tensor of codes or a path-like string."
329
+ )
330
+
331
+ if len(paths) > 0:
332
+ encoded_from_paths = self.encode_audios_from_path(paths, n_vq)
333
+ if len(encoded_from_paths) != len(paths):
334
+ raise RuntimeError(
335
+ "encode_audios_from_path returned an unexpected number of items."
336
+ )
337
+ for pos, codes in zip(path_positions, encoded_from_paths):
338
+ encoded_items[pos] = codes
339
+
340
+ audio_codes_list = [cast(torch.Tensor, t) for t in encoded_items]
341
+ unified_codes.append(
342
+ self._get_unified_codes(
343
+ message["role"], content, audio_codes_list, truncation
344
+ )
345
+ )
346
 
347
  unified_codes = torch.cat(unified_codes)
348
  input_ids_list.append(unified_codes)
349
 
350
+ return BatchFeature(data=self._pad(input_ids_list))
351
 
352
  @staticmethod
353
  def build_user_message(
 
406
  def _pad(self, input_ids_list: List[torch.Tensor]):
407
  device = input_ids_list[0].device
408
  lengths = torch.tensor([w.shape[0] for w in input_ids_list], device=device)
409
+ pad_input_ids = torch.nn.utils.rnn.pad_sequence(
410
+ input_ids_list,
411
+ batch_first=True,
412
+ padding_value=self.model_config.audio_pad_code,
413
+ padding_side="left",
414
+ )
415
+ other_channel_mask = (pad_input_ids.shape[1] - lengths).unsqueeze(
416
+ 1
417
+ ) > torch.arange(pad_input_ids.shape[1], device=device).unsqueeze(0)
418
  pad_input_ids[..., 0][other_channel_mask] = self.model_config.pad_token_id
419
+ attention_mask = torch.zeros(
420
+ pad_input_ids.shape[0], pad_input_ids.shape[1], device=device
421
+ )
422
  attention_mask[~other_channel_mask] = 1
423
  attention_mask = attention_mask.bool()
424
  return {
425
+ "input_ids": pad_input_ids, # [batch_size, seqlen, n_vq]
426
  "attention_mask": attention_mask,
427
  }
428
 
 
434
  gen_slot_token: str,
435
  delay_slot_token: str,
436
  audio_start_token: str,
437
+ audio_end_token: str,
438
  ) -> str:
439
  if n_vq < 1:
440
  raise ValueError(f"n_vq must be >= 1, got {n_vq}")
 
476
  return content, audio_codes_list
477
 
478
  if len(matches) != len(audio_codes_list):
479
+ raise ValueError(
480
+ "Audio placeholders do not match the provided audio codes list."
481
+ )
482
 
483
  new_audio_codes_list = []
484
  new_parts = []
 
488
  j = i
489
  while (
490
  j + 1 < len(matches)
491
+ and content[matches[j].end() : matches[j + 1].start()].strip() == ""
492
  ):
493
  j += 1
494
 
495
+ new_parts.append(content[last_pos : matches[i].start()])
496
  new_parts.append(AUDIO_PLACEHOLDER)
497
  last_pos = matches[j].end()
498
 
499
  if j == i:
500
  new_audio_codes_list.append(audio_codes_list[i])
501
  else:
502
+ new_audio_codes_list.append(
503
+ torch.cat(audio_codes_list[i : j + 1], dim=0)
504
+ )
505
 
506
  i = j + 1
507
 
 
517
  dtype=codes.dtype,
518
  )
519
  for i in range(codes.shape[1]):
520
+ delayed_tokens[i : i + codes.shape[0], i] = codes[:, i]
521
  return delayed_tokens
522
+
523
  @staticmethod
524
  def apply_de_delay_pattern(delay_codes: torch.Tensor) -> torch.Tensor:
525
  tokens = torch.full(
 
529
  dtype=delay_codes.dtype,
530
  )
531
  for i in range(delay_codes.shape[1]):
532
+ tokens[:, i] = delay_codes[i : i + tokens.shape[0], i]
533
  return tokens
534
 
535
+ def _get_unified_codes(
536
+ self,
537
+ role: str,
538
+ content: str,
539
+ audio_codes_list: List[torch.Tensor],
540
+ truncation: bool,
541
+ ) -> torch.Tensor:
542
  """
543
  此时的 content 已经是带上了对话格式
544
  """
 
566
  audio_start_token=self.audio_start_token,
567
  audio_end_token=self.audio_end_token,
568
  )
569
+ text_codes = torch.tensor(
570
+ self.tokenizer.encode(content),
571
+ device=audio_codes_list[0].device if audio_codes_list else None,
572
+ )
573
 
574
+ audio_start_indices = torch.where(
575
+ text_codes == self.model_config.audio_start_token_id
576
+ )[0]
577
+ audio_end_indices = torch.where(
578
+ text_codes == self.model_config.audio_end_token_id
579
+ )[0]
580
+ if len(audio_start_indices) != len(audio_codes_list) or len(
581
+ audio_end_indices
582
+ ) != len(audio_codes_list):
583
+ raise ValueError(
584
+ "Audio placeholders do not match the provided audio codes list."
585
+ )
586
 
587
  delay_audio_codes_list = []
588
  if len(audio_codes_list) == 0:
 
594
  )
595
  else:
596
  prefix_idx = 0
597
+ for audio_start_idx_t, audio_end_idx_t, audio_codes in zip(
598
+ audio_start_indices, audio_end_indices, audio_codes_list
599
+ ):
600
+ audio_start_idx = int(audio_start_idx_t.item())
601
+ audio_end_idx = int(audio_end_idx_t.item())
602
+ delay_audio_codes = self.apply_delay_pattern(
603
+ audio_codes, self.model_config.audio_pad_code
604
+ )
605
  pad_codes = torch.full(
606
  (audio_start_idx - prefix_idx + 1, n_vq),
607
  self.model_config.audio_pad_code,
 
612
  prefix_idx = audio_end_idx
613
 
614
  if truncation:
615
+ delay_audio_codes_list[-1] = delay_audio_codes_list[-1][
616
+ : -(n_vq - 1), :
617
+ ]
618
  else:
619
+ last_audio_end_idx = int(audio_end_indices[-1].item())
620
  pad_codes = torch.full(
621
+ (len(text_codes) - last_audio_end_idx, n_vq),
622
  self.model_config.audio_pad_code,
623
  device=audio_codes_list[0].device,
624
  dtype=audio_codes_list[0].dtype,
 
626
  delay_audio_codes_list.append(pad_codes)
627
 
628
  delay_audio_codes_list = torch.cat(delay_audio_codes_list)
629
+
630
  if text_codes.shape[0] != delay_audio_codes_list.shape[0]:
631
+ text_codes = text_codes[: delay_audio_codes_list.shape[0]]
632
 
633
+ unified_codes = torch.cat(
634
+ [text_codes.unsqueeze(1), delay_audio_codes_list], dim=1
635
+ )
636
  return unified_codes
637
 
638
  def _parse_text_codes(self, start_length, text_codes):
639
+ text = cast(str, self.tokenizer.decode(text_codes))
640
+ prefix = cast(str, self.tokenizer.decode(text_codes[:start_length]))
641
+ text = text[len(prefix) :]
642
 
643
  AUDIO_PATTERN = re.compile(
644
+ rf"(?:{self.audio_start_token})?"
645
+ rf"(?:{self.audio_assistant_gen_slot_token})*"
646
+ rf"(?:{self.audio_assistant_delay_slot_token})*"
647
+ rf"{self.audio_end_token}"
648
  )
649
 
650
  def normalize_audio_segments(text: str) -> str:
651
  def repl(match: re.Match) -> str:
652
  seg = match.group(0)
653
+ # Replace with <|audio|> if gen_slot is present in the segment;
654
  if self.audio_assistant_gen_slot_token in seg:
655
  return AUDIO_PLACEHOLDER
656
+ # Otherwise, remove it.
657
  return ""
658
+
659
  return AUDIO_PATTERN.sub(repl, text)
660
 
661
  return normalize_audio_segments(text)
 
679
 
680
  audio_codes_list = [audio_codes[s] for s in segments_idx]
681
 
682
+ # Batch-decode all audio segments together.
683
+ decoded_audio_list = self.decode_audio_codes(audio_codes_list)
 
 
 
684
 
685
  # Keep codec causal context by decoding the whole first segment first,
686
  # then trim at waveform level according to start_length ratio.
687
+ if (
688
+ start_length > 0
689
+ and len(audio_codes_list) > 0
690
+ and len(decoded_audio_list) > 0
691
+ ):
692
  first_codes_length = audio_codes_list[0].shape[0]
693
  if first_codes_length > 0:
694
+ trim_ratio = max(
695
+ 0.0, min(float(start_length) / float(first_codes_length), 1.0)
696
+ )
697
  first_audio = decoded_audio_list[0]
698
  if trim_ratio >= 1.0:
699
  decoded_audio_list = decoded_audio_list[1:]
 
703
 
704
  return decoded_audio_list
705
 
 
706
  def decode(self, output: List[Tuple[int, torch.Tensor]]):
707
  """
708
  1. 这里不管怎样,都需要一个完整的 assistant generation ids;
 
712
  genearted_messages = []
713
  for start_length, generation_ids in output:
714
  content = self._parse_text_codes(start_length, generation_ids[:, 0])
715
+ audio_codes_list = self._parse_audio_codes(
716
+ start_length, generation_ids[:, 1:]
717
+ )
718
  if content == "":
719
  message = None
720
  else:
721
  message = AssistantMessage(
722
  content=content,
723
+ audio_codes_list=cast(
724
+ List[Union[str, torch.Tensor]], audio_codes_list
725
+ ),
726
  )
727
  genearted_messages.append(message)
728
  return genearted_messages
729
 
730
  @staticmethod
731
+ def loudness_normalize(
732
+ wav: torch.Tensor,
733
+ target_dbfs: float = -20,
734
+ gain_range: tuple[float, float] = (-3.0, 3.0),
735
+ ) -> torch.Tensor:
736
  wav = wav.to(torch.float32)
737
+ if wav.numel() == 0:
738
+ return wav
739
+ current_dbfs = 10.0 * torch.log10(torch.mean(wav**2) + 1e-9)
740
  gain = float(target_dbfs - current_dbfs)
741
  gain = max(gain_range[0], min(gain, gain_range[1]))
742
  factor = 10.0 ** (gain / 20.0)
743
  return wav * factor
744
 
745
+ def _get_audio_tokenizer_device(self) -> torch.device:
746
+ """Best-effort device inference for `self.audio_tokenizer`.
747
+
748
+ Notes:
749
+ - Old TAC wrapper exposed `.device`, but standard `torch.nn.Module` does not.
750
+ - New MossAudioTokenizerModel is a `PreTrainedModel`; parameters define its device.
751
+ """
752
+
753
+ audio_tokenizer = getattr(self, "audio_tokenizer", None)
754
+ if audio_tokenizer is None:
755
+ logger.warning(
756
+ "audio_tokenizer is not set on processor. Using CPU as default."
757
+ )
758
+ return torch.device("cpu")
759
+
760
+ device_attr = getattr(audio_tokenizer, "device", None)
761
+ if isinstance(device_attr, torch.device):
762
+ return device_attr
763
+
764
+ try:
765
+ return next(audio_tokenizer.parameters()).device
766
+ except StopIteration:
767
+ # No parameters (shouldn't happen for real models); default to CPU.
768
+ logger.warning(
769
+ "No parameters found on audio_tokenizer. Using CPU as default."
770
+ )
771
+ return torch.device("cpu")
772
+
773
+ def encode_audios_from_wav(
774
+ self,
775
+ wav_list: List[torch.Tensor],
776
+ sampling_rate: int,
777
+ n_vq: Optional[int] = None,
778
+ ):
779
+ if self.audio_tokenizer is None:
780
+ raise RuntimeError("audio_tokenizer is not set on processor.")
781
+ audio_tokenizer = self.audio_tokenizer
782
+
783
  if isinstance(wav_list, torch.Tensor):
784
  wav_list = [wav_list]
785
  wav_list_ = []
786
  resample = False
787
  if sampling_rate != self.model_config.sampling_rate:
788
  resample = True
789
+ device = self._get_audio_tokenizer_device()
790
  for wav in wav_list:
791
  if wav.shape[0] > 1:
792
  wav = torch.mean(wav, dim=0, keepdim=True)
793
  if resample:
794
+ wav = torchaudio.functional.resample(
795
+ waveform=wav,
796
+ orig_freq=sampling_rate,
797
+ new_freq=self.model_config.sampling_rate,
798
+ )
799
+ wav = wav.to(device)
800
  wav_list_.append(self.loudness_normalize(wav.squeeze(0)))
 
801
 
802
+ # New MossAudioTokenizerModel API: prefer batch_encode(list[wav])
803
+ if hasattr(audio_tokenizer, "batch_encode"):
804
+ enc = audio_tokenizer.batch_encode(wav_list_, num_quantizers=n_vq)
805
+ audio_codes = enc.audio_codes # (NQ, B, T)
806
+ audio_codes_lengths = enc.audio_codes_lengths # (B,)
807
+ else:
808
+ # Fallback: use encode() with explicit padding.
809
+ max_len = max(int(wav.shape[-1]) for wav in wav_list_)
810
+ input_values = torch.zeros(
811
+ len(wav_list_), 1, max_len, device=device, dtype=torch.float32
812
+ )
813
+ padding_mask = torch.zeros(
814
+ len(wav_list_), max_len, device=device, dtype=torch.bool
815
+ )
816
+ for i, wav in enumerate(wav_list_):
817
+ this_len = int(wav.shape[-1])
818
+ input_values[i, 0, :this_len] = wav
819
+ padding_mask[i, :this_len] = True
820
+ enc = audio_tokenizer.encode(
821
+ input_values,
822
+ padding_mask=padding_mask,
823
+ num_quantizers=n_vq,
824
+ return_dict=True,
825
+ )
826
+ audio_codes = enc.audio_codes
827
+ audio_codes_lengths = enc.audio_codes_lengths
828
+
829
+ if audio_codes is None or audio_codes_lengths is None:
830
+ raise RuntimeError(
831
+ "audio_tokenizer.encode() returned empty outputs (audio_codes/audio_codes_lengths)."
832
+ )
833
+
834
+ # Keep processor's historical contract: list[Tensor] with shape (T, NQ)
835
+ # and on CPU (so downstream text/audio packing remains device-agnostic).
836
+ codes_list: List[torch.Tensor] = []
837
+ for i in range(int(audio_codes.shape[1])):
838
+ length_i = int(audio_codes_lengths[i].item())
839
+ codes_i = (
840
+ audio_codes[:, i, :length_i]
841
+ .transpose(0, 1)
842
+ .contiguous()
843
+ .to(torch.long)
844
+ .cpu()
845
+ )
846
+ codes_list.append(codes_i)
847
+ return codes_list
848
+
849
+ def encode_audios_from_path(
850
+ self, wav_path_list: Union[str, List[str]], n_vq: Optional[int] = None
851
+ ):
852
  if isinstance(wav_path_list, str):
853
  wav_path_list = [wav_path_list]
854
+
855
+ if len(wav_path_list) == 0:
856
+ raise ValueError("Empty wav_path_list")
857
+
858
+ # Load + (if needed) resample each wav independently, so callers can
859
+ # pass a heterogeneous batch of files while still benefiting from
860
+ # audio_tokenizer.batch_encode.
861
+ target_sr = int(self.model_config.sampling_rate)
862
+ wav_list: List[torch.Tensor] = []
863
  for wav_path in wav_path_list:
864
  wav, sr = torchaudio.load(wav_path)
865
+ if int(sr) != target_sr:
866
+ wav = torchaudio.functional.resample(
867
+ waveform=wav,
868
+ orig_freq=int(sr),
869
+ new_freq=target_sr,
870
+ )
871
+ wav_list.append(wav)
872
+
873
+ return self.encode_audios_from_wav(wav_list, target_sr, n_vq)
874
+
875
+ def decode_audio_codes(
876
+ self, audio_tokens_list: Union[torch.Tensor, List[torch.Tensor]]
877
+ ):
878
+ if self.audio_tokenizer is None:
879
+ raise RuntimeError("audio_tokenizer is not set on processor.")
880
+ audio_tokenizer = self.audio_tokenizer
881
+
882
+ if isinstance(audio_tokens_list, torch.Tensor):
883
+ audio_tokens_list = [audio_tokens_list]
884
+ if len(audio_tokens_list) == 0:
885
+ return []
886
+
887
+ device = self._get_audio_tokenizer_device()
888
+
889
+ # Processor uses (T, NQ); MossAudioTokenizer expects (NQ, T) (or (NQ, B, T)).
890
+ codes_list = [
891
+ codes.transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
892
+ for codes in audio_tokens_list
893
+ ]
894
+
895
+ if hasattr(audio_tokenizer, "batch_decode"):
896
+ dec = audio_tokenizer.batch_decode(codes_list)
897
+ audio = dec.audio # (B, C, T)
898
+ audio_lengths = dec.audio_lengths # (B,)
899
+ else:
900
+ # Fallback: pad to (NQ, B, T) + mask, then decode.
901
+ nq = int(codes_list[0].shape[0])
902
+ max_t = max(int(c.shape[1]) for c in codes_list)
903
+ audio_codes = torch.zeros(
904
+ nq, len(codes_list), max_t, device=device, dtype=torch.long
905
+ )
906
+ padding_mask = torch.zeros(
907
+ len(codes_list), max_t, device=device, dtype=torch.bool
908
+ )
909
+ for i, c in enumerate(codes_list):
910
+ t = int(c.shape[1])
911
+ audio_codes[:, i, :t] = c
912
+ padding_mask[i, :t] = True
913
+ dec = audio_tokenizer.decode(
914
+ audio_codes, padding_mask=padding_mask, return_dict=True
915
+ )
916
+ audio = dec.audio
917
+ audio_lengths = dec.audio_lengths
918
+
919
+ if audio is None or audio_lengths is None:
920
+ raise RuntimeError(
921
+ "audio_tokenizer.decode() returned empty outputs (audio/audio_lengths)."
922
+ )
923
+
924
+ # Return historical contract: list of 1D waveforms (T,)
925
+ wav_list: List[torch.Tensor] = []
926
+ for i in range(int(audio.shape[0])):
927
+ length_i = int(audio_lengths[i].item())
928
+ wav = audio[i, 0, :length_i].contiguous().to(torch.float32).cpu()
929
  wav_list.append(wav)
930
+ return wav_list