from typing import List, Optional, Union

import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as P


class TTSStreamingGenerator:
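    """Streaming audio-token generator for an autoregressive TTS model.

    Condition (text) embeddings arrive chunk by chunk; audio tokens are
    sampled one at a time against a persistent KV cache and yielded in
    fixed-size chunks by ``generate_with_buffer``.
    """
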
    def __init__(
        self,
        model,
        temperature: float,
        eos_token: Union[int, torch.Tensor],
        chunk_size: int = 25,
        tts_last_turn_tokens: Optional[torch.Tensor] = None,
        logits_processors: Optional[List] = None,
        logits_warpers: Optional[List] = None,
    ):
        self.tts = model
        self.device = model.device
        self.temperature = torch.tensor([temperature], dtype=torch.float, device=self.device)
        self.eos_token = (
            torch.tensor(eos_token, device=self.device) if isinstance(eos_token, int) else eos_token.to(self.device)
        )

        self.num_vq = model.num_vq
        self.num_audio_tokens = model.num_audio_tokens
        self.window_size = model.window_size
        self.recomputed_chunks = model.recomputed_chunks
        self.emb_code = model.emb_code
        self.head_code = model.head_code

        # Default to fresh lists rather than shared mutable default arguments.
        self.logits_processors = logits_processors if logits_processors is not None else []
        self.logits_warpers = logits_warpers if logits_warpers is not None else []

        # Streaming state carried across successive calls.
        self.past_key_values = None
        self.text_start_pos = 0
        self.idx = -1
        self.all_conditions = []
        self.all_generated_tokens = []
        self.tts_last_turn_tokens = tts_last_turn_tokens
        self.spk_emb = None

        # Pre-compute embeddings for the audio-BOS and text-EOS control tokens.
        audio_bos = torch.tensor(
            [self.tts.audio_bos_token_id],
            device=self.tts.emb_text.weight.device,
            dtype=torch.long,
        )
        self.audio_bos_embeds = self.tts.emb_text(audio_bos).unsqueeze(0)
        self.text_eos_embed = self.tts.emb_text(
            torch.tensor(
                [self.tts.config.text_eos_token_id],
                device=self.tts.emb_text.weight.device,
                dtype=torch.long,
            )
        ).unsqueeze(0)

        self.chunk_size = chunk_size
        self._token_buffer: List[torch.Tensor] = []

    @torch.inference_mode()
    def generate_with_buffer(
        self,
        condition: torch.Tensor,
        text_finished: bool = False,
        max_new_token: int = 500,
    ):
        """Consume one condition-embedding chunk and stream audio tokens.

        Tokens are sampled one at a time and accumulated in a persistent
        buffer; a chunk is yielded only once the buffer holds ``chunk_size``
        tokens, or when generation finishes.

        Yields:
            Tuple of ``(tokens, is_final)``: a ``[1, chunk_size]`` long
            tensor (shorter, possibly empty, on the final yield) and a flag
            marking the last chunk of the utterance.
        """
        self.idx += 1
        self.device = self.tts.device

        # Append the text-EOS embedding once the caller has no more text,
        # then the audio-BOS embedding that cues audio-token generation.
        if text_finished:
            condition = torch.cat([condition, self.text_eos_embed], dim=1)
        condition = torch.cat([condition, self.audio_bos_embeds], dim=1).to(self.device)

        self.all_conditions.append(condition)
        current_condition = condition
        condition_length = current_condition.shape[1]
        finished = torch.zeros(1, dtype=torch.bool, device=self.device)
        chunk_generated_tokens = []
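
        # Autoregressive loop: step 0 prefills the KV cache with the whole
        # condition chunk; each later step decodes a single token.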
        for t in range(max_new_token):
            if t == 0:
                inputs_embeds = current_condition
                pos_ids = torch.arange(
                    self.text_start_pos,
                    self.text_start_pos + condition_length,
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)
            else:
                # Feed only the last sampled token, embedded through the
                # first VQ codebook.
                last = self.all_generated_tokens[-1]
                inputs_embeds = self.emb_code[0](last)
                pos_ids = torch.tensor(
                    [self.text_start_pos + condition_length + t - 1],
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)

            outputs = self.tts.model(
                position_ids=pos_ids,
                past_key_values=self.past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=True,
            )
            hidden_states = outputs.last_hidden_state
            self.past_key_values = outputs.past_key_values

            # Project hidden states through each VQ head; P.cached() reuses
            # parametrization results across the per-codebook heads.
            with P.cached():
                logits = torch.empty(
                    hidden_states.size(0),
                    hidden_states.size(1),
                    self.num_audio_tokens,
                    self.num_vq,
                    dtype=torch.float,
                    device=self.device,
                )
                for num_vq_iter in range(self.num_vq):
                    x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
                    logits[..., num_vq_iter] = x
                    del x
            del hidden_states

            # Keep only the last position, flatten to
            # [batch * num_vq, num_audio_tokens], and apply temperature scaling.
            logits = logits[:, -1].float()
            logits = logits.permute(0, 2, 1)
            logits = logits.reshape(-1, logits.size(2))
            logits /= self.temperature

            # Processors and warpers need a token history, so skip them on
            # the very first step, before anything has been sampled.
            audio_bos = len(self.all_generated_tokens) == 0 and t == 0
            if not audio_bos:
                all_generated_tokens = torch.cat(self.all_generated_tokens, dim=1).to(self.device)
                for processor in self.logits_processors:
                    logits = processor(all_generated_tokens, logits)
                for warper in self.logits_warpers:
                    logits = warper(all_generated_tokens, logits)
                del all_generated_tokens

            # Sample the next token; keep only the first VQ codebook's pick.
            scores = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(scores, num_samples=1)
            next_id = idx_next.view(-1, self.num_vq)[:, 0:1]
            del scores

            if next_id.eq(self.eos_token).any():
                finished[:] = True
            else:
                # Normalize the sampled token to shape [1, 1].
                if next_id.dim() == 0:
                    next_tok = next_id.unsqueeze(0).unsqueeze(0)
                elif next_id.dim() == 1:
                    next_tok = next_id.unsqueeze(0)
                else:
                    next_tok = next_id

                self.all_generated_tokens.append(next_tok)
                chunk_generated_tokens.append(next_tok)
                self._token_buffer.append(next_tok)
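
            # Flush protocol: emit a full chunk as soon as one accumulates; on
            # EOS, emit the buffered tail (marked final) only when the caller
            # has declared the text finished, otherwise hold it for next call.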
            if len(self._token_buffer) == 0:
                # EOS arrived with nothing buffered.
                if text_finished:
                    yield torch.empty(1, 0, dtype=torch.long, device=self.device), True
                    break
                else:
                    break
            elif len(self._token_buffer) >= self.chunk_size:
                batch = torch.cat(self._token_buffer[: self.chunk_size], dim=1)
                yield batch, False
                self._token_buffer = self._token_buffer[self.chunk_size :]
            elif finished.all():
                if text_finished:
                    batch = torch.cat(self._token_buffer, dim=1)
                    yield batch, True
                    self._token_buffer = []
                    break
                else:
                    break
            else:
                continue

        # Advance the absolute position cursor past this chunk's condition
        # and generated tokens.
        self.text_start_pos += condition_length + len(chunk_generated_tokens)
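

# Usage sketch (illustrative only): `model`, `cond_chunks`, `eos_id`, and
# `play_or_decode` below are hypothetical stand-ins, not part of this module.
#
#     generator = TTSStreamingGenerator(model, temperature=0.7, eos_token=eos_id)
#     for i, cond in enumerate(cond_chunks):
#         is_last = i == len(cond_chunks) - 1
#         for tokens, is_final in generator.generate_with_buffer(cond, text_finished=is_last):
#             play_or_decode(tokens)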