Commit
·
f0e23bb
1
Parent(s):
e4aa3d2
Refactor processing_moss_tts.py: Improve type hints, enhance message classes, and streamline audio token handling
Browse files- processing_moss_tts.py +432 -131
processing_moss_tts.py
CHANGED
|
@@ -14,14 +14,22 @@
|
|
| 14 |
# limitations under the License.
|
| 15 |
|
| 16 |
import os
|
| 17 |
-
from typing import Dict, List, Optional, Tuple, Type, Union, Literal, Final
|
| 18 |
from dataclasses import dataclass
|
| 19 |
from pathlib import Path
|
| 20 |
import re
|
| 21 |
import torchaudio
|
| 22 |
|
| 23 |
import torch
|
| 24 |
-
from transformers import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
from .configuration_moss_tts import MossTTSDelayConfig
|
| 27 |
|
|
@@ -34,8 +42,8 @@ AUDIO_PLACEHOLDER = "<|audio|>"
|
|
| 34 |
|
| 35 |
@dataclass
|
| 36 |
class Message:
|
| 37 |
-
|
| 38 |
-
|
| 39 |
|
| 40 |
|
| 41 |
@dataclass
|
|
@@ -78,13 +86,16 @@ class UserMessage(Message):
|
|
| 78 |
if speaker_reference is not None:
|
| 79 |
reference.append(f"[S{speaker_idx}]:\n{AUDIO_PLACEHOLDER}")
|
| 80 |
reference = "\n".join(reference)
|
| 81 |
-
audio_codes_list = [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
else:
|
| 83 |
raise TypeError("`reference` should be exactly a list when it is not None.")
|
| 84 |
|
| 85 |
content = (
|
| 86 |
-
template
|
| 87 |
-
.replace("{reference}", str(reference))
|
| 88 |
.replace("{instruction}", str(self.instruction))
|
| 89 |
.replace("{tokens}", str(self.tokens))
|
| 90 |
.replace("{quality}", str(self.quality))
|
|
@@ -101,22 +112,23 @@ class UserMessage(Message):
|
|
| 101 |
return {
|
| 102 |
"role": "user",
|
| 103 |
"content": self._content,
|
| 104 |
-
"audio_codes_list": self._audio_codes_list
|
| 105 |
}
|
| 106 |
|
| 107 |
|
| 108 |
@dataclass
|
| 109 |
class AssistantMessage(Message):
|
| 110 |
audio_codes_list: List[Union[str, torch.Tensor]]
|
| 111 |
-
content: str = AUDIO_PLACEHOLDER
|
| 112 |
|
| 113 |
def to_dict(self):
|
| 114 |
return {
|
| 115 |
"role": "assistant",
|
| 116 |
"content": self.content,
|
| 117 |
-
"audio_codes_list": self.audio_codes_list
|
| 118 |
}
|
| 119 |
|
|
|
|
| 120 |
USER_MESSAGE_FIELDS = (
|
| 121 |
"text",
|
| 122 |
"reference",
|
|
@@ -129,27 +141,25 @@ USER_MESSAGE_FIELDS = (
|
|
| 129 |
)
|
| 130 |
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
class MossTTSDelayProcessor(ProcessorMixin):
|
| 138 |
tokenizer_class = "AutoTokenizer"
|
| 139 |
audio_tokenizer_class = "AutoModel"
|
| 140 |
|
|
|
|
|
|
|
|
|
|
| 141 |
def __init__(
|
| 142 |
self,
|
| 143 |
tokenizer: PreTrainedTokenizerBase,
|
| 144 |
-
audio_tokenizer:
|
| 145 |
model_config: Optional[MossTTSDelayConfig] = None,
|
| 146 |
-
**kwargs
|
| 147 |
):
|
| 148 |
-
super().__init__(
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
if model_config is None:
|
| 154 |
model_config = MossTTSDelayConfig()
|
| 155 |
self.model_config = model_config
|
|
@@ -158,68 +168,107 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 158 |
self.imend_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
| 159 |
self.newline_token_id = 198
|
| 160 |
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
@classmethod
|
| 168 |
-
def from_pretrained(cls, pretrained_model_name_or_path,
|
| 169 |
-
kwargs.pop("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
|
| 171 |
-
model_config =
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
-
audio_tokenizer_name_or_path = kwargs.pop("codec_path", "OpenMOSS-Team/MOSS-Audio-Tokenizer")
|
| 175 |
-
assert isinstance(audio_tokenizer_name_or_path, str), f"Unsupported audio_tokenizer_path input format: {type(audio_tokenizer_name_or_path)}"
|
| 176 |
-
audio_tokenizer = AutoModel.from_pretrained(audio_tokenizer_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
|
| 177 |
-
|
| 178 |
return cls(
|
| 179 |
tokenizer=tokenizer,
|
| 180 |
audio_tokenizer=audio_tokenizer,
|
| 181 |
model_config=model_config,
|
| 182 |
-
**kwargs
|
| 183 |
)
|
| 184 |
-
|
| 185 |
-
def __call__(
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
|
|
|
|
|
|
|
|
|
| 193 |
"""
|
| 194 |
-
mode
|
| 195 |
"""
|
| 196 |
-
|
| 197 |
if mode not in {"generation", "continuation"}:
|
| 198 |
raise RuntimeError
|
| 199 |
-
|
| 200 |
if isinstance(conversations, (Message, Dict)):
|
| 201 |
conversations = [conversations]
|
| 202 |
-
|
| 203 |
truncation = False
|
| 204 |
if mode == "continuation":
|
| 205 |
truncation = True
|
| 206 |
-
|
| 207 |
input_ids_list = []
|
| 208 |
for conversation in conversations:
|
| 209 |
if isinstance(conversation, (Message, Dict)):
|
| 210 |
conversation = [conversation]
|
| 211 |
|
|
|
|
|
|
|
|
|
|
| 212 |
if (mode == "generation") ^ (len(conversation) % 2 != 0):
|
| 213 |
raise ValueError
|
| 214 |
|
| 215 |
-
if (mode == "generation") ^ (conversation[-1][
|
| 216 |
raise ValueError
|
| 217 |
|
| 218 |
unified_codes = []
|
| 219 |
for message_idx, message in enumerate(conversation):
|
| 220 |
-
message = self._normalize_message(message)
|
| 221 |
if apply_chat_template:
|
| 222 |
-
add_generation_prompt =
|
|
|
|
|
|
|
| 223 |
try:
|
| 224 |
content = self.tokenizer.apply_chat_template(
|
| 225 |
[{"role": message["role"], "content": message["content"]}],
|
|
@@ -229,29 +278,76 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 229 |
except TypeError:
|
| 230 |
try:
|
| 231 |
content = self.tokenizer.apply_chat_template(
|
| 232 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
add_generation_prompt=add_generation_prompt,
|
| 234 |
)
|
| 235 |
except Exception:
|
| 236 |
-
logger.warning(
|
|
|
|
|
|
|
| 237 |
content = message["content"]
|
| 238 |
else:
|
| 239 |
-
content = message[
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
unified_codes = torch.cat(unified_codes)
|
| 252 |
input_ids_list.append(unified_codes)
|
| 253 |
|
| 254 |
-
return self._pad(input_ids_list)
|
| 255 |
|
| 256 |
@staticmethod
|
| 257 |
def build_user_message(
|
|
@@ -310,14 +406,23 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 310 |
def _pad(self, input_ids_list: List[torch.Tensor]):
|
| 311 |
device = input_ids_list[0].device
|
| 312 |
lengths = torch.tensor([w.shape[0] for w in input_ids_list], device=device)
|
| 313 |
-
pad_input_ids = torch.nn.utils.rnn.pad_sequence(
|
| 314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
pad_input_ids[..., 0][other_channel_mask] = self.model_config.pad_token_id
|
| 316 |
-
attention_mask = torch.zeros(
|
|
|
|
|
|
|
| 317 |
attention_mask[~other_channel_mask] = 1
|
| 318 |
attention_mask = attention_mask.bool()
|
| 319 |
return {
|
| 320 |
-
"input_ids": pad_input_ids,
|
| 321 |
"attention_mask": attention_mask,
|
| 322 |
}
|
| 323 |
|
|
@@ -329,7 +434,7 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 329 |
gen_slot_token: str,
|
| 330 |
delay_slot_token: str,
|
| 331 |
audio_start_token: str,
|
| 332 |
-
audio_end_token: str
|
| 333 |
) -> str:
|
| 334 |
if n_vq < 1:
|
| 335 |
raise ValueError(f"n_vq must be >= 1, got {n_vq}")
|
|
@@ -371,7 +476,9 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 371 |
return content, audio_codes_list
|
| 372 |
|
| 373 |
if len(matches) != len(audio_codes_list):
|
| 374 |
-
raise ValueError(
|
|
|
|
|
|
|
| 375 |
|
| 376 |
new_audio_codes_list = []
|
| 377 |
new_parts = []
|
|
@@ -381,18 +488,20 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 381 |
j = i
|
| 382 |
while (
|
| 383 |
j + 1 < len(matches)
|
| 384 |
-
and content[matches[j].end():matches[j + 1].start()].strip() == ""
|
| 385 |
):
|
| 386 |
j += 1
|
| 387 |
|
| 388 |
-
new_parts.append(content[last_pos:matches[i].start()])
|
| 389 |
new_parts.append(AUDIO_PLACEHOLDER)
|
| 390 |
last_pos = matches[j].end()
|
| 391 |
|
| 392 |
if j == i:
|
| 393 |
new_audio_codes_list.append(audio_codes_list[i])
|
| 394 |
else:
|
| 395 |
-
new_audio_codes_list.append(
|
|
|
|
|
|
|
| 396 |
|
| 397 |
i = j + 1
|
| 398 |
|
|
@@ -408,9 +517,9 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 408 |
dtype=codes.dtype,
|
| 409 |
)
|
| 410 |
for i in range(codes.shape[1]):
|
| 411 |
-
delayed_tokens[i: i + codes.shape[0], i] = codes[:, i]
|
| 412 |
return delayed_tokens
|
| 413 |
-
|
| 414 |
@staticmethod
|
| 415 |
def apply_de_delay_pattern(delay_codes: torch.Tensor) -> torch.Tensor:
|
| 416 |
tokens = torch.full(
|
|
@@ -420,11 +529,16 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 420 |
dtype=delay_codes.dtype,
|
| 421 |
)
|
| 422 |
for i in range(delay_codes.shape[1]):
|
| 423 |
-
tokens[:, i] = delay_codes[i: i + tokens.shape[0], i]
|
| 424 |
return tokens
|
| 425 |
|
| 426 |
-
|
| 427 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
"""
|
| 429 |
此时的 content 已经是带上了对话格式
|
| 430 |
"""
|
|
@@ -452,12 +566,23 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 452 |
audio_start_token=self.audio_start_token,
|
| 453 |
audio_end_token=self.audio_end_token,
|
| 454 |
)
|
| 455 |
-
text_codes = torch.tensor(
|
|
|
|
|
|
|
|
|
|
| 456 |
|
| 457 |
-
audio_start_indices = torch.where(
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
|
| 462 |
delay_audio_codes_list = []
|
| 463 |
if len(audio_codes_list) == 0:
|
|
@@ -469,8 +594,14 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 469 |
)
|
| 470 |
else:
|
| 471 |
prefix_idx = 0
|
| 472 |
-
for
|
| 473 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
pad_codes = torch.full(
|
| 475 |
(audio_start_idx - prefix_idx + 1, n_vq),
|
| 476 |
self.model_config.audio_pad_code,
|
|
@@ -481,10 +612,13 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 481 |
prefix_idx = audio_end_idx
|
| 482 |
|
| 483 |
if truncation:
|
| 484 |
-
delay_audio_codes_list[-1] = delay_audio_codes_list[-1][
|
|
|
|
|
|
|
| 485 |
else:
|
|
|
|
| 486 |
pad_codes = torch.full(
|
| 487 |
-
(len(text_codes) -
|
| 488 |
self.model_config.audio_pad_code,
|
| 489 |
device=audio_codes_list[0].device,
|
| 490 |
dtype=audio_codes_list[0].dtype,
|
|
@@ -492,34 +626,36 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 492 |
delay_audio_codes_list.append(pad_codes)
|
| 493 |
|
| 494 |
delay_audio_codes_list = torch.cat(delay_audio_codes_list)
|
| 495 |
-
|
| 496 |
if text_codes.shape[0] != delay_audio_codes_list.shape[0]:
|
| 497 |
-
text_codes = text_codes[:delay_audio_codes_list.shape[0]]
|
| 498 |
|
| 499 |
-
unified_codes = torch.cat(
|
|
|
|
|
|
|
| 500 |
return unified_codes
|
| 501 |
|
| 502 |
def _parse_text_codes(self, start_length, text_codes):
|
| 503 |
-
text = self.tokenizer.decode(text_codes)
|
| 504 |
-
prefix = self.tokenizer.decode(text_codes[:start_length])
|
| 505 |
-
text = text[len(prefix):]
|
| 506 |
|
| 507 |
AUDIO_PATTERN = re.compile(
|
| 508 |
-
rf
|
| 509 |
-
rf
|
| 510 |
-
rf
|
| 511 |
-
rf
|
| 512 |
)
|
| 513 |
|
| 514 |
def normalize_audio_segments(text: str) -> str:
|
| 515 |
def repl(match: re.Match) -> str:
|
| 516 |
seg = match.group(0)
|
| 517 |
-
#
|
| 518 |
if self.audio_assistant_gen_slot_token in seg:
|
| 519 |
return AUDIO_PLACEHOLDER
|
| 520 |
-
#
|
| 521 |
return ""
|
| 522 |
-
|
| 523 |
return AUDIO_PATTERN.sub(repl, text)
|
| 524 |
|
| 525 |
return normalize_audio_segments(text)
|
|
@@ -543,18 +679,21 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 543 |
|
| 544 |
audio_codes_list = [audio_codes[s] for s in segments_idx]
|
| 545 |
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
decoded_segment = self.decode_audio_codes([segment_codes])
|
| 549 |
-
if len(decoded_segment) > 0:
|
| 550 |
-
decoded_audio_list.append(decoded_segment[0])
|
| 551 |
|
| 552 |
# Keep codec causal context by decoding the whole first segment first,
|
| 553 |
# then trim at waveform level according to start_length ratio.
|
| 554 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 555 |
first_codes_length = audio_codes_list[0].shape[0]
|
| 556 |
if first_codes_length > 0:
|
| 557 |
-
trim_ratio = max(
|
|
|
|
|
|
|
| 558 |
first_audio = decoded_audio_list[0]
|
| 559 |
if trim_ratio >= 1.0:
|
| 560 |
decoded_audio_list = decoded_audio_list[1:]
|
|
@@ -564,7 +703,6 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 564 |
|
| 565 |
return decoded_audio_list
|
| 566 |
|
| 567 |
-
|
| 568 |
def decode(self, output: List[Tuple[int, torch.Tensor]]):
|
| 569 |
"""
|
| 570 |
1. 这里不管怎样,都需要一个完整的 assistant generation ids;
|
|
@@ -574,56 +712,219 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 574 |
genearted_messages = []
|
| 575 |
for start_length, generation_ids in output:
|
| 576 |
content = self._parse_text_codes(start_length, generation_ids[:, 0])
|
| 577 |
-
audio_codes_list = self._parse_audio_codes(
|
|
|
|
|
|
|
| 578 |
if content == "":
|
| 579 |
message = None
|
| 580 |
else:
|
| 581 |
message = AssistantMessage(
|
| 582 |
content=content,
|
| 583 |
-
audio_codes_list=
|
|
|
|
|
|
|
| 584 |
)
|
| 585 |
genearted_messages.append(message)
|
| 586 |
return genearted_messages
|
| 587 |
|
| 588 |
@staticmethod
|
| 589 |
-
def loudness_normalize(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 590 |
wav = wav.to(torch.float32)
|
| 591 |
-
if wav.numel() == 0:
|
| 592 |
-
|
| 593 |
-
current_dbfs =
|
| 594 |
gain = float(target_dbfs - current_dbfs)
|
| 595 |
gain = max(gain_range[0], min(gain, gain_range[1]))
|
| 596 |
factor = 10.0 ** (gain / 20.0)
|
| 597 |
return wav * factor
|
| 598 |
|
| 599 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 600 |
if isinstance(wav_list, torch.Tensor):
|
| 601 |
wav_list = [wav_list]
|
| 602 |
wav_list_ = []
|
| 603 |
resample = False
|
| 604 |
if sampling_rate != self.model_config.sampling_rate:
|
| 605 |
resample = True
|
|
|
|
| 606 |
for wav in wav_list:
|
| 607 |
if wav.shape[0] > 1:
|
| 608 |
wav = torch.mean(wav, dim=0, keepdim=True)
|
| 609 |
if resample:
|
| 610 |
-
wav = torchaudio.functional.resample(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 611 |
wav_list_.append(self.loudness_normalize(wav.squeeze(0)))
|
| 612 |
-
return self.audio_tokenizer.encode(wav_list_, n_vq)
|
| 613 |
|
| 614 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 615 |
if isinstance(wav_path_list, str):
|
| 616 |
wav_path_list = [wav_path_list]
|
| 617 |
-
|
| 618 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 619 |
for wav_path in wav_path_list:
|
| 620 |
wav, sr = torchaudio.load(wav_path)
|
| 621 |
-
if
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 625 |
wav_list.append(wav)
|
| 626 |
-
return
|
| 627 |
-
|
| 628 |
-
def decode_audio_codes(self, audio_tokens_list: List[torch.Tensor]):
|
| 629 |
-
return self.audio_tokenizer.decode(audio_tokens_list)
|
|
|
|
| 14 |
# limitations under the License.
|
| 15 |
|
| 16 |
import os
|
| 17 |
+
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal, Final, cast
|
| 18 |
from dataclasses import dataclass
|
| 19 |
from pathlib import Path
|
| 20 |
import re
|
| 21 |
import torchaudio
|
| 22 |
|
| 23 |
import torch
|
| 24 |
+
from transformers import (
|
| 25 |
+
PreTrainedTokenizerBase,
|
| 26 |
+
BatchFeature,
|
| 27 |
+
ProcessorMixin,
|
| 28 |
+
logging,
|
| 29 |
+
AutoConfig,
|
| 30 |
+
AutoModel,
|
| 31 |
+
AutoTokenizer,
|
| 32 |
+
)
|
| 33 |
|
| 34 |
from .configuration_moss_tts import MossTTSDelayConfig
|
| 35 |
|
|
|
|
| 42 |
|
| 43 |
@dataclass
|
| 44 |
class Message:
|
| 45 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 46 |
+
raise NotImplementedError
|
| 47 |
|
| 48 |
|
| 49 |
@dataclass
|
|
|
|
| 86 |
if speaker_reference is not None:
|
| 87 |
reference.append(f"[S{speaker_idx}]:\n{AUDIO_PLACEHOLDER}")
|
| 88 |
reference = "\n".join(reference)
|
| 89 |
+
audio_codes_list = [
|
| 90 |
+
speaker_reference
|
| 91 |
+
for speaker_reference in self.reference
|
| 92 |
+
if speaker_reference is not None
|
| 93 |
+
]
|
| 94 |
else:
|
| 95 |
raise TypeError("`reference` should be exactly a list when it is not None.")
|
| 96 |
|
| 97 |
content = (
|
| 98 |
+
template.replace("{reference}", str(reference))
|
|
|
|
| 99 |
.replace("{instruction}", str(self.instruction))
|
| 100 |
.replace("{tokens}", str(self.tokens))
|
| 101 |
.replace("{quality}", str(self.quality))
|
|
|
|
| 112 |
return {
|
| 113 |
"role": "user",
|
| 114 |
"content": self._content,
|
| 115 |
+
"audio_codes_list": self._audio_codes_list,
|
| 116 |
}
|
| 117 |
|
| 118 |
|
| 119 |
@dataclass
|
| 120 |
class AssistantMessage(Message):
|
| 121 |
audio_codes_list: List[Union[str, torch.Tensor]]
|
| 122 |
+
content: str = AUDIO_PLACEHOLDER
|
| 123 |
|
| 124 |
def to_dict(self):
|
| 125 |
return {
|
| 126 |
"role": "assistant",
|
| 127 |
"content": self.content,
|
| 128 |
+
"audio_codes_list": self.audio_codes_list,
|
| 129 |
}
|
| 130 |
|
| 131 |
+
|
| 132 |
USER_MESSAGE_FIELDS = (
|
| 133 |
"text",
|
| 134 |
"reference",
|
|
|
|
| 141 |
)
|
| 142 |
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
class MossTTSDelayProcessor(ProcessorMixin):
|
| 145 |
tokenizer_class = "AutoTokenizer"
|
| 146 |
audio_tokenizer_class = "AutoModel"
|
| 147 |
|
| 148 |
+
tokenizer: PreTrainedTokenizerBase
|
| 149 |
+
audio_tokenizer: Any
|
| 150 |
+
|
| 151 |
def __init__(
|
| 152 |
self,
|
| 153 |
tokenizer: PreTrainedTokenizerBase,
|
| 154 |
+
audio_tokenizer: Any = None,
|
| 155 |
model_config: Optional[MossTTSDelayConfig] = None,
|
| 156 |
+
**kwargs,
|
| 157 |
):
|
| 158 |
+
super().__init__(tokenizer=tokenizer, audio_tokenizer=audio_tokenizer, **kwargs)
|
| 159 |
+
|
| 160 |
+
# Explicit assignments for type-checkers; ProcessorMixin sets these too.
|
| 161 |
+
self.tokenizer = tokenizer
|
| 162 |
+
self.audio_tokenizer = audio_tokenizer
|
| 163 |
if model_config is None:
|
| 164 |
model_config = MossTTSDelayConfig()
|
| 165 |
self.model_config = model_config
|
|
|
|
| 168 |
self.imend_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
| 169 |
self.newline_token_id = 198
|
| 170 |
|
| 171 |
+
def _id_to_token(token_id: int) -> str:
|
| 172 |
+
tok = tokenizer.convert_ids_to_tokens(int(token_id))
|
| 173 |
+
if isinstance(tok, list):
|
| 174 |
+
return tok[0] if len(tok) > 0 else ""
|
| 175 |
+
return cast(str, tok)
|
| 176 |
+
|
| 177 |
+
self.audio_user_slot_token = _id_to_token(
|
| 178 |
+
self.model_config.audio_user_slot_token_id
|
| 179 |
+
)
|
| 180 |
+
self.audio_assistant_gen_slot_token = _id_to_token(
|
| 181 |
+
self.model_config.audio_assistant_gen_slot_token_id
|
| 182 |
+
)
|
| 183 |
+
self.audio_assistant_delay_slot_token = _id_to_token(
|
| 184 |
+
self.model_config.audio_assistant_delay_slot_token_id
|
| 185 |
+
)
|
| 186 |
+
self.audio_start_token = _id_to_token(self.model_config.audio_start_token_id)
|
| 187 |
+
self.audio_end_token = _id_to_token(self.model_config.audio_end_token_id)
|
| 188 |
|
| 189 |
@classmethod
|
| 190 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
| 191 |
+
trust_remote_code = kwargs.pop("trust_remote_code", True)
|
| 192 |
+
kwargs.pop("_from_auto", None)
|
| 193 |
+
|
| 194 |
+
audio_tokenizer_name_or_path = kwargs.pop(
|
| 195 |
+
"codec_path", "OpenMOSS-Team/MOSS-Audio-Tokenizer"
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
|
| 199 |
+
model_config = cast(
|
| 200 |
+
MossTTSDelayConfig,
|
| 201 |
+
AutoConfig.from_pretrained(
|
| 202 |
+
pretrained_model_name_or_path,
|
| 203 |
+
*args,
|
| 204 |
+
trust_remote_code=trust_remote_code,
|
| 205 |
+
**kwargs,
|
| 206 |
+
),
|
| 207 |
+
)
|
| 208 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 209 |
+
pretrained_model_name_or_path,
|
| 210 |
+
*args,
|
| 211 |
+
trust_remote_code=trust_remote_code,
|
| 212 |
+
**kwargs,
|
| 213 |
+
)
|
| 214 |
+
audio_tokenizer = AutoModel.from_pretrained(
|
| 215 |
+
audio_tokenizer_name_or_path,
|
| 216 |
+
trust_remote_code=trust_remote_code,
|
| 217 |
+
**kwargs,
|
| 218 |
+
)
|
| 219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
return cls(
|
| 221 |
tokenizer=tokenizer,
|
| 222 |
audio_tokenizer=audio_tokenizer,
|
| 223 |
model_config=model_config,
|
| 224 |
+
**kwargs,
|
| 225 |
)
|
| 226 |
+
|
| 227 |
+
def __call__(self, *args, **kwargs) -> BatchFeature:
|
| 228 |
+
conversations = args[0] if len(args) > 0 else kwargs.pop("conversations")
|
| 229 |
+
mode: str = kwargs.pop("mode", "generation")
|
| 230 |
+
apply_chat_template: bool = kwargs.pop("apply_chat_template", True)
|
| 231 |
+
n_vq: Optional[int] = kwargs.pop("n_vq", None)
|
| 232 |
+
|
| 233 |
+
# Common ProcessorMixin kwargs that we ignore because we always return torch tensors.
|
| 234 |
+
kwargs.pop("return_tensors", None)
|
| 235 |
+
kwargs.pop("padding", None)
|
| 236 |
+
kwargs.pop("truncation", None)
|
| 237 |
+
|
| 238 |
"""
|
| 239 |
+
mode only works when a Message is converted to a dict.
|
| 240 |
"""
|
| 241 |
+
|
| 242 |
if mode not in {"generation", "continuation"}:
|
| 243 |
raise RuntimeError
|
| 244 |
+
|
| 245 |
if isinstance(conversations, (Message, Dict)):
|
| 246 |
conversations = [conversations]
|
| 247 |
+
|
| 248 |
truncation = False
|
| 249 |
if mode == "continuation":
|
| 250 |
truncation = True
|
| 251 |
+
|
| 252 |
input_ids_list = []
|
| 253 |
for conversation in conversations:
|
| 254 |
if isinstance(conversation, (Message, Dict)):
|
| 255 |
conversation = [conversation]
|
| 256 |
|
| 257 |
+
# Normalize early so downstream logic always deals with dict messages.
|
| 258 |
+
conversation = [self._normalize_message(m) for m in conversation]
|
| 259 |
+
|
| 260 |
if (mode == "generation") ^ (len(conversation) % 2 != 0):
|
| 261 |
raise ValueError
|
| 262 |
|
| 263 |
+
if (mode == "generation") ^ (conversation[-1]["role"] == "user"):
|
| 264 |
raise ValueError
|
| 265 |
|
| 266 |
unified_codes = []
|
| 267 |
for message_idx, message in enumerate(conversation):
|
|
|
|
| 268 |
if apply_chat_template:
|
| 269 |
+
add_generation_prompt = (
|
| 270 |
+
mode == "generation" and message_idx == len(conversation) - 1
|
| 271 |
+
)
|
| 272 |
try:
|
| 273 |
content = self.tokenizer.apply_chat_template(
|
| 274 |
[{"role": message["role"], "content": message["content"]}],
|
|
|
|
| 278 |
except TypeError:
|
| 279 |
try:
|
| 280 |
content = self.tokenizer.apply_chat_template(
|
| 281 |
+
[
|
| 282 |
+
{
|
| 283 |
+
"role": message["role"],
|
| 284 |
+
"content": message["content"],
|
| 285 |
+
}
|
| 286 |
+
],
|
| 287 |
add_generation_prompt=add_generation_prompt,
|
| 288 |
)
|
| 289 |
except Exception:
|
| 290 |
+
logger.warning(
|
| 291 |
+
"apply_chat_template failed; fallback to raw content."
|
| 292 |
+
)
|
| 293 |
content = message["content"]
|
| 294 |
else:
|
| 295 |
+
content = message["content"]
|
| 296 |
+
|
| 297 |
+
if not isinstance(content, str):
|
| 298 |
+
content = str(content)
|
| 299 |
+
|
| 300 |
+
# Batch-encode all path-based references in one call when possible.
|
| 301 |
+
# This ensures we actually exercise audio_tokenizer.batch_encode for multi-reference prompts,
|
| 302 |
+
# instead of repeatedly calling it with batch=1.
|
| 303 |
+
raw_audio_items = message.get("audio_codes_list", [])
|
| 304 |
+
|
| 305 |
+
audio_codes_list: List[torch.Tensor] = []
|
| 306 |
+
if len(raw_audio_items) > 0:
|
| 307 |
+
encoded_items: List[Optional[torch.Tensor]] = [None] * len(
|
| 308 |
+
raw_audio_items
|
| 309 |
+
)
|
| 310 |
+
paths: List[str] = []
|
| 311 |
+
path_positions: List[int] = []
|
| 312 |
+
|
| 313 |
+
for idx, item in enumerate(raw_audio_items):
|
| 314 |
+
if isinstance(item, torch.Tensor):
|
| 315 |
+
if n_vq is not None and item.shape[1] != n_vq:
|
| 316 |
+
raise RuntimeError(
|
| 317 |
+
"audio_codes's n_vq is not equal to the parameter `n_vq`. Your can set the parameter `n_vq` as None if you have already tokenzied the wavs."
|
| 318 |
+
)
|
| 319 |
+
encoded_items[idx] = item
|
| 320 |
+
continue
|
| 321 |
+
|
| 322 |
+
if isinstance(item, (str, os.PathLike)):
|
| 323 |
+
paths.append(str(item))
|
| 324 |
+
path_positions.append(idx)
|
| 325 |
+
continue
|
| 326 |
+
|
| 327 |
+
raise TypeError(
|
| 328 |
+
"Each audio item must be a torch.Tensor of codes or a path-like string."
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
if len(paths) > 0:
|
| 332 |
+
encoded_from_paths = self.encode_audios_from_path(paths, n_vq)
|
| 333 |
+
if len(encoded_from_paths) != len(paths):
|
| 334 |
+
raise RuntimeError(
|
| 335 |
+
"encode_audios_from_path returned an unexpected number of items."
|
| 336 |
+
)
|
| 337 |
+
for pos, codes in zip(path_positions, encoded_from_paths):
|
| 338 |
+
encoded_items[pos] = codes
|
| 339 |
+
|
| 340 |
+
audio_codes_list = [cast(torch.Tensor, t) for t in encoded_items]
|
| 341 |
+
unified_codes.append(
|
| 342 |
+
self._get_unified_codes(
|
| 343 |
+
message["role"], content, audio_codes_list, truncation
|
| 344 |
+
)
|
| 345 |
+
)
|
| 346 |
|
| 347 |
unified_codes = torch.cat(unified_codes)
|
| 348 |
input_ids_list.append(unified_codes)
|
| 349 |
|
| 350 |
+
return BatchFeature(data=self._pad(input_ids_list))
|
| 351 |
|
| 352 |
@staticmethod
|
| 353 |
def build_user_message(
|
|
|
|
| 406 |
def _pad(self, input_ids_list: List[torch.Tensor]):
|
| 407 |
device = input_ids_list[0].device
|
| 408 |
lengths = torch.tensor([w.shape[0] for w in input_ids_list], device=device)
|
| 409 |
+
pad_input_ids = torch.nn.utils.rnn.pad_sequence(
|
| 410 |
+
input_ids_list,
|
| 411 |
+
batch_first=True,
|
| 412 |
+
padding_value=self.model_config.audio_pad_code,
|
| 413 |
+
padding_side="left",
|
| 414 |
+
)
|
| 415 |
+
other_channel_mask = (pad_input_ids.shape[1] - lengths).unsqueeze(
|
| 416 |
+
1
|
| 417 |
+
) > torch.arange(pad_input_ids.shape[1], device=device).unsqueeze(0)
|
| 418 |
pad_input_ids[..., 0][other_channel_mask] = self.model_config.pad_token_id
|
| 419 |
+
attention_mask = torch.zeros(
|
| 420 |
+
pad_input_ids.shape[0], pad_input_ids.shape[1], device=device
|
| 421 |
+
)
|
| 422 |
attention_mask[~other_channel_mask] = 1
|
| 423 |
attention_mask = attention_mask.bool()
|
| 424 |
return {
|
| 425 |
+
"input_ids": pad_input_ids, # [batch_size, seqlen, n_vq]
|
| 426 |
"attention_mask": attention_mask,
|
| 427 |
}
|
| 428 |
|
|
|
|
| 434 |
gen_slot_token: str,
|
| 435 |
delay_slot_token: str,
|
| 436 |
audio_start_token: str,
|
| 437 |
+
audio_end_token: str,
|
| 438 |
) -> str:
|
| 439 |
if n_vq < 1:
|
| 440 |
raise ValueError(f"n_vq must be >= 1, got {n_vq}")
|
|
|
|
| 476 |
return content, audio_codes_list
|
| 477 |
|
| 478 |
if len(matches) != len(audio_codes_list):
|
| 479 |
+
raise ValueError(
|
| 480 |
+
"Audio placeholders do not match the provided audio codes list."
|
| 481 |
+
)
|
| 482 |
|
| 483 |
new_audio_codes_list = []
|
| 484 |
new_parts = []
|
|
|
|
| 488 |
j = i
|
| 489 |
while (
|
| 490 |
j + 1 < len(matches)
|
| 491 |
+
and content[matches[j].end() : matches[j + 1].start()].strip() == ""
|
| 492 |
):
|
| 493 |
j += 1
|
| 494 |
|
| 495 |
+
new_parts.append(content[last_pos : matches[i].start()])
|
| 496 |
new_parts.append(AUDIO_PLACEHOLDER)
|
| 497 |
last_pos = matches[j].end()
|
| 498 |
|
| 499 |
if j == i:
|
| 500 |
new_audio_codes_list.append(audio_codes_list[i])
|
| 501 |
else:
|
| 502 |
+
new_audio_codes_list.append(
|
| 503 |
+
torch.cat(audio_codes_list[i : j + 1], dim=0)
|
| 504 |
+
)
|
| 505 |
|
| 506 |
i = j + 1
|
| 507 |
|
|
|
|
| 517 |
dtype=codes.dtype,
|
| 518 |
)
|
| 519 |
for i in range(codes.shape[1]):
|
| 520 |
+
delayed_tokens[i : i + codes.shape[0], i] = codes[:, i]
|
| 521 |
return delayed_tokens
|
| 522 |
+
|
| 523 |
@staticmethod
|
| 524 |
def apply_de_delay_pattern(delay_codes: torch.Tensor) -> torch.Tensor:
|
| 525 |
tokens = torch.full(
|
|
|
|
| 529 |
dtype=delay_codes.dtype,
|
| 530 |
)
|
| 531 |
for i in range(delay_codes.shape[1]):
|
| 532 |
+
tokens[:, i] = delay_codes[i : i + tokens.shape[0], i]
|
| 533 |
return tokens
|
| 534 |
|
| 535 |
+
def _get_unified_codes(
|
| 536 |
+
self,
|
| 537 |
+
role: str,
|
| 538 |
+
content: str,
|
| 539 |
+
audio_codes_list: List[torch.Tensor],
|
| 540 |
+
truncation: bool,
|
| 541 |
+
) -> torch.Tensor:
|
| 542 |
"""
|
| 543 |
此时的 content 已经是带上了对话格式
|
| 544 |
"""
|
|
|
|
| 566 |
audio_start_token=self.audio_start_token,
|
| 567 |
audio_end_token=self.audio_end_token,
|
| 568 |
)
|
| 569 |
+
text_codes = torch.tensor(
|
| 570 |
+
self.tokenizer.encode(content),
|
| 571 |
+
device=audio_codes_list[0].device if audio_codes_list else None,
|
| 572 |
+
)
|
| 573 |
|
| 574 |
+
audio_start_indices = torch.where(
|
| 575 |
+
text_codes == self.model_config.audio_start_token_id
|
| 576 |
+
)[0]
|
| 577 |
+
audio_end_indices = torch.where(
|
| 578 |
+
text_codes == self.model_config.audio_end_token_id
|
| 579 |
+
)[0]
|
| 580 |
+
if len(audio_start_indices) != len(audio_codes_list) or len(
|
| 581 |
+
audio_end_indices
|
| 582 |
+
) != len(audio_codes_list):
|
| 583 |
+
raise ValueError(
|
| 584 |
+
"Audio placeholders do not match the provided audio codes list."
|
| 585 |
+
)
|
| 586 |
|
| 587 |
delay_audio_codes_list = []
|
| 588 |
if len(audio_codes_list) == 0:
|
|
|
|
| 594 |
)
|
| 595 |
else:
|
| 596 |
prefix_idx = 0
|
| 597 |
+
for audio_start_idx_t, audio_end_idx_t, audio_codes in zip(
|
| 598 |
+
audio_start_indices, audio_end_indices, audio_codes_list
|
| 599 |
+
):
|
| 600 |
+
audio_start_idx = int(audio_start_idx_t.item())
|
| 601 |
+
audio_end_idx = int(audio_end_idx_t.item())
|
| 602 |
+
delay_audio_codes = self.apply_delay_pattern(
|
| 603 |
+
audio_codes, self.model_config.audio_pad_code
|
| 604 |
+
)
|
| 605 |
pad_codes = torch.full(
|
| 606 |
(audio_start_idx - prefix_idx + 1, n_vq),
|
| 607 |
self.model_config.audio_pad_code,
|
|
|
|
| 612 |
prefix_idx = audio_end_idx
|
| 613 |
|
| 614 |
if truncation:
|
| 615 |
+
delay_audio_codes_list[-1] = delay_audio_codes_list[-1][
|
| 616 |
+
: -(n_vq - 1), :
|
| 617 |
+
]
|
| 618 |
else:
|
| 619 |
+
last_audio_end_idx = int(audio_end_indices[-1].item())
|
| 620 |
pad_codes = torch.full(
|
| 621 |
+
(len(text_codes) - last_audio_end_idx, n_vq),
|
| 622 |
self.model_config.audio_pad_code,
|
| 623 |
device=audio_codes_list[0].device,
|
| 624 |
dtype=audio_codes_list[0].dtype,
|
|
|
|
| 626 |
delay_audio_codes_list.append(pad_codes)
|
| 627 |
|
| 628 |
delay_audio_codes_list = torch.cat(delay_audio_codes_list)
|
| 629 |
+
|
| 630 |
if text_codes.shape[0] != delay_audio_codes_list.shape[0]:
|
| 631 |
+
text_codes = text_codes[: delay_audio_codes_list.shape[0]]
|
| 632 |
|
| 633 |
+
unified_codes = torch.cat(
|
| 634 |
+
[text_codes.unsqueeze(1), delay_audio_codes_list], dim=1
|
| 635 |
+
)
|
| 636 |
return unified_codes
|
| 637 |
|
| 638 |
def _parse_text_codes(self, start_length, text_codes):
    """Decode generated text ids and normalize inline audio spans.

    Strips the prompt prefix (the first ``start_length`` ids), then rewrites
    each ``audio_start ... audio_end`` span: a span containing at least one
    assistant *gen* slot becomes the ``<|audio|>`` placeholder, while a span
    made only of delay slots (pure padding) is removed entirely.

    Args:
        start_length: Number of leading ids that belong to the prompt.
        text_codes: 1D sequence of text token ids for the full generation.

    Returns:
        The normalized assistant text with audio placeholders.
    """
    text = cast(str, self.tokenizer.decode(text_codes))
    prefix = cast(str, self.tokenizer.decode(text_codes[:start_length]))
    text = text[len(prefix) :]

    # Special tokens such as "<|audio_start|>" contain regex metacharacters
    # ("|", "*", ...). They must be escaped, otherwise e.g. "<|x|>" parses as
    # the alternation "<" OR "x" OR ">" and the pattern matches garbage.
    audio_pattern = re.compile(
        rf"(?:{re.escape(self.audio_start_token)})?"
        rf"(?:{re.escape(self.audio_assistant_gen_slot_token)})*"
        rf"(?:{re.escape(self.audio_assistant_delay_slot_token)})*"
        rf"{re.escape(self.audio_end_token)}"
    )

    def repl(match: re.Match) -> str:
        seg = match.group(0)
        # Replace with <|audio|> if a gen slot is present in the segment;
        # otherwise (delay-only padding) drop the segment.
        if self.audio_assistant_gen_slot_token in seg:
            return AUDIO_PLACEHOLDER
        return ""

    return audio_pattern.sub(repl, text)
|
|
|
|
| 679 |
|
| 680 |
audio_codes_list = [audio_codes[s] for s in segments_idx]
|
| 681 |
|
| 682 |
+
# Batch-decode all audio segments together.
|
| 683 |
+
decoded_audio_list = self.decode_audio_codes(audio_codes_list)
|
|
|
|
|
|
|
|
|
|
| 684 |
|
| 685 |
# Keep codec causal context by decoding the whole first segment first,
|
| 686 |
# then trim at waveform level according to start_length ratio.
|
| 687 |
+
if (
|
| 688 |
+
start_length > 0
|
| 689 |
+
and len(audio_codes_list) > 0
|
| 690 |
+
and len(decoded_audio_list) > 0
|
| 691 |
+
):
|
| 692 |
first_codes_length = audio_codes_list[0].shape[0]
|
| 693 |
if first_codes_length > 0:
|
| 694 |
+
trim_ratio = max(
|
| 695 |
+
0.0, min(float(start_length) / float(first_codes_length), 1.0)
|
| 696 |
+
)
|
| 697 |
first_audio = decoded_audio_list[0]
|
| 698 |
if trim_ratio >= 1.0:
|
| 699 |
decoded_audio_list = decoded_audio_list[1:]
|
|
|
|
| 703 |
|
| 704 |
return decoded_audio_list
|
| 705 |
|
|
|
|
| 706 |
def decode(self, output: List[Tuple[int, torch.Tensor]]):
    """Convert raw generation ids back into ``AssistantMessage`` objects.

    Args:
        output: One ``(start_length, generation_ids)`` pair per sample, where
            ``generation_ids`` is a 2D tensor whose first column holds text
            ids and whose remaining columns hold the per-codebook audio codes.

    Returns:
        A list with one entry per sample: an ``AssistantMessage`` carrying the
        normalized text content and the parsed audio code segments, or
        ``None`` when the decoded text content is empty.
    """
    generated_messages = []
    for start_length, generation_ids in output:
        content = self._parse_text_codes(start_length, generation_ids[:, 0])
        audio_codes_list = self._parse_audio_codes(
            start_length, generation_ids[:, 1:]
        )
        if content == "":
            message = None
        else:
            message = AssistantMessage(
                content=content,
                audio_codes_list=cast(
                    List[Union[str, torch.Tensor]], audio_codes_list
                ),
            )
        generated_messages.append(message)
    return generated_messages
|
| 729 |
|
| 730 |
@staticmethod
|
| 731 |
+
def loudness_normalize(
|
| 732 |
+
wav: torch.Tensor,
|
| 733 |
+
target_dbfs: float = -20,
|
| 734 |
+
gain_range: tuple[float, float] = (-3.0, 3.0),
|
| 735 |
+
) -> torch.Tensor:
|
| 736 |
wav = wav.to(torch.float32)
|
| 737 |
+
if wav.numel() == 0:
|
| 738 |
+
return wav
|
| 739 |
+
current_dbfs = 10.0 * torch.log10(torch.mean(wav**2) + 1e-9)
|
| 740 |
gain = float(target_dbfs - current_dbfs)
|
| 741 |
gain = max(gain_range[0], min(gain, gain_range[1]))
|
| 742 |
factor = 10.0 ** (gain / 20.0)
|
| 743 |
return wav * factor
|
| 744 |
|
| 745 |
+
def _get_audio_tokenizer_device(self) -> torch.device:
|
| 746 |
+
"""Best-effort device inference for `self.audio_tokenizer`.
|
| 747 |
+
|
| 748 |
+
Notes:
|
| 749 |
+
- Old TAC wrapper exposed `.device`, but standard `torch.nn.Module` does not.
|
| 750 |
+
- New MossAudioTokenizerModel is a `PreTrainedModel`; parameters define its device.
|
| 751 |
+
"""
|
| 752 |
+
|
| 753 |
+
audio_tokenizer = getattr(self, "audio_tokenizer", None)
|
| 754 |
+
if audio_tokenizer is None:
|
| 755 |
+
logger.warning(
|
| 756 |
+
"audio_tokenizer is not set on processor. Using CPU as default."
|
| 757 |
+
)
|
| 758 |
+
return torch.device("cpu")
|
| 759 |
+
|
| 760 |
+
device_attr = getattr(audio_tokenizer, "device", None)
|
| 761 |
+
if isinstance(device_attr, torch.device):
|
| 762 |
+
return device_attr
|
| 763 |
+
|
| 764 |
+
try:
|
| 765 |
+
return next(audio_tokenizer.parameters()).device
|
| 766 |
+
except StopIteration:
|
| 767 |
+
# No parameters (shouldn't happen for real models); default to CPU.
|
| 768 |
+
logger.warning(
|
| 769 |
+
"No parameters found on audio_tokenizer. Using CPU as default."
|
| 770 |
+
)
|
| 771 |
+
return torch.device("cpu")
|
| 772 |
+
|
| 773 |
+
def encode_audios_from_wav(
    self,
    wav_list: List[torch.Tensor],
    sampling_rate: int,
    n_vq: Optional[int] = None,
):
    """Tokenize raw waveforms into discrete audio codes.

    Each waveform is downmixed to mono, resampled to the model's sampling
    rate when needed, moved to the audio tokenizer's device, and
    loudness-normalized before encoding.

    Args:
        wav_list: Waveforms (a single tensor is also accepted and wrapped
            into a list). Assumes channel-first layout, i.e. shape (C, T) —
            TODO confirm against callers.
        sampling_rate: Sampling rate of the input waveforms.
        n_vq: Number of quantizers/codebooks to use; passed through to the
            tokenizer (None lets the tokenizer decide).

    Returns:
        list[torch.Tensor]: One (T, NQ) long tensor per input, on CPU.

    Raises:
        RuntimeError: If no audio tokenizer is configured, or if the
            tokenizer returns empty outputs.
    """
    if self.audio_tokenizer is None:
        raise RuntimeError("audio_tokenizer is not set on processor.")
    audio_tokenizer = self.audio_tokenizer

    # Accept a bare tensor for convenience.
    if isinstance(wav_list, torch.Tensor):
        wav_list = [wav_list]
    wav_list_ = []
    resample = False
    if sampling_rate != self.model_config.sampling_rate:
        resample = True
    device = self._get_audio_tokenizer_device()
    for wav in wav_list:
        # Downmix multi-channel audio to mono before resampling.
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
        if resample:
            wav = torchaudio.functional.resample(
                waveform=wav,
                orig_freq=sampling_rate,
                new_freq=self.model_config.sampling_rate,
            )
        wav = wav.to(device)
        wav_list_.append(self.loudness_normalize(wav.squeeze(0)))

    # New MossAudioTokenizerModel API: prefer batch_encode(list[wav])
    if hasattr(audio_tokenizer, "batch_encode"):
        enc = audio_tokenizer.batch_encode(wav_list_, num_quantizers=n_vq)
        audio_codes = enc.audio_codes  # (NQ, B, T)
        audio_codes_lengths = enc.audio_codes_lengths  # (B,)
    else:
        # Fallback: use encode() with explicit padding.
        max_len = max(int(wav.shape[-1]) for wav in wav_list_)
        input_values = torch.zeros(
            len(wav_list_), 1, max_len, device=device, dtype=torch.float32
        )
        padding_mask = torch.zeros(
            len(wav_list_), max_len, device=device, dtype=torch.bool
        )
        for i, wav in enumerate(wav_list_):
            this_len = int(wav.shape[-1])
            input_values[i, 0, :this_len] = wav
            # True marks valid (non-padded) samples — NOTE(review): confirm
            # mask polarity against the tokenizer's encode() contract.
            padding_mask[i, :this_len] = True
        enc = audio_tokenizer.encode(
            input_values,
            padding_mask=padding_mask,
            num_quantizers=n_vq,
            return_dict=True,
        )
        audio_codes = enc.audio_codes
        audio_codes_lengths = enc.audio_codes_lengths

    if audio_codes is None or audio_codes_lengths is None:
        raise RuntimeError(
            "audio_tokenizer.encode() returned empty outputs (audio_codes/audio_codes_lengths)."
        )

    # Keep processor's historical contract: list[Tensor] with shape (T, NQ)
    # and on CPU (so downstream text/audio packing remains device-agnostic).
    codes_list: List[torch.Tensor] = []
    for i in range(int(audio_codes.shape[1])):
        length_i = int(audio_codes_lengths[i].item())
        codes_i = (
            audio_codes[:, i, :length_i]
            .transpose(0, 1)
            .contiguous()
            .to(torch.long)
            .cpu()
        )
        codes_list.append(codes_i)
    return codes_list
|
| 848 |
+
|
| 849 |
+
def encode_audios_from_path(
    self, wav_path_list: Union[str, List[str]], n_vq: Optional[int] = None
):
    """Load waveforms from disk and tokenize them into audio codes.

    Every file is loaded and, when its sampling rate differs from the
    model's, resampled individually — so a heterogeneous batch of files can
    still be forwarded to `encode_audios_from_wav` (and benefit from the
    tokenizer's batch encoding) in one call.
    """
    if isinstance(wav_path_list, str):
        wav_path_list = [wav_path_list]
    if not wav_path_list:
        raise ValueError("Empty wav_path_list")

    target_sr = int(self.model_config.sampling_rate)

    loaded: List[torch.Tensor] = []
    for path in wav_path_list:
        waveform, source_sr = torchaudio.load(path)
        if int(source_sr) != target_sr:
            waveform = torchaudio.functional.resample(
                waveform=waveform,
                orig_freq=int(source_sr),
                new_freq=target_sr,
            )
        loaded.append(waveform)

    return self.encode_audios_from_wav(loaded, target_sr, n_vq)
|
| 874 |
+
|
| 875 |
+
def decode_audio_codes(
    self, audio_tokens_list: Union[torch.Tensor, List[torch.Tensor]]
):
    """Turn audio code tensors of shape (T, NQ) back into waveforms.

    Accepts a single tensor or a list of them. Returns a list of 1D CPU
    float32 waveforms, each trimmed to its segment's valid length.
    """
    if self.audio_tokenizer is None:
        raise RuntimeError("audio_tokenizer is not set on processor.")
    audio_tokenizer = self.audio_tokenizer

    if isinstance(audio_tokens_list, torch.Tensor):
        audio_tokens_list = [audio_tokens_list]
    if not audio_tokens_list:
        return []

    device = self._get_audio_tokenizer_device()

    # Processor uses (T, NQ); MossAudioTokenizer expects (NQ, T) (or (NQ, B, T)).
    per_segment_codes = [
        seg.transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
        for seg in audio_tokens_list
    ]

    if hasattr(audio_tokenizer, "batch_decode"):
        decoded = audio_tokenizer.batch_decode(per_segment_codes)
        audio, audio_lengths = decoded.audio, decoded.audio_lengths
    else:
        # Fallback: pad to (NQ, B, T) + mask, then decode.
        batch = len(per_segment_codes)
        n_quantizers = int(per_segment_codes[0].shape[0])
        longest = max(int(seg.shape[1]) for seg in per_segment_codes)
        padded_codes = torch.zeros(
            n_quantizers, batch, longest, device=device, dtype=torch.long
        )
        padding_mask = torch.zeros(
            batch, longest, device=device, dtype=torch.bool
        )
        for idx, seg in enumerate(per_segment_codes):
            seg_len = int(seg.shape[1])
            padded_codes[:, idx, :seg_len] = seg
            padding_mask[idx, :seg_len] = True
        decoded = audio_tokenizer.decode(
            padded_codes, padding_mask=padding_mask, return_dict=True
        )
        audio, audio_lengths = decoded.audio, decoded.audio_lengths

    if audio is None or audio_lengths is None:
        raise RuntimeError(
            "audio_tokenizer.decode() returned empty outputs (audio/audio_lengths)."
        )

    # Return historical contract: list of 1D waveforms (T,)
    return [
        audio[idx, 0, : int(audio_lengths[idx].item())]
        .contiguous()
        .to(torch.float32)
        .cpu()
        for idx in range(int(audio.shape[0]))
    ]
|
|
|
|
|
|
|
|
|