update readme
Browse files
- README.md: +2 -4
- modeling_moss_tts.py: +4 -21
README.md
CHANGED
|
@@ -104,10 +104,8 @@ For full details, see:
|
|
| 104 |
|
| 105 |
| Model | audio_temperature | audio_top_p | audio_top_k | audio_repetition_penalty |
|
| 106 |
|---|---:|---:|---:|---:|
|
| 107 |
-
| **
|
| 108 |
-
| **
|
| 109 |
-
|
| 110 |
-
> Note: `max_new_tokens` controls duration. At 12.5 tokens per second, **1s ≈ 12.5 tokens**.
|
| 111 |
|
| 112 |
|
| 113 |
|
|
|
|
| 104 |
|
| 105 |
| Model | audio_temperature | audio_top_p | audio_top_k | audio_repetition_penalty |
|
| 106 |
|---|---:|---:|---:|---:|
|
| 107 |
+
| **MossTTSDelay-8B** | 1.7 | 0.8 | 25 | 1.0 |
|
| 108 |
+
| **MossTTSLocal-1.7B** | 1.0 | 0.95 | 50 | 1.1 |
|
|
|
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
|
modeling_moss_tts.py
CHANGED
|
@@ -424,14 +424,10 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
|
|
| 424 |
generation_ids = input_ids[:]
|
| 425 |
is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
| 426 |
|
| 427 |
-
|
| 428 |
-
audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device) # 0 的时候表示阶段1;
|
| 429 |
torch_int64_max = torch.iinfo(torch.int64).max
|
| 430 |
-
delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device)
|
| 431 |
|
| 432 |
-
# 考虑 continuation 时 audio_start 已经在 input_ids 中的情况;
|
| 433 |
-
# NOTE 注意我们目前不考虑任何输入已经开始 delay 的情况;
|
| 434 |
-
# 需要同时考虑 continuation 和直接生成的情况;
|
| 435 |
is_continuation = (input_ids[:, -1, 0] == self.config.audio_start_token_id) | (input_ids[:, -1, 0] == self.config.audio_assistant_gen_slot_token_id)
|
| 436 |
audio_start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
|
| 437 |
audio_start_mask = is_continuation & (audio_start_indices != -1)
|
|
@@ -443,8 +439,6 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
|
|
| 443 |
pre_exclude_mask1 = torch.ones(self.config.language_config.vocab_size, device=device).bool()
|
| 444 |
pre_exclude_mask1[[self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id]] = False
|
| 445 |
|
| 446 |
-
|
| 447 |
-
# 注意 time_step 未必表示对于实际对话时,当前输出token的位置,因为有续写的情况;
|
| 448 |
for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
|
| 449 |
outputs = self(
|
| 450 |
input_ids=current_input_ids,
|
|
@@ -456,9 +450,7 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
|
|
| 456 |
|
| 457 |
next_token_logits = [logit[:, -1, :] / text_temperature if logit_idx == 0 else logit[:, -1, :] / audio_temperature for logit_idx, logit in enumerate(outputs.logits)] # List, len=n_vq+1, [batch_size, 1, vocab_size];
|
| 458 |
next_token_logits[0] = next_token_logits[0].clone()
|
| 459 |
-
# 1. 先处理 text token;
|
| 460 |
next_text_token = torch.full((batch_size,), self.config.pad_token_id, device=device)
|
| 461 |
-
# 第二个 audio_assistant_delay_slot_token_id 和 audio_end 是不需要采样的,audio_start, 每一个 audio_assistant_gen_slot_token_ids 和第一个 audio_assistant_delay_slot_token_id 是需要采样的;
|
| 462 |
next_text_token[~is_stopping & (delayed_lengths < n_vq)] = self.config.audio_assistant_delay_slot_token_id
|
| 463 |
is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
|
| 464 |
next_text_token[is_audio_eos] = self.config.audio_end_token_id
|
|
@@ -471,7 +463,6 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
|
|
| 471 |
if time_step <= n_vq:
|
| 472 |
next_token_logits[0][..., self.config.im_end_token_id] = float('-inf')
|
| 473 |
|
| 474 |
-
# 文本层不使用重复惩罚;
|
| 475 |
next_text_token[sampling_text_mask] = sample_token(
|
| 476 |
logits=next_token_logits[0][sampling_text_mask],
|
| 477 |
top_p=text_top_p,
|
|
@@ -479,15 +470,10 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
|
|
| 479 |
do_sample=text_do_sample
|
| 480 |
)
|
| 481 |
is_audio[next_text_token == self.config.audio_start_token_id] = True
|
| 482 |
-
# 只存在一种停止逻辑,即 next_text_token = <|im_end|>;
|
| 483 |
is_stopping[next_text_token == self.config.im_end_token_id] = True
|
| 484 |
|
| 485 |
-
# 2. 再处理 audio tokens;
|
| 486 |
-
# audio_start 和 audio_end 之外的内容直接pad,默认是 pad,我们只需要填充有值的部分即可;
|
| 487 |
next_audio_tokens = torch.full((batch_size, n_vq), self.config.audio_pad_code, device=device)
|
| 488 |
|
| 489 |
-
# 需要考虑的是与 audio_start 的距离;
|
| 490 |
-
# 先查看是否是pad的情况; true 表示有值;
|
| 491 |
pre_audio_mask = audio_lengths.unsqueeze(1) > torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq)
|
| 492 |
post_audio_mask = torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq) > delayed_lengths.unsqueeze(1) - 1
|
| 493 |
post_audio_mask[delayed_lengths == torch_int64_max] = True
|
|
@@ -516,18 +502,15 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
|
|
| 516 |
do_sample=audio_do_sample
|
| 517 |
)
|
| 518 |
|
| 519 |
-
# 这里显示的是下一个时间步时可以直接使用的 audio_lengths 和 delayed_lengths 的状态;
|
| 520 |
-
# audio_lengths[(next_text_token == self.audio_start_token_id) & (audio_lengths > 0)] += 1
|
| 521 |
-
# audio_lengths[(next_text_token == self.audio_start_token_id) | (next_text_token == self.audio_assistant_gen_slot_token_id)] += 1
|
| 522 |
audio_lengths[(next_text_token == self.config.audio_start_token_id) | (next_text_token == self.config.audio_assistant_gen_slot_token_id) | (next_text_token == self.config.audio_assistant_delay_slot_token_id)] += 1
|
| 523 |
audio_lengths[next_text_token == self.config.audio_end_token_id] = 0
|
| 524 |
delayed_lengths[(delayed_lengths == torch_int64_max) & (next_text_token == self.config.audio_assistant_delay_slot_token_id)] = 0
|
| 525 |
delayed_lengths[delayed_lengths != torch_int64_max] += 1
|
| 526 |
delayed_lengths[delayed_lengths > n_vq] = torch_int64_max
|
| 527 |
|
| 528 |
-
current_input_ids = torch.cat([next_text_token[:, None, None], next_audio_tokens[:, None, :]], dim=2)
|
| 529 |
current_attention_mask = torch.cat([current_attention_mask, (~is_stopping).unsqueeze(-1)], dim=-1)
|
| 530 |
-
generation_ids = torch.cat([generation_ids, current_input_ids], dim=1)
|
| 531 |
|
| 532 |
if is_stopping.sum() == batch_size:
|
| 533 |
break
|
|
|
|
| 424 |
generation_ids = input_ids[:]
|
| 425 |
is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
| 426 |
|
| 427 |
+
audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device)
|
|
|
|
| 428 |
torch_int64_max = torch.iinfo(torch.int64).max
|
| 429 |
+
delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device)
|
| 430 |
|
|
|
|
|
|
|
|
|
|
| 431 |
is_continuation = (input_ids[:, -1, 0] == self.config.audio_start_token_id) | (input_ids[:, -1, 0] == self.config.audio_assistant_gen_slot_token_id)
|
| 432 |
audio_start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
|
| 433 |
audio_start_mask = is_continuation & (audio_start_indices != -1)
|
|
|
|
| 439 |
pre_exclude_mask1 = torch.ones(self.config.language_config.vocab_size, device=device).bool()
|
| 440 |
pre_exclude_mask1[[self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id]] = False
|
| 441 |
|
|
|
|
|
|
|
| 442 |
for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
|
| 443 |
outputs = self(
|
| 444 |
input_ids=current_input_ids,
|
|
|
|
| 450 |
|
| 451 |
next_token_logits = [logit[:, -1, :] / text_temperature if logit_idx == 0 else logit[:, -1, :] / audio_temperature for logit_idx, logit in enumerate(outputs.logits)] # List, len=n_vq+1, [batch_size, 1, vocab_size];
|
| 452 |
next_token_logits[0] = next_token_logits[0].clone()
|
|
|
|
| 453 |
next_text_token = torch.full((batch_size,), self.config.pad_token_id, device=device)
|
|
|
|
| 454 |
next_text_token[~is_stopping & (delayed_lengths < n_vq)] = self.config.audio_assistant_delay_slot_token_id
|
| 455 |
is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
|
| 456 |
next_text_token[is_audio_eos] = self.config.audio_end_token_id
|
|
|
|
| 463 |
if time_step <= n_vq:
|
| 464 |
next_token_logits[0][..., self.config.im_end_token_id] = float('-inf')
|
| 465 |
|
|
|
|
| 466 |
next_text_token[sampling_text_mask] = sample_token(
|
| 467 |
logits=next_token_logits[0][sampling_text_mask],
|
| 468 |
top_p=text_top_p,
|
|
|
|
| 470 |
do_sample=text_do_sample
|
| 471 |
)
|
| 472 |
is_audio[next_text_token == self.config.audio_start_token_id] = True
|
|
|
|
| 473 |
is_stopping[next_text_token == self.config.im_end_token_id] = True
|
| 474 |
|
|
|
|
|
|
|
| 475 |
next_audio_tokens = torch.full((batch_size, n_vq), self.config.audio_pad_code, device=device)
|
| 476 |
|
|
|
|
|
|
|
| 477 |
pre_audio_mask = audio_lengths.unsqueeze(1) > torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq)
|
| 478 |
post_audio_mask = torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq) > delayed_lengths.unsqueeze(1) - 1
|
| 479 |
post_audio_mask[delayed_lengths == torch_int64_max] = True
|
|
|
|
| 502 |
do_sample=audio_do_sample
|
| 503 |
)
|
| 504 |
|
|
|
|
|
|
|
|
|
|
| 505 |
audio_lengths[(next_text_token == self.config.audio_start_token_id) | (next_text_token == self.config.audio_assistant_gen_slot_token_id) | (next_text_token == self.config.audio_assistant_delay_slot_token_id)] += 1
|
| 506 |
audio_lengths[next_text_token == self.config.audio_end_token_id] = 0
|
| 507 |
delayed_lengths[(delayed_lengths == torch_int64_max) & (next_text_token == self.config.audio_assistant_delay_slot_token_id)] = 0
|
| 508 |
delayed_lengths[delayed_lengths != torch_int64_max] += 1
|
| 509 |
delayed_lengths[delayed_lengths > n_vq] = torch_int64_max
|
| 510 |
|
| 511 |
+
current_input_ids = torch.cat([next_text_token[:, None, None], next_audio_tokens[:, None, :]], dim=2)
|
| 512 |
current_attention_mask = torch.cat([current_attention_mask, (~is_stopping).unsqueeze(-1)], dim=-1)
|
| 513 |
+
generation_ids = torch.cat([generation_ids, current_input_ids], dim=1)
|
| 514 |
|
| 515 |
if is_stopping.sum() == batch_size:
|
| 516 |
break
|