YWMditto committed on
Commit
e7e1239
·
1 Parent(s): a6f32ab

update readme

Browse files
Files changed (2) hide show
  1. README.md +2 -4
  2. modeling_moss_tts.py +4 -21
README.md CHANGED
@@ -104,10 +104,8 @@ For full details, see:
104
 
105
  | Model | audio_temperature | audio_top_p | audio_top_k | audio_repetition_penalty |
106
  |---|---:|---:|---:|---:|
107
- | **MOSS-TTSDelay-8B** | 1.7 | 0.8 | 25 | 1.0 |
108
- | **MOSS-TTSLocal-1.7B** | 1.0 | 0.95 | 50 | 1.1 |
109
-
110
- > Note: `max_new_tokens` controls duration. At 12.5 tokens per second, **1s ≈ 12.5 tokens**.
111
 
112
 
113
 
 
104
 
105
  | Model | audio_temperature | audio_top_p | audio_top_k | audio_repetition_penalty |
106
  |---|---:|---:|---:|---:|
107
+ | **MossTTSDelay-8B** | 1.7 | 0.8 | 25 | 1.0 |
108
+ | **MossTTSLocal-1.7B** | 1.0 | 0.95 | 50 | 1.1 |
 
 
109
 
110
 
111
 
modeling_moss_tts.py CHANGED
@@ -424,14 +424,10 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
424
  generation_ids = input_ids[:]
425
  is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)
426
 
427
- # 三个阶段: 1. 非 audio; 2. audio not delay; 3. audio delay
428
- audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device) # 0 的时候表示阶段1;
429
  torch_int64_max = torch.iinfo(torch.int64).max
430
- delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device) # 最大值的时候表示阶段2;
431
 
432
- # 考虑 continuation 时 audio_start 已经在 input_ids 中的情况;
433
- # NOTE 注意我们目前不考虑任何输入已经开始 delay 的情况;
434
- # 需要同时考虑 continuation 和直接生成的情况;
435
  is_continuation = (input_ids[:, -1, 0] == self.config.audio_start_token_id) | (input_ids[:, -1, 0] == self.config.audio_assistant_gen_slot_token_id)
436
  audio_start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
437
  audio_start_mask = is_continuation & (audio_start_indices != -1)
@@ -443,8 +439,6 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
443
  pre_exclude_mask1 = torch.ones(self.config.language_config.vocab_size, device=device).bool()
444
  pre_exclude_mask1[[self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id]] = False
445
 
446
-
447
- # 注意 time_step 未必表示对于实际对话时,当前输出token的位置,因为有续写的情况;
448
  for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
449
  outputs = self(
450
  input_ids=current_input_ids,
@@ -456,9 +450,7 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
456
 
457
  next_token_logits = [logit[:, -1, :] / text_temperature if logit_idx == 0 else logit[:, -1, :] / audio_temperature for logit_idx, logit in enumerate(outputs.logits)] # List, len=n_vq+1, [batch_size, 1, vocab_size];
458
  next_token_logits[0] = next_token_logits[0].clone()
459
- # 1. 先处理 text token;
460
  next_text_token = torch.full((batch_size,), self.config.pad_token_id, device=device)
461
- # 第二个 audio_assistant_delay_slot_token_id 和 audio_end 是不需要采样的,audio_start, 每一个 audio_assistant_gen_slot_token_ids 和第一个 audio_assistant_delay_slot_token_id 是需要采样的;
462
  next_text_token[~is_stopping & (delayed_lengths < n_vq)] = self.config.audio_assistant_delay_slot_token_id
463
  is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
464
  next_text_token[is_audio_eos] = self.config.audio_end_token_id
@@ -471,7 +463,6 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
471
  if time_step <= n_vq:
472
  next_token_logits[0][..., self.config.im_end_token_id] = float('-inf')
473
 
474
- # 文本层不使用重复惩罚;
475
  next_text_token[sampling_text_mask] = sample_token(
476
  logits=next_token_logits[0][sampling_text_mask],
477
  top_p=text_top_p,
@@ -479,15 +470,10 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
479
  do_sample=text_do_sample
480
  )
481
  is_audio[next_text_token == self.config.audio_start_token_id] = True
482
- # 只存在一种停止逻辑,即 next_text_token = <|im_end|>;
483
  is_stopping[next_text_token == self.config.im_end_token_id] = True
484
 
485
- # 2. 再处理 audio tokens;
486
- # audio_start 和 audio_end 之外的内容直接pad,默认是 pad,我们只需要填充有值的部分即可;
487
  next_audio_tokens = torch.full((batch_size, n_vq), self.config.audio_pad_code, device=device)
488
 
489
- # 需要考虑的是与 audio_start 的距离;
490
- # 先查看是否是pad的情况; true 表示有值;
491
  pre_audio_mask = audio_lengths.unsqueeze(1) > torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq)
492
  post_audio_mask = torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq) > delayed_lengths.unsqueeze(1) - 1
493
  post_audio_mask[delayed_lengths == torch_int64_max] = True
@@ -516,18 +502,15 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
516
  do_sample=audio_do_sample
517
  )
518
 
519
- # 这里显示的是下一个时间步时可以直接使用的 audio_lengths 和 delayed_lengths 的状态;
520
- # audio_lengths[(next_text_token == self.audio_start_token_id) & (audio_lengths > 0)] += 1
521
- # audio_lengths[(next_text_token == self.audio_start_token_id) | (next_text_token == self.audio_assistant_gen_slot_token_id)] += 1
522
  audio_lengths[(next_text_token == self.config.audio_start_token_id) | (next_text_token == self.config.audio_assistant_gen_slot_token_id) | (next_text_token == self.config.audio_assistant_delay_slot_token_id)] += 1
523
  audio_lengths[next_text_token == self.config.audio_end_token_id] = 0
524
  delayed_lengths[(delayed_lengths == torch_int64_max) & (next_text_token == self.config.audio_assistant_delay_slot_token_id)] = 0
525
  delayed_lengths[delayed_lengths != torch_int64_max] += 1
526
  delayed_lengths[delayed_lengths > n_vq] = torch_int64_max
527
 
528
- current_input_ids = torch.cat([next_text_token[:, None, None], next_audio_tokens[:, None, :]], dim=2) # [batch_size, 1, n_vq + 1]
529
  current_attention_mask = torch.cat([current_attention_mask, (~is_stopping).unsqueeze(-1)], dim=-1)
530
- generation_ids = torch.cat([generation_ids, current_input_ids], dim=1) # [batch_size, seq_len, n_vq + 1]
531
 
532
  if is_stopping.sum() == batch_size:
533
  break
 
424
  generation_ids = input_ids[:]
425
  is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)
426
 
427
+ audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device)
 
428
  torch_int64_max = torch.iinfo(torch.int64).max
429
+ delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device)
430
 
 
 
 
431
  is_continuation = (input_ids[:, -1, 0] == self.config.audio_start_token_id) | (input_ids[:, -1, 0] == self.config.audio_assistant_gen_slot_token_id)
432
  audio_start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
433
  audio_start_mask = is_continuation & (audio_start_indices != -1)
 
439
  pre_exclude_mask1 = torch.ones(self.config.language_config.vocab_size, device=device).bool()
440
  pre_exclude_mask1[[self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id]] = False
441
 
 
 
442
  for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
443
  outputs = self(
444
  input_ids=current_input_ids,
 
450
 
451
  next_token_logits = [logit[:, -1, :] / text_temperature if logit_idx == 0 else logit[:, -1, :] / audio_temperature for logit_idx, logit in enumerate(outputs.logits)] # List, len=n_vq+1, [batch_size, 1, vocab_size];
452
  next_token_logits[0] = next_token_logits[0].clone()
 
453
  next_text_token = torch.full((batch_size,), self.config.pad_token_id, device=device)
 
454
  next_text_token[~is_stopping & (delayed_lengths < n_vq)] = self.config.audio_assistant_delay_slot_token_id
455
  is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
456
  next_text_token[is_audio_eos] = self.config.audio_end_token_id
 
463
  if time_step <= n_vq:
464
  next_token_logits[0][..., self.config.im_end_token_id] = float('-inf')
465
 
 
466
  next_text_token[sampling_text_mask] = sample_token(
467
  logits=next_token_logits[0][sampling_text_mask],
468
  top_p=text_top_p,
 
470
  do_sample=text_do_sample
471
  )
472
  is_audio[next_text_token == self.config.audio_start_token_id] = True
 
473
  is_stopping[next_text_token == self.config.im_end_token_id] = True
474
 
 
 
475
  next_audio_tokens = torch.full((batch_size, n_vq), self.config.audio_pad_code, device=device)
476
 
 
 
477
  pre_audio_mask = audio_lengths.unsqueeze(1) > torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq)
478
  post_audio_mask = torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq) > delayed_lengths.unsqueeze(1) - 1
479
  post_audio_mask[delayed_lengths == torch_int64_max] = True
 
502
  do_sample=audio_do_sample
503
  )
504
 
 
 
 
505
  audio_lengths[(next_text_token == self.config.audio_start_token_id) | (next_text_token == self.config.audio_assistant_gen_slot_token_id) | (next_text_token == self.config.audio_assistant_delay_slot_token_id)] += 1
506
  audio_lengths[next_text_token == self.config.audio_end_token_id] = 0
507
  delayed_lengths[(delayed_lengths == torch_int64_max) & (next_text_token == self.config.audio_assistant_delay_slot_token_id)] = 0
508
  delayed_lengths[delayed_lengths != torch_int64_max] += 1
509
  delayed_lengths[delayed_lengths > n_vq] = torch_int64_max
510
 
511
+ current_input_ids = torch.cat([next_text_token[:, None, None], next_audio_tokens[:, None, :]], dim=2)
512
  current_attention_mask = torch.cat([current_attention_mask, (~is_stopping).unsqueeze(-1)], dim=-1)
513
+ generation_ids = torch.cat([generation_ids, current_input_ids], dim=1)
514
 
515
  if is_stopping.sum() == batch_size:
516
  break