YWMditto committed on
Commit
eea4939
·
1 Parent(s): 53fca80

update readme

Browse files
Files changed (1) hide show
  1. modeling_moss_tts.py +18 -25
modeling_moss_tts.py CHANGED
@@ -401,7 +401,7 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
401
  audio_temperature: float = 1.5,
402
  audio_top_p: float = 0.6,
403
  audio_top_k: int = 50,
404
- audio_repetition_penalty: float = 1.1,
405
  ):
406
  if text_temperature > 0:
407
  text_do_sample = True
@@ -424,14 +424,10 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
424
  generation_ids = input_ids[:]
425
  is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)
426
 
427
- # 三个阶段: 1. 非 audio; 2. audio not delay; 3. audio delay
428
- audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device) # 0 的时候表示阶段1;
429
  torch_int64_max = torch.iinfo(torch.int64).max
430
- delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device) # 最大值的时候表示阶段2;
431
 
432
- # 考虑 continuation 时 audio_start 已经在 input_ids 中的情况;
433
- # NOTE 注意我们目前不考虑任何输入已经开始 delay 的情况;
434
- # 需要同时考虑 continuation 和直接生成的情况;
435
  is_continuation = (input_ids[:, -1, 0] == self.config.audio_start_token_id) | (input_ids[:, -1, 0] == self.config.audio_assistant_gen_slot_token_id)
436
  audio_start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
437
  audio_start_mask = is_continuation & (audio_start_indices != -1)
@@ -443,8 +439,6 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
443
  pre_exclude_mask1 = torch.ones(self.config.language_config.vocab_size, device=device).bool()
444
  pre_exclude_mask1[[self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id]] = False
445
 
446
-
447
- # 注意 time_step 未必表示对于实际对话时,当前输出token的位置,因为有续写的情况;
448
  for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
449
  outputs = self(
450
  input_ids=current_input_ids,
@@ -456,9 +450,7 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
456
 
457
  next_token_logits = [logit[:, -1, :] / text_temperature if logit_idx == 0 else logit[:, -1, :] / audio_temperature for logit_idx, logit in enumerate(outputs.logits)] # List, len=n_vq+1, [batch_size, 1, vocab_size];
458
  next_token_logits[0] = next_token_logits[0].clone()
459
- # 1. 先处理 text token;
460
  next_text_token = torch.full((batch_size,), self.config.pad_token_id, device=device)
461
- # 第二个 audio_assistant_delay_slot_token_id 和 audio_end 是不需要采样的,audio_start, 每一个 audio_assistant_gen_slot_token_ids 和第一个 audio_assistant_delay_slot_token_id 是需要采样的;
462
  next_text_token[~is_stopping & (delayed_lengths < n_vq)] = self.config.audio_assistant_delay_slot_token_id
463
  is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
464
  next_text_token[is_audio_eos] = self.config.audio_end_token_id
@@ -471,7 +463,6 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
471
  if time_step <= n_vq:
472
  next_token_logits[0][..., self.config.im_end_token_id] = float('-inf')
473
 
474
- # 文本层不使用重复惩罚;
475
  next_text_token[sampling_text_mask] = sample_token(
476
  logits=next_token_logits[0][sampling_text_mask],
477
  top_p=text_top_p,
@@ -479,15 +470,10 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
479
  do_sample=text_do_sample
480
  )
481
  is_audio[next_text_token == self.config.audio_start_token_id] = True
482
- # 只存在一种停止逻辑,即 next_text_token = <|im_end|>;
483
  is_stopping[next_text_token == self.config.im_end_token_id] = True
484
 
485
- # 2. 再处理 audio tokens;
486
- # audio_start 和 audio_end 之外的内容直接pad,默认是 pad,我们只需要填充有值的部分即可;
487
  next_audio_tokens = torch.full((batch_size, n_vq), self.config.audio_pad_code, device=device)
488
 
489
- # 需要考虑的是与 audio_start 的距离;
490
- # 先查看是否是pad的情况; true 表示有值;
491
  pre_audio_mask = audio_lengths.unsqueeze(1) > torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq)
492
  post_audio_mask = torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq) > delayed_lengths.unsqueeze(1) - 1
493
  post_audio_mask[delayed_lengths == torch_int64_max] = True
@@ -495,29 +481,36 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
495
  next_audio_tokens[~sampling_audio_mask] = self.config.audio_pad_code
496
 
497
  if sampling_audio_mask.sum() > 0:
498
- audio_logits = torch.stack(next_token_logits[1:], dim=1)[sampling_audio_mask] # torch.stack -> [batch_size, n_vq - 1, vocab_size]
 
 
499
  audio_logits[..., self.config.audio_pad_code] = float('-inf')
500
- next_audio_tokens[sampling_audio_mask] = sample_token(
 
 
 
 
 
 
 
 
501
  logits=audio_logits,
502
- prev_tokens=generation_ids[:, :, 1:],
503
  repetition_penalty=audio_repetition_penalty,
504
  top_p=audio_top_p,
505
  top_k=audio_top_k,
506
  do_sample=audio_do_sample
507
  )
508
 
509
- # 这里显示的是下一个时间步时可以直接使用的 audio_lengths 和 delayed_lengths 的状态;
510
- # audio_lengths[(next_text_token == self.audio_start_token_id) & (audio_lengths > 0)] += 1
511
- # audio_lengths[(next_text_token == self.audio_start_token_id) | (next_text_token == self.audio_assistant_gen_slot_token_id)] += 1
512
  audio_lengths[(next_text_token == self.config.audio_start_token_id) | (next_text_token == self.config.audio_assistant_gen_slot_token_id) | (next_text_token == self.config.audio_assistant_delay_slot_token_id)] += 1
513
  audio_lengths[next_text_token == self.config.audio_end_token_id] = 0
514
  delayed_lengths[(delayed_lengths == torch_int64_max) & (next_text_token == self.config.audio_assistant_delay_slot_token_id)] = 0
515
  delayed_lengths[delayed_lengths != torch_int64_max] += 1
516
  delayed_lengths[delayed_lengths > n_vq] = torch_int64_max
517
 
518
- current_input_ids = torch.cat([next_text_token[:, None, None], next_audio_tokens[:, None, :]], dim=2) # [batch_size, 1, n_vq + 1]
519
  current_attention_mask = torch.cat([current_attention_mask, (~is_stopping).unsqueeze(-1)], dim=-1)
520
- generation_ids = torch.cat([generation_ids, current_input_ids], dim=1) # [batch_size, seq_len, n_vq + 1]
521
 
522
  if is_stopping.sum() == batch_size:
523
  break
 
401
  audio_temperature: float = 1.5,
402
  audio_top_p: float = 0.6,
403
  audio_top_k: int = 50,
404
+ audio_repetition_penalty: float = 1.1
405
  ):
406
  if text_temperature > 0:
407
  text_do_sample = True
 
424
  generation_ids = input_ids[:]
425
  is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)
426
 
427
+ audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device)
 
428
  torch_int64_max = torch.iinfo(torch.int64).max
429
+ delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device)
430
 
 
 
 
431
  is_continuation = (input_ids[:, -1, 0] == self.config.audio_start_token_id) | (input_ids[:, -1, 0] == self.config.audio_assistant_gen_slot_token_id)
432
  audio_start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
433
  audio_start_mask = is_continuation & (audio_start_indices != -1)
 
439
  pre_exclude_mask1 = torch.ones(self.config.language_config.vocab_size, device=device).bool()
440
  pre_exclude_mask1[[self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id]] = False
441
 
 
 
442
  for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
443
  outputs = self(
444
  input_ids=current_input_ids,
 
450
 
451
  next_token_logits = [logit[:, -1, :] / text_temperature if logit_idx == 0 else logit[:, -1, :] / audio_temperature for logit_idx, logit in enumerate(outputs.logits)] # List, len=n_vq+1, [batch_size, 1, vocab_size];
452
  next_token_logits[0] = next_token_logits[0].clone()
 
453
  next_text_token = torch.full((batch_size,), self.config.pad_token_id, device=device)
 
454
  next_text_token[~is_stopping & (delayed_lengths < n_vq)] = self.config.audio_assistant_delay_slot_token_id
455
  is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
456
  next_text_token[is_audio_eos] = self.config.audio_end_token_id
 
463
  if time_step <= n_vq:
464
  next_token_logits[0][..., self.config.im_end_token_id] = float('-inf')
465
 
 
466
  next_text_token[sampling_text_mask] = sample_token(
467
  logits=next_token_logits[0][sampling_text_mask],
468
  top_p=text_top_p,
 
470
  do_sample=text_do_sample
471
  )
472
  is_audio[next_text_token == self.config.audio_start_token_id] = True
 
473
  is_stopping[next_text_token == self.config.im_end_token_id] = True
474
 
 
 
475
  next_audio_tokens = torch.full((batch_size, n_vq), self.config.audio_pad_code, device=device)
476
 
 
 
477
  pre_audio_mask = audio_lengths.unsqueeze(1) > torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq)
478
  post_audio_mask = torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq) > delayed_lengths.unsqueeze(1) - 1
479
  post_audio_mask[delayed_lengths == torch_int64_max] = True
 
481
  next_audio_tokens[~sampling_audio_mask] = self.config.audio_pad_code
482
 
483
  if sampling_audio_mask.sum() > 0:
484
+ audio_ch0_logits = next_token_logits[1][sampling_audio_mask[:, 0]]
485
+ audio_logits = torch.stack(next_token_logits[2:], dim=1)[sampling_audio_mask[:, 1:]]
486
+ audio_ch0_logits[..., self.config.audio_pad_code] = float('-inf')
487
  audio_logits[..., self.config.audio_pad_code] = float('-inf')
488
+ next_audio_tokens[:, 0][sampling_audio_mask[:, 0]] = sample_token(
489
+ logits=audio_ch0_logits,
490
+ prev_tokens=generation_ids[:, :, 1],
491
+ repetition_penalty=audio_repetition_penalty,
492
+ top_p=audio_top_p,
493
+ top_k=audio_top_k,
494
+ do_sample=audio_do_sample
495
+ )
496
+ next_audio_tokens[:, 1:][sampling_audio_mask[:, 1:]] = sample_token(
497
  logits=audio_logits,
498
+ prev_tokens=generation_ids[:, :, 2:],
499
  repetition_penalty=audio_repetition_penalty,
500
  top_p=audio_top_p,
501
  top_k=audio_top_k,
502
  do_sample=audio_do_sample
503
  )
504
 
 
 
 
505
  audio_lengths[(next_text_token == self.config.audio_start_token_id) | (next_text_token == self.config.audio_assistant_gen_slot_token_id) | (next_text_token == self.config.audio_assistant_delay_slot_token_id)] += 1
506
  audio_lengths[next_text_token == self.config.audio_end_token_id] = 0
507
  delayed_lengths[(delayed_lengths == torch_int64_max) & (next_text_token == self.config.audio_assistant_delay_slot_token_id)] = 0
508
  delayed_lengths[delayed_lengths != torch_int64_max] += 1
509
  delayed_lengths[delayed_lengths > n_vq] = torch_int64_max
510
 
511
+ current_input_ids = torch.cat([next_text_token[:, None, None], next_audio_tokens[:, None, :]], dim=2)
512
  current_attention_mask = torch.cat([current_attention_mask, (~is_stopping).unsqueeze(-1)], dim=-1)
513
+ generation_ids = torch.cat([generation_ids, current_input_ids], dim=1)
514
 
515
  if is_stopping.sum() == batch_size:
516
  break