| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import json |
| import logging |
| import math |
| import os |
| import types |
| from collections.abc import Iterator |
| from copy import deepcopy |
| from dataclasses import dataclass |
| from threading import Thread |
| from typing import List |
| from typing import Literal |
| from typing import Optional |
| from typing import Tuple |
| from typing import Union |
|
|
| import numpy as np |
| import soundfile as sf |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.nn.utils.parametrize as P |
| from huggingface_hub import hf_hub_download |
| from PIL import Image |
| from torch.nn.utils.parametrizations import weight_norm |
| from tqdm import tqdm |
| from transformers import AutoProcessor |
| from transformers import BertTokenizerFast |
| from transformers import LlamaConfig |
| from transformers import LlamaModel |
| from transformers import LogitsWarper |
| from transformers import PreTrainedModel |
| from transformers import Qwen2ForCausalLM |
| from transformers import Qwen2PreTrainedModel |
| from transformers import TextIteratorStreamer |
| from transformers import TopKLogitsWarper |
| from transformers import TopPLogitsWarper |
| from transformers.cache_utils import Cache |
| from transformers.cache_utils import DynamicCache |
| from transformers.cache_utils import EncoderDecoderCache |
| from transformers.cache_utils import StaticCache |
| from transformers.modeling_outputs import BaseModelOutputWithPast |
| from transformers.modeling_outputs import ModelOutput |
| from transformers.models.whisper.modeling_whisper import ACT2FN |
| from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES |
| from transformers.models.whisper.modeling_whisper import WhisperConfig |
| from transformers.models.whisper.modeling_whisper import WhisperEncoder |
|
|
| try: |
| from vector_quantize_pytorch import GroupedResidualFSQ |
| from vocos import Vocos |
| from vocos.pretrained import instantiate_class |
|
|
| _tts_deps = True |
| except: |
| _tts_deps = False |
|
|
| from .configuration_minicpm import ConditionalChatTTSConfig |
| from .configuration_minicpm import MiniCPMOConfig |
| from .modeling_navit_siglip import SiglipVisionTransformer |
| from .resampler import Resampler |
| from .utils import NumberToTextConverter |
| from .utils import sentence_end |
| from .utils import VoiceChecker |
|
|
# Module-level logger; the model code below emits its warnings through it.
logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
class OmniOutput(ModelOutput):
    """Output container for MiniCPM-o omni (text + speech) generation.

    Every field defaults to ``None``. Which fields get populated is decided by the producer, which is
    not visible in this section (presumably ``MiniCPMO.chat``; confirm).
    """

    # Generated text: a single string, a list of strings, or an iterator
    # (NOTE(review): presumably the streaming text iterator when streaming is enabled; confirm).
    text: Optional[Union[str, List[str], Iterator]] = None
    # Speaker embedding tensor, presumably populated only when the caller requests it.
    spk_embeds: Optional[torch.FloatTensor] = None
    # Generated audio as a numpy array (presumably a waveform sampled at `sampling_rate`; confirm).
    audio_wav: Optional[np.ndarray] = None
    # Sampling rate associated with `audio_wav`.
    sampling_rate: Optional[int] = None
|
|
|
|
class MiniCPMOPreTrainedModel(Qwen2PreTrainedModel):
    """Qwen2 pre-trained-model base class bound to ``MiniCPMOConfig``.

    ``config_class`` is the attribute transformers' ``PreTrainedModel`` uses to associate a
    configuration class with a model class.
    """

    config_class = MiniCPMOConfig
|
|
|
|
| class MiniCPMO(MiniCPMOPreTrainedModel): |
    def __init__(self, config):
        """Build MiniCPM-o: a Qwen2 causal LM plus optional vision, audio and TTS modules.

        Which optional modules are created is controlled by ``config.init_vision``,
        ``config.init_audio`` and ``config.init_tts``. The processor is loaded from
        ``config._name_or_path`` (this performs I/O at construction time), and streaming
        session state is reset at the end.
        """
        super().__init__(config)
        self.llm = Qwen2ForCausalLM(config)
        # Replace the LLM's prepare_inputs_for_generation with the module-level function of the same
        # name (defined elsewhere in this file, not visible here), bound as a method of self.llm.
        self.llm.prepare_inputs_for_generation = types.MethodType(prepare_inputs_for_generation, self.llm)

        # Width of the LLM token embeddings; vision and audio features are projected to this size.
        self.embed_dim = self.llm.config.hidden_size

        # Vision tower plus the resampler that maps its features to embed_dim.
        if self.config.init_vision:
            self.vpm = self.init_vision_module()
            self.vision_dim = self.vpm.embed_dim
            self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)

        # Audio encoder, average pooling (kernel = stride = audio_pool_step) and projection to embed_dim.
        if self.config.init_audio:
            self.apm = self.init_audio_module()
            # NOTE(review): projector input width is taken as encoder_ffn_dim // 4 — presumably equal to
            # the encoder's hidden size for this config; confirm.
            audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4)
            self.audio_avg_pooler = nn.AvgPool1d(self.config.audio_pool_step, stride=self.config.audio_pool_step)
            self.audio_projection_layer = MultiModalProjector(in_dim=audio_output_dim, out_dim=self.embed_dim)
            # Index into the encoder's hidden_states used by get_audio_embedding (-1 = last entry).
            self.audio_encoder_layer = -1

        # Text-to-speech head; requires the optional vector_quantize_pytorch / vocos packages.
        if self.config.init_tts:
            assert _tts_deps, "please make sure vector_quantize_pytorch and vocos are installed."
            self.tts = self.init_tts_module()

        self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)

        # Token strings whose ids are used as eos_token_id during generation.
        self.terminators = ["<|im_end|>", "<|endoftext|>"]

        # Chat template that, when add_generation_prompt is set, opens the assistant turn with the
        # speaker/TTS special tokens.
        self.default_tts_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}"
        self.force_no_stop = False

        # Initialize per-session streaming state (ids, flags, KV caches).
        self.reset_session()
|
|
| def reset_session(self): |
| self.session_id = None |
| self.new_user_msg = True |
| self.llm_generated = False |
| self.llm_generate_completed = False |
| self.llm_past_key_values = None |
| self.audio_past_key_values = None |
|
|
| def init_tts( |
| self, |
| tts_text_tokenizer_path=None, |
| vocos_ckpt_path=None, |
| ): |
| """ |
| load tts tokenizer and vocos |
| 1. try load form local 2. try load from huggingface |
| """ |
| from .processing_minicpmo import ChatTTSProcessor |
|
|
| if tts_text_tokenizer_path is None: |
| tts_text_tokenizer_path = os.path.join(self.config._name_or_path, "assets/chattts_tokenizer") |
| if not os.path.exists(tts_text_tokenizer_path): |
| |
| tts_text_tokenizer_path = "openbmb/chattts_tokenizer" |
|
|
| tts_text_tokenizer = BertTokenizerFast.from_pretrained(tts_text_tokenizer_path) |
| self.tts_processor = ChatTTSProcessor(text_tokenizer=tts_text_tokenizer) |
|
|
| if vocos_ckpt_path is None: |
| vocos_ckpt_path = os.path.join(self.config._name_or_path, "assets/Vocos.pt") |
| if not os.path.exists(vocos_ckpt_path): |
| vocos_ckpt_path = hf_hub_download(repo_id="openbmb/MiniCPM-o-2_6", subfolder="assets", filename="Vocos.pt") |
|
|
| assert os.path.exists(vocos_ckpt_path) |
| self.vocos = self.initialize_vocos(vocos_ckpt_path) |
|
|
| def initialize_vocos(self, ckpt_path): |
| feature_extractor = instantiate_class( |
| args=(), |
| init={ |
| "class_path": "vocos.feature_extractors.MelSpectrogramFeatures", |
| "init_args": {"sample_rate": 24000, "n_fft": 1024, "hop_length": 256, "n_mels": 100}, |
| }, |
| ) |
| backbone = instantiate_class( |
| args=(), |
| init={ |
| "class_path": "vocos.models.VocosBackbone", |
| "init_args": {"input_channels": 100, "dim": 512, "intermediate_dim": 1536, "num_layers": 8}, |
| }, |
| ) |
| head = instantiate_class( |
| args=(), |
| init={"class_path": "vocos.heads.ISTFTHead", "init_args": {"dim": 512, "n_fft": 1024, "hop_length": 256}}, |
| ) |
| vocos = Vocos(feature_extractor, backbone, head).to("cuda").eval().to(torch.float32) |
| vocos.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True)) |
| return vocos |
|
|
| def init_vision_module(self): |
| if self.config._attn_implementation == "flash_attention_2": |
| self.config.vision_config._attn_implementation = "flash_attention_2" |
| else: |
| self.config.vision_config._attn_implementation = "eager" |
| model = SiglipVisionTransformer(self.config.vision_config) |
| if self.config.drop_vision_last_layer: |
| model.encoder.layers = model.encoder.layers[:-1] |
|
|
| setattr(model, "embed_dim", model.embeddings.embed_dim) |
| setattr(model, "patch_size", model.embeddings.patch_size) |
|
|
| return model |
|
|
| def init_resampler(self, embed_dim, vision_dim): |
| return Resampler( |
| num_queries=self.config.query_num, |
| embed_dim=embed_dim, |
| num_heads=embed_dim // 128, |
| kv_dim=vision_dim, |
| adaptive=True, |
| ) |
|
|
| def init_audio_module(self): |
| model = MiniCPMWhisperEncoder(self.config.audio_config) |
| return model |
|
|
| def init_tts_module(self): |
| model = ConditionalChatTTS(self.config.tts_config) |
| return model |
|
|
| def get_input_embeddings(self): |
| return self.llm.get_input_embeddings() |
|
|
| def set_input_embeddings(self, value): |
| self.llm.embed_tokens = value |
|
|
| def get_output_embeddings(self): |
| return self.llm.lm_head |
|
|
| def set_output_embeddings(self, new_embeddings): |
| self.llm.lm_head = new_embeddings |
|
|
| def set_decoder(self, decoder): |
| self.llm = decoder |
|
|
| def get_decoder(self): |
| return self.llm |
|
|
| def subsequent_chunk_mask( |
| self, |
| size: int, |
| chunk_size: int, |
| num_left_chunks: int = -1, |
| device: torch.device = torch.device("cpu"), |
| num_lookhead: int = 0, |
| ) -> torch.Tensor: |
| """Create mask for subsequent steps (size, size) with chunk size, |
| this is for streaming encoder |
| |
| Args: |
| size (int): size of mask |
| chunk_size (int): size of chunk |
| num_left_chunks (int): number of left chunks |
| <0: use full chunk |
| >=0: use num_left_chunks |
| device (torch.device): "cpu" or "cuda" or torch.Tensor.device |
| |
| Returns: |
| torch.Tensor: mask |
| |
| Examples: |
| >>> subsequent_chunk_mask(4, 2) |
| [[1, 1, 0, 0], |
| [1, 1, 0, 0], |
| [1, 1, 1, 1], |
| [1, 1, 1, 1]] |
| """ |
| ret = torch.zeros(size, size, device=device, dtype=torch.bool) |
| for i in range(size): |
| if num_left_chunks < 0: |
| start = 0 |
| else: |
| start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) |
| ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, size) |
| ret[i, start:ending] = True |
| return ret |
|
|
| def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): |
| """ |
| Computes the output length of the convolutional layers and the output length of the audio encoder |
| """ |
| input_lengths_after_cnn = (input_lengths - 1) // 2 + 1 |
| input_lengths_after_pooling = ( |
| input_lengths_after_cnn - self.config.audio_pool_step |
| ) // self.config.audio_pool_step + 1 |
| input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32) |
|
|
| return input_lengths_after_cnn, input_lengths_after_pooling |
|
|
| def get_vllm_embedding(self, data): |
| """ |
| Compute all visual embeddings, and set into llm embeddings. |
| Args: |
| data: Dict |
| tgt_sizes: image size after patch embedding |
| pixel_values: image features |
| image_bound: position of each picture corresponding to input_ids |
| input_ids: full input_ids, include placeholder |
| Returns: |
| embedding with vision, vision_hidden_states |
| """ |
| if "vision_hidden_states" not in data: |
| dtype = self.llm.model.embed_tokens.weight.dtype |
| device = self.llm.model.embed_tokens.weight.device |
| tgt_sizes = data["tgt_sizes"] |
| pixel_values_list = data["pixel_values"] |
| vision_hidden_states = [] |
| all_pixel_values = [] |
| img_cnt = [] |
| for pixel_values in pixel_values_list: |
| img_cnt.append(len(pixel_values)) |
| all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values]) |
|
|
| |
| if all_pixel_values: |
| tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)] |
| tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32) |
|
|
| max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]) |
|
|
| all_pixel_values = torch.nn.utils.rnn.pad_sequence( |
| all_pixel_values, batch_first=True, padding_value=0.0 |
| ) |
| B, L, _ = all_pixel_values.shape |
| all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L) |
|
|
| patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device) |
| for i in range(B): |
| patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True |
|
|
| vision_batch_size = self.config.vision_batch_size |
| all_pixel_values = all_pixel_values.type(dtype) |
| if B > vision_batch_size: |
| hs = [] |
| for i in range(0, B, vision_batch_size): |
| start_idx = i |
| end_idx = i + vision_batch_size |
| tmp_hs = self.vpm( |
| all_pixel_values[start_idx:end_idx], |
| patch_attention_mask=patch_attn_mask[start_idx:end_idx], |
| tgt_sizes=tgt_sizes[start_idx:end_idx], |
| ).last_hidden_state |
| hs.append(tmp_hs) |
| vision_embedding = torch.cat(hs, dim=0) |
| else: |
| vision_embedding = self.vpm( |
| all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes |
| ).last_hidden_state |
| vision_embedding = self.resampler(vision_embedding, tgt_sizes) |
|
|
| start = 0 |
| for pixel_values in pixel_values_list: |
| img_cnt = len(pixel_values) |
| if img_cnt > 0: |
| vision_hidden_states.append(vision_embedding[start : start + img_cnt]) |
| start += img_cnt |
| else: |
| vision_hidden_states.append([]) |
| else: |
| if self.training: |
| dummy_image = torch.zeros((1, 3, 224, 224), device=device, dtype=dtype) |
| tgt_sizes = torch.Tensor( |
| [[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]] |
| ).type(torch.int32) |
| dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes) |
| else: |
| dummy_feature = [] |
| for _ in range(len(pixel_values_list)): |
| vision_hidden_states.append(dummy_feature) |
|
|
| else: |
| vision_hidden_states = data["vision_hidden_states"] |
|
|
| if hasattr(self.llm.config, "scale_emb"): |
| vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb |
| else: |
| vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) |
|
|
| new_vllm_embedding = vllm_embedding.clone() |
| |
| vision_hidden_states = [ |
| i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states |
| ] |
| |
| bs = len(data["input_ids"]) |
| for i in range(bs): |
| cur_vs_hs = vision_hidden_states[i] |
| if len(cur_vs_hs) > 0: |
| cur_vllm_emb = vllm_embedding[i] |
| cur_image_bound = data["image_bound"][i] |
| if len(cur_image_bound) > 0: |
| image_indices = torch.stack( |
| [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound] |
| ).to(vllm_embedding.device) |
|
|
| new_vllm_embedding[i] = cur_vllm_emb.scatter( |
| 0, |
| image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]), |
| cur_vs_hs.view(-1, cur_vs_hs.shape[-1]), |
| ) |
|
|
| elif self.training: |
| new_vllm_embedding[i] += cur_vs_hs[0].mean() * 0 |
|
|
| return new_vllm_embedding, vision_hidden_states |
|
|
| def get_audio_embedding_streaming(self, data): |
| r""" |
| Extract audio embeddings in a streaming manner using cached key-value pairs. |
| |
| This method processes incoming audio features incrementally and stores/updates `past_key_values` |
| for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended |
| for streaming scenarios. |
| |
| Args: |
| data (dict): |
| - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`. |
| - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch. |
| |
| Returns: |
| List[List[torch.Tensor]]: audio embeddings |
| """ |
| wavforms = data.get("audio_features", []) |
| audio_feature_lens_raw = data.get("audio_feature_lens", []) |
|
|
| |
| if len(wavforms) > 0: |
| audio_feature_lens = torch.hstack(audio_feature_lens_raw) |
| batch_size, _, max_mel_seq_len = wavforms.shape |
| assert batch_size == 1 |
| max_seq_len = (max_mel_seq_len - 1) // 2 + 1 |
|
|
| if self.audio_past_key_values is not None: |
| cache_length = self.audio_past_key_values[0][0].shape[2] |
| apm_max_len = self.apm.embed_positions.weight.shape[0] |
| if cache_length + max_seq_len >= apm_max_len: |
| logger.warning( |
| f"audio_past_key_values length {cache_length + max_seq_len} exceed {apm_max_len}, reset." |
| ) |
| self.audio_past_key_values = None |
|
|
| audio_outputs = self.apm(wavforms, past_key_values=self.audio_past_key_values, use_cache=True) |
| audio_states = audio_outputs.last_hidden_state |
| self.audio_past_key_values = audio_outputs.past_key_values |
|
|
| audio_embeds = self.audio_projection_layer(audio_states) |
|
|
| audio_embeds = audio_embeds.transpose(1, 2) |
| audio_embeds = self.audio_avg_pooler(audio_embeds) |
| audio_embeds = audio_embeds.transpose(1, 2) |
|
|
| _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens) |
|
|
| num_audio_tokens = feature_lens_after_pooling |
|
|
| final_audio_embeds = [] |
| idx = 0 |
| for i in range(len(audio_feature_lens_raw)): |
| target_audio_embeds = [] |
| for _ in range(len(audio_feature_lens_raw[i])): |
| target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :]) |
| idx += 1 |
| final_audio_embeds.append(target_audio_embeds) |
| return final_audio_embeds |
| else: |
| return [] |
|
|
| def get_audio_embedding(self, data, chunk_length=-1, dummy=True): |
| r""" |
| Extract full audio embeddings with optional chunk-based attention. |
| |
| This method computes embeddings for all audio frames at once, either using full attention (when |
| `chunk_length` is -1) or chunk-based attention (when `chunk_length` is a positive number). It does |
| not use key-value caching and is suitable for non-streaming inference. |
| |
| Args: |
| data (dict): |
| - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`. |
| - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch. |
| chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based |
| attention (>0) during embedding computation. |
| |
| Returns: |
| List[List[torch.Tensor]]: audio embeddings |
| """ |
| |
| wavforms = data.get("audio_features", []) |
| audio_feature_lens_raw = data.get("audio_feature_lens", []) |
|
|
| |
| if len(wavforms) > 0: |
| audio_feature_lens = torch.hstack(audio_feature_lens_raw) |
| batch_size, _, max_mel_seq_len = wavforms.shape |
| max_seq_len = (max_mel_seq_len - 1) // 2 + 1 |
|
|
| |
| seq_range = ( |
| torch.arange(0, max_seq_len, dtype=audio_feature_lens.dtype, device=audio_feature_lens.device) |
| .unsqueeze(0) |
| .expand(batch_size, max_seq_len) |
| ) |
| lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len) |
| |
| padding_mask = seq_range >= lengths_expand |
|
|
| audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( |
| batch_size, 1, max_seq_len, max_seq_len |
| ) |
| audio_attention_mask = audio_attention_mask_.to( |
| dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device |
| ) |
|
|
| if chunk_length > 0: |
| chunk_num_frame = int(chunk_length * 50) |
| chunk_mask = self.subsequent_chunk_mask( |
| size=max_seq_len, |
| chunk_size=chunk_num_frame, |
| num_left_chunks=-1, |
| device=audio_attention_mask_.device, |
| ) |
| audio_attention_mask_ = torch.logical_or(audio_attention_mask_, torch.logical_not(chunk_mask)) |
|
|
| audio_attention_mask[audio_attention_mask_] = float("-inf") |
| audio_states = self.apm( |
| wavforms, output_hidden_states=True, attention_mask=audio_attention_mask |
| ).hidden_states[self.audio_encoder_layer] |
| audio_embeds = self.audio_projection_layer(audio_states) |
|
|
| audio_embeds = audio_embeds.transpose(1, 2) |
| audio_embeds = self.audio_avg_pooler(audio_embeds) |
| audio_embeds = audio_embeds.transpose(1, 2) |
|
|
| _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens) |
|
|
| num_audio_tokens = feature_lens_after_pooling |
|
|
| final_audio_embeds = [] |
| idx = 0 |
| for i in range(len(audio_feature_lens_raw)): |
| target_audio_embeds = [] |
| for _ in range(len(audio_feature_lens_raw[i])): |
| target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :]) |
| idx += 1 |
| final_audio_embeds.append(target_audio_embeds) |
| return final_audio_embeds |
| elif self.training and dummy: |
| dtype = self.apm.embed_positions.weight.dtype |
| device = self.apm.embed_positions.weight.device |
|
|
| dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype) |
| audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer] |
|
|
| audio_embeds = self.audio_projection_layer(audio_states) |
|
|
| audio_embeds = audio_embeds.transpose(1, 2) |
| audio_embeds = self.audio_avg_pooler(audio_embeds) |
| audio_embeds = audio_embeds.transpose(1, 2) |
| return [audio_embeds] |
|
|
| else: |
| return [] |
|
|
| def get_omni_embedding(self, data, input_embeddings, chunk_length=-1, stream_input=False): |
| """ |
| Args: |
| data: |
| input_embeddings: |
| chunk_length: whisper use full attention or chunk attention |
| stream_input: use streaming audio embedding |
| Returns: |
| final embeddings with audio feature |
| """ |
| if stream_input: |
| audio_embeddings = self.get_audio_embedding_streaming(data) |
| else: |
| audio_embeddings = self.get_audio_embedding(data, chunk_length) |
|
|
| bs = len(input_embeddings) |
| if len(data.get("audio_features", [])) > 0: |
| assert len(audio_embeddings) == len(input_embeddings) |
| if len(audio_embeddings) > 0: |
| audio_bounds = data["audio_bounds"] |
|
|
| if self.config.chunk_input: |
| for i in range(bs): |
| if not audio_embeddings[i]: |
| continue |
| audio_embs = torch.cat(audio_embeddings[i], dim=0).to( |
| device=input_embeddings.device, dtype=input_embeddings.dtype |
| ) |
| audio_start_pos = 0 |
| for bound in audio_bounds[i]: |
| audio_len = bound[1] - bound[0] |
| input_embeddings[i, bound[0] : bound[1]] = audio_embs[ |
| audio_start_pos : audio_start_pos + audio_len, : |
| ] |
| audio_start_pos += audio_len |
| else: |
| for i in range(bs): |
| audio_embs = audio_embeddings[i] |
| bounds = audio_bounds[i] |
| for embs, bound in zip(audio_embs, bounds): |
| audio_indices = torch.arange(bound[0], bound[1], dtype=torch.long).to( |
| input_embeddings.device |
| ) |
|
|
| if embs.shape[0] != len(audio_indices): |
| raise ValueError( |
| f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} " |
| f"to input indices of length {len(audio_indices)}" |
| ) |
| input_embeddings[i, audio_indices] = embs.to(input_embeddings.dtype) |
| elif self.training: |
| for i in range(bs): |
| |
| input_embeddings = input_embeddings + audio_embeddings[0].mean() * 0 |
|
|
| return input_embeddings |
|
|
| def forward(self, data, **kwargs): |
| vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data) |
|
|
| if self.config.init_audio: |
| vllm_embedding = self.get_omni_embedding( |
| data, input_embeddings=vllm_embedding, chunk_length=self.config.audio_chunk_length |
| ) |
|
|
| position_ids = data["position_ids"] |
| if position_ids.dtype != torch.int64: |
| position_ids = position_ids.long() |
|
|
| |
| for key in ["input_ids", "inputs_embeds", "position_ids"]: |
| if key in kwargs: |
| del kwargs[key] |
|
|
| return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs) |
|
|
| def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs): |
| kwargs.pop("output_hidden_states", None) |
| kwargs.pop("return_dict_in_generate", None) |
| terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] |
| outputs = self.llm.generate( |
| inputs_embeds=inputs_embeds, |
| pad_token_id=0, |
| eos_token_id=terminators, |
| attention_mask=attention_mask, |
| output_hidden_states=True, |
| return_dict_in_generate=True, |
| **kwargs, |
| ) |
| return outputs |
|
|
| def _decode_stream(self, inputs_embeds, tokenizer, **kwargs): |
| terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] |
| streamer = TextIteratorStreamer(tokenizer=tokenizer) |
| generation_kwargs = { |
| "inputs_embeds": inputs_embeds, |
| "pad_token_id": 0, |
| "eos_token_id": terminators, |
| "streamer": streamer, |
| } |
| generation_kwargs.update(kwargs) |
|
|
| thread = Thread(target=self.llm.generate, kwargs=generation_kwargs) |
| thread.start() |
|
|
| return streamer |
|
|
| def _decode_text(self, result_ids, tokenizer): |
| terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] |
| result_text = [] |
| for result in result_ids: |
| result = result[result != 0] |
| if result[0] == tokenizer.bos_id: |
| result = result[1:] |
| if result[-1] in terminators: |
| result = result[:-1] |
| result_text.append(tokenizer.decode(result)) |
| return result_text |
|
|
| def get_sys_prompt(self, ref_audio=None, mode="default", language="zh"): |
| """ |
| Choose different system prompts according to different tasks |
| Args: |
| ref_audio: if ref_audio is not None, will use the voice cloning prompts, and the voice |
| generated by the model will refer to the timbre of ref audio |
| mode: |
| "default": default system prompt and not refer to any task |
| "omni": input video and audio simultaneously |
| "audio_assistant": Default voice-only mode, the model will use the ref_audio's voice to reply user's question as a helpful assistant. |
| "audio_roleplay": Roleplay voice-only mode, the model will use the ref_audio's voice to reply, and also role-play the character based on the audio prompt. |
| "voice_cloning": TTS mode, the model will clone the voice of ref_audio. |
| language: prompts language, the model has the ability to automatically select the response language |
| based on the question language |
| Returns: |
| |
| """ |
| if ref_audio is not None: |
| assert isinstance(ref_audio, np.ndarray), "ref_audio error" |
| if mode == "omni": |
| if language == "zh": |
| sys_prompt = "你是一个AI助手。你能接受视频,音频和文本输入并输出语音和文本。" |
| vc_prompt_prefix = sys_prompt + "模仿输入音频中的声音特征。" |
| vc_prompt_suffix = "作为助手,你将使用这种声音风格说话。" |
| else: |
| sys_prompt = "You are a helpful assistant. You can accept video, audio and text input and output voice and text. " |
| vc_prompt_prefix = sys_prompt + "Clone the voice in the provided audio prompt." |
| vc_prompt_suffix = "As an assistant, you will speak using this voice style." |
|
|
| if ref_audio is not None: |
| sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]} |
|
|
| else: |
| sys_msgs = {"role": "user", "content": [sys_prompt]} |
|
|
| return sys_msgs |
| elif mode == "audio_assistant": |
| if language == "zh": |
| vc_prompt_prefix = "模仿输入音频中的声音特征。" |
| vc_prompt_suffix = "作为助手,你将使用这种声音风格说话。" |
| else: |
| vc_prompt_prefix = "Use the voice in the audio prompt to synthesize new content." |
| vc_prompt_suffix = "You are a helpful assistant with the above voice style." |
|
|
| if ref_audio is not None: |
| sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]} |
|
|
| else: |
| logger.warning( |
| "Warning: ref_audio is None, speech generation will be performed based on the default voice." |
| ) |
| sys_msgs = {"role": "user", "content": ["Use the <reserved_53> voice.", vc_prompt_suffix]} |
|
|
| return sys_msgs |
| elif mode == "audio_roleplay": |
| if language == "zh": |
| vc_prompt_prefix = "模仿输入音频中的声音特征。" |
| vc_prompt_suffix = "假装你是上述音频中的人物,与我进行对话。" |
| else: |
| vc_prompt_prefix = "Clone the voice in the provided audio prompt." |
| vc_prompt_suffix = "Try to role-play the character based on the audio prompt above." |
|
|
| if ref_audio is not None: |
| sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]} |
| else: |
| print("Warning: ref_audio is None, speech generation will be performed based on the default voice.") |
| sys_msgs = {"role": "user", "content": ["Use the <reserved_53> voice.", vc_prompt_suffix]} |
|
|
| return sys_msgs |
| elif mode == "voice_cloning": |
| if language == "zh": |
| vc_prompt_prefix = "模仿输入音频中的声音特征。" |
| else: |
| vc_prompt_prefix = "Clone the voice in the provided audio prompt." |
|
|
| if ref_audio is not None: |
| sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio]} |
| else: |
| raise ValueError("ref_audio con't be None in voice_cloning mode.") |
|
|
| return sys_msgs |
| else: |
| sys_prompt = "You are a helpful assistant. You can accept audio and text input and output voice and text." |
| sys_msgs = {"role": "user", "content": [sys_prompt]} |
|
|
| return sys_msgs |
|
|
| def generate( |
| self, |
| input_ids=None, |
| pixel_values=None, |
| tgt_sizes=None, |
| audio_features=[], |
| audio_feature_lens=None, |
| image_bound=None, |
| audio_bounds=None, |
| spk_bounds=None, |
| attention_mask=None, |
| tokenizer=None, |
| vision_hidden_states=None, |
| stream=False, |
| decode_text=True, |
| **kwargs, |
| ): |
| assert input_ids is not None |
| assert len(input_ids) == len(pixel_values) |
|
|
| model_inputs = { |
| "input_ids": input_ids, |
| "audio_features": audio_features, |
| "audio_feature_lens": audio_feature_lens, |
| "image_bound": image_bound, |
| "audio_bounds": audio_bounds, |
| "spk_bounds": spk_bounds, |
| } |
|
|
| if vision_hidden_states is None: |
| model_inputs["pixel_values"] = pixel_values |
| model_inputs["tgt_sizes"] = tgt_sizes |
| else: |
| model_inputs["vision_hidden_states"] = vision_hidden_states |
|
|
| model_output = {} |
| with torch.inference_mode(): |
| model_inputs["inputs_embeds"], vision_hidden_states = self.get_vllm_embedding(model_inputs) |
| model_inputs["inputs_embeds"] = self.get_omni_embedding( |
| model_inputs, |
| input_embeddings=model_inputs["inputs_embeds"], |
| chunk_length=self.config.audio_chunk_length, |
| ) |
|
|
| if stream: |
| result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs) |
| |
| outputs = {} |
| else: |
| outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs) |
|
|
| result = self._decode_text(outputs.sequences, tokenizer) |
| |
| if decode_text is False: |
| return outputs |
| |
| return result, outputs |
|
|
| def chat( |
| self, |
| image=None, |
| msgs=None, |
| tokenizer=None, |
| processor=None, |
| vision_hidden_states=None, |
| max_new_tokens=2048, |
| min_new_tokens=0, |
| sampling=True, |
| max_inp_length=32768, |
| stream=False, |
| chunk_input=True, |
| omni_input=False, |
| max_slice_nums=None, |
| use_image_id=None, |
| use_tts_template=False, |
| generate_audio=False, |
| return_spk_embed=False, |
| return_dict=False, |
| output_audio_path=None, |
| **kwargs, |
| ): |
| """ |
| Unified chat function |
| |
| Args: |
| image: use for batch_size=1 vqa, It is not recommended to continue to use this parameter |
| msgs: the input chat msgs, support text: (string) / image: (PIL.Image) / audio (numpy.ndarray) |
| tokenizer: tokenizer for llm |
| processor: if None, use the default processor |
| max_new_tokens: the maximum length of the generation |
| min_new_tokens: the minimum length of the generation |
| sampling: whether to use sampling decoding or beam search decoding |
| max_inp_length: the maximum length of input |
| stream: whether to return generator, only used when tts is not required |
| chunk_input: whether to split audio into 1s chunks |
| omni_input: determine whether it is omni mode |
| max_slice_nums: control the maximum number of image slices |
| use_image_id: for video understanding or omni understanding, use_image_id should be False |
| use_tts_template: if the msgs contain audio, use_tts_template should be True |
| generate_audio: whether to generate audio output, only used when return_dict=True |
| return_spk_embed: whether to return spk embedding, only used when return_dict=True |
| return_dict: whether to return dict |
| output_audio_path: audio save path when generate_audio |
| **kwargs: |
| """ |
| if isinstance(msgs[0], list): |
| batched = True |
| else: |
| batched = False |
|
|
| if generate_audio or return_spk_embed: |
| return_dict = True |
|
|
| msgs_list = msgs |
| images_list = image |
|
|
| if batched is False: |
| images_list, msgs_list = [images_list], [msgs_list] |
| else: |
| assert images_list is None, "Please integrate image to msgs when using batch inference." |
| images_list = [None] * len(msgs_list) |
| assert len(images_list) == len(msgs_list), "The batch dim of images_list and msgs_list should be the same." |
|
|
| if processor is None: |
| if self.processor is None: |
| self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True) |
| processor = self.processor |
|
|
| assert ( |
| self.config.query_num == processor.image_processor.image_feature_size |
| ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`." |
| assert ( |
| self.config.patch_size == processor.image_processor.patch_size |
| ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`." |
| assert ( |
| self.config.use_image_id == processor.image_processor.use_image_id |
| ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`." |
| assert ( |
| self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums |
| ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`." |
| assert ( |
| self.config.slice_mode == processor.image_processor.slice_mode |
| ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`." |
|
|
| prompts_lists = [] |
| input_images_list = [] |
| input_audios_list = [] |
| audio_parts_list = [] |
|
|
| for image, msgs in zip(images_list, msgs_list): |
| if isinstance(msgs, str): |
| msgs = json.loads(msgs) |
| copy_msgs = deepcopy(msgs) |
|
|
| assert len(msgs) > 0, "msgs is empty" |
| assert sampling or not stream, "if use stream mode, make sure sampling=True" |
|
|
| if image is not None and isinstance(copy_msgs[0]["content"], str): |
| copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]] |
|
|
| images = [] |
| audios = [] |
| audio_parts = [] |
| for i, msg in enumerate(copy_msgs): |
| role = msg["role"] |
| content = msg["content"] |
| assert role in ["system", "user", "assistant"] |
| if i == 0: |
| assert role in ["user", "system"], "The role of first msg should be user" |
| if isinstance(content, str): |
| content = [content] |
| cur_msgs = [] |
| for c in content: |
| if isinstance(c, Image.Image): |
| images.append(c) |
| cur_msgs.append("(<image>./</image>)") |
| elif isinstance(c, np.ndarray): |
| audios.append(c) |
| audio_parts.append(i) |
| cur_msgs.append("(<audio>./</audio>)") |
| use_tts_template = True |
| elif isinstance(c, str): |
| cur_msgs.append(c) |
| if omni_input: |
| msg["content"] = "".join(cur_msgs) |
| else: |
| msg["content"] = "\n".join(cur_msgs) |
|
|
| prompts_lists.append( |
| processor.tokenizer.apply_chat_template( |
| copy_msgs, |
| tokenize=False, |
| add_generation_prompt=True, |
| chat_template=self.default_tts_chat_template if use_tts_template else None, |
| ) |
| ) |
| input_images_list.append(images) |
| input_audios_list.append(audios) |
| audio_parts_list.append(audio_parts) |
|
|
| inputs = processor( |
| prompts_lists, |
| input_images_list, |
| input_audios_list, |
| audio_parts_list, |
| max_slice_nums=max_slice_nums, |
| use_image_id=use_image_id, |
| chunk_input=chunk_input, |
| return_tensors="pt", |
| max_length=max_inp_length, |
| ).to(self.device) |
|
|
| if sampling: |
| generation_config = { |
| "top_p": 0.8, |
| "top_k": 100, |
| "temperature": 0.7, |
| "do_sample": True, |
| "repetition_penalty": 1.05, |
| } |
| else: |
| generation_config = { |
| "num_beams": 3, |
| "repetition_penalty": 1.2, |
| } |
|
|
| if min_new_tokens > 0: |
| generation_config["min_new_tokens"] = min_new_tokens |
|
|
| generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()) |
|
|
| inputs.pop("image_sizes") |
| with torch.inference_mode(): |
| res, outputs = self.generate( |
| **inputs, |
| tokenizer=tokenizer, |
| max_new_tokens=max_new_tokens, |
| vision_hidden_states=vision_hidden_states, |
| stream=stream, |
| **generation_config, |
| ) |
|
|
| if stream: |
|
|
| def stream_gen(): |
| for text in res: |
| for term in self.terminators: |
| text = text.replace(term, "") |
| yield text |
|
|
| if return_dict: |
| return OmniOutput(text=stream_gen()) |
| else: |
| return stream_gen() |
|
|
| else: |
| spk_embeds = wav_numpy = sr = None |
|
|
| if batched: |
| answer = res |
| else: |
| answer = res[0] |
|
|
| if use_tts_template and generate_audio: |
| mel_spec = self._generate_mel_spec(inputs, outputs, answer) |
| wav_numpy, sr = self.decode_mel_to_audio(mel_spec, output_audio_path) |
|
|
| if return_spk_embed: |
| spk_embeds = self._get_last_spk_embeds(inputs, outputs) |
|
|
| if isinstance(answer, list): |
| answer = [i.replace(tokenizer.tts_end, "") for i in answer] |
| else: |
| answer = answer.replace(tokenizer.tts_end, "") |
|
|
| if return_dict: |
| return OmniOutput(text=answer, spk_embeds=spk_embeds, audio_wav=wav_numpy, sampling_rate=sr) |
| else: |
| return answer |
|
|
| @torch.inference_mode() |
| def streaming_prefill( |
| self, |
| session_id, |
| msgs, |
| tokenizer, |
| omni_input=True, |
| max_slice_nums=None, |
| ls_temperature=1.0, |
| **kwargs, |
| ): |
| """ |
| Streaming video/audio input and output audio stream, Only support batch_size=1 |
| Args: |
| session_id: Note: new connection should use a new session_id |
| """ |
| assert session_id is not None |
| if self.session_id is None or session_id != self.session_id: |
| self.is_first = True |
| else: |
| self.is_first = False |
|
|
| images = [] |
| audios = [] |
|
|
| assert len(msgs) == 1 |
| copy_msgs = deepcopy(msgs) |
| msg = copy_msgs[0] |
|
|
| assert msg["role"] in ["system", "user", "assistant"] |
|
|
| content = msg["content"] |
| cur_msgs = [] |
| for j, c in enumerate(content): |
| if isinstance(c, Image.Image): |
| images.append(c) |
| cur_msgs.append("(<image>./</image>)") |
| elif isinstance(c, np.ndarray): |
| audios.append(c) |
| cur_msgs.append("(<audio>./</audio>)") |
| elif isinstance(c, str): |
| cur_msgs.append(c) |
| else: |
| logger.error("Invalid content type:", c) |
|
|
| cur_contents = "".join(cur_msgs) if omni_input else "\n".join(omni_input) |
| if not self.is_first and self.new_user_msg and msg["role"] == "user": |
| if self.llm_generated: |
| if self.llm_generate_completed: |
| msg["content"] = "<|im_end|>\n<|im_start|>user\n" + cur_contents |
| else: |
| msg["content"] = "<|tts_eos|><|im_end|>\n<|im_start|>user\n" + cur_contents |
| else: |
| msg["content"] = "<|im_start|>user\n" + cur_contents |
| self.new_user_msg = False |
| else: |
| msg["content"] = cur_contents |
|
|
| if msg["role"] in ["system", "assistant"]: |
| self.new_user_msg = True |
| self.audio_past_key_values = None |
|
|
| if self.is_first: |
| |
| logger.info(f"new session_id: {session_id}, reset kv cache") |
| self.reset_session() |
| self.session_id = session_id |
|
|
| prompt = tokenizer.apply_chat_template( |
| copy_msgs, tokenize=False, add_generation_prompt=False, chat_template=self.default_tts_chat_template |
| ) |
| add_special_tokens = True |
| else: |
| prompt = copy_msgs[0]["content"] |
| add_special_tokens = False |
|
|
| model_inputs = self.processor( |
| [prompt], |
| [images], |
| [audios], |
| max_slice_nums=1 if max_slice_nums is None else max_slice_nums, |
| use_image_id=False, |
| chunk_input=True, |
| return_tensors="pt", |
| max_length=None, |
| sampling_rate=16000, |
| add_special_tokens=add_special_tokens, |
| ).to(self.device) |
|
|
| |
| model_inputs["inputs_embeds"], _ = self.get_vllm_embedding(model_inputs) |
| |
| inputs_embeds = self.get_omni_embedding( |
| model_inputs, input_embeddings=model_inputs["inputs_embeds"], stream_input=True |
| ) |
|
|
| if self.is_first: |
| |
| self.audio_past_key_values = None |
|
|
| if self.llm_past_key_values is not None: |
| cache_length = self.llm_past_key_values[0][0].shape[2] |
| else: |
| cache_length = 0 |
|
|
| attention_mask = torch.ones((1, cache_length + inputs_embeds.shape[1]), dtype=torch.bool, device=self.device) |
|
|
| |
| outputs = self.llm( |
| past_key_values=self.llm_past_key_values, |
| inputs_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| position_ids=None, |
| use_cache=True, |
| return_dict=True, |
| ) |
| self.llm_past_key_values = outputs["past_key_values"] |
| return |
|
|
| @torch.inference_mode() |
| def streaming_generate( |
| self, |
| session_id, |
| tokenizer, |
| max_new_tokens=512, |
| min_new_tokens=0, |
| sampling=True, |
| generate_audio=True, |
| enable_regenerate=False, |
| **kwargs, |
| ): |
| """ |
| Streaming video/audio input and output audio stream |
| Args: |
| """ |
| if sampling: |
| generation_config = { |
| "top_p": 0.8, |
| "top_k": 100, |
| "temperature": 0.7, |
| "do_sample": True, |
| "repetition_penalty": 1.05, |
| } |
| else: |
| generation_config = { |
| "num_beams": 3, |
| "repetition_penalty": 1.2, |
| } |
| generation_config["min_new_tokens"] = min_new_tokens |
| generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()) |
|
|
| |
| |
| self.new_user_msg = True |
| self.llm_generated = True |
| self.llm_generate_completed = False |
| self.audio_past_key_values = None |
|
|
| terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] |
| generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>" |
| input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].cuda() |
|
|
| spk_start_idx = torch.where(input_ids[0] == tokenizer.spk_start_id)[0] |
| spk_end_idx = torch.where(input_ids[0] == tokenizer.spk_end_id)[0] |
| spk_bounds = [ |
| torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)]) |
| ] |
|
|
| cache_length = past_length = self.llm_past_key_values[0][0].shape[2] |
| attention_mask = torch.ones((1, cache_length + input_ids.shape[1]), dtype=torch.bool, device=self.device) |
|
|
| generation_config["max_new_tokens"] = max_new_tokens |
| streamer = self.llm_generate_chunk(input_ids, attention_mask, tokenizer, terminators, generation_config) |
|
|
| if generate_audio: |
| result = self._generate_mel_spec_audio_streaming( |
| spk_bounds, streamer, output_chunk_size=25, enable_regenerate=enable_regenerate |
| ) |
| return result |
| else: |
| return streamer |
|
|
| def llm_generate_chunk(self, input_ids, attention_mask, tokenizer, terminators, generation_config): |
| def check_uncompleted_token(ids): |
| cur_text = tokenizer.decode(ids) |
| end = len(ids) |
| while cur_text[-1] == "�": |
| end -= 1 |
| if end == 0: |
| break |
| cur_text = tokenizer.decode(ids[:end]) |
| return end |
|
|
| max_new_tokens = int(generation_config.pop("max_new_tokens", 2048)) |
| new_len = 0 |
| first_chunk = True |
| eos = False |
| left_ids = None |
|
|
| while True: |
| outputs = self.llm.generate( |
| input_ids=input_ids, |
| past_key_values=self.llm_past_key_values, |
| attention_mask=attention_mask, |
| use_cache=True, |
| max_new_tokens=3, |
| pad_token_id=0, |
| output_hidden_states=True if first_chunk else False, |
| return_dict_in_generate=True, |
| eos_token_id=terminators, |
| **generation_config, |
| ) |
| if outputs.sequences[0, -1] in terminators: |
| eos = True |
| input_len = input_ids.shape[1] |
| cur_ids = outputs.sequences[:, input_len:] |
| new_len += cur_ids.shape[1] |
|
|
| if left_ids is not None and left_ids.shape[1] > 0: |
| cur_ids = torch.cat([left_ids, cur_ids], dim=1) |
| end = check_uncompleted_token(cur_ids[0]) |
| left_ids = cur_ids[:, end:] |
| cur_ids = cur_ids[:, :end] |
| text = self._decode_text(cur_ids, tokenizer)[0] if end > 0 else "" |
|
|
| self.llm_past_key_values = outputs.past_key_values |
| input_ids = outputs.sequences[:, -1:] |
| cache_length = past_length = self.llm_past_key_values[0][0].shape[2] |
| attention_mask = torch.ones((1, cache_length + input_ids.shape[1]), dtype=torch.bool, device=self.device) |
|
|
| res = {"text": text} |
| if first_chunk: |
| res["hidden_states"] = outputs.hidden_states |
| first_chunk = False |
| yield res |
|
|
| if eos: |
| self.llm_generate_completed = True |
| break |
| if new_len >= max_new_tokens: |
| logger.debug(f"LLM generation {new_len} exceeds max_new_tokens({max_new_tokens}), break.") |
| break |
|
|
| def prepare_tts_text(self, text): |
| tts_tokens = self.tts_processor.text_tokenizer.encode(text, add_special_tokens=False) |
| tts_tokens_len = len(tts_tokens) |
| if tts_tokens_len < self.tts.streaming_text_reserved_len: |
| num_pad_tokens = self.tts.streaming_text_reserved_len - tts_tokens_len |
|
|
| pad_str = "[Etts]" + "[PAD]" * (num_pad_tokens - 1) |
| else: |
| tts_tokens = tts_tokens[0 : self.tts.streaming_text_reserved_len] |
| tts_tokens_len = len(tts_tokens) |
| text = self.tts_processor.text_tokenizer.decode(tts_tokens, add_special_tokens=False) |
| pad_str = "" |
| spk_emb_placeholder_tts = "[spk_emb]" * self.tts.num_spk_embs |
|
|
| new_text_tts = f"[Stts]{spk_emb_placeholder_tts}{text}{pad_str}[Ptts]" |
| return new_text_tts, tts_tokens_len |
|
|
| def get_tts_text_start_token_ids(self): |
| text = "[Stts]" + "[spk_emb]" * self.tts.num_spk_embs |
| tts_input_ids = self.tts_processor.text_tokenizer(text, return_tensors="pt", add_special_tokens=False)[ |
| "input_ids" |
| ].cuda() |
| return tts_input_ids |
|
|
| def _build_streaming_mask(self, tts_tokens_len): |
| tts_sequence_full_length = ( |
| 1 + self.tts.num_spk_embs * self.tts.use_speaker_embedding + self.tts.streaming_text_reserved_len + 1 |
| ) |
| streaming_attention_mask = torch.zeros(tts_sequence_full_length, dtype=torch.int8) |
| streaming_attention_mask[0 : 1 + 1 + tts_tokens_len + 1] = 1 |
| streaming_attention_mask[-1] = 1 |
| return streaming_attention_mask |
|
|
| def _get_last_spk_embeds(self, inputs, outputs): |
| last_hidden_states = [hs[-1] for hs in outputs.hidden_states] |
|
|
| |
| last_hidden_states = torch.vstack([i[0] for i in last_hidden_states]) |
|
|
| |
| spk_bound = inputs["spk_bounds"][0][-1] |
|
|
| spk_embeds = last_hidden_states[spk_bound[0] : spk_bound[1]] |
| return spk_embeds |
|
|
| def _generate_mel_spec(self, inputs, outputs, text, output_chunk_size=25, tts_max_new_tokens=2048): |
| spk_embeds = self._get_last_spk_embeds(inputs, outputs) |
|
|
| text = text.split("<|tts_bos|>")[-1] |
| gen_text = text.split("<|tts_eos|>")[0] |
| tts_text, tts_token_lens = self.prepare_tts_text(gen_text) |
| tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False) |
| tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to("cuda", dtype=torch.long) |
| streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) |
|
|
| logits_warpers, logits_processors = gen_logits( |
| num_code=626, top_P=self.tts.top_p, top_K=self.tts.top_k, repetition_penalty=self.tts.repetition_penalty |
| ) |
|
|
| condition_length = ( |
| 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs + self.tts.streaming_text_reserved_len + 1 |
| ) |
|
|
| dtype = self.tts.emb_text.weight.dtype |
| emb = torch.zeros(1, condition_length, self.tts.num_vq, dtype=dtype, device=self.tts.device) |
| past_key_values = [ |
| ( |
| torch.zeros( |
| 1, |
| self.tts.config.num_attention_heads, |
| condition_length - 1, |
| self.tts.config.hidden_size // self.tts.config.num_attention_heads, |
| dtype=emb.dtype, |
| device=self.tts.device, |
| ), |
| torch.zeros( |
| 1, |
| self.tts.config.num_attention_heads, |
| condition_length - 1, |
| self.tts.config.hidden_size // self.tts.config.num_attention_heads, |
| dtype=emb.dtype, |
| device=self.tts.device, |
| ), |
| ) |
| for _ in range(self.tts.config.num_hidden_layers) |
| ] |
|
|
| audio_input_ids = torch.zeros(1, condition_length, self.tts.num_vq, dtype=torch.long, device=self.tts.device) |
|
|
| eos_lab = False |
| for chunk_idx in range(math.ceil(emb.shape[1] / self.tts.streaming_text_chunk_size)): |
| if chunk_idx == 0: |
| begin = chunk_idx * self.tts.streaming_text_chunk_size + 0 |
| end = ( |
| (chunk_idx + 1) * self.tts.streaming_text_chunk_size |
| + 1 |
| + self.tts.use_speaker_embedding * self.tts.num_spk_embs |
| ) |
| else: |
| begin = ( |
| chunk_idx * self.tts.streaming_text_chunk_size |
| + 1 |
| + self.tts.use_speaker_embedding * self.tts.num_spk_embs |
| ) |
| end = min( |
| (chunk_idx + 1) * self.tts.streaming_text_chunk_size |
| + 1 |
| + self.tts.use_speaker_embedding * self.tts.num_spk_embs, |
| condition_length - 1, |
| ) |
|
|
| if end - begin > 0: |
| text_input_ids = tts_input_ids[:, begin:end] |
| position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0) |
|
|
| if begin == 0: |
| past_key_values = self.tts.prefill_text( |
| input_ids=text_input_ids, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| lm_spk_emb_last_hidden_states=spk_embeds, |
| ) |
| else: |
| past_key_values = self.tts.prefill_text( |
| input_ids=text_input_ids, position_ids=position_ids, past_key_values=past_key_values |
| ) |
|
|
| outputs = self.tts.generate( |
| input_ids=audio_input_ids, |
| past_key_values=past_key_values, |
| streaming_tts_text_mask=streaming_tts_text_mask, |
| max_new_token=output_chunk_size, |
| force_no_stop=self.force_no_stop, |
| temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), |
| eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), |
| logits_warpers=logits_warpers, |
| logits_processors=logits_processors, |
| ) |
| audio_input_ids = outputs.audio_input_ids |
| past_key_values = outputs.past_key_values |
|
|
| if outputs.finished: |
| logger.debug("Generation finished.") |
| eos_lab = True |
| break |
|
|
| if not eos_lab: |
| logger.debug("eos_lab False, Generation continue.") |
| while True: |
| outputs = self.tts.generate( |
| input_ids=audio_input_ids, |
| past_key_values=past_key_values, |
| streaming_tts_text_mask=streaming_tts_text_mask, |
| max_new_token=output_chunk_size, |
| force_no_stop=self.force_no_stop, |
| temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), |
| eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), |
| logits_warpers=logits_warpers, |
| logits_processors=logits_processors, |
| ) |
|
|
| audio_input_ids = outputs.audio_input_ids |
| past_key_values = outputs.past_key_values |
|
|
| if outputs.finished: |
| logger.debug("Generation finished.") |
| break |
| if outputs.new_ids.shape[1] > tts_max_new_tokens: |
| logger.debug(f"Generation length > {tts_max_new_tokens}, stopped.") |
| break |
|
|
| mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids) |
| return mel_spec |
|
|
| def _linear_overlap_add2_wav(self, frames: List[torch.Tensor], overlap: int): |
| """ |
| Merge two audio waveforms with smooth in streaming audio generation. |
| Borrowed some codes from `https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py` |
| """ |
| assert len(frames) == 2 |
| device = frames[0].device |
| dtype = frames[0].dtype |
| |
|
|
| frame0_length = frames[0].shape[-1] |
| frame1_length = frames[1].shape[-1] |
| total_size = frame0_length + frame1_length - overlap |
| weight_len = max(frame0_length, frame1_length) + overlap |
| t = torch.linspace(0, 1, weight_len + 2, device=device, dtype=dtype)[1:-1] |
| weight = 0.5 - (t - 0.5).abs() |
|
|
| sum_weight = torch.zeros(total_size, device=device, dtype=dtype) |
| out = torch.zeros(total_size, device=device, dtype=dtype) |
| offset: int = 0 |
|
|
| out[offset : offset + frame0_length] += weight[-frame0_length:] * frames[0] |
| sum_weight[offset : offset + frame0_length] += weight[-frame0_length:] |
| offset += frame0_length - overlap |
| out[offset : offset + frame1_length] += weight[:frame1_length] * frames[1] |
| sum_weight[offset : offset + frame1_length] += weight[:frame1_length] |
|
|
| assert sum_weight.min() > 0 |
| out = out / sum_weight |
| return out[:frame0_length], out[frame0_length:] |
|
|
| def _generate_mel_spec_audio_streaming( |
| self, |
| spk_bounds, |
| streamer, |
| output_chunk_size=25, |
| spk_embeds=None, |
| prev_seg_text_ids=None, |
| prev_seg_text_left="", |
| prev_seg_audio_ids=None, |
| enable_regenerate=False, |
| ): |
| |
| gen_text = "" |
| tts_text = "" |
| new_segment_gen = False |
| if spk_embeds is None: |
| spk_bound = spk_bounds[0][-1] |
| r = next(streamer) |
| txt = r["text"] |
| gen_text += txt.split("<|tts_eos|>")[0] |
| tts_text, tts_token_lens = self.prepare_tts_text(gen_text) |
| last_hidden_states = r["hidden_states"][0][-1][0] |
| spk_embeds = last_hidden_states[spk_bound[0] : spk_bound[1]] |
|
|
| |
| logits_warpers, logits_processors = gen_logits( |
| num_code=626, top_P=self.tts.top_p, top_K=self.tts.top_k, repetition_penalty=self.tts.repetition_penalty |
| ) |
| condition_length = ( |
| 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs + self.tts.streaming_text_reserved_len + 1 |
| ) |
| tts_start_token_len = 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs |
| dtype = self.tts.emb_text.weight.dtype |
| past_key_values = [ |
| ( |
| torch.zeros( |
| 1, |
| self.tts.config.num_attention_heads, |
| condition_length - 1, |
| self.tts.config.hidden_size // self.tts.config.num_attention_heads, |
| dtype=dtype, |
| device=self.tts.device, |
| ), |
| torch.zeros( |
| 1, |
| self.tts.config.num_attention_heads, |
| condition_length - 1, |
| self.tts.config.hidden_size // self.tts.config.num_attention_heads, |
| dtype=dtype, |
| device=self.tts.device, |
| ), |
| ) |
| for _ in range(self.tts.config.num_hidden_layers) |
| ] |
| audio_input_ids = torch.zeros(1, condition_length, self.tts.num_vq, dtype=torch.long, device=self.tts.device) |
|
|
| |
| chunk_idx = 0 |
| new_ids_len = 0 |
| prev_text_len = 0 |
| if prev_seg_text_ids is not None and prev_seg_audio_ids is not None: |
| tts_token_lens = prev_seg_text_ids.shape[1] |
| |
| streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) |
| position_ids = torch.arange( |
| 0, tts_token_lens + tts_start_token_len, dtype=torch.long, device=self.tts.device |
| ).unsqueeze(0) |
|
|
| text_input_ids = self.get_tts_text_start_token_ids() |
| text_input_ids = torch.cat([text_input_ids, prev_seg_text_ids], dim=1) |
| past_key_values = self.tts.prefill_text( |
| input_ids=text_input_ids, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| lm_spk_emb_last_hidden_states=spk_embeds, |
| ) |
| past_key_values = self.tts.prefill_audio_ids( |
| input_ids=prev_seg_audio_ids[:, :-1, :], |
| |
| past_key_values=past_key_values, |
| streaming_tts_text_mask=streaming_tts_text_mask, |
| ) |
|
|
| |
| chunk_idx += int(tts_token_lens / self.tts.streaming_text_chunk_size) |
| audio_input_ids = torch.cat([audio_input_ids, prev_seg_audio_ids], dim=1) |
| text = self.tts_processor.text_tokenizer.decode(prev_seg_text_ids[0].tolist(), add_special_tokens=False) |
|
|
| gen_text += text |
| gen_text += prev_seg_text_left |
| prev_text_len = len(gen_text) |
| new_ids_len += prev_seg_audio_ids.shape[1] |
|
|
| prev_wav = None |
| eos_lab = False |
| stop = False |
| shift_len = 180 |
| voice_checker = VoiceChecker() |
| number_converter = NumberToTextConverter() |
| lang = None |
| gen_text_raw = gen_text |
| for t, r in enumerate(streamer): |
| t += 1 |
| txt = r["text"] |
| txt = txt.split("<|tts_eos|>")[0] |
| gen_text_raw += txt |
| if t == 1 and txt == "" and prev_seg_text_ids is not None: |
| logger.warning("New segment is empty, generation finished.") |
| return |
| if t <= 2: |
| lang = number_converter.detect_language(gen_text_raw) |
| gen_text += number_converter.replace_numbers_with_text(txt, lang).replace("*", "") |
|
|
| |
| tts_text, tts_token_lens = self.prepare_tts_text(gen_text) |
|
|
| if tts_token_lens >= self.tts.streaming_text_reserved_len - shift_len: |
| end_c = sentence_end(txt) |
| if end_c: |
| end_c_idx = gen_text.rfind(end_c) |
| assert end_c_idx != -1 |
| text_left = gen_text[end_c_idx + 1 :] |
| gen_text = gen_text[: end_c_idx + 1] |
| tts_text, tts_token_lens = self.prepare_tts_text(gen_text) |
| new_segment_gen = True |
| logger.debug( |
| f"tts_text tokens {tts_token_lens} exceed {self.tts.streaming_text_reserved_len - shift_len}, starting a new segment generation" |
| ) |
| break |
|
|
| if tts_token_lens >= (chunk_idx + 1) * self.tts.streaming_text_chunk_size: |
|
|
| |
| if chunk_idx == 0: |
| begin = 0 |
| end = (chunk_idx + 1) * self.tts.streaming_text_chunk_size + tts_start_token_len |
| else: |
| begin = chunk_idx * self.tts.streaming_text_chunk_size + tts_start_token_len |
| end = min( |
| (chunk_idx + 1) * self.tts.streaming_text_chunk_size + tts_start_token_len, condition_length - 1 |
| ) |
|
|
| tts_input_ids = self.tts_processor.text_tokenizer( |
| tts_text, return_tensors="pt", add_special_tokens=False |
| )["input_ids"].cuda() |
| text_input_ids = tts_input_ids[:, begin:end] |
| streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) |
| position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0) |
|
|
| past_key_values = self.tts.prefill_text( |
| input_ids=text_input_ids, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| lm_spk_emb_last_hidden_states=spk_embeds if chunk_idx == 0 else None, |
| ) |
| outputs = self.tts.generate( |
| input_ids=audio_input_ids, |
| past_key_values=past_key_values, |
| streaming_tts_text_mask=streaming_tts_text_mask, |
| max_new_token=output_chunk_size, |
| force_no_stop=self.force_no_stop, |
| temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), |
| eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), |
| logits_warpers=logits_warpers, |
| logits_processors=logits_processors, |
| ) |
| audio_input_ids = ( |
| outputs.audio_input_ids |
| ) |
| past_key_values = outputs.past_key_values |
| chunk_idx += 1 |
|
|
| mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, max(new_ids_len - 4, 0) :, :]) |
| new_ids_len = outputs.new_ids.shape[1] |
|
|
| wav_np, sr = self.decode_mel_to_audio(mel_spec) |
|
|
| if enable_regenerate: |
| if prev_wav is not None: |
| check_wav_np = wav_np[2048:].cpu().numpy() |
| check_mel = mel_spec[0, :, 8:].cpu().numpy() |
| else: |
| check_wav_np = wav_np.cpu().numpy() |
| check_mel = mel_spec[0].cpu().numpy() |
| if enable_regenerate and voice_checker.is_bad(check_wav_np, check_mel, chunk_size=2560): |
| voice_checker.reset() |
| |
| N = output_chunk_size if prev_wav is None else output_chunk_size * 2 |
| past_kv = [] |
| for i in range(len(past_key_values)): |
| past_kv.append( |
| ( |
| past_key_values[i][0][:, :, :-N, :], |
| past_key_values[i][1][:, :, :-N, :], |
| ) |
| ) |
| outputs = self.tts.generate( |
| input_ids=audio_input_ids[:, :-N, :], |
| past_key_values=past_kv, |
| streaming_tts_text_mask=streaming_tts_text_mask, |
| max_new_token=N, |
| force_no_stop=self.force_no_stop, |
| temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), |
| eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), |
| logits_warpers=logits_warpers, |
| logits_processors=logits_processors, |
| ) |
| audio_input_ids = outputs.audio_input_ids |
| past_key_values = outputs.past_key_values |
|
|
| new_ids_len -= N |
| mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, new_ids_len:, :]) |
| new_ids_len = outputs.new_ids.shape[1] |
| wav_np, sr = self.decode_mel_to_audio(mel_spec) |
|
|
| if prev_wav is not None: |
| wav_y = wav_np[: len(prev_wav)] |
| prev_wav = wav_np[len(prev_wav) :] |
| cur_text = gen_text_raw[prev_text_len:] |
| prev_text_len = len(gen_text_raw) |
| yield OmniOutput(text=cur_text, audio_wav=wav_y, sampling_rate=sr) |
|
|
| else: |
| prev_wav = wav_np |
| else: |
| |
| if prev_wav is not None: |
| wav_np, prev_wav = self._linear_overlap_add2_wav( |
| [prev_wav, wav_np], overlap=512 * 4 |
| ) |
| cur_text = gen_text_raw[prev_text_len:] |
| prev_text_len = len(gen_text_raw) |
| yield OmniOutput(text=cur_text, audio_wav=wav_np, sampling_rate=sr) |
|
|
| else: |
| prev_wav = wav_np |
|
|
| if outputs.finished: |
| logger.debug("Generation finished.") |
| eos_lab = True |
| break |
|
|
| if not eos_lab and tts_text: |
| logger.debug("eos_lab False, Generation continue.") |
|
|
| if chunk_idx == 0: |
| begin = 0 |
| else: |
| begin = chunk_idx * self.tts.streaming_text_chunk_size + tts_start_token_len |
| end = tts_token_lens + tts_start_token_len + 1 |
| if end > begin: |
| tts_input_ids = self.tts_processor.text_tokenizer( |
| tts_text, return_tensors="pt", add_special_tokens=False |
| )["input_ids"].cuda() |
| text_input_ids = tts_input_ids[:, begin:end] |
| streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) |
| position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0) |
|
|
| past_key_values = self.tts.prefill_text( |
| input_ids=text_input_ids, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| lm_spk_emb_last_hidden_states=spk_embeds if chunk_idx == 0 else None, |
| ) |
|
|
| while True: |
| |
| outputs = self.tts.generate( |
| input_ids=audio_input_ids, |
| past_key_values=past_key_values, |
| streaming_tts_text_mask=streaming_tts_text_mask, |
| max_new_token=output_chunk_size, |
| force_no_stop=self.force_no_stop, |
| |
| temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), |
| eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), |
| logits_warpers=logits_warpers, |
| logits_processors=logits_processors, |
| ) |
| audio_input_ids = outputs.audio_input_ids |
| past_key_values = outputs.past_key_values |
| chunk_idx += 1 |
|
|
| mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, max(new_ids_len - 4, 0) :, :]) |
| new_ids_len = outputs.new_ids.shape[1] |
|
|
| wav_np, sr = self.decode_mel_to_audio(mel_spec) |
|
|
| if enable_regenerate: |
| if prev_wav is not None: |
| check_wav_np = wav_np[2048:].cpu().numpy() |
| check_mel = mel_spec[0, :, 8:].cpu().numpy() |
| else: |
| check_wav_np = wav_np.cpu().numpy() |
| check_mel = mel_spec[0].cpu().numpy() |
| if enable_regenerate and voice_checker.is_bad(check_wav_np, check_mel, chunk_size=2560): |
| voice_checker.reset() |
| |
| N = output_chunk_size if prev_wav is None else output_chunk_size * 2 |
| past_kv = [] |
| for i in range(len(past_key_values)): |
| past_kv.append( |
| ( |
| past_key_values[i][0][:, :, :-N, :], |
| past_key_values[i][1][:, :, :-N, :], |
| ) |
| ) |
| outputs = self.tts.generate( |
| input_ids=audio_input_ids[:, :-N, :], |
| past_key_values=past_kv, |
| streaming_tts_text_mask=streaming_tts_text_mask, |
| max_new_token=N, |
| force_no_stop=self.force_no_stop, |
| temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), |
| eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), |
| logits_warpers=logits_warpers, |
| logits_processors=logits_processors, |
| ) |
| audio_input_ids = outputs.audio_input_ids |
| past_key_values = outputs.past_key_values |
|
|
| new_ids_len -= N |
| mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, new_ids_len:, :]) |
| new_ids_len = outputs.new_ids.shape[1] |
| wav_np, sr = self.decode_mel_to_audio(mel_spec) |
|
|
| if prev_wav is not None: |
| wav_y = wav_np[: len(prev_wav)] |
| prev_wav = wav_np[len(prev_wav) :] |
| cur_text = gen_text_raw[prev_text_len:] |
| prev_text_len = len(gen_text_raw) |
| yield OmniOutput(text=cur_text, audio_wav=wav_y, sampling_rate=sr) |
| else: |
| prev_wav = wav_np |
| else: |
| |
| if prev_wav is not None: |
| wav_np, prev_wav = self._linear_overlap_add2_wav( |
| [prev_wav, wav_np], overlap=512 * 4 |
| ) |
| cur_text = gen_text_raw[prev_text_len:] |
| prev_text_len = len(gen_text_raw) |
| yield OmniOutput(text=cur_text, audio_wav=wav_np, sampling_rate=sr) |
| else: |
| prev_wav = wav_np |
|
|
| if outputs.finished: |
| logger.debug("Generation finished.") |
| break |
| if outputs.new_ids.shape[1] > 2048: |
| stop = True |
| logger.debug("Generation length > 2048, stopped.") |
| break |
|
|
| if prev_wav is not None: |
| cur_text = gen_text_raw[prev_text_len:] |
| yield OmniOutput(text=cur_text, audio_wav=prev_wav, sampling_rate=sr) |
|
|
| if new_segment_gen and not stop: |
| logger.debug( |
| f"tts_text tokens {tts_token_lens} exceed {self.tts.streaming_text_reserved_len - shift_len}, start a new segment generation" |
| ) |
| tid_len = 5 |
| prev_seg_text_ids = tts_input_ids[:, end - 1 - tid_len : end - 1] |
| aid_len = 50 |
| prev_seg_audio_ids = outputs.new_ids[:, -aid_len:, :] |
|
|
| result = self._generate_mel_spec_audio_streaming( |
| spk_bounds, |
| streamer, |
| output_chunk_size, |
| spk_embeds, |
| prev_seg_text_ids, |
| text_left, |
| prev_seg_audio_ids, |
| enable_regenerate=enable_regenerate, |
| ) |
| for res in result: |
| yield res |
|
|
def decode_mel_to_audio(self, mel_spec, output_path=""):
    """Turn a mel spectrogram into a waveform using ``self.vocos``.

    Args:
        mel_spec (torch.Tensor): Mel spectrogram to vocode; it is cast to float32
            before being handed to ``self.vocos.decode``.
        output_path (str, optional): If non-empty, the waveform is also written to
            this path via ``soundfile`` and a log line is emitted. Defaults to ``""``.

    Returns:
        tuple: ``(waveform, sample_rate)`` where ``waveform`` is the decoder output
        moved to CPU and squeezed (a tensor, not a NumPy array) and ``sample_rate``
        is the fixed value 24000.
    """
    sample_rate = 24000
    with torch.inference_mode():
        decoded = self.vocos.decode(mel_spec.float())
        waveform = decoded.cpu().squeeze()
    if not output_path:
        return waveform, sample_rate
    # soundfile needs a NumPy array; the tensor itself is what callers receive.
    sf.write(output_path, waveform.numpy(), samplerate=sample_rate)
    logger.info(f"Audio saved to {output_path}")
    return waveform, sample_rate
|
|
|
|
| |
class MiniCPMWhisperEncoderLayer(nn.Module):
    """Pre-LayerNorm Whisper-style encoder layer that threads a KV cache through self-attention.

    The layer applies LayerNorm -> self-attention -> dropout -> residual, then
    LayerNorm -> fc1 -> activation -> dropout -> fc2 -> dropout -> residual. It forwards
    ``past_key_values`` to the attention module and can return whatever cache object the
    attention module hands back.
    """

    def __init__(self, config: WhisperConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.embed_dim = config.d_model
        # The attention class is chosen by the configured attention implementation.
        self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
            layer_idx=layer_idx,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
        past_key_values: Optional[EncoderDecoderCache] = None,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, ...]:
        r"""
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, embed_dim)`):
                Hidden states to be fed into the encoder layer.
            attention_mask (`torch.FloatTensor` of shape `(batch_size, 1, tgt_len, src_len)`):
                Attention mask where padding elements are indicated by large negative values.
            layer_head_mask (`torch.FloatTensor` of shape `(encoder_attention_heads,)`):
                Mask to nullify selected heads of the attention modules.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention weights.
            past_key_values (`EncoderDecoderCache`, *optional*):
                Cache passed to the attention module as `past_key_value`.
            use_cache (`bool`, *optional*):
                Whether or not to append the cache returned by the attention module to the outputs.

        Returns:
            A tuple `(hidden_states,)`, extended with `attn_weights` when `output_attentions`
            is truthy and then with the attention module's returned cache when `use_cache` is truthy.
        """
        # Self-attention sub-block: pre-LayerNorm, attention, dropout, residual.
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, past_key_values = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            past_key_value=past_key_values,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Feed-forward sub-block: pre-LayerNorm, fc1 + activation, dropout, fc2, dropout, residual.
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # fp16 overflow guard: if any value became inf/NaN, clamp into the finite fp16 range.
        # (torch.clamp maps +/-inf to the bounds; NaNs pass through unchanged.)
        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        if use_cache:
            outputs += (past_key_values,)
        return outputs
|
|
|
|
| |
class MiniCPMWhisperEncoder(WhisperEncoder):
    """Whisper encoder that can run incrementally with a self-attention KV cache.

    The inherited layer stack is replaced by :class:`MiniCPMWhisperEncoderLayer` instances
    (which accept and return ``past_key_values``). Positional embeddings are offset by the
    number of positions already held in the cache.
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [MiniCPMWhisperEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]
        )

    @staticmethod
    def _to_encoder_decoder_cache(past_key_values):
        """Normalize the accepted cache inputs into an ``EncoderDecoderCache``.

        ``None`` yields a fresh empty cache. A legacy list/tuple of per-layer ``(key, value)``
        pairs or a bare ``DynamicCache`` is wrapped as the self-attention half. Anything else
        is returned unchanged.
        """
        if past_key_values is None:
            return EncoderDecoderCache(DynamicCache(), DynamicCache())
        if isinstance(past_key_values, (list, tuple)):
            return EncoderDecoderCache(DynamicCache.from_legacy_cache(past_key_values), DynamicCache())
        if isinstance(past_key_values, DynamicCache):
            return EncoderDecoderCache(past_key_values, DynamicCache())
        return past_key_values

    @staticmethod
    def _select_position_embeddings(embed_pos: torch.Tensor, seq_len: int, past_len: int) -> torch.Tensor:
        """Return ``seq_len`` rows of the position table, starting at row ``past_len``.

        If the window runs past the end of the table, the rows that exist are used and the
        table's final row is repeated for the remainder (with a warning).
        """
        max_positions = embed_pos.shape[0]
        if seq_len + past_len <= max_positions:
            return embed_pos[past_len : past_len + seq_len, :]

        logger.warning("seems the audio is longer than 30s. repeating the last part of the audio")
        # Empty once the cache is at least as long as the table.
        front = embed_pos[past_len:, :]
        # Derive the pad count from what `front` actually holds. The previous formula
        # (seq_len - max_positions + past_len) over-padded once past_len > max_positions,
        # which produced a tensor longer than seq_len and broke the addition downstream.
        padding = torch.repeat_interleave(
            embed_pos[-1, :].unsqueeze(0),
            seq_len - front.shape[0],
            dim=0,
        )
        return torch.cat((front, padding))

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        past_key_values: Optional[EncoderDecoderCache] = None,
        use_cache: Optional[bool] = None,
    ):
        r"""
        Forward pass of the Whisper encoder with optional KV caching.

        Args:
            input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
                Log-mel features (e.g. produced by `WhisperFeatureExtractor`). They are cast to the
                dtype/device of `conv1`, passed through `conv1`/`conv2` with GELU, and transposed to
                `(batch_size, frames, channels)` before positional embeddings are added.

            attention_mask (`torch.Tensor`, *optional*):
                Passed unchanged as the attention mask to every encoder layer.

            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Per-layer head mask; entry `head_mask[i]` is given to layer `i`. 1 keeps a head and
                0 masks it. Its first dimension must equal the number of layers.

            output_attentions (`bool`, *optional*):
                Whether to collect each layer's attention weights. Defaults to `config.output_attentions`.

            output_hidden_states (`bool`, *optional*):
                Whether to collect the hidden states entering each layer plus the final (normed)
                output. Defaults to `config.output_hidden_states`.

            return_dict (`bool`, *optional*):
                Return a `BaseModelOutputWithPast` instead of a tuple. Defaults to `config.use_return_dict`.

            past_key_values (`EncoderDecoderCache`, *optional*):
                Cache from a previous chunk. When `use_cache` is truthy it may also be `None`, a
                legacy list/tuple of per-layer `(key, value)` pairs, or a `DynamicCache`; these are
                wrapped into an `EncoderDecoderCache`. Its self-attention length offsets the
                positional embeddings. When `use_cache` is falsy it is forwarded to the layers as-is.

            use_cache (`bool`, *optional*):
                When truthy, the cache returned by the last executed layer is exposed as
                `past_key_values` on the returned `BaseModelOutputWithPast`.

        Returns:
            `BaseModelOutputWithPast` (if `return_dict`) with `last_hidden_state`, `hidden_states`,
            `attentions` and `past_key_values` (the cache object returned by the last executed
            layer when `use_cache` is truthy, else `None`).

            Otherwise a tuple `(last_hidden_state, hidden_states, attentions)` containing only the
            non-`None` entries. The cache is not included in the tuple form.

            If the requested positions exceed the positional-embedding table, the last table row is
            repeated for the overflow (a warning is logged).

        Example:
            >>> encoder = MiniCPMWhisperEncoder(config)
            >>> out = encoder(input_features, use_cache=True, return_dict=True)
            >>> out.last_hidden_state.shape  # (batch_size, frames, d_model)
            >>> # Feed the next chunk, reusing the cache:
            >>> out2 = encoder(next_features, past_key_values=out.past_key_values, use_cache=True, return_dict=True)
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Match the conv stem so callers may pass features in any dtype/device.
        input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
        # (batch, channels, frames) -> (batch, frames, channels)
        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        seq_len = inputs_embeds.shape[1]

        past_key_values_length = 0
        if use_cache:
            past_key_values = self._to_encoder_decoder_cache(past_key_values)
            past_key_values_length = past_key_values.self_attention_cache.get_usable_length(seq_len)

        embed_pos = self._select_position_embeddings(
            self.embed_positions.weight, seq_len, past_key_values_length
        )

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        # Stays None without caching; keeps the last real layer's cache if later layers are dropped.
        next_encoder_cache = None
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            # LayerDrop (training only): a skipped layer leaves hidden_states and the cache untouched.
            if self.training and torch.rand([]) < self.layerdrop:
                if output_attentions:
                    all_attentions = all_attentions + (None,)
                continue

            layer_head_mask = head_mask[idx] if head_mask is not None else None
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                    past_key_values,
                    use_cache,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    layer_head_mask=layer_head_mask,
                    output_attentions=output_attentions,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            # Layer outputs are (hidden, [attn], [cache]); the cache slot shifts when attentions are present.
            if use_cache:
                next_encoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
            past_key_values=next_encoder_cache,
        )
|
|
|
|
| |
class ConvNeXtBlock(nn.Module):
    """1-D ConvNeXt block with a residual connection.

    Pipeline: depthwise dilated ``Conv1d`` -> ``LayerNorm`` -> ``Linear`` (dim -> intermediate_dim)
    -> GELU -> ``Linear`` (intermediate_dim -> dim) -> optional per-channel scale ``coef``,
    added back onto the input.

    Input/output layout is ``(batch, dim, time)``. The norm and the two linear layers operate on
    the channel-last view ``(batch, time, dim)``.

    Args:
        dim: Number of channels.
        intermediate_dim: Hidden width of the pointwise MLP.
        kernel: Depthwise kernel size. With an odd kernel the padding keeps the time length.
        dilation: Depthwise dilation.
        layer_scale_init_value: Initial value of ``coef``. When ``<= 0`` no scale is created
            (``coef`` is ``None``).
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        kernel: int,
        dilation: int,
        layer_scale_init_value: float = 1e-6,
    ):
        super().__init__()
        depthwise_padding = dilation * (kernel // 2)
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel,
            padding=depthwise_padding,
            dilation=dilation,
            groups=dim,  # one filter per channel
        )
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        if layer_scale_init_value > 0:
            self.coef = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.coef = None

    def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
        # `cond` is accepted for interface compatibility; it is not used.
        channels_last = self.dwconv(x).transpose(1, 2)
        mixed = self.pwconv2(self.act(self.pwconv1(self.norm(channels_last))))
        if self.coef is not None:
            mixed.mul_(self.coef)
        return mixed.transpose(1, 2) + x
|
|
|
|
| |
class GFSQ(nn.Module):
    """Thin wrapper around ``GroupedResidualFSQ`` that packs/unpacks grouped residual indices.

    Args:
        dim: Feature dimension fed to the quantizer.
        levels: FSQ levels per quantizer; ``n_ind`` stores their product.
        G: Number of groups.
        R: Number of residual quantizers.
        eps: Stored as an attribute; not used by this class's methods.
        transpose: When True, features/indices are channel-first (``(batch, channels, time)``)
            at the interface and are transposed to/from the channel-last layout the quantizer uses.
    """

    def __init__(
        self,
        dim: int,
        levels: List[int],
        G: int,
        R: int,
        eps=1e-5,
        transpose=True,
    ):
        super(GFSQ, self).__init__()
        self.quantizer = GroupedResidualFSQ(
            dim=dim,
            levels=list(levels),
            num_quantizers=R,
            groups=G,
        )
        self.n_ind = math.prod(levels)
        self.eps = eps
        self.transpose = transpose
        self.G = G
        self.R = R

    def _embed(self, x: torch.Tensor):
        """Map packed indices ``(batch, time, G*R)`` (or channel-first if ``transpose``) back to features."""
        if self.transpose:
            x = x.transpose(1, 2)
        # (batch, time, G*R) -> (G, batch, time, R), the layout passed to get_output_from_indices.
        x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3)
        feat = self.quantizer.get_output_from_indices(x)
        return feat.transpose_(1, 2) if self.transpose else feat

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Quantize features and return packed indices of shape ``(batch, G*R, time)`` if ``transpose``.

        The input tensor is not modified. Previously ``x.transpose_(1, 2)`` changed the
        caller's tensor shape as a side effect.
        """
        if self.transpose:
            x = x.transpose(1, 2)
        _, ind = self.quantizer(x)
        # (G, batch, time, R) -> (batch, time, G, R) -> (batch, time, G*R)
        ind = ind.permute(1, 2, 0, 3).contiguous()
        ind = ind.view(ind.size(0), ind.size(1), -1)
        return ind.transpose_(1, 2) if self.transpose else ind
|
|
|
|
| |
class DVAEDecoder(nn.Module):
    """Convolutional stack used by ``DVAE`` for both its encoder and its decoder.

    Layout: ``conv_in`` (Conv1d idim -> bn_dim, GELU, Conv1d bn_dim -> hidden, kernel 3) ->
    ``n_layer`` ``ConvNeXtBlock``s at width ``hidden`` (MLP width ``4 * hidden``) ->
    1x1 ``conv_out`` (hidden -> odim, no bias).

    ``up`` is stored as an attribute but not used by ``forward``.
    """

    def __init__(
        self,
        idim: int,
        odim: int,
        n_layer=12,
        bn_dim=64,
        hidden=256,
        kernel=7,
        dilation=2,
        up=False,
    ):
        super().__init__()
        self.up = up
        self.conv_in = nn.Sequential(
            nn.Conv1d(idim, bn_dim, 3, 1, 1),
            nn.GELU(),
            nn.Conv1d(bn_dim, hidden, 3, 1, 1),
        )
        blocks = [ConvNeXtBlock(hidden, hidden * 4, kernel, dilation) for _ in range(n_layer)]
        self.decoder_block = nn.ModuleList(blocks)
        self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)

    def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor:
        # Rebinding `x` drops the local reference to the input as soon as conv_in is done.
        x = self.conv_in(x)
        for block in self.decoder_block:
            x = block(x, conditioning)
        return self.conv_out(x)
|
|
|
|
| |
class DVAE(nn.Module):
    """Discrete VAE between 100-bin mel spectrograms and grouped-FSQ codes.

    Sub-modules (registered in this order):
      * ``coef``: per-bin scale, shape ``(1, 100, 1)``. It is randomly initialised here;
        presumably overwritten by checkpoint weights — TODO confirm.
      * ``downsample_conv``: Conv1d 100 -> 512 (k3), GELU, Conv1d 512 -> 512 (k4, stride 2), GELU.
      * ``encoder``: ``DVAEDecoder`` 512 -> 1024.
      * ``decoder``: ``DVAEDecoder`` 512 -> 512.
      * ``out_conv``: Conv1d 512 -> 100 (k3, no bias).
      * ``vq_layer``: ``GFSQ`` over 1024 dims with levels (5, 5, 5, 5), G=2, R=2.
    """

    def __init__(
        self,
    ):
        super().__init__()

        self.coef = nn.Parameter(torch.rand(100).reshape(1, 100, 1))

        self.downsample_conv = nn.Sequential(
            nn.Conv1d(100, 512, 3, 1, 1),
            nn.GELU(),
            nn.Conv1d(512, 512, 4, 2, 1),
            nn.GELU(),
        )

        self.encoder = DVAEDecoder(idim=512, odim=1024, hidden=256, n_layer=12, bn_dim=128)
        self.decoder = DVAEDecoder(idim=512, odim=512, hidden=256, n_layer=12, bn_dim=128)
        self.out_conv = nn.Conv1d(512, 100, 3, 1, 1, bias=False)
        self.vq_layer = GFSQ(dim=1024, levels=(5, 5, 5, 5), G=2, R=2)

    @torch.inference_mode()
    def forward(self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode") -> torch.Tensor:
        """Encode a mel spectrogram to codes (``mode="encode"``) or decode codes to a mel (any other mode).

        Encoding divides a copy of ``inp`` by ``coef`` per bin, downsamples, adds a leading batch
        axis, runs ``encoder`` and quantizes with ``vq_layer``. It appears to expect an unbatched
        ``(100, frames)`` mel — TODO confirm.

        Decoding embeds the codes via ``vq_layer._embed`` (when present), splits the channel axis
        into two halves that are interleaved into the last axis, runs ``decoder`` + ``out_conv`` and
        multiplies by ``coef`` in place.
        """
        wants_encoding = mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None
        if wants_encoding:
            normalized = inp.clone()
            normalized.div_(self.coef.view(100, 1).expand(normalized.shape))
            hidden = self.downsample_conv(normalized).unsqueeze_(0)
            del normalized
            hidden = self.encoder(hidden)
            return self.vq_layer(hidden)

        feats = self.vq_layer._embed(inp) if self.vq_layer is not None else inp

        # (B, C, T) -> (B, 2, C/2, T) -> (B, C/2, T, 2) -> (B, C/2, 2*T)
        half_channels = feats.size(1) // 2
        feats = feats.view((feats.size(0), 2, half_channels, feats.size(2)))
        feats = feats.permute(0, 2, 3, 1).flatten(2)

        mel = self.out_conv(self.decoder(x=feats))
        del feats

        return mel.mul_(self.coef)
|
|
|
|
def apply_spk_emb(
    input_ids: torch.Tensor = None,
    spk_emb: torch.Tensor = None,
    input_embeds: torch.Tensor = None,
    spk_emb_token_id: int = 0,
    num_spk_embs: int = 1,
):
    """
    Write speaker embeddings into the placeholder positions of ``input_embeds``, in place.

    For every batch row, the positions whose id equals ``spk_emb_token_id`` are located. Their
    count must equal ``num_spk_embs``. The span from the first to the last such position
    (inclusive) is overwritten with that row's ``spk_emb``. Nothing is returned.

    Args:
        input_ids (torch.Tensor): Token ids, shape [batch_size, seq_len_max].
        spk_emb (torch.Tensor): Speaker embeddings, shape [batch_size, num_spk_emb, hidden_dim].
        input_embeds (torch.Tensor): Embeddings to modify, shape [batch_size, seq_len_max, hidden_dim].
        spk_emb_token_id (int): Placeholder token id.
        num_spk_embs (int): Expected number of placeholders per row.

    Raises:
        AssertionError: If a row does not contain exactly ``num_spk_embs`` placeholders.

    Returns:
        None
    """
    for row in range(input_ids.shape[0]):
        row_ids = input_ids[row]
        row_spk = spk_emb[row]
        placeholder_positions = (row_ids == spk_emb_token_id).nonzero(as_tuple=False)
        assert placeholder_positions.shape[0] == num_spk_embs
        first = placeholder_positions.min()
        stop = placeholder_positions.max() + 1
        input_embeds[row, first:stop, :] = row_spk

    return
|
|
|
|
def make_streaming_chunk_mask_generation(
    inputs_embeds: torch.Tensor,
    past_seen_tokens: int,
    streaming_tts_text_mask: torch.Tensor,
    streaming_reserved_length: int = 300,
    streaming_audio_chunk_size: int = 50,
    streaming_text_chunk_size: int = 10,
    num_spk_emb: int = 1,
    use_spk_emb: bool = True,
) -> torch.Tensor:
    """
    Build the additive attention mask for one streaming TTS generation step.

    It decides which reserved `text` positions the TTS model may attend to while generating the
    current chunk of `audio` tokens. The number of visible text positions grows by
    ``streaming_text_chunk_size`` for every (started) audio chunk of ``streaming_audio_chunk_size``
    tokens, capped at ``streaming_reserved_length``.

    Prefix layout assumed by the index arithmetic: 1 leading token, ``num_spk_emb * use_spk_emb``
    speaker slots, ``streaming_reserved_length`` text slots, then 1 more slot.

    Args:
        inputs_embeds (torch.Tensor): Current-step embeddings. Only the batch size (must be 1),
            sequence length, dtype and device are used.
        past_seen_tokens (int): Number of tokens already in the KV cache.
        streaming_tts_text_mask (torch.Tensor): Validity mask over the prefix. It must broadcast to
            length ``1 + num_spk_emb * use_spk_emb + streaming_reserved_length + 1``. Positions where
            it is 0 are hidden.
        streaming_reserved_length (int, optional): Number of reserved text slots. Defaults to 300.
        streaming_audio_chunk_size (int, optional): Audio tokens per chunk. Defaults to 50.
        streaming_text_chunk_size (int, optional): Text tokens unlocked per audio chunk. Defaults to 10.
        num_spk_emb (int, optional): Number of speaker-embedding slots. Defaults to 1.
        use_spk_emb (bool, optional): Whether the speaker slots are present. Defaults to True.

    Returns:
        torch.Tensor: Additive mask of shape ``[1, 1, 1, past_seen_tokens + seq_len]`` in
        ``inputs_embeds.dtype``. 0 means visible; ``torch.finfo(dtype).min`` means hidden.

    Raises:
        AssertionError: If the batch size is not 1 (only batch size 1 is supported for inference).
    """
    assert inputs_embeds.shape[0] == 1

    dtype = inputs_embeds.dtype
    device = inputs_embeds.device
    min_dtype = torch.finfo(dtype).min

    spk_slots = num_spk_emb * use_spk_emb
    total_len = past_seen_tokens + inputs_embeds.shape[1]

    # Start fully visible.
    causal_mask = torch.zeros((1, total_len), dtype=dtype, device=device)

    # Text slots that are unlocked so far, capped at the reserved length.
    visible_text = min(
        math.ceil((past_seen_tokens - streaming_reserved_length) / streaming_audio_chunk_size)
        * streaming_text_chunk_size,
        streaming_reserved_length,
    )
    # Hide everything from the first locked text slot up to (exclusive)
    # reserved + 1 + spk_slots + 1. That end index also covers the slot right after the text block.
    invisible_text_tokens_start = visible_text + 1 + spk_slots
    invisible_text_tokens_end = streaming_reserved_length + 1 + spk_slots + 1
    causal_mask[0, invisible_text_tokens_start:invisible_text_tokens_end] = min_dtype

    # Additionally hide prefix positions that the caller marked as invalid.
    prefix_len = 1 + spk_slots + streaming_reserved_length + 1
    causal_mask[0, 0:prefix_len].masked_fill_(streaming_tts_text_mask == 0, min_dtype)

    # [1, total_len] -> [1, 1, 1, total_len]
    return causal_mask.unsqueeze(0).unsqueeze(0)
|
|
|
|
| |
class CustomRepetitionPenaltyLogitsProcessorRepeat:
    """Repetition penalty scaled by how often each token occurs in a recent window.

    For each row, the last ``past_window`` ids are counted per vocabulary entry. Each score is
    then scaled by ``penalty ** count``: negative scores are multiplied and non-negative scores
    are divided, so repeated tokens become less likely when ``penalty > 1``.

    Args:
        penalty: Strictly positive float base of the penalty.
        max_input_ids: Rows (dim 0 of the count matrix) at or beyond this index get zero counts.
        past_window: Number of most recent ids per row that are considered.
    """

    def __init__(self, penalty: float, max_input_ids: int, past_window: int):
        is_valid = isinstance(penalty, float) and penalty > 0
        if not is_valid:
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty
        self.max_input_ids = max_input_ids
        self.past_window = past_window

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        window = self.past_window
        if input_ids.size(1) > window:
            input_ids = input_ids.narrow(1, -window, window)

        # counts[row, token] = occurrences of `token` in the row's window.
        counts = F.one_hot(input_ids, scores.size(1)).sum(1)

        cutoff = self.max_input_ids
        overflow = counts.size(0) - cutoff
        if overflow > 0:
            # NOTE(review): this zeroes whole rows (dim 0), not vocabulary ids >= cutoff (dim 1).
            # Preserved as-is; verify whether the vocab dimension was intended.
            counts.narrow(0, cutoff, overflow).zero_()

        alpha = torch.pow(self.penalty, counts)
        scores = scores.contiguous()
        return torch.where(scores < 0, scores * alpha, scores / alpha)
|
|
|
|
@dataclass
class ConditionalChatTTSGenerationOutput(ModelOutput):
    """
    Output of a ConditionalChatTTS generation call.

    Every field defaults to ``None``. The annotations say so explicitly via ``Optional``
    (implicit Optional, ``x: T = None``, is disallowed by PEP 484).

    Args:
        new_ids (torch.LongTensor): Newly generated audio codes, shape (batch_size, sequence_length, num_vq).
        audio_input_ids (torch.LongTensor): Input ids updated with the condition and the generated
            audio codes, shape (batch_size, full_sequence_length, num_vq).
        past_key_values: Attention key/value cache after this call. Documented as per-layer tuples
            of tensors with shape (batch_size, num_heads, sequence_length, embed_size_per_head).
        finished (bool): Whether generation is complete.
    """

    new_ids: Optional[torch.LongTensor] = None
    audio_input_ids: Optional[torch.LongTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    finished: Optional[bool] = None
|
|
|
|
class MultiModalProjector(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) projecting ``in_dim`` features to ``out_dim``.

    Both linear layers have a bias. The second layer maps ``out_dim`` to ``out_dim``.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True)

    def forward(self, audio_features):
        return self.linear2(self.relu(self.linear1(audio_features)))
|
|
|
|
| class ConditionalChatTTS(PreTrainedModel): |
| """A conditional text-to-speech model that can generate speech from text with speaker conditioning. |
| |
| This model extends PreTrainedModel to provide text-to-speech capabilities with: |
| - LLM hidden state conditioning |
| - Streaming generation |
| |
| The model uses a transformer architecture with LLM hidden states and can operate in both |
| streaming and non-streaming modes for flexible deployment. |
| |
| The model process sequence in the following format: |
| | text bos token | LLM embedding projected to tts embedding space | text tokens (fixed length, reserved for future tokens) | audio bos token | audio tokens (audio token length is not fixed)| audio eos token | |
| |
| The format is designed to support LLM-conditioned streaming audio generation. |
| |
| Usage: |
| To support streaming generation, two global variables should be maintained outside of the model. |
| 1. `audio_input_ids`: stores *discrete* audio codes. It is a tensor with shape [1, sequence length+1, num_vq]. |
| 2. `past_key_values`: stores the KV cache for both text tokens and audio codes. It is a list of tuples, each tuple contains two tensors with shape [1, num_attention_heads, sequence length, hidden_size // num_attention_heads] |
| |
| where `num_vq` is the number of audio codebooks, in default setting, it is `4`. |
| |
| 1. Create an empty `past_key_values` with |
| ```python |
| initial_kv_cache_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len # where `1` denotes the `bos` token |
| dtype = model.emb_text.weight.dtype |
| device = model.emb_text.weight.device |
| past_key_values = [ |
| ( |
| torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device), |
| torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device) |
| ) |
| for _ in range(model.config.num_hidden_layers) |
| ] |
| ``` |
| 2. At the same time, create an empty `audio_input_ids` with shape [1, sequence length, num_vq], where `num_vq` is the number of audio codebook layers. The sequence also spans the text positions; those entries are zero placeholders and are never used. |
| |
| ```python |
| initial_audio_input_ids_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len + 1 |
| # [bos token, speaker embeddings, text tokens, audio bos token] |
| audio_input_ids = torch.zeros(1, initial_audio_input_ids_length, model.num_vq, dtype=torch.long) |
| ``` |
| |
| 2. Prefill some text tokens to TTS model (for example, 10 tokens) using `prefill_text` method. |
| |
| ```python |
| outputs = llm.generate(**kwargs) |
| llm_tokens = some_function_to_extract_llm_tokens(outputs) |
| lm_spk_emb_last_hidden_states = some_function_to_extract_lm_spk_emb_last_hidden_states(outputs) |
| tts_text_input_ids = tts_tokenizer.encode(llm_tokenizer.decode(llm_tokens)) |
| # here assume we are prefilling text token 0 to text token 9 (included), totally 10 tokens. |
| begin = 0 |
| end = 9+1 |
| position_ids = torch.arange(begin, end, dtype=torch.long, device=device) |
| |
| past_key_values = model.prefill_text( |
| input_ids=tts_text_input_ids, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states, |
| ) |
| ``` |
| |
| 3. Make a `streaming_tts_text_mask` to denote which position contains valid text tokens, similar to `attention_mask` in standard causal attention. |
| |
| ```python |
| streaming_tts_text_mask = torch.zeros(model.streaming_text_reserved_len) |
| streaming_tts_text_mask[0:end] = 1  # marks these positions as holding valid text tokens |
| ``` |
| |
| 3. Generate audio codes using `generate` method. |
| |
| ```python |
| outputs = model.generate( |
| input_ids=audio_input_ids, |
| past_key_values=past_key_values, |
| streaming_tts_text_mask=streaming_tts_text_mask, |
| max_new_token=50, |
| ) |
| |
| # update past_key_values and input_ids |
| past_key_values = outputs.past_key_values |
| audio_input_ids = outputs.audio_input_ids |
| ``` |
| |
| After `generate` returns, both `past_key_values` and `audio_input_ids` have been extended by up to `max_new_token` (here 50) positions. |
| |
| 4. Notice that after prefilling `10` text tokens, the model can generate up to `50` audio tokens, if you want to generate more audio tokens, you need to prefill next `10` text tokens. And it is okay to only generate `25` audio tokens for faster initial response. |
| |
| 5. Repeat steps `2,3,4` as needed in your streaming audio generation cases, following the guidelines discussed above. |
| """ |
|
|
| config_class = ConditionalChatTTSConfig |
| _no_split_modules = [] |
|
|
def __init__(self, config: ConditionalChatTTSConfig):
    """Copy scalar settings from ``config`` and build the TTS sub-modules.

    Registered modules, in creation order:
      * ``projector``: ``MultiModalProjector`` if ``config.use_mlp`` else a bias-free ``nn.Linear``
        (llm_dim -> hidden_size).
      * ``emb_code``: ``num_vq`` audio-code embeddings.
      * ``emb_text``: text embedding.
      * ``head_code``: ``num_vq`` weight-normalised linear heads (hidden_size -> num_audio_tokens).
      * ``dvae``: ``DVAE``.
      * ``model``: ``LlamaModel`` built from the relevant ``config`` sizes.
    """
    super().__init__(config)

    # Mirror these config values onto the module under the same names (same order as the config reads).
    for attr_name in (
        "use_speaker_embedding",
        "use_llm_hidden_state",
        "num_spk_embs",
        "spk_emb_token_id",
        "use_text",
        "streaming",
        "streaming_text_chunk_size",
        "streaming_audio_chunk_size",
        "streaming_text_reserved_len",
        "audio_bos_token_id",
        "num_mel_bins",
        "num_vq",
        "num_audio_tokens",
        "top_p",
        "top_k",
        "repetition_penalty",
    ):
        setattr(self, attr_name, getattr(config, attr_name))

    hidden = config.hidden_size

    if self.config.use_mlp:
        self.projector = MultiModalProjector(config.llm_dim, hidden)
    else:
        self.projector = nn.Linear(config.llm_dim, hidden, bias=False)

    self.emb_code = nn.ModuleList(
        [nn.Embedding(config.num_audio_tokens, hidden) for _ in range(config.num_vq)]
    )
    self.emb_text = nn.Embedding(config.num_text_tokens, hidden)
    self.head_code = nn.ModuleList(
        [
            weight_norm(nn.Linear(hidden, config.num_audio_tokens, bias=False), name="weight")
            for _ in range(config.num_vq)
        ]
    )
    self.dvae = DVAE()

    llama_config = LlamaConfig(
        hidden_size=hidden,
        intermediate_size=config.intermediate_size,
        num_attention_heads=config.num_attention_heads,
        num_hidden_layers=config.num_hidden_layers,
        max_position_embeddings=config.max_position_embeddings,
        attn_implementation=config.attn_implementation,
    )
    self.model = LlamaModel(llama_config)
|
|
@torch.inference_mode()
def merge_inputs_embeds(
    self,
    input_ids: torch.Tensor,
    lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
):
    """Embed `input_ids` and splice projected LLM speaker states into the placeholder slots.

    Args:
        input_ids (torch.Tensor): Input token IDs; the batch size must be 1.
        lm_spk_emb_last_hidden_states (Optional[torch.Tensor], optional): LLM hidden states used as
            speaker embeddings. Required when `input_ids` contains `self.spk_emb_token_id`.
            Defaults to None.

    Raises:
        AssertionError: If the batch size is not 1, or placeholders are present but
            `lm_spk_emb_last_hidden_states` is None.
        NotImplementedError: If `self.use_speaker_embedding` is False.

    Returns:
        torch.Tensor: `self.emb_text(input_ids)` with the speaker placeholder positions overwritten
        by the projected, L2-normalised speaker embeddings.
    """
    assert input_ids.shape[0] == 1

    inputs_embeds = self.emb_text(input_ids)

    if not self.use_speaker_embedding:
        raise NotImplementedError

    spk_emb_mask = input_ids == self.spk_emb_token_id
    if spk_emb_mask.any():
        assert lm_spk_emb_last_hidden_states is not None

        # Cast to the projector's parameter dtype. Taking the first parameter works for both
        # projector variants built in __init__. For MultiModalProjector it is linear1.weight,
        # the same tensor the old lookup used. The bias-free nn.Linear used when
        # `config.use_mlp` is False has no `linear1`, so `self.projector.linear1` raised there.
        projector_dtype = next(self.projector.parameters()).dtype
        lm_spk_emb_last_hidden_states = lm_spk_emb_last_hidden_states.to(projector_dtype)
        projected_spk_emb = self.projector(lm_spk_emb_last_hidden_states)
        projected_spk_emb = F.normalize(projected_spk_emb, p=2, dim=-1)
        apply_spk_emb(
            input_ids=input_ids,
            spk_emb=projected_spk_emb,
            input_embeds=inputs_embeds,
            spk_emb_token_id=self.spk_emb_token_id,
            num_spk_embs=self.num_spk_embs,
        )

    return inputs_embeds
|
|
@torch.inference_mode()
def prefill_text(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.LongTensor,
    past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
    lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
):
    """Prefill a chunk of new text tokens into a preallocated streaming KV cache.

    The model reads the new tokens on top of the already-filled cache prefix
    `[0, position_ids[0, 0])`. The keys/values it produces for the new positions are then
    copied in place into `past_key_values`.

    Args:
        input_ids (Tensor): Tensor of shape [batch_size, seq_len].
        position_ids (LongTensor): Tensor of shape [batch_size, seq_len]. The write-back
            covers `[position_ids[:, 0], position_ids[:, -1] + 1)`, so positions are assumed
            to be consecutive.
        past_key_values (List[Tuple[Tensor]]): KV cache of all layers, each layer a
            `(keys, values)` tuple whose seq dimension is `self.streaming_text_reserved_len`.
            Updated in place and also returned.
        lm_spk_emb_last_hidden_states (Tensor, optional): Tensor of shape
            [batch_size, num_spk_emb, llm_dim], forwarded to `merge_inputs_embeds`.
            Defaults to None.

    Note that all `batch_size` should be `1`.

    Returns:
        The same `past_key_values` object, updated in place.
    """
    assert input_ids.shape[0] == 1
    assert past_key_values is not None

    inputs_embeds = self.merge_inputs_embeds(
        input_ids=input_ids,
        lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,
    )

    # Bounds of the new chunk inside the preallocated cache.
    chunk_start = position_ids[:, 0]
    chunk_end = position_ids[:, -1] + 1

    # Give the model a copy of only the filled prefix, so the caller's buffers are untouched
    # until the explicit write-back below.
    prefix_cache = [
        (
            layer_kv[0][:, :, :chunk_start, :].clone(),
            layer_kv[1][:, :, :chunk_start, :].clone(),
        )
        for layer_kv in past_key_values
    ]

    outputs_prefill = self.model(
        attention_mask=None,
        position_ids=position_ids,
        past_key_values=prefix_cache,
        inputs_embeds=inputs_embeds,
        use_cache=True,
        output_attentions=False,
        cache_position=position_ids,
    )
    updated_cache = outputs_prefill.past_key_values

    # Copy just the freshly computed slice of every layer back into the caller's buffers.
    for layer_idx, layer_kv in enumerate(past_key_values):
        new_keys, new_values = updated_cache[layer_idx][0], updated_cache[layer_idx][1]
        layer_kv[0][:, :, chunk_start:chunk_end, :] = new_keys[:, :, chunk_start:chunk_end].clone()
        layer_kv[1][:, :, chunk_start:chunk_end, :] = new_values[:, :, chunk_start:chunk_end].clone()

    return past_key_values
|
|
@torch.inference_mode()
def prefill_audio_ids(
    self,
    input_ids: torch.Tensor,
    past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
    streaming_tts_text_mask=None,
    add_audio_bos: bool = True,
):
    """Run a chunk of existing audio codes through the model to extend the KV cache.

    Used in sliding-window long audio generation: audio ids (typically from the previous
    window) are prefilled into the new window in a single forward pass.

    Args:
        input_ids (torch.Tensor): (1, seq_len, num_vq) Audio input token ids.
        past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): KV cache to extend; the new
            positions continue right after its current length.
        streaming_tts_text_mask (optional): Forwarded to
            `make_streaming_chunk_mask_generation`. Defaults to None.
        add_audio_bos (bool): Prepend the embedding of `self.audio_bos_token_id` (looked up
            in `self.emb_text`) before the codes. Defaults to True.

    Returns:
        The `past_key_values` returned by `self.model` for this forward pass.
    """
    assert input_ids.shape[0] == 1
    assert past_key_values is not None

    # One embedding per time step: the sum of the embeddings of its per-codebook ids.
    per_codebook = [self.emb_code[q](input_ids[:, :, q]) for q in range(self.num_vq)]
    inputs_embeds = torch.stack(per_codebook, 3).sum(3)
    chunk_len = input_ids.shape[1]

    if add_audio_bos:
        bos_id = torch.tensor([[self.audio_bos_token_id]], dtype=torch.long, device=self.device)
        inputs_embeds = torch.cat([self.emb_text(bos_id), inputs_embeds], dim=1)
        chunk_len += 1

    # New positions start right after what the cache already holds.
    cached_len = past_key_values[0][0].shape[2]
    position_ids = torch.arange(
        cached_len, cached_len + chunk_len, dtype=torch.long, device=self.device
    ).unsqueeze(0)

    causal_mask = make_streaming_chunk_mask_generation(
        inputs_embeds=inputs_embeds,
        past_seen_tokens=cached_len,
        streaming_tts_text_mask=streaming_tts_text_mask,
        streaming_reserved_length=self.streaming_text_reserved_len,
        streaming_text_chunk_size=self.streaming_text_chunk_size,
    )

    outputs = self.model(
        attention_mask=causal_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=True,
        output_attentions=False,
        cache_position=position_ids.clone(),
    )
    return outputs.past_key_values
|
|
@torch.inference_mode()
def generate(
    self,
    input_ids: torch.Tensor,
    past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
    temperature: torch.Tensor,
    eos_token: Union[int, torch.Tensor],
    streaming_tts_text_mask=None,
    force_no_stop=False,
    min_new_token=10,
    max_new_token=50,
    logits_warpers: Optional[List[LogitsWarper]] = None,
    logits_processors: Optional[List[CustomRepetitionPenaltyLogitsProcessorRepeat]] = None,
    show_tqdm=False,
):
    """Sample audio codes autoregressively, one frame (all `num_vq` codebooks) per step.

    Works in streaming and non-streaming settings, including when not all text tokens have
    been prefilled yet. The method does not prefill by itself: always pass a valid
    `past_key_values` produced by `prefill_text` (see the class docstring for details).

    In this method, we borrowed a lot of codes from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/gpt.py`.

    Args:
        input_ids (torch.Tensor): Ids of shape (1, seq_len, num_vq). The first
            `condition_length` positions are the conditioning prefix; anything after it is
            previously generated audio codes. When `seq_len == condition_length`, the first
            step feeds the audio BOS embedding instead of the last code frame.
        past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): KV cache. It must hold
            exactly `seq_len - 1` positions on entry (asserted at every step).
        temperature (torch.Tensor): Sampling temperature, expanded to one row per
            (batch, codebook). Presumably of shape (num_vq,); confirm against callers.
        eos_token (Union[int, torch.Tensor]): Code id that stops generation when it is
            sampled in any codebook.
        streaming_tts_text_mask (Optional[torch.Tensor], optional): Forwarded to
            `make_streaming_chunk_mask_generation`. Defaults to None.
        force_no_stop (bool, optional): If True, EOS is never sampled. Defaults to False.
        min_new_token (int, optional): EOS is masked out for the first `min_new_token`
            steps. Defaults to 10.
        max_new_token (int, optional): Maximum number of new frames to generate.
            Defaults to 50.
        logits_warpers (List[LogitsWarper], optional): Applied after the processors; skipped
            on the audio-BOS step. Defaults to None (no warpers).
        logits_processors (List[CustomRepetitionPenaltyLogitsProcessorRepeat], optional):
            Applied to the temperature-scaled logits; skipped on the audio-BOS step.
            Defaults to None (no processors).
        show_tqdm (bool, optional): Show a progress bar and log when `max_new_token` is hit.
            Defaults to False.

    Returns:
        ConditionalChatTTSGenerationOutput:
            - `new_ids`: codes after the conditioning prefix, with the trailing EOS frame
              dropped when finished.
            - `audio_input_ids`: the full id sequence.
            - `past_key_values`: the updated cache.
            - `finished`: whether EOS was sampled.
    """
    assert input_ids.shape[0] == 1
    assert past_key_values is not None

    # Avoid shared mutable defaults; an omitted list means "nothing to apply".
    if logits_warpers is None:
        logits_warpers = []
    if logits_processors is None:
        logits_processors = []

    # Length of the conditioning prefix. Generated codes are stored after it, and the step
    # that sees a sequence of exactly this length feeds the audio BOS embedding.
    condition_length = 1 + self.num_spk_embs * self.use_speaker_embedding + self.streaming_text_reserved_len + 1

    batch_size = input_ids.shape[0]
    finish = torch.zeros(batch_size, device=input_ids.device).bool()

    # One temperature row per (batch, codebook), matching the flattened logits layout below.
    temperature = temperature.unsqueeze(0).expand(batch_size, -1).contiguous().view(-1, 1)

    # Preallocate room for every possible new frame; `input_ids` is always a view of the
    # filled prefix of this buffer.
    progress = input_ids.shape[1]
    input_ids_buf = torch.zeros(
        batch_size,
        progress + max_new_token,
        input_ids.shape[2],
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    input_ids_buf.narrow(1, 0, progress).copy_(input_ids)
    del input_ids
    input_ids = input_ids_buf.narrow(1, 0, progress)

    pbar: Optional[tqdm] = None
    if show_tqdm:
        pbar = tqdm(
            total=max_new_token,
            desc="code",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
        )

    for i in range(max_new_token):
        audio_bos = progress == condition_length
        past_len = past_key_values[0][0].shape[2]
        # The cache must cover every position except the one fed in this step.
        assert progress == past_len + 1

        if audio_bos:
            bos_ids = torch.tensor([[self.audio_bos_token_id]], dtype=torch.long, device=self.device)
            inputs_embeds = self.emb_text(bos_ids)
            del bos_ids
        else:
            # Embed the last frame as the sum of its per-codebook embeddings.
            last_frame = input_ids.narrow(dim=1, start=input_ids.shape[1] - 1, length=1)
            code_emb = [self.emb_code[q](last_frame[:, :, q]) for q in range(self.num_vq)]
            inputs_embeds = torch.stack(code_emb, 3).sum(3)

        position_ids = torch.tensor([past_len], dtype=torch.long, device=self.device).unsqueeze(0)
        cache_position = position_ids.clone()
        causal_mask = make_streaming_chunk_mask_generation(
            inputs_embeds=inputs_embeds,
            past_seen_tokens=past_len,
            streaming_tts_text_mask=streaming_tts_text_mask,
            streaming_reserved_length=self.streaming_text_reserved_len,
            streaming_text_chunk_size=self.streaming_text_chunk_size,
        )

        outputs = self.model(
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=True,
            output_attentions=False,
            cache_position=cache_position,
        )
        del position_ids, inputs_embeds, cache_position, causal_mask

        hidden_states = outputs.last_hidden_state
        past_key_values = outputs.past_key_values

        # Score every codebook with its own head. `P.cached()` caches the weight-norm
        # parametrized head weights while inside this context.
        with P.cached():
            logits = torch.empty(
                hidden_states.size(0),
                hidden_states.size(1),
                self.num_audio_tokens,
                self.num_vq,
                dtype=torch.float,
                device=self.device,
            )
            for q in range(self.num_vq):
                logits[..., q] = self.head_code[q](hidden_states)
        del hidden_states

        # Keep the last position, then flatten to (batch * num_vq, num_audio_tokens).
        logits = logits.narrow(1, -1, 1).squeeze_(1).float()
        logits = logits.permute(0, 2, 1)
        logits = logits.reshape(-1, logits.size(2))

        # Previously generated codes in the same (batch * num_vq, time) layout; this is the
        # history the processors/warpers see.
        generated = input_ids.narrow(1, condition_length, input_ids.size(1) - condition_length).permute(0, 2, 1)
        logits_token = generated.reshape(generated.size(0) * generated.size(1), -1).to(self.device)
        del generated

        logits /= temperature

        if not audio_bos:
            for processor in logits_processors:
                logits = processor(logits_token, logits)
            for warper in logits_warpers:
                logits = warper(logits_token, logits)
        del logits_token

        if i < min_new_token or force_no_stop:
            logits[:, eos_token] = -torch.inf

        scores = F.softmax(logits, dim=-1)
        del logits
        idx_next = torch.multinomial(scores, num_samples=1)
        del scores

        # One sampled code per codebook; EOS in any codebook finishes the sequence.
        idx_next = idx_next.view(-1, self.num_vq)
        finish.logical_or_(idx_next.eq(eos_token).any(1))
        input_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1))
        del idx_next

        # Always extend the visible sequence with the sampled frame, even an EOS on the very
        # first step. The trailing-EOS trim after the loop then removes only that EOS frame
        # and never a real code frame.
        progress += 1
        input_ids = input_ids_buf.narrow(1, 0, progress)

        if finish.all():
            break

        if pbar is not None:
            pbar.update(1)

    if pbar is not None:
        pbar.close()

    if not finish.all() and show_tqdm:
        logger.info("incomplete result. hit max_new_token: %s", max_new_token)

    del input_ids_buf

    if finish.all():
        # The last frame is the EOS frame; drop it.
        generated_input_ids = input_ids[:, condition_length:-1, :]
    else:
        # Hit `max_new_token`: there is no EOS frame to drop.
        generated_input_ids = input_ids[:, condition_length:, :]

    return ConditionalChatTTSGenerationOutput(
        new_ids=generated_input_ids,
        audio_input_ids=input_ids,
        past_key_values=past_key_values,
        finished=finish.all(),
    )
|
|
@torch.inference_mode()
def decode_to_mel_specs(
    self,
    result_list: List[torch.Tensor],
):
    """Decode discrete audio codes to mel spectrograms with `self.dvae`.

    Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/core.py`

    Each code tensor is indexed as (time, channels). The tensors are transposed, right-padded
    with zeros to the longest time length and stacked into one batch of shape
    (len(result_list), result_list[0].size(1), max_time). That batch is decoded in a single
    `self.dvae` call.

    Args:
        result_list (List[torch.Tensor]): Audio codes output from `generate`.

    Returns:
        Whatever `self.dvae` returns for the padded batch. When `result_list` is empty, the
        result is an empty float32 `np.ndarray`, not a tensor.
    """
    decoder = self.dvae
    if len(result_list) == 0:
        return np.array([], dtype=np.float32)

    longest = max(codes.size(0) for codes in result_list)
    template = result_list[0]
    padded = torch.zeros(
        (len(result_list), template.size(1), longest),
        dtype=template.dtype,
        device=template.device,
    )
    # Fill the left part of each row with the transposed codes; the tail stays zero.
    for row, codes in zip(padded, result_list):
        row[:, : codes.size(0)] = codes.permute(1, 0)

    mel_specs = decoder(padded)
    del padded
    return mel_specs
|
|
|
|
| |
def gen_logits(
    num_code: int,
    top_P=0.7,
    top_K=20,
    repetition_penalty=1.0,
):
    """Build the logits warpers and processors used when sampling audio codes.

    Args:
        num_code: Passed as the second positional argument of
            `CustomRepetitionPenaltyLogitsProcessorRepeat` (presumably the size of the code
            vocabulary; confirm against that class).
        top_P: Threshold for `TopPLogitsWarper`; `None` disables it.
        top_K: Cut-off for `TopKLogitsWarper`; `None` disables it.
        repetition_penalty: Penalty for `CustomRepetitionPenaltyLogitsProcessorRepeat`;
            `None` or `1` disables it.

    Returns:
        A `(logits_warpers, logits_processors)` tuple of lists. The warpers are ordered
        top-p then top-k, and each keeps at least 3 tokens.
    """
    # The conditional expressions are lazy: a disabled option never constructs its class.
    logits_warpers = [
        *([TopPLogitsWarper(top_P, min_tokens_to_keep=3)] if top_P is not None else []),
        *([TopKLogitsWarper(top_K, min_tokens_to_keep=3)] if top_K is not None else []),
    ]

    penalty_enabled = repetition_penalty is not None and repetition_penalty != 1
    # NOTE(review): the trailing 16 is a fixed third argument of the processor (looks like a
    # look-back window); confirm against CustomRepetitionPenaltyLogitsProcessorRepeat.
    logits_processors = (
        [CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, num_code, 16)]
        if penalty_enabled
        else []
    )

    return logits_warpers, logits_processors
|
|
|
|
| |
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    **kwargs,
):
    """Build the per-step model inputs for `generate` (adapted from transformers' Qwen2ForCausalLM).

    The function performs four steps:

    - Trims `input_ids` to the tokens not yet covered by `past_key_values`.
    - Derives `position_ids` from `attention_mask` when none are given.
    - Feeds `inputs_embeds` only on the first step (`cache_position[0] == 0`).
    - For a `StaticCache` with a 2-D mask, expands the mask to 4-D.

    Args:
        input_ids: Token ids of shape (batch, seq_len).
        past_key_values: A transformers `Cache`, a legacy per-layer `(key, value)` sequence,
            or None.
        attention_mask: Optional mask whose second dimension spans cached + new tokens.
        inputs_embeds: Optional embeddings used instead of `input_ids` on the first step.
        cache_position: Positions of the new tokens in the cache. It is indexed whenever
            `inputs_embeds` is given.
        position_ids: Optional explicit position ids.
        use_cache: Forwarded unchanged.
        **kwargs: Ignored.

    Returns:
        dict: Keys `input_ids`, `inputs_embeds`, `position_ids`, `past_key_values`,
        `use_cache` and `attention_mask`.
    """
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            # `seen_tokens` is deprecated. Depending on the transformers version and cache
            # class it can be missing or return None; fall back to the cached length.
            past_length = getattr(past_key_values, "seen_tokens", None)
            if past_length is None:
                past_length = past_key_values.get_seq_length()
        else:
            past_length = past_key_values[0][0].shape[2]

        # Keep only the unprocessed tokens:
        # 1 - attention_mask is longer than input_ids: some inputs live only in the cache
        #     (e.g. when `inputs_embeds` were passed).
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - input_ids holds all tokens: drop the ones already in the cache.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise assume input_ids only has unprocessed tokens.

    if attention_mask is not None and position_ids is None:
        # Derive positions from the mask; masked-out slots get a dummy position of 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]
            # Materialize a fresh contiguous copy of the slice, as upstream transformers does.
            position_ids = position_ids.clone(memory_format=torch.contiguous_format)

    # `inputs_embeds` are only used on the first generation step.
    if inputs_embeds is not None and cache_position[0] == 0:
        model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
    else:
        model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

    # A StaticCache needs a full 4-D mask. There is nothing to expand when there is no cache
    # or no 2-D mask; the unguarded `attention_mask.ndim` raised AttributeError on None.
    if (
        past_key_values is not None
        and isinstance(past_key_values, StaticCache)
        and attention_mask is not None
        and attention_mask.ndim == 2
    ):
        if model_inputs["inputs_embeds"] is not None:
            batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
            device = model_inputs["inputs_embeds"].device
        else:
            batch_size, sequence_length = model_inputs["input_ids"].shape
            device = model_inputs["input_ids"].device

        dtype = self.lm_head.weight.dtype
        min_dtype = torch.finfo(dtype).min

        attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=past_key_values.get_max_length(),
            dtype=dtype,
            device=device,
            min_dtype=min_dtype,
            cache_position=cache_position,
            batch_size=batch_size,
        )

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
        }
    )
    return model_inputs
|
|