| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import json |
| | import logging |
| | import math |
| | import os |
| | import types |
| | from collections.abc import Iterator |
| | from copy import deepcopy |
| | from dataclasses import dataclass |
| | from threading import Thread |
| | from typing import List |
| | from typing import Literal |
| | from typing import Optional |
| | from typing import Tuple |
| | from typing import Union |
| |
|
| | import numpy as np |
| | import soundfile as sf |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.nn.utils.parametrize as P |
| | from huggingface_hub import hf_hub_download |
| | from PIL import Image |
| | from torch.nn.utils.parametrizations import weight_norm |
| | from tqdm import tqdm |
| | from transformers import AutoProcessor |
| | from transformers import BertTokenizerFast |
| | from transformers import LlamaConfig |
| | from transformers import LlamaModel |
| | from transformers import LogitsWarper |
| | from transformers import PreTrainedModel |
| | from transformers import Qwen2ForCausalLM |
| | from transformers import Qwen2PreTrainedModel |
| | from transformers import TextIteratorStreamer |
| | from transformers import TopKLogitsWarper |
| | from transformers import TopPLogitsWarper |
| | from transformers.cache_utils import Cache |
| | from transformers.cache_utils import DynamicCache |
| | from transformers.cache_utils import EncoderDecoderCache |
| | from transformers.cache_utils import StaticCache |
| | from transformers.modeling_outputs import BaseModelOutputWithPast |
| | from transformers.modeling_outputs import ModelOutput |
| | from transformers.models.whisper.modeling_whisper import ACT2FN |
| | from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES |
| | from transformers.models.whisper.modeling_whisper import WhisperConfig |
| | from transformers.models.whisper.modeling_whisper import WhisperEncoder |
| |
|
# Optional TTS dependencies. Their availability is recorded in `_tts_deps`
# instead of failing at import time, so the module still loads without them.
try:
    from vector_quantize_pytorch import GroupedResidualFSQ
    from vocos import Vocos
    from vocos.pretrained import instantiate_class

    _tts_deps = True
except Exception:
    # Deliberately broad: a broken optional install (not only a missing one)
    # must not prevent this module from loading. Unlike a bare `except:`, this
    # still lets KeyboardInterrupt / SystemExit propagate.
    _tts_deps = False
| |
|
| | from .configuration_minicpm import ConditionalChatTTSConfig |
| | from .configuration_minicpm import MiniCPMOConfig |
| | from .modeling_navit_siglip import SiglipVisionTransformer |
| | from .resampler import Resampler |
| | from .utils import NumberToTextConverter |
| | from .utils import sentence_end |
| | from .utils import VoiceChecker |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
@dataclass
class OmniOutput(ModelOutput):
    """Output bundle holding text, speaker embedding, waveform and sampling rate.

    Every field is optional and defaults to ``None``.
    """

    # Text result: a single string, a list of strings, or an iterator.
    text: Optional[Union[str, List[str], Iterator]] = None
    # Speaker embedding tensor.
    spk_embeds: Optional[torch.FloatTensor] = None
    # Waveform samples as a NumPy array.
    audio_wav: Optional[np.ndarray] = None
    # Sampling rate that goes with `audio_wav` (presumably in Hz -- confirm with the producer).
    sampling_rate: Optional[int] = None
| |
|
| |
|
class MiniCPMOPreTrainedModel(Qwen2PreTrainedModel):
    """``Qwen2PreTrainedModel`` subclass that uses ``MiniCPMOConfig`` as its configuration class."""

    config_class = MiniCPMOConfig
| |
|
| |
|
| | class MiniCPMO(MiniCPMOPreTrainedModel): |
def __init__(self, config):
    """Build the language model and the optional vision, audio and TTS sub-modules.

    Args:
        config: model configuration. Its ``init_vision``, ``init_audio`` and
            ``init_tts`` flags decide which optional sub-modules are created.
    """
    super().__init__(config)
    self.llm = Qwen2ForCausalLM(config)
    # Bind the module-level `prepare_inputs_for_generation` function to this llm instance.
    self.llm.prepare_inputs_for_generation = types.MethodType(prepare_inputs_for_generation, self.llm)

    self.embed_dim = self.llm.config.hidden_size

    # Vision tower plus the resampler that maps vision features to the llm width.
    if self.config.init_vision:
        self.vpm = self.init_vision_module()
        self.vision_dim = self.vpm.embed_dim
        self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)

    # Audio encoder, average pooling over time, and projection to the llm width.
    if self.config.init_audio:
        self.apm = self.init_audio_module()
        # The projector input width is a quarter of the encoder's FFN width.
        audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4)
        self.audio_avg_pooler = nn.AvgPool1d(self.config.audio_pool_step, stride=self.config.audio_pool_step)
        self.audio_projection_layer = MultiModalProjector(in_dim=audio_output_dim, out_dim=self.embed_dim)
        # Index into the encoder's `hidden_states` tuple used by `get_audio_embedding` (-1 = last entry).
        self.audio_encoder_layer = -1

    # Text-to-speech head; requires the optional dependencies probed at import time.
    if self.config.init_tts:
        assert _tts_deps, "please make sure vector_quantize_pytorch and vocos are installed."
        self.tts = self.init_tts_module()

    self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)

    # Token strings converted to ids and used as EOS by the decode helpers.
    self.terminators = ["<|im_end|>", "<|endoftext|>"]

    self.default_tts_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}"
    self.force_no_stop = False

    # Start with clean per-session state.
    self.reset_session()
| |
|
def reset_session(self):
    """Restore every per-session attribute to its initial value.

    Clears the session id, the generation progress flags and both cached
    key/value stores (llm and audio).
    """
    initial_state = {
        "session_id": None,
        "new_user_msg": True,
        "llm_generated": False,
        "llm_generate_completed": False,
        "llm_past_key_values": None,
        "audio_past_key_values": None,
    }
    for attr_name, initial_value in initial_state.items():
        setattr(self, attr_name, initial_value)
| |
|
def init_tts(
    self,
    tts_text_tokenizer_path=None,
    vocos_ckpt_path=None,
):
    """
    Load the TTS text tokenizer and the Vocos vocoder.

    Each resource is looked up locally first and fetched from the Hugging Face
    Hub when the local path does not exist.

    Args:
        tts_text_tokenizer_path: tokenizer location; defaults to
            ``<config._name_or_path>/assets/chattts_tokenizer``.
        vocos_ckpt_path: Vocos checkpoint file; defaults to
            ``<config._name_or_path>/assets/Vocos.pt``.

    Side effects:
        Sets ``self.tts_processor`` and ``self.vocos``.
    """
    from .processing_minicpmo import ChatTTSProcessor

    # NOTE(review): the `os.path.exists` fallbacks below are taken to be at the
    # same level as the `is None` checks (so they also apply to caller-supplied
    # paths) -- confirm against the canonical file.
    if tts_text_tokenizer_path is None:
        tts_text_tokenizer_path = os.path.join(self.config._name_or_path, "assets/chattts_tokenizer")
    if not os.path.exists(tts_text_tokenizer_path):
        # No local copy: fall back to the Hub model id.
        tts_text_tokenizer_path = "openbmb/chattts_tokenizer"

    tts_text_tokenizer = BertTokenizerFast.from_pretrained(tts_text_tokenizer_path)
    self.tts_processor = ChatTTSProcessor(text_tokenizer=tts_text_tokenizer)

    if vocos_ckpt_path is None:
        vocos_ckpt_path = os.path.join(self.config._name_or_path, "assets/Vocos.pt")
    if not os.path.exists(vocos_ckpt_path):
        # No local copy: download the checkpoint from the Hub.
        vocos_ckpt_path = hf_hub_download(repo_id="openbmb/MiniCPM-o-2_6", subfolder="assets", filename="Vocos.pt")

    assert os.path.exists(vocos_ckpt_path)
    self.vocos = self.initialize_vocos(vocos_ckpt_path)
| |
|
def initialize_vocos(self, ckpt_path):
    """Build the Vocos vocoder and load its weights from ``ckpt_path``.

    The vocoder is assembled from a mel-spectrogram feature extractor, a
    ``VocosBackbone`` and an ``ISTFTHead``, switched to eval mode and kept in
    float32.

    Args:
        ckpt_path: path to a state-dict checkpoint readable by ``torch.load``.

    Returns:
        The initialised ``Vocos`` module, on CUDA when available and on CPU
        otherwise.
    """
    feature_extractor = instantiate_class(
        args=(),
        init={
            "class_path": "vocos.feature_extractors.MelSpectrogramFeatures",
            "init_args": {"sample_rate": 24000, "n_fft": 1024, "hop_length": 256, "n_mels": 100},
        },
    )
    backbone = instantiate_class(
        args=(),
        init={
            "class_path": "vocos.models.VocosBackbone",
            "init_args": {"input_channels": 100, "dim": 512, "intermediate_dim": 1536, "num_layers": 8},
        },
    )
    head = instantiate_class(
        args=(),
        init={"class_path": "vocos.heads.ISTFTHead", "init_args": {"dim": 512, "n_fft": 1024, "hop_length": 256}},
    )
    # Use CUDA when present (the previous behaviour) but do not crash on
    # CPU-only hosts, where an unconditional `.to("cuda")` raises.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vocos = Vocos(feature_extractor, backbone, head).to(device).eval().to(torch.float32)
    # Map the checkpoint to CPU so it loads regardless of the device it was
    # saved from; `load_state_dict` copies the values onto the module's device.
    state_dict = torch.load(ckpt_path, weights_only=True, mmap=True, map_location="cpu")
    vocos.load_state_dict(state_dict)
    return vocos
| |
|
def init_vision_module(self):
    """Create the SigLIP vision transformer used as the vision tower.

    The vision config inherits ``flash_attention_2`` from the top-level config
    and otherwise uses ``eager`` attention. When ``config.drop_vision_last_layer``
    is set, the final encoder layer is removed. ``embed_dim`` and ``patch_size``
    are copied from the embeddings onto the returned model.
    """
    use_flash = self.config._attn_implementation == "flash_attention_2"
    self.config.vision_config._attn_implementation = "flash_attention_2" if use_flash else "eager"

    vision_tower = SiglipVisionTransformer(self.config.vision_config)
    if self.config.drop_vision_last_layer:
        vision_tower.encoder.layers = vision_tower.encoder.layers[:-1]

    vision_tower.embed_dim = vision_tower.embeddings.embed_dim
    vision_tower.patch_size = vision_tower.embeddings.patch_size
    return vision_tower
| |
|
def init_resampler(self, embed_dim, vision_dim):
    """Create the adaptive resampler that maps vision features to ``embed_dim``.

    Args:
        embed_dim: output width; also fixes the head count as ``embed_dim // 128``.
        vision_dim: width of the key/value (vision) features.
    """
    resampler_kwargs = {
        "num_queries": self.config.query_num,
        "embed_dim": embed_dim,
        "num_heads": embed_dim // 128,
        "kv_dim": vision_dim,
        "adaptive": True,
    }
    return Resampler(**resampler_kwargs)
| |
|
def init_audio_module(self):
    """Create the Whisper-based audio encoder from ``config.audio_config``."""
    return MiniCPMWhisperEncoder(self.config.audio_config)
| |
|
def init_tts_module(self):
    """Create the conditional ChatTTS module from ``config.tts_config``."""
    return ConditionalChatTTS(self.config.tts_config)
| |
|
def get_input_embeddings(self):
    """Return the input embedding module of the wrapped language model."""
    language_model = self.llm
    return language_model.get_input_embeddings()
| |
|
def set_input_embeddings(self, value):
    """Replace the input embedding module of the wrapped language model.

    Delegates to the language model's own setter so that the module returned by
    ``get_input_embeddings`` is the one replaced. Assigning ``self.llm.embed_tokens``
    directly only adds an unused attribute on the causal-LM wrapper, because the
    embedding actually used lives at ``self.llm.model.embed_tokens``.

    Args:
        value: the new embedding module.
    """
    self.llm.set_input_embeddings(value)
| |
|
def get_output_embeddings(self):
    """Return the language model's output projection (``lm_head``)."""
    language_model = self.llm
    return language_model.lm_head
| |
|
def set_output_embeddings(self, new_embeddings):
    """Replace the language model's output projection (``lm_head``)."""
    setattr(self.llm, "lm_head", new_embeddings)
| |
|
def set_decoder(self, decoder):
    """Replace the wrapped language model with ``decoder``."""
    setattr(self, "llm", decoder)
| |
|
def get_decoder(self):
    """Return the wrapped language model."""
    decoder = self.llm
    return decoder
| |
|
def subsequent_chunk_mask(
    self,
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: torch.device = torch.device("cpu"),
    num_lookhead: int = 0,
) -> torch.Tensor:
    """Create a (size, size) chunk-wise attention mask for a streaming encoder.

    Row ``i`` is True for every column a frame at position ``i`` may attend to:
    from the start of the allowed left context up to the end of its own chunk
    plus ``num_lookhead`` extra frames.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
        num_lookhead (int): extra frames visible past the end of the own chunk

    Returns:
        torch.Tensor: boolean mask of shape (size, size)

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    positions = torch.arange(size, device=device)
    chunk_index = positions // chunk_size

    # Exclusive end of the visible window per row. Clamping at 0 keeps a
    # negative end from wrapping around, which Python slice semantics did.
    row_end = ((chunk_index + 1) * chunk_size + num_lookhead).clamp(min=0, max=size)

    if num_left_chunks < 0:
        # Unlimited left context.
        row_start = torch.zeros_like(positions)
    else:
        row_start = ((chunk_index - num_left_chunks) * chunk_size).clamp(min=0)

    # Broadcast (1, size) columns against (size, 1) per-row bounds.
    columns = positions.unsqueeze(0)
    return (columns >= row_start.unsqueeze(1)) & (columns < row_end.unsqueeze(1))
| |
|
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
    """
    Compute sequence lengths after the convolutional front-end and after pooling.

    Returns:
        A pair ``(lengths_after_cnn, lengths_after_pooling)``; the pooled
        lengths are cast to ``torch.int32``.
    """
    pool_step = self.config.audio_pool_step
    cnn_lengths = (input_lengths - 1) // 2 + 1
    pooled_lengths = ((cnn_lengths - pool_step) // pool_step + 1).to(dtype=torch.int32)
    return cnn_lengths, pooled_lengths
| |
|
def get_vllm_embedding(self, data):
    """
    Compute all visual embeddings and write them into the llm input embeddings.

    Args:
        data: Dict
            tgt_sizes: image size after patch embedding
            pixel_values: image features
            image_bound: position of each picture corresponding to input_ids
            input_ids: full input_ids, include placeholder
            vision_hidden_states (optional): precomputed per-sample vision
                features; when present the vision tower is skipped.
    Returns:
        embedding with vision, vision_hidden_states
    """
    if "vision_hidden_states" not in data:
        dtype = self.llm.model.embed_tokens.weight.dtype
        device = self.llm.model.embed_tokens.weight.device
        tgt_sizes = data["tgt_sizes"]
        pixel_values_list = data["pixel_values"]
        vision_hidden_states = []
        all_pixel_values = []
        # NOTE(review): this list is filled but never read; the name is rebound
        # to an int in the split loop further down.
        img_cnt = []
        for pixel_values in pixel_values_list:
            img_cnt.append(len(pixel_values))
            # Flatten every image across the whole batch into one list.
            all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])

        # At least one image exists somewhere in the batch.
        if all_pixel_values:
            tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
            tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)

            # Largest patch count (product of the two size columns) over all images.
            max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])

            # Zero-pad every image to the longest sequence, then restore a 4-D layout.
            all_pixel_values = torch.nn.utils.rnn.pad_sequence(
                all_pixel_values, batch_first=True, padding_value=0.0
            )
            B, L, _ = all_pixel_values.shape
            all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)

            # True marks real (non-padding) patches of each image.
            patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
            for i in range(B):
                patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True

            vision_batch_size = self.config.vision_batch_size
            all_pixel_values = all_pixel_values.type(dtype)
            if B > vision_batch_size:
                # Run the vision tower in slices of at most `vision_batch_size` images.
                hs = []
                for i in range(0, B, vision_batch_size):
                    start_idx = i
                    end_idx = i + vision_batch_size
                    tmp_hs = self.vpm(
                        all_pixel_values[start_idx:end_idx],
                        patch_attention_mask=patch_attn_mask[start_idx:end_idx],
                        tgt_sizes=tgt_sizes[start_idx:end_idx],
                    ).last_hidden_state
                    hs.append(tmp_hs)
                vision_embedding = torch.cat(hs, dim=0)
            else:
                vision_embedding = self.vpm(
                    all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes
                ).last_hidden_state
            vision_embedding = self.resampler(vision_embedding, tgt_sizes)

            # Split the flat per-image embeddings back into per-sample groups.
            start = 0
            for pixel_values in pixel_values_list:
                img_cnt = len(pixel_values)
                if img_cnt > 0:
                    vision_hidden_states.append(vision_embedding[start : start + img_cnt])
                    start += img_cnt
                else:
                    vision_hidden_states.append([])
        else:
            # No image in the batch.
            if self.training:
                # Run the vision path on a zero image so it still takes part in the
                # forward pass (presumably to keep its parameters in the autograd
                # graph -- confirm).
                dummy_image = torch.zeros((1, 3, 224, 224), device=device, dtype=dtype)
                tgt_sizes = torch.Tensor(
                    [[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]
                ).type(torch.int32)
                dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
            else:
                dummy_feature = []
            for _ in range(len(pixel_values_list)):
                vision_hidden_states.append(dummy_feature)

    else:
        vision_hidden_states = data["vision_hidden_states"]

    if hasattr(self.llm.config, "scale_emb"):
        vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb
    else:
        vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])

    # Write into a clone so the token embeddings read below stay untouched.
    new_vllm_embedding = vllm_embedding.clone()

    # Match the embedding dtype; empty-list placeholders are passed through as-is.
    vision_hidden_states = [
        i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
    ]

    bs = len(data["input_ids"])
    for i in range(bs):
        cur_vs_hs = vision_hidden_states[i]
        if len(cur_vs_hs) > 0:
            cur_vllm_emb = vllm_embedding[i]
            cur_image_bound = data["image_bound"][i]
            if len(cur_image_bound) > 0:
                # `torch.stack` requires every [start, end) range to have the same length.
                image_indices = torch.stack(
                    [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                ).to(vllm_embedding.device)

                # Overwrite the placeholder positions with the vision features.
                new_vllm_embedding[i] = cur_vllm_emb.scatter(
                    0,
                    image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                    cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
                )

            elif self.training:
                # Adds exactly zero; only references the vision features in the
                # computation.
                new_vllm_embedding[i] += cur_vs_hs[0].mean() * 0

    return new_vllm_embedding, vision_hidden_states
| |
|
def get_audio_embedding_streaming(self, data):
    r"""
    Encode incoming audio incrementally, reusing cached encoder key/values.

    The encoder cache is kept in ``self.audio_past_key_values`` and updated on
    every call. When the cached length plus the new frames would reach the size
    of the encoder's position table, the cache is dropped first. Only
    batch_size=1 is supported.

    Args:
        data (dict):
            - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
            - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.

    Returns:
        List[List[torch.Tensor]]: audio embeddings, or ``[]`` when there is no audio.
    """
    wavforms = data.get("audio_features", [])
    audio_feature_lens_raw = data.get("audio_feature_lens", [])

    # Guard clause: nothing to encode.
    if len(wavforms) == 0:
        return []

    audio_feature_lens = torch.hstack(audio_feature_lens_raw)
    batch_size, _, max_mel_seq_len = wavforms.shape
    assert batch_size == 1
    max_seq_len = (max_mel_seq_len - 1) // 2 + 1

    if self.audio_past_key_values is not None:
        cache_length = self.audio_past_key_values[0][0].shape[2]
        apm_max_len = self.apm.embed_positions.weight.shape[0]
        if cache_length + max_seq_len >= apm_max_len:
            logger.warning(
                f"audio_past_key_values length {cache_length + max_seq_len} exceed {apm_max_len}, reset."
            )
            self.audio_past_key_values = None

    encoder_out = self.apm(wavforms, past_key_values=self.audio_past_key_values, use_cache=True)
    hidden = encoder_out.last_hidden_state
    self.audio_past_key_values = encoder_out.past_key_values

    # Project, then average-pool along time (the pooler works on (B, C, T)).
    projected = self.audio_projection_layer(hidden)
    audio_embeds = self.audio_avg_pooler(projected.transpose(1, 2)).transpose(1, 2)

    _, num_audio_tokens = self._get_feat_extract_output_lengths(audio_feature_lens)

    # Regroup the flat rows per batch item, trimming each to its valid length.
    grouped_embeds = []
    row = 0
    for segment_lens in audio_feature_lens_raw:
        segment_count = len(segment_lens)
        grouped_embeds.append(
            [audio_embeds[r, : num_audio_tokens[r], :] for r in range(row, row + segment_count)]
        )
        row += segment_count
    return grouped_embeds
| |
|
def get_audio_embedding(self, data, chunk_length=-1, dummy=True):
    r"""
    Extract full audio embeddings with optional chunk-based attention.

    This method computes embeddings for all audio frames at once, either using full attention (when
    `chunk_length` is -1) or chunk-based attention (when `chunk_length` is a positive number). It does
    not use key-value caching and is suitable for non-streaming inference.

    Args:
        data (dict):
            - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
            - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
        chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based
            attention (>0) during embedding computation.
        dummy (bool, optional): When there is no audio and the module is in training mode, run the
            audio path on an all-zero input and return that result instead of an empty list.

    Returns:
        List[List[torch.Tensor]]: audio embeddings. In the dummy case a one-element list holding the
        pooled embedding tensor; ``[]`` when there is no audio otherwise.
    """
    wavforms = data.get("audio_features", [])
    audio_feature_lens_raw = data.get("audio_feature_lens", [])

    # Audio is present.
    if len(wavforms) > 0:
        audio_feature_lens = torch.hstack(audio_feature_lens_raw)
        batch_size, _, max_mel_seq_len = wavforms.shape
        max_seq_len = (max_mel_seq_len - 1) // 2 + 1

        # Position indices of shape (batch_size, max_seq_len).
        seq_range = (
            torch.arange(0, max_seq_len, dtype=audio_feature_lens.dtype, device=audio_feature_lens.device)
            .unsqueeze(0)
            .expand(batch_size, max_seq_len)
        )
        lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len)
        # True marks padded positions.
        # NOTE(review): positions are on the halved (max_seq_len) axis while
        # `audio_feature_lens` is compared without the same halving -- confirm
        # whether the lengths are meant to be mel frames or post-conv frames.
        padding_mask = seq_range >= lengths_expand

        audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
            batch_size, 1, max_seq_len, max_seq_len
        )
        # Additive mask in the encoder's dtype/device; masked entries are set to -inf below.
        audio_attention_mask = audio_attention_mask_.to(
            dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device
        )

        if chunk_length > 0:
            # `chunk_length` is scaled by 50 to get the chunk size in encoder
            # frames (presumably seconds at 50 frames/s -- confirm).
            chunk_num_frame = int(chunk_length * 50)
            chunk_mask = self.subsequent_chunk_mask(
                size=max_seq_len,
                chunk_size=chunk_num_frame,
                num_left_chunks=-1,
                device=audio_attention_mask_.device,
            )
            # Also block every position the chunk mask does not allow.
            audio_attention_mask_ = torch.logical_or(audio_attention_mask_, torch.logical_not(chunk_mask))

        audio_attention_mask[audio_attention_mask_] = float("-inf")
        audio_states = self.apm(
            wavforms, output_hidden_states=True, attention_mask=audio_attention_mask
        ).hidden_states[self.audio_encoder_layer]
        audio_embeds = self.audio_projection_layer(audio_states)

        # The pooler works on (B, C, T): transpose in, pool, transpose back.
        audio_embeds = audio_embeds.transpose(1, 2)
        audio_embeds = self.audio_avg_pooler(audio_embeds)
        audio_embeds = audio_embeds.transpose(1, 2)

        _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens)

        num_audio_tokens = feature_lens_after_pooling

        # Regroup the flat rows per batch item, trimming each to its valid length.
        final_audio_embeds = []
        idx = 0
        for i in range(len(audio_feature_lens_raw)):
            target_audio_embeds = []
            for _ in range(len(audio_feature_lens_raw[i])):
                target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :])
                idx += 1
            final_audio_embeds.append(target_audio_embeds)
        return final_audio_embeds
    elif self.training and dummy:
        dtype = self.apm.embed_positions.weight.dtype
        device = self.apm.embed_positions.weight.device

        # All-zero input so the audio path still runs when the batch has no audio.
        dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
        audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer]

        audio_embeds = self.audio_projection_layer(audio_states)

        audio_embeds = audio_embeds.transpose(1, 2)
        audio_embeds = self.audio_avg_pooler(audio_embeds)
        audio_embeds = audio_embeds.transpose(1, 2)
        return [audio_embeds]

    else:
        return []
| |
|
def get_omni_embedding(self, data, input_embeddings, chunk_length=-1, stream_input=False):
    """
    Write audio embeddings into the placeholder spans of ``input_embeddings``.

    Args:
        data: dict with "audio_features" and, when audio is present, "audio_bounds"
            (per-sample (start, end) spans to overwrite)
        input_embeddings: embeddings to fill; modified in place when audio is present
        chunk_length: whisper use full attention or chunk attention
        stream_input: use streaming audio embedding
    Returns:
        final embeddings with audio feature
    """
    if stream_input:
        audio_embeddings = self.get_audio_embedding_streaming(data)
    else:
        audio_embeddings = self.get_audio_embedding(data, chunk_length)

    batch_size = len(input_embeddings)
    has_audio = len(data.get("audio_features", [])) > 0

    if not has_audio:
        if self.training:
            # Adds exactly zero; only references the dummy audio embedding.
            for _ in range(batch_size):
                input_embeddings = input_embeddings + audio_embeddings[0].mean() * 0
        return input_embeddings

    assert len(audio_embeddings) == len(input_embeddings)
    if len(audio_embeddings) == 0:
        return input_embeddings

    audio_bounds = data["audio_bounds"]

    if self.config.chunk_input:
        # All segments of a sample are concatenated and consumed span by span.
        for i in range(batch_size):
            if not audio_embeddings[i]:
                continue
            merged = torch.cat(audio_embeddings[i], dim=0).to(
                device=input_embeddings.device, dtype=input_embeddings.dtype
            )
            cursor = 0
            for bound in audio_bounds[i]:
                span_start, span_end = bound[0], bound[1]
                span_len = span_end - span_start
                input_embeddings[i, span_start:span_end] = merged[cursor : cursor + span_len, :]
                cursor += span_len
        return input_embeddings

    # One embedding tensor per span; lengths must match exactly.
    for i in range(batch_size):
        for embs, bound in zip(audio_embeddings[i], audio_bounds[i]):
            audio_indices = torch.arange(bound[0], bound[1], dtype=torch.long).to(input_embeddings.device)

            if embs.shape[0] != len(audio_indices):
                raise ValueError(
                    f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
                    f"to input indices of length {len(audio_indices)}"
                )
            input_embeddings[i, audio_indices] = embs.to(input_embeddings.dtype)

    return input_embeddings
| |
|
def forward(self, data, **kwargs):
    """Embed the multimodal inputs in ``data`` and run the language model on them.

    Args:
        data: dict consumed by ``get_vllm_embedding`` / ``get_omni_embedding``;
            must also contain "position_ids".
        **kwargs: forwarded to the language model, except ``input_ids``,
            ``inputs_embeds`` and ``position_ids``, which are supplied here.

    Returns:
        Whatever the wrapped language model returns.
    """
    inputs_embeds, _ = self.get_vllm_embedding(data)

    if self.config.init_audio:
        inputs_embeds = self.get_omni_embedding(
            data, input_embeddings=inputs_embeds, chunk_length=self.config.audio_chunk_length
        )

    position_ids = data["position_ids"]
    if position_ids.dtype != torch.int64:
        position_ids = position_ids.long()

    # These three are provided explicitly below, so drop caller duplicates.
    reserved = ("input_ids", "inputs_embeds", "position_ids")
    passthrough = {name: value for name, value in kwargs.items() if name not in reserved}

    return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=inputs_embeds, **passthrough)
| |
|
def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs):
    """Run blocking generation on ``inputs_embeds`` and return the full generate output.

    ``output_hidden_states`` and ``return_dict_in_generate`` are always forced
    to True; any caller-supplied values for them are discarded. Pad id is 0 and
    the ids of ``self.terminators`` are used as EOS.
    """
    forced_flags = ("output_hidden_states", "return_dict_in_generate")
    for flag in forced_flags:
        kwargs.pop(flag, None)

    eos_ids = list(map(tokenizer.convert_tokens_to_ids, self.terminators))
    return self.llm.generate(
        inputs_embeds=inputs_embeds,
        pad_token_id=0,
        eos_token_id=eos_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        return_dict_in_generate=True,
        **kwargs,
    )
| |
|
def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
    """Start generation on a background thread and return the text streamer.

    Caller kwargs override the defaults set here (inputs, pad id 0, EOS ids of
    ``self.terminators``, and the streamer itself).
    """
    eos_ids = list(map(tokenizer.convert_tokens_to_ids, self.terminators))
    streamer = TextIteratorStreamer(tokenizer=tokenizer)

    generation_kwargs = {
        "inputs_embeds": inputs_embeds,
        "pad_token_id": 0,
        "eos_token_id": eos_ids,
        "streamer": streamer,
        **kwargs,
    }

    worker = Thread(target=self.llm.generate, kwargs=generation_kwargs)
    worker.start()

    return streamer
| |
|
def _decode_text(self, result_ids, tokenizer):
    """Decode generated id sequences to text after stripping special ids.

    For each sequence: remove padding ids (0), drop a leading BOS id, drop a
    trailing terminator id, then decode. Sequences that are empty (or become
    empty) are decoded as-is instead of raising ``IndexError``.

    Args:
        result_ids: iterable of id sequences supporting boolean-mask indexing
            (e.g. tensors).
        tokenizer: provides ``convert_tokens_to_ids``, ``bos_id`` and ``decode``.

    Returns:
        list with one decoded entry per input sequence.
    """
    terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
    result_text = []
    for result in result_ids:
        result = result[result != 0]
        # Guard each positional access: the sequence may be empty once the
        # padding (or the BOS id) has been stripped.
        if len(result) > 0 and result[0] == tokenizer.bos_id:
            result = result[1:]
        if len(result) > 0 and result[-1] in terminators:
            result = result[:-1]
        result_text.append(tokenizer.decode(result))
    return result_text
| |
|
def get_sys_prompt(self, ref_audio=None, mode="default", language="zh"):
    """
    Build the system-prompt message for the requested task.

    Args:
        ref_audio: if ref_audio is not None, will use the voice cloning prompts, and the voice
            generated by the model will refer to the timbre of ref audio
        mode:
            "default": default system prompt and not refer to any task
            "omni": input video and audio simultaneously
            "audio_assistant": Default voice-only mode, the model will use the ref_audio's voice to reply user's question as a helpful assistant.
            "audio_roleplay": Roleplay voice-only mode, the model will use the ref_audio's voice to reply, and also role-play the character based on the audio prompt.
            "voice_cloning": TTS mode, the model will clone the voice of ref_audio.
        language: prompts language ("zh" selects the Chinese prompts, anything else English);
            the model has the ability to automatically select the response language
            based on the question language
    Returns:
        dict: ``{"role": "user", "content": [...]}`` where content is a list of prompt strings
        with ``ref_audio`` inserted when it is used.
    Raises:
        ValueError: if ``mode`` is "voice_cloning" and ``ref_audio`` is None.
    """
    if ref_audio is not None:
        assert isinstance(ref_audio, np.ndarray), "ref_audio error"

    use_zh = language == "zh"
    missing_ref_warning = (
        "Warning: ref_audio is None, speech generation will be performed based on the default voice."
    )
    default_voice_hint = "Use the <reserved_53> voice."

    if mode == "omni":
        if use_zh:
            sys_prompt = "你是一个AI助手。你能接受视频,音频和文本输入并输出语音和文本。"
            vc_prompt_prefix = sys_prompt + "模仿输入音频中的声音特征。"
            vc_prompt_suffix = "作为助手,你将使用这种声音风格说话。"
        else:
            sys_prompt = "You are a helpful assistant. You can accept video, audio and text input and output voice and text. "
            vc_prompt_prefix = sys_prompt + "Clone the voice in the provided audio prompt."
            vc_prompt_suffix = "As an assistant, you will speak using this voice style."

        if ref_audio is not None:
            content = [vc_prompt_prefix, ref_audio, vc_prompt_suffix]
        else:
            content = [sys_prompt]
    elif mode == "audio_assistant":
        if use_zh:
            vc_prompt_prefix = "模仿输入音频中的声音特征。"
            vc_prompt_suffix = "作为助手,你将使用这种声音风格说话。"
        else:
            vc_prompt_prefix = "Use the voice in the audio prompt to synthesize new content."
            vc_prompt_suffix = "You are a helpful assistant with the above voice style."

        if ref_audio is not None:
            content = [vc_prompt_prefix, ref_audio, vc_prompt_suffix]
        else:
            logger.warning(missing_ref_warning)
            content = [default_voice_hint, vc_prompt_suffix]
    elif mode == "audio_roleplay":
        if use_zh:
            vc_prompt_prefix = "模仿输入音频中的声音特征。"
            vc_prompt_suffix = "假装你是上述音频中的人物,与我进行对话。"
        else:
            vc_prompt_prefix = "Clone the voice in the provided audio prompt."
            vc_prompt_suffix = "Try to role-play the character based on the audio prompt above."

        if ref_audio is not None:
            content = [vc_prompt_prefix, ref_audio, vc_prompt_suffix]
        else:
            # Same channel as the "audio_assistant" branch (was a bare print).
            logger.warning(missing_ref_warning)
            content = [default_voice_hint, vc_prompt_suffix]
    elif mode == "voice_cloning":
        if ref_audio is None:
            raise ValueError("ref_audio can't be None in voice_cloning mode.")
        if use_zh:
            vc_prompt_prefix = "模仿输入音频中的声音特征。"
        else:
            vc_prompt_prefix = "Clone the voice in the provided audio prompt."
        content = [vc_prompt_prefix, ref_audio]
    else:
        # Any other mode gets the generic English prompt, whatever `language` is.
        content = ["You are a helpful assistant. You can accept audio and text input and output voice and text."]

    return {"role": "user", "content": content}
| |
|
def generate(
    self,
    input_ids=None,
    pixel_values=None,
    tgt_sizes=None,
    audio_features=None,
    audio_feature_lens=None,
    image_bound=None,
    audio_bounds=None,
    spk_bounds=None,
    attention_mask=None,
    tokenizer=None,
    vision_hidden_states=None,
    stream=False,
    decode_text=True,
    **kwargs,
):
    """Embed the multimodal inputs and generate a response with the language model.

    Args:
        input_ids: token ids including the image/audio placeholders (required).
        pixel_values: per-sample lists of image tensors; required unless
            ``vision_hidden_states`` is given.
        tgt_sizes: image sizes after patch embedding (used with ``pixel_values``).
        audio_features: audio input features; ``None`` means no audio.
        audio_feature_lens: lengths of each audio segment per sample.
        image_bound: placeholder spans for images in ``input_ids``.
        audio_bounds: placeholder spans for audio in ``input_ids``.
        spk_bounds: speaker-embedding spans, carried along in the model inputs.
        attention_mask: attention mask passed to non-streaming generation.
        tokenizer: tokenizer used for terminator ids and for decoding.
        vision_hidden_states: precomputed vision features; skips the vision tower.
        stream: if True, start generation in the background and return a streamer.
        decode_text: if False, return only the raw generate output.
        **kwargs: forwarded to the underlying ``generate`` call.

    Returns:
        ``(result, outputs)``: with ``stream=False`` the decoded texts and the
        generate output, with ``stream=True`` the streamer and an empty dict.
        When ``decode_text`` is False only ``outputs`` is returned.

    Raises:
        TypeError: if neither ``pixel_values`` nor ``vision_hidden_states`` is given.
    """
    assert input_ids is not None
    if pixel_values is not None:
        assert len(input_ids) == len(pixel_values)
    elif vision_hidden_states is None:
        # Previously surfaced as an opaque `len(None)` TypeError.
        raise TypeError("generate() requires either pixel_values or vision_hidden_states")

    # A fresh list per call instead of a shared mutable default; an explicit
    # None now also means "no audio".
    if audio_features is None:
        audio_features = []

    model_inputs = {
        "input_ids": input_ids,
        "audio_features": audio_features,
        "audio_feature_lens": audio_feature_lens,
        "image_bound": image_bound,
        "audio_bounds": audio_bounds,
        "spk_bounds": spk_bounds,
    }

    if vision_hidden_states is None:
        model_inputs["pixel_values"] = pixel_values
        model_inputs["tgt_sizes"] = tgt_sizes
    else:
        model_inputs["vision_hidden_states"] = vision_hidden_states

    with torch.inference_mode():
        model_inputs["inputs_embeds"], vision_hidden_states = self.get_vllm_embedding(model_inputs)
        model_inputs["inputs_embeds"] = self.get_omni_embedding(
            model_inputs,
            input_embeddings=model_inputs["inputs_embeds"],
            chunk_length=self.config.audio_chunk_length,
        )

        if stream:
            result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
            # Streaming yields text through the streamer; there is no generate output.
            outputs = {}
        else:
            outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs)
            result = self._decode_text(outputs.sequences, tokenizer)

    if decode_text is False:
        return outputs

    return result, outputs
| |
|
def chat(
    self,
    image=None,
    msgs=None,
    tokenizer=None,
    processor=None,
    vision_hidden_states=None,
    max_new_tokens=2048,
    min_new_tokens=0,
    sampling=True,
    max_inp_length=32768,
    stream=False,
    chunk_input=True,
    omni_input=False,
    max_slice_nums=None,
    use_image_id=None,
    use_tts_template=False,
    generate_audio=False,
    return_spk_embed=False,
    return_dict=False,
    output_audio_path=None,
    **kwargs,
):
    """
    Unified chat function

    Args:
        image: use for batch_size=1 vqa, It is not recommended to continue to use this parameter
        msgs: the input chat msgs, support text: (string) / image: (PIL.Image) / audio (numpy.ndarray).
            A list of message lists is treated as a batch; a JSON string is decoded with `json.loads`.
        tokenizer: tokenizer for llm; if None, `processor.tokenizer` is used
        processor: if None, use the default processor (loaded lazily and cached on `self.processor`)
        vision_hidden_states: forwarded unchanged to `self.generate`
        max_new_tokens: the maximum length of the generation
        min_new_tokens: the minimum length of the generation (only forwarded when > 0)
        sampling: whether to use sampling decoding or beam search decoding
        max_inp_length: the maximum length of input
        stream: whether to return generator, only used when tts is not required; requires sampling=True
        chunk_input: whether to split audio into 1s chunks
        omni_input: determine whether it is omni mode (content parts are joined with "" instead of "\\n")
        max_slice_nums: control the maximum number of image slices
        use_image_id: for video understanding or omni understanding, use_image_id should be False
        use_tts_template: if the msgs contain audio, use_tts_template should be True
            (it is switched on automatically as soon as an audio array is found in msgs)
        generate_audio: whether to generate audio output; forces return_dict=True
        return_spk_embed: whether to return spk embedding; forces return_dict=True
        return_dict: whether to return an `OmniOutput` instead of the bare text
        output_audio_path: audio save path when generate_audio
        **kwargs: may override any key already present in the default generation config
            (e.g. top_p, top_k, temperature, repetition_penalty, num_beams)

    Returns:
        stream=True: a generator of text pieces (wrapped in `OmniOutput` when return_dict is set).
        stream=False: the answer text (a list of texts for batched input), or an `OmniOutput`
        carrying text / spk_embeds / audio_wav / sampling_rate when return_dict is set.
    """
    # `bool(msgs)` lets an empty msgs reach the "msgs is empty" assertion below
    # instead of raising IndexError on msgs[0].
    batched = bool(msgs) and isinstance(msgs[0], list)

    if generate_audio or return_spk_embed:
        return_dict = True

    msgs_list = msgs
    images_list = image

    if not batched:
        images_list, msgs_list = [images_list], [msgs_list]
    else:
        assert images_list is None, "Please integrate image to msgs when using batch inference."
        images_list = [None] * len(msgs_list)
    assert len(images_list) == len(msgs_list), "The batch dim of images_list and msgs_list should be the same."

    if processor is None:
        if self.processor is None:
            self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
        processor = self.processor

    # The prompt is built with processor.tokenizer, so it is the natural default for decoding too.
    # NOTE(review): assumes processor.tokenizer exposes `tts_end` like the tokenizer callers pass in — confirm.
    if tokenizer is None:
        tokenizer = processor.tokenizer

    image_processor = processor.image_processor
    mismatch_msg = "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
    assert self.config.query_num == image_processor.image_feature_size, mismatch_msg
    assert self.config.patch_size == image_processor.patch_size, mismatch_msg
    assert self.config.use_image_id == image_processor.use_image_id, mismatch_msg
    assert self.config.slice_config.max_slice_nums == image_processor.max_slice_nums, mismatch_msg
    assert self.config.slice_mode == image_processor.slice_mode, mismatch_msg

    # Does not depend on the individual conversation, so check it once.
    assert sampling or not stream, "if use stream mode, make sure sampling=True"

    prompts_lists = []
    input_images_list = []
    input_audios_list = []
    audio_parts_list = []

    for image, msgs in zip(images_list, msgs_list):
        if isinstance(msgs, str):
            msgs = json.loads(msgs)
        # Work on a copy: message contents are rewritten into placeholder text below.
        copy_msgs = deepcopy(msgs)

        assert len(msgs) > 0, "msgs is empty"

        if image is not None and isinstance(copy_msgs[0]["content"], str):
            copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]

        images = []
        audios = []
        audio_parts = []
        for i, msg in enumerate(copy_msgs):
            role = msg["role"]
            content = msg["content"]
            assert role in ["system", "user", "assistant"]
            if i == 0:
                assert role in ["user", "system"], "The role of first msg should be user"
            if isinstance(content, str):
                content = [content]
            cur_msgs = []
            for c in content:
                if isinstance(c, Image.Image):
                    images.append(c)
                    cur_msgs.append("(<image>./</image>)")
                elif isinstance(c, np.ndarray):
                    audios.append(c)
                    # Remember which message index each audio clip belongs to.
                    audio_parts.append(i)
                    cur_msgs.append("(<audio>./</audio>)")
                    use_tts_template = True
                elif isinstance(c, str):
                    cur_msgs.append(c)
            separator = "" if omni_input else "\n"
            msg["content"] = separator.join(cur_msgs)

        prompts_lists.append(
            processor.tokenizer.apply_chat_template(
                copy_msgs,
                tokenize=False,
                add_generation_prompt=True,
                chat_template=self.default_tts_chat_template if use_tts_template else None,
            )
        )
        input_images_list.append(images)
        input_audios_list.append(audios)
        audio_parts_list.append(audio_parts)

    inputs = processor(
        prompts_lists,
        input_images_list,
        input_audios_list,
        audio_parts_list,
        max_slice_nums=max_slice_nums,
        use_image_id=use_image_id,
        chunk_input=chunk_input,
        return_tensors="pt",
        max_length=max_inp_length,
    ).to(self.device)

    if sampling:
        generation_config = {
            "top_p": 0.8,
            "top_k": 100,
            "temperature": 0.7,
            "do_sample": True,
            "repetition_penalty": 1.05,
        }
    else:
        generation_config = {
            "num_beams": 3,
            "repetition_penalty": 1.2,
        }

    if min_new_tokens > 0:
        generation_config["min_new_tokens"] = min_new_tokens

    # Only keys that already exist in the defaults can be overridden through **kwargs.
    generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())

    inputs.pop("image_sizes")
    with torch.inference_mode():
        res, outputs = self.generate(
            **inputs,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            vision_hidden_states=vision_hidden_states,
            stream=stream,
            **generation_config,
        )

    if stream:

        def stream_gen():
            # Strip terminator strings from every streamed piece.
            for text in res:
                for term in self.terminators:
                    text = text.replace(term, "")
                yield text

        if return_dict:
            return OmniOutput(text=stream_gen())
        return stream_gen()

    spk_embeds = wav_numpy = sr = None

    answer = res if batched else res[0]

    if use_tts_template and generate_audio:
        mel_spec = self._generate_mel_spec(inputs, outputs, answer)
        wav_numpy, sr = self.decode_mel_to_audio(mel_spec, output_audio_path)

    if return_spk_embed:
        spk_embeds = self._get_last_spk_embeds(inputs, outputs)

    if isinstance(answer, list):
        answer = [i.replace(tokenizer.tts_end, "") for i in answer]
    else:
        answer = answer.replace(tokenizer.tts_end, "")

    if return_dict:
        return OmniOutput(text=answer, spk_embeds=spk_embeds, audio_wav=wav_numpy, sampling_rate=sr)
    return answer
| |
|
@torch.inference_mode()
def streaming_prefill(
    self,
    session_id,
    msgs,
    tokenizer,
    omni_input=True,
    max_slice_nums=None,
    ls_temperature=1.0,
    **kwargs,
):
    """
    Streaming video/audio input and output audio stream, Only support batch_size=1

    Converts the single message in `msgs` into a prompt, embeds it and runs the LLM once so
    that `self.llm_past_key_values` is extended with the new content. Nothing is generated here.

    Args:
        session_id: Note: new connection should use a new session_id. A session_id different
            from `self.session_id` resets the session via `self.reset_session()`.
        msgs: list containing exactly one message dict with "role" and "content"; content is a
            string or a list of str / PIL.Image / numpy.ndarray (audio) parts
        tokenizer: tokenizer used to apply the chat template on the first message of a session
        omni_input: join content parts with "" when True, with "\\n" otherwise
        max_slice_nums: maximum number of image slices (1 when None)
        ls_temperature: not used in this method
        **kwargs: not used in this method

    Returns:
        None. The updated KV cache is stored on `self.llm_past_key_values`.
    """
    assert session_id is not None
    self.is_first = self.session_id is None or session_id != self.session_id

    images = []
    audios = []

    assert len(msgs) == 1
    copy_msgs = deepcopy(msgs)
    msg = copy_msgs[0]

    assert msg["role"] in ["system", "user", "assistant"]

    content = msg["content"]
    # Same normalisation as `chat`: a bare string is a single text part, not a sequence of chars.
    if isinstance(content, str):
        content = [content]
    cur_msgs = []
    for c in content:
        if isinstance(c, Image.Image):
            images.append(c)
            cur_msgs.append("(<image>./</image>)")
        elif isinstance(c, np.ndarray):
            audios.append(c)
            cur_msgs.append("(<audio>./</audio>)")
        elif isinstance(c, str):
            cur_msgs.append(c)
        else:
            logger.error("Invalid content type: %s", c)

    # The parts are what gets joined; the separator depends on omni mode.
    cur_contents = "".join(cur_msgs) if omni_input else "\n".join(cur_msgs)
    if not self.is_first and self.new_user_msg and msg["role"] == "user":
        # A new user turn inside a running session: close the previous turn by hand.
        if self.llm_generated:
            if self.llm_generate_completed:
                msg["content"] = "<|im_end|>\n<|im_start|>user\n" + cur_contents
            else:
                # Generation did not reach a terminator, so the TTS span is closed explicitly too.
                msg["content"] = "<|tts_eos|><|im_end|>\n<|im_start|>user\n" + cur_contents
        else:
            msg["content"] = "<|im_start|>user\n" + cur_contents
        self.new_user_msg = False
    else:
        msg["content"] = cur_contents

    if msg["role"] in ["system", "assistant"]:
        self.new_user_msg = True
        self.audio_past_key_values = None

    if self.is_first:
        logger.info(f"new session_id: {session_id}, reset kv cache")
        self.reset_session()
        self.session_id = session_id

        prompt = tokenizer.apply_chat_template(
            copy_msgs, tokenize=False, add_generation_prompt=False, chat_template=self.default_tts_chat_template
        )
        add_special_tokens = True
    else:
        # Later chunks are appended raw; the template markers were added by hand above.
        prompt = copy_msgs[0]["content"]
        add_special_tokens = False

    model_inputs = self.processor(
        [prompt],
        [images],
        [audios],
        max_slice_nums=1 if max_slice_nums is None else max_slice_nums,
        use_image_id=False,
        chunk_input=True,
        return_tensors="pt",
        max_length=None,
        sampling_rate=16000,
        add_special_tokens=add_special_tokens,
    ).to(self.device)

    model_inputs["inputs_embeds"], _ = self.get_vllm_embedding(model_inputs)
    inputs_embeds = self.get_omni_embedding(
        model_inputs, input_embeddings=model_inputs["inputs_embeds"], stream_input=True
    )

    if self.is_first:
        self.audio_past_key_values = None

    if self.llm_past_key_values is not None:
        cache_length = self.llm_past_key_values[0][0].shape[2]
    else:
        cache_length = 0

    # The mask covers the cached positions plus the newly embedded ones.
    attention_mask = torch.ones((1, cache_length + inputs_embeds.shape[1]), dtype=torch.bool, device=self.device)

    outputs = self.llm(
        past_key_values=self.llm_past_key_values,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=None,
        use_cache=True,
        return_dict=True,
    )
    self.llm_past_key_values = outputs["past_key_values"]
    return
| |
|
@torch.inference_mode()
def streaming_generate(
    self,
    session_id,
    tokenizer,
    max_new_tokens=512,
    min_new_tokens=0,
    sampling=True,
    generate_audio=True,
    enable_regenerate=False,
    **kwargs,
):
    """
    Streaming video/audio input and output audio stream

    Appends the assistant generation prompt to the cached conversation and starts chunked
    LLM generation through `self.llm_generate_chunk`.

    Args:
        session_id: not used in this method
        tokenizer: tokenizer for llm; must provide `spk_start_id` / `spk_end_id`
        max_new_tokens: the maximum length of the generation
        min_new_tokens: the minimum length of the generation
        sampling: whether to use sampling decoding or beam search decoding
        generate_audio: when True, wrap the text stream with `_generate_mel_spec_audio_streaming`
        enable_regenerate: forwarded to `_generate_mel_spec_audio_streaming`
        **kwargs: may override any key already present in the generation config

    Returns:
        The result of `_generate_mel_spec_audio_streaming` when generate_audio is True,
        otherwise the text chunk generator returned by `llm_generate_chunk`.
    """
    if sampling:
        generation_config = {
            "top_p": 0.8,
            "top_k": 100,
            "temperature": 0.7,
            "do_sample": True,
            "repetition_penalty": 1.05,
        }
    else:
        generation_config = {
            "num_beams": 3,
            "repetition_penalty": 1.2,
        }
    generation_config["min_new_tokens"] = min_new_tokens
    generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())

    # Session flags read by `streaming_prefill` when the next user message arrives.
    self.new_user_msg = True
    self.llm_generated = True
    self.llm_generate_completed = False
    self.audio_past_key_values = None

    terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
    generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>"
    # Keep the prompt on the model's device (same device as the attention mask below).
    input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(
        self.device
    )

    # Bounds of the span between the speaker start and end tokens inside the prompt.
    spk_start_idx = torch.where(input_ids[0] == tokenizer.spk_start_id)[0]
    spk_end_idx = torch.where(input_ids[0] == tokenizer.spk_end_id)[0]
    spk_bounds = [
        torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
    ]

    # Mirror `streaming_prefill`: a missing cache simply means nothing has been prefilled yet.
    if self.llm_past_key_values is not None:
        cache_length = self.llm_past_key_values[0][0].shape[2]
    else:
        cache_length = 0
    attention_mask = torch.ones((1, cache_length + input_ids.shape[1]), dtype=torch.bool, device=self.device)

    generation_config["max_new_tokens"] = max_new_tokens
    streamer = self.llm_generate_chunk(input_ids, attention_mask, tokenizer, terminators, generation_config)

    if generate_audio:
        return self._generate_mel_spec_audio_streaming(
            spk_bounds, streamer, output_chunk_size=25, enable_regenerate=enable_regenerate
        )
    return streamer
| |
|
def llm_generate_chunk(self, input_ids, attention_mask, tokenizer, terminators, generation_config):
    """
    Generate text from the LLM a few tokens at a time and yield each decoded piece.

    Every round asks `self.llm.generate` for at most 3 new tokens on top of
    `self.llm_past_key_values`, stores the returned cache back on `self`, and yields a dict.
    Trailing ids whose decoded text ends with "�" are held back and prepended to the next round.

    Args:
        input_ids: prompt ids for the first round; later rounds feed the last generated id
        attention_mask: mask for the first round; rebuilt from the cache length afterwards
        tokenizer: used to decode ids to text
        terminators: token ids that end the generation
        generation_config: extra kwargs for `generate`; "max_new_tokens" (default 2048) is
            popped from this dict and used as the overall budget

    Yields:
        {"text": str} per round; the first round additionally carries "hidden_states".
    """

    def check_uncompleted_token(ids):
        # Number of leading ids whose decoded text does not end with "�".
        cur_text = tokenizer.decode(ids)
        end = len(ids)
        # `cur_text and` keeps an empty decode (e.g. no new ids this round) from raising IndexError.
        while cur_text and cur_text[-1] == "�":
            end -= 1
            if end == 0:
                break
            cur_text = tokenizer.decode(ids[:end])
        return end

    max_new_tokens = int(generation_config.pop("max_new_tokens", 2048))
    new_len = 0
    first_chunk = True
    eos = False
    left_ids = None

    while True:
        outputs = self.llm.generate(
            input_ids=input_ids,
            past_key_values=self.llm_past_key_values,
            attention_mask=attention_mask,
            use_cache=True,
            max_new_tokens=3,
            pad_token_id=0,
            # Hidden states are only requested (and yielded) for the first round.
            output_hidden_states=first_chunk,
            return_dict_in_generate=True,
            eos_token_id=terminators,
            **generation_config,
        )
        if outputs.sequences[0, -1] in terminators:
            eos = True
        input_len = input_ids.shape[1]
        cur_ids = outputs.sequences[:, input_len:]
        new_len += cur_ids.shape[1]

        # Re-attach ids held back by the previous round before decoding.
        if left_ids is not None and left_ids.shape[1] > 0:
            cur_ids = torch.cat([left_ids, cur_ids], dim=1)
        end = check_uncompleted_token(cur_ids[0])
        left_ids = cur_ids[:, end:]
        cur_ids = cur_ids[:, :end]
        text = self._decode_text(cur_ids, tokenizer)[0] if end > 0 else ""

        self.llm_past_key_values = outputs.past_key_values
        input_ids = outputs.sequences[:, -1:]
        cache_length = self.llm_past_key_values[0][0].shape[2]
        attention_mask = torch.ones((1, cache_length + input_ids.shape[1]), dtype=torch.bool, device=self.device)

        res = {"text": text}
        if first_chunk:
            res["hidden_states"] = outputs.hidden_states
            first_chunk = False
        yield res

        if eos:
            self.llm_generate_completed = True
            break
        if new_len >= max_new_tokens:
            logger.debug(f"LLM generation {new_len} exceeds max_new_tokens({max_new_tokens}), break.")
            break
| |
|
def prepare_tts_text(self, text):
    """
    Wrap `text` in the TTS control tokens and fit it to the reserved streaming text length.

    Text shorter than `self.tts.streaming_text_reserved_len` tokens is followed by "[Etts]"
    and "[PAD]" fillers; longer text is cut to the reserved length and decoded back.

    Args:
        text: the text to be spoken

    Returns:
        (tts_text, token_count): the wrapped text and the number of (possibly truncated)
        text tokens, excluding control and padding tokens.
    """
    text_tokenizer = self.tts_processor.text_tokenizer
    reserved_len = self.tts.streaming_text_reserved_len

    token_ids = text_tokenizer.encode(text, add_special_tokens=False)
    token_count = len(token_ids)

    if token_count >= reserved_len:
        # Too long: keep only the tokens that fit; no end marker or padding is appended.
        token_ids = token_ids[0:reserved_len]
        token_count = len(token_ids)
        text = text_tokenizer.decode(token_ids, add_special_tokens=False)
        padding = ""
    else:
        # "[Etts]" occupies one of the free slots, "[PAD]" fills the remaining ones.
        free_slots = reserved_len - token_count
        padding = "[Etts]" + "[PAD]" * (free_slots - 1)

    speaker_slots = "[spk_emb]" * self.tts.num_spk_embs
    return f"[Stts]{speaker_slots}{text}{padding}[Ptts]", token_count
| |
|
def get_tts_text_start_token_ids(self):
    """
    Tokenize the fixed TTS prefix: "[Stts]" followed by `self.tts.num_spk_embs` "[spk_emb]" tokens.

    Returns:
        The "input_ids" tensor produced by the TTS text tokenizer, moved to `self.tts.device`
        so it lives on the same device as the other tensors fed to the TTS model.
    """
    text = "[Stts]" + "[spk_emb]" * self.tts.num_spk_embs
    encoded = self.tts_processor.text_tokenizer(text, return_tensors="pt", add_special_tokens=False)
    return encoded["input_ids"].to(self.tts.device)
| |
|
def _build_streaming_mask(self, tts_tokens_len):
    """
    Build the int8 mask over the full TTS text condition for a given number of text tokens.

    The mask length is 1 + num_spk_embs * use_speaker_embedding + streaming_text_reserved_len + 1.
    The first `tts_tokens_len + 3` entries and the last entry are 1; everything else is 0.

    Args:
        tts_tokens_len: number of text tokens currently available

    Returns:
        A 1-D `torch.int8` tensor (created on the default device).
    """
    tts = self.tts
    full_length = 1 + tts.num_spk_embs * tts.use_speaker_embedding + tts.streaming_text_reserved_len + 1
    mask = torch.zeros(full_length, dtype=torch.int8)
    visible_len = 1 + 1 + tts_tokens_len + 1
    mask[:visible_len] = 1
    # The final position is always enabled.
    mask[-1] = 1
    return mask
| |
|
def _get_last_spk_embeds(self, inputs, outputs):
    """
    Slice the speaker-embedding rows out of the last-layer hidden states.

    For every entry of `outputs.hidden_states`, the last element's first item is taken; these
    are stacked vertically and cut with the last bound of `inputs["spk_bounds"][0]`.

    Args:
        inputs: mapping providing "spk_bounds"
        outputs: generation output providing `hidden_states`
            (presumably one tuple of per-layer tensors per generation step — verify against caller)

    Returns:
        The stacked hidden states restricted to rows [bound[0], bound[1]).
    """
    # Last layer, first batch item, for every entry of hidden_states.
    per_step_rows = [step_states[-1][0] for step_states in outputs.hidden_states]
    stacked = torch.vstack(per_step_rows)

    bound = inputs["spk_bounds"][0][-1]
    return stacked[bound[0] : bound[1]]
| |
|
def _generate_mel_spec(self, inputs, outputs, text, output_chunk_size=25, tts_max_new_tokens=2048):
    """
    Turn generated text into mel spectrograms with the TTS model.

    The text between "<|tts_bos|>" and "<|tts_eos|>" is wrapped by `prepare_tts_text`, prefilled
    into the TTS model chunk by chunk (`self.tts.streaming_text_chunk_size` tokens at a time),
    and audio ids are generated after every text chunk. When the text is exhausted before the
    TTS model reports `finished`, generation continues until it finishes or the number of
    generated ids exceeds `tts_max_new_tokens`.

    Args:
        inputs: processor outputs; used for the speaker bounds
        outputs: LLM generation outputs; used for the hidden states
        text: generated answer text (a single string)
        output_chunk_size: number of audio ids requested per `tts.generate` call
        tts_max_new_tokens: upper bound on generated audio ids in the continuation phase

    Returns:
        The result of `self.tts.decode_to_mel_specs` on all generated audio ids.
    """
    tts = self.tts
    device = tts.device
    spk_embeds = self._get_last_spk_embeds(inputs, outputs)

    text = text.split("<|tts_bos|>")[-1]
    gen_text = text.split("<|tts_eos|>")[0]
    tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
    tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False)
    # Build the ids as integers directly on the TTS device (no float round-trip, no fixed "cuda").
    tts_input_ids = torch.tensor(tts_inputs, dtype=torch.long, device=device).unsqueeze(0)
    streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=device)

    logits_warpers, logits_processors = gen_logits(
        num_code=626, top_P=tts.top_p, top_K=tts.top_k, repetition_penalty=tts.repetition_penalty
    )

    # Start token plus the speaker-embedding slots (if enabled).
    start_token_len = 1 + tts.use_speaker_embedding * tts.num_spk_embs
    condition_length = start_token_len + tts.streaming_text_reserved_len + 1
    chunk_size = tts.streaming_text_chunk_size

    dtype = tts.emb_text.weight.dtype
    kv_shape = (
        1,
        tts.config.num_attention_heads,
        condition_length - 1,
        tts.config.hidden_size // tts.config.num_attention_heads,
    )
    # Zero-initialised key/value pair per layer, filled in by `prefill_text`.
    past_key_values = [
        (
            torch.zeros(kv_shape, dtype=dtype, device=device),
            torch.zeros(kv_shape, dtype=dtype, device=device),
        )
        for _ in range(tts.config.num_hidden_layers)
    ]

    audio_input_ids = torch.zeros(1, condition_length, tts.num_vq, dtype=torch.long, device=device)

    def generate_audio_chunk(audio_ids, kv_cache):
        # One `tts.generate` call with the sampling settings shared by both phases.
        return tts.generate(
            input_ids=audio_ids,
            past_key_values=kv_cache,
            streaming_tts_text_mask=streaming_tts_text_mask,
            max_new_token=output_chunk_size,
            force_no_stop=self.force_no_stop,
            temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=device),
            eos_token=torch.tensor([625], dtype=torch.long, device=device),
            logits_warpers=logits_warpers,
            logits_processors=logits_processors,
        )

    eos_lab = False
    for chunk_idx in range(math.ceil(condition_length / chunk_size)):
        if chunk_idx == 0:
            # The first chunk also carries the start token and speaker slots.
            begin = 0
            end = chunk_size + start_token_len
        else:
            begin = chunk_idx * chunk_size + start_token_len
            end = min((chunk_idx + 1) * chunk_size + start_token_len, condition_length - 1)

        if end - begin > 0:
            prefill_kwargs = {
                "input_ids": tts_input_ids[:, begin:end],
                "position_ids": torch.arange(begin, end, dtype=torch.long, device=device).unsqueeze(0),
                "past_key_values": past_key_values,
            }
            if begin == 0:
                # Speaker embeddings are only supplied together with the first chunk.
                prefill_kwargs["lm_spk_emb_last_hidden_states"] = spk_embeds
            past_key_values = tts.prefill_text(**prefill_kwargs)

        tts_outputs = generate_audio_chunk(audio_input_ids, past_key_values)
        audio_input_ids = tts_outputs.audio_input_ids
        past_key_values = tts_outputs.past_key_values

        if tts_outputs.finished:
            logger.debug("Generation finished.")
            eos_lab = True
            break

    if not eos_lab:
        logger.debug("eos_lab False, Generation continue.")
        while True:
            tts_outputs = generate_audio_chunk(audio_input_ids, past_key_values)

            audio_input_ids = tts_outputs.audio_input_ids
            past_key_values = tts_outputs.past_key_values

            if tts_outputs.finished:
                logger.debug("Generation finished.")
                break
            if tts_outputs.new_ids.shape[1] > tts_max_new_tokens:
                logger.debug(f"Generation length > {tts_max_new_tokens}, stopped.")
                break

    mel_spec = tts.decode_to_mel_specs(tts_outputs.new_ids)
    return mel_spec
| |
|
def _linear_overlap_add2_wav(self, frames: List[torch.Tensor], overlap: int):
    """
    Merge two audio waveforms with smooth in streaming audio generation.
    Borrowed some codes from `https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py`

    Both frames are weighted with a triangular window and the second frame is placed `overlap`
    samples before the end of the first one; the weighted sum is normalised by the summed weights.

    Args:
        frames: exactly two waveforms
        overlap: number of samples shared by the end of frames[0] and the start of frames[1]

    Returns:
        (head, tail): the merged signal split at len(frames[0]); `tail` holds the remaining
        len(frames[1]) - overlap samples.
    """
    assert len(frames) == 2
    first, second = frames[0], frames[1]
    device, dtype = first.device, first.dtype

    first_len = first.shape[-1]
    second_len = second.shape[-1]
    merged_len = first_len + second_len - overlap

    # Triangular window without its zero end points, long enough to cover either frame.
    window_len = max(first_len, second_len) + overlap
    ramp = torch.linspace(0, 1, window_len + 2, device=device, dtype=dtype)[1:-1]
    window = 0.5 - (ramp - 0.5).abs()

    weight_total = torch.zeros(merged_len, device=device, dtype=dtype)
    merged = torch.zeros(merged_len, device=device, dtype=dtype)

    # First frame: weighted with the tail of the window, placed at the start.
    first_weight = window[-first_len:]
    merged[0:first_len] += first_weight * first
    weight_total[0:first_len] += first_weight

    # Second frame: weighted with the head of the window, starting `overlap` samples early.
    second_start = first_len - overlap
    second_weight = window[:second_len]
    merged[second_start : second_start + second_len] += second_weight * second
    weight_total[second_start : second_start + second_len] += second_weight

    assert weight_total.min() > 0
    merged = merged / weight_total
    return merged[:first_len], merged[first_len:]
| |
|
| | def _generate_mel_spec_audio_streaming( |
| | self, |
| | spk_bounds, |
| | streamer, |
| | output_chunk_size=25, |
| | spk_embeds=None, |
| | prev_seg_text_ids=None, |
| | prev_seg_text_left="", |
| | prev_seg_audio_ids=None, |
| | enable_regenerate=False, |
| | ): |
| | |
| | gen_text = "" |
| | tts_text = "" |
| | new_segment_gen = False |
| | if spk_embeds is None: |
| | spk_bound = spk_bounds[0][-1] |
| | r = next(streamer) |
| | txt = r["text"] |
| | gen_text += txt.split("<|tts_eos|>")[0] |
| | tts_text, tts_token_lens = self.prepare_tts_text(gen_text) |
| | last_hidden_states = r["hidden_states"][0][-1][0] |
| | spk_embeds = last_hidden_states[spk_bound[0] : spk_bound[1]] |
| |
|
| | |
| | logits_warpers, logits_processors = gen_logits( |
| | num_code=626, top_P=self.tts.top_p, top_K=self.tts.top_k, repetition_penalty=self.tts.repetition_penalty |
| | ) |
| | condition_length = ( |
| | 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs + self.tts.streaming_text_reserved_len + 1 |
| | ) |
| | tts_start_token_len = 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs |
| | dtype = self.tts.emb_text.weight.dtype |
| | past_key_values = [ |
| | ( |
| | torch.zeros( |
| | 1, |
| | self.tts.config.num_attention_heads, |
| | condition_length - 1, |
| | self.tts.config.hidden_size // self.tts.config.num_attention_heads, |
| | dtype=dtype, |
| | device=self.tts.device, |
| | ), |
| | torch.zeros( |
| | 1, |
| | self.tts.config.num_attention_heads, |
| | condition_length - 1, |
| | self.tts.config.hidden_size // self.tts.config.num_attention_heads, |
| | dtype=dtype, |
| | device=self.tts.device, |
| | ), |
| | ) |
| | for _ in range(self.tts.config.num_hidden_layers) |
| | ] |
| | audio_input_ids = torch.zeros(1, condition_length, self.tts.num_vq, dtype=torch.long, device=self.tts.device) |
| |
|
| | |
| | chunk_idx = 0 |
| | new_ids_len = 0 |
| | prev_text_len = 0 |
| | if prev_seg_text_ids is not None and prev_seg_audio_ids is not None: |
| | tts_token_lens = prev_seg_text_ids.shape[1] |
| | |
| | streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) |
| | position_ids = torch.arange( |
| | 0, tts_token_lens + tts_start_token_len, dtype=torch.long, device=self.tts.device |
| | ).unsqueeze(0) |
| |
|
| | text_input_ids = self.get_tts_text_start_token_ids() |
| | text_input_ids = torch.cat([text_input_ids, prev_seg_text_ids], dim=1) |
| | past_key_values = self.tts.prefill_text( |
| | input_ids=text_input_ids, |
| | position_ids=position_ids, |
| | past_key_values=past_key_values, |
| | lm_spk_emb_last_hidden_states=spk_embeds, |
| | ) |
| | past_key_values = self.tts.prefill_audio_ids( |
| | input_ids=prev_seg_audio_ids[:, :-1, :], |
| | |
| | past_key_values=past_key_values, |
| | streaming_tts_text_mask=streaming_tts_text_mask, |
| | ) |
| |
|
| | |
| | chunk_idx += int(tts_token_lens / self.tts.streaming_text_chunk_size) |
| | audio_input_ids = torch.cat([audio_input_ids, prev_seg_audio_ids], dim=1) |
| | text = self.tts_processor.text_tokenizer.decode(prev_seg_text_ids[0].tolist(), add_special_tokens=False) |
| |
|
| | gen_text += text |
| | gen_text += prev_seg_text_left |
| | prev_text_len = len(gen_text) |
| | new_ids_len += prev_seg_audio_ids.shape[1] |
| |
|
| | prev_wav = None |
| | eos_lab = False |
| | stop = False |
| | shift_len = 180 |
| | voice_checker = VoiceChecker() |
| | number_converter = NumberToTextConverter() |
| | lang = None |
| | gen_text_raw = gen_text |
| | for t, r in enumerate(streamer): |
| | t += 1 |
| | txt = r["text"] |
| | txt = txt.split("<|tts_eos|>")[0] |
| | gen_text_raw += txt |
| | if t == 1 and txt == "" and prev_seg_text_ids is not None: |
| | logger.warning("New segment is empty, generation finished.") |
| | return |
| | if t <= 2: |
| | lang = number_converter.detect_language(gen_text_raw) |
| | gen_text += number_converter.replace_numbers_with_text(txt, lang).replace("*", "") |
| |
|
| | |
| | tts_text, tts_token_lens = self.prepare_tts_text(gen_text) |
| |
|
| | if tts_token_lens >= self.tts.streaming_text_reserved_len - shift_len: |
| | end_c = sentence_end(txt) |
| | if end_c: |
| | end_c_idx = gen_text.rfind(end_c) |
| | assert end_c_idx != -1 |
| | text_left = gen_text[end_c_idx + 1 :] |
| | gen_text = gen_text[: end_c_idx + 1] |
| | tts_text, tts_token_lens = self.prepare_tts_text(gen_text) |
| | new_segment_gen = True |
| | logger.debug( |
| | f"tts_text tokens {tts_token_lens} exceed {self.tts.streaming_text_reserved_len - shift_len}, starting a new segment generation" |
| | ) |
| | break |
| |
|
| | if tts_token_lens >= (chunk_idx + 1) * self.tts.streaming_text_chunk_size: |
| |
|
| | |
| | if chunk_idx == 0: |
| | begin = 0 |
| | end = (chunk_idx + 1) * self.tts.streaming_text_chunk_size + tts_start_token_len |
| | else: |
| | begin = chunk_idx * self.tts.streaming_text_chunk_size + tts_start_token_len |
| | end = min( |
| | (chunk_idx + 1) * self.tts.streaming_text_chunk_size + tts_start_token_len, condition_length - 1 |
| | ) |
| |
|
| | tts_input_ids = self.tts_processor.text_tokenizer( |
| | tts_text, return_tensors="pt", add_special_tokens=False |
| | )["input_ids"].cuda() |
| | text_input_ids = tts_input_ids[:, begin:end] |
| | streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) |
| | position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0) |
| |
|
| | past_key_values = self.tts.prefill_text( |
| | input_ids=text_input_ids, |
| | position_ids=position_ids, |
| | past_key_values=past_key_values, |
| | lm_spk_emb_last_hidden_states=spk_embeds if chunk_idx == 0 else None, |
| | ) |
| | outputs = self.tts.generate( |
| | input_ids=audio_input_ids, |
| | past_key_values=past_key_values, |
| | streaming_tts_text_mask=streaming_tts_text_mask, |
| | max_new_token=output_chunk_size, |
| | force_no_stop=self.force_no_stop, |
| | temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), |
| | eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), |
| | logits_warpers=logits_warpers, |
| | logits_processors=logits_processors, |
| | ) |
| | audio_input_ids = ( |
| | outputs.audio_input_ids |
| | ) |
| | past_key_values = outputs.past_key_values |
| | chunk_idx += 1 |
| |
|
| | mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, max(new_ids_len - 4, 0) :, :]) |
| | new_ids_len = outputs.new_ids.shape[1] |
| |
|
| | wav_np, sr = self.decode_mel_to_audio(mel_spec) |
| |
|
| | if enable_regenerate: |
| | if prev_wav is not None: |
| | check_wav_np = wav_np[2048:].cpu().numpy() |
| | check_mel = mel_spec[0, :, 8:].cpu().numpy() |
| | else: |
| | check_wav_np = wav_np.cpu().numpy() |
| | check_mel = mel_spec[0].cpu().numpy() |
| | if enable_regenerate and voice_checker.is_bad(check_wav_np, check_mel, chunk_size=2560): |
| | voice_checker.reset() |
| | |
| | N = output_chunk_size if prev_wav is None else output_chunk_size * 2 |
| | past_kv = [] |
| | for i in range(len(past_key_values)): |
| | past_kv.append( |
| | ( |
| | past_key_values[i][0][:, :, :-N, :], |
| | past_key_values[i][1][:, :, :-N, :], |
| | ) |
| | ) |
| | outputs = self.tts.generate( |
| | input_ids=audio_input_ids[:, :-N, :], |
| | past_key_values=past_kv, |
| | streaming_tts_text_mask=streaming_tts_text_mask, |
| | max_new_token=N, |
| | force_no_stop=self.force_no_stop, |
| | temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), |
| | eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), |
| | logits_warpers=logits_warpers, |
| | logits_processors=logits_processors, |
| | ) |
| | audio_input_ids = outputs.audio_input_ids |
| | past_key_values = outputs.past_key_values |
| |
|
| | new_ids_len -= N |
| | mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, new_ids_len:, :]) |
| | new_ids_len = outputs.new_ids.shape[1] |
| | wav_np, sr = self.decode_mel_to_audio(mel_spec) |
| |
|
| | if prev_wav is not None: |
| | wav_y = wav_np[: len(prev_wav)] |
| | prev_wav = wav_np[len(prev_wav) :] |
| | cur_text = gen_text_raw[prev_text_len:] |
| | prev_text_len = len(gen_text_raw) |
| | yield OmniOutput(text=cur_text, audio_wav=wav_y, sampling_rate=sr) |
| |
|
| | else: |
| | prev_wav = wav_np |
| | else: |
| | |
| | if prev_wav is not None: |
| | wav_np, prev_wav = self._linear_overlap_add2_wav( |
| | [prev_wav, wav_np], overlap=512 * 4 |
| | ) |
| | cur_text = gen_text_raw[prev_text_len:] |
| | prev_text_len = len(gen_text_raw) |
| | yield OmniOutput(text=cur_text, audio_wav=wav_np, sampling_rate=sr) |
| |
|
| | else: |
| | prev_wav = wav_np |
| |
|
| | if outputs.finished: |
| | logger.debug("Generation finished.") |
| | eos_lab = True |
| | break |
| |
|
| | if not eos_lab and tts_text: |
| | logger.debug("eos_lab False, Generation continue.") |
| |
|
| | if chunk_idx == 0: |
| | begin = 0 |
| | else: |
| | begin = chunk_idx * self.tts.streaming_text_chunk_size + tts_start_token_len |
| | end = tts_token_lens + tts_start_token_len + 1 |
| | if end > begin: |
| | tts_input_ids = self.tts_processor.text_tokenizer( |
| | tts_text, return_tensors="pt", add_special_tokens=False |
| | )["input_ids"].cuda() |
| | text_input_ids = tts_input_ids[:, begin:end] |
| | streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) |
| | position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0) |
| |
|
| | past_key_values = self.tts.prefill_text( |
| | input_ids=text_input_ids, |
| | position_ids=position_ids, |
| | past_key_values=past_key_values, |
| | lm_spk_emb_last_hidden_states=spk_embeds if chunk_idx == 0 else None, |
| | ) |
| |
|
| | while True: |
| | |
| | outputs = self.tts.generate( |
| | input_ids=audio_input_ids, |
| | past_key_values=past_key_values, |
| | streaming_tts_text_mask=streaming_tts_text_mask, |
| | max_new_token=output_chunk_size, |
| | force_no_stop=self.force_no_stop, |
| | |
| | temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), |
| | eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), |
| | logits_warpers=logits_warpers, |
| | logits_processors=logits_processors, |
| | ) |
| | audio_input_ids = outputs.audio_input_ids |
| | past_key_values = outputs.past_key_values |
| | chunk_idx += 1 |
| |
|
| | mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, max(new_ids_len - 4, 0) :, :]) |
| | new_ids_len = outputs.new_ids.shape[1] |
| |
|
| | wav_np, sr = self.decode_mel_to_audio(mel_spec) |
| |
|
| | if enable_regenerate: |
| | if prev_wav is not None: |
| | check_wav_np = wav_np[2048:].cpu().numpy() |
| | check_mel = mel_spec[0, :, 8:].cpu().numpy() |
| | else: |
| | check_wav_np = wav_np.cpu().numpy() |
| | check_mel = mel_spec[0].cpu().numpy() |
| | if enable_regenerate and voice_checker.is_bad(check_wav_np, check_mel, chunk_size=2560): |
| | voice_checker.reset() |
| | |
| | N = output_chunk_size if prev_wav is None else output_chunk_size * 2 |
| | past_kv = [] |
| | for i in range(len(past_key_values)): |
| | past_kv.append( |
| | ( |
| | past_key_values[i][0][:, :, :-N, :], |
| | past_key_values[i][1][:, :, :-N, :], |
| | ) |
| | ) |
| | outputs = self.tts.generate( |
| | input_ids=audio_input_ids[:, :-N, :], |
| | past_key_values=past_kv, |
| | streaming_tts_text_mask=streaming_tts_text_mask, |
| | max_new_token=N, |
| | force_no_stop=self.force_no_stop, |
| | temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), |
| | eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), |
| | logits_warpers=logits_warpers, |
| | logits_processors=logits_processors, |
| | ) |
| | audio_input_ids = outputs.audio_input_ids |
| | past_key_values = outputs.past_key_values |
| |
|
| | new_ids_len -= N |
| | mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, new_ids_len:, :]) |
| | new_ids_len = outputs.new_ids.shape[1] |
| | wav_np, sr = self.decode_mel_to_audio(mel_spec) |
| |
|
| | if prev_wav is not None: |
| | wav_y = wav_np[: len(prev_wav)] |
| | prev_wav = wav_np[len(prev_wav) :] |
| | cur_text = gen_text_raw[prev_text_len:] |
| | prev_text_len = len(gen_text_raw) |
| | yield OmniOutput(text=cur_text, audio_wav=wav_y, sampling_rate=sr) |
| | else: |
| | prev_wav = wav_np |
| | else: |
| | |
| | if prev_wav is not None: |
| | wav_np, prev_wav = self._linear_overlap_add2_wav( |
| | [prev_wav, wav_np], overlap=512 * 4 |
| | ) |
| | cur_text = gen_text_raw[prev_text_len:] |
| | prev_text_len = len(gen_text_raw) |
| | yield OmniOutput(text=cur_text, audio_wav=wav_np, sampling_rate=sr) |
| | else: |
| | prev_wav = wav_np |
| |
|
| | if outputs.finished: |
| | logger.debug("Generation finished.") |
| | break |
| | if outputs.new_ids.shape[1] > 2048: |
| | stop = True |
| | logger.debug("Generation length > 2048, stopped.") |
| | break |
| |
|
| | if prev_wav is not None: |
| | cur_text = gen_text_raw[prev_text_len:] |
| | yield OmniOutput(text=cur_text, audio_wav=prev_wav, sampling_rate=sr) |
| |
|
| | if new_segment_gen and not stop: |
| | logger.debug( |
| | f"tts_text tokens {tts_token_lens} exceed {self.tts.streaming_text_reserved_len - shift_len}, start a new segment generation" |
| | ) |
| | tid_len = 5 |
| | prev_seg_text_ids = tts_input_ids[:, end - 1 - tid_len : end - 1] |
| | aid_len = 50 |
| | prev_seg_audio_ids = outputs.new_ids[:, -aid_len:, :] |
| |
|
| | result = self._generate_mel_spec_audio_streaming( |
| | spk_bounds, |
| | streamer, |
| | output_chunk_size, |
| | spk_embeds, |
| | prev_seg_text_ids, |
| | text_left, |
| | prev_seg_audio_ids, |
| | enable_regenerate=enable_regenerate, |
| | ) |
| | for res in result: |
| | yield res |
| |
|
def decode_mel_to_audio(self, mel_spec, output_path=""):
    """Turn a mel spectrogram into a waveform with the `self.vocos` vocoder.

    Args:
        mel_spec: Mel spectrogram tensor; it is cast to float32 before decoding.
        output_path: Optional file path. When non-empty, the waveform is also
            written there with `soundfile`.

    Returns:
        A `(waveform, sampling_rate)` pair. The waveform is a CPU tensor with
        all singleton dimensions squeezed away; the sampling rate is always 24000.
    """
    sample_rate = 24000
    with torch.inference_mode():
        waveform = self.vocos.decode(mel_spec.float())
        waveform = waveform.cpu().squeeze()
        if not output_path:
            return waveform, sample_rate
        sf.write(output_path, waveform.numpy(), samplerate=sample_rate)
        logger.info(f"Audio saved to {output_path}")
    return waveform, sample_rate
| |
|
| |
|
| | |
class MiniCPMWhisperEncoderLayer(nn.Module):
    """Pre-norm Whisper encoder layer that passes a key/value cache to its self-attention.

    The layer consists of a self-attention sub-block followed by a two-layer
    feed-forward sub-block, each wrapped in layer norm, dropout and a residual add.
    """

    def __init__(self, config: WhisperConfig, layer_idx: int = None):
        super().__init__()
        self.embed_dim = config.d_model

        # The attention implementation is selected by the config.
        attention_cls = WHISPER_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
            layer_idx=layer_idx,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
        past_key_values: Optional[EncoderDecoderCache] = None,
        use_cache: Optional[bool] = False,
    ) -> torch.Tensor:
        r"""
        Run the attention and feed-forward sub-blocks on `hidden_states`.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, embed_dim)`):
                Input to the layer.
            attention_mask (`torch.FloatTensor` of shape `(batch_size, 1, tgt_len, src_len)`):
                Additive mask; padded positions carry large negative values.
            layer_head_mask (`torch.FloatTensor` of shape `(encoder_attention_heads,)`):
                Mask that nullifies selected attention heads.
            output_attentions (`bool`, *optional*):
                When true, the attention weights are appended to the result.
            past_key_values (`EncoderDecoderCache`, *optional*):
                Cache handed to the self-attention module as `past_key_value`.
            use_cache (`bool`, *optional*):
                When true, the cache returned by self-attention is appended to the result.

        Returns:
            A tuple `(hidden_states, [attn_weights], [past_key_values])`; the bracketed
            entries are present only when the matching flag is set.
        """
        # --- self-attention sub-block (pre-norm + residual) ---
        attn_input = self.self_attn_layer_norm(hidden_states)
        attn_output, attn_weights, past_key_values = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            past_key_value=past_key_values,
        )
        attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
        hidden_states = hidden_states + attn_output

        # --- feed-forward sub-block (pre-norm + residual) ---
        ffn_output = self.final_layer_norm(hidden_states)
        ffn_output = self.activation_fn(self.fc1(ffn_output))
        ffn_output = nn.functional.dropout(ffn_output, p=self.activation_dropout, training=self.training)
        ffn_output = self.fc2(ffn_output)
        ffn_output = nn.functional.dropout(ffn_output, p=self.dropout, training=self.training)
        hidden_states = hidden_states + ffn_output

        # In half precision, pull infinities back into the representable range.
        if hidden_states.dtype == torch.float16:
            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
                limit = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-limit, max=limit)

        results = [hidden_states]
        if output_attentions:
            results.append(attn_weights)
        if use_cache:
            results.append(past_key_values)
        return tuple(results)
| |
|
| |
|
| | |
class MiniCPMWhisperEncoder(WhisperEncoder):
    """Whisper encoder whose layers thread a key/value cache through self-attention.

    The layer stack created by `WhisperEncoder` is replaced with
    `MiniCPMWhisperEncoderLayer` instances, so audio can be encoded chunk by chunk
    while re-using the attention states of previously encoded chunks.
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [MiniCPMWhisperEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]
        )

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        past_key_values: Optional[EncoderDecoderCache] = None,
        use_cache: Optional[bool] = None,
    ):
        r"""
        Forward pass of the cache-aware Whisper encoder.

        Args:
            input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
                Log-mel features extracted from the raw waveform (for example by a
                `WhisperFeatureExtractor`). They are moved to the dtype/device of `conv1`,
                projected by `conv1` and `conv2` (each followed by GELU) and then combined
                with the positional embeddings.

            attention_mask (`torch.Tensor`, *optional*):
                Passed unchanged to every encoder layer as its `attention_mask`.

            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected attention heads; row `i` is handed to layer `i`.
                Its first dimension must equal the number of layers.

            output_attentions (`bool`, *optional*):
                Whether to collect the attention weights of every layer. Defaults to
                `config.output_attentions`.

            output_hidden_states (`bool`, *optional*):
                Whether to collect the input of every layer plus the final normalized output.
                Defaults to `config.output_hidden_states`.

            return_dict (`bool`, *optional*):
                Whether to return a `BaseModelOutputWithPast` instead of a plain tuple.
                Defaults to `config.use_return_dict`.

            past_key_values (`EncoderDecoderCache`, `DynamicCache` or legacy `list`, *optional*):
                Attention states of previously encoded chunks. Only consulted for positional
                offsets when `use_cache` is true; a `DynamicCache` or legacy list is wrapped
                into an `EncoderDecoderCache`, and `None` starts a fresh cache.

            use_cache (`bool`, *optional*):
                When true, positions continue after the cached length and the updated cache
                is returned as `past_key_values` of the output object.

        Returns:
            `BaseModelOutputWithPast` or `tuple`:
                With `return_dict=True`, an object with
                - **last_hidden_state**: output of the final layer norm,
                - **hidden_states**: tuple of per-layer states if `output_hidden_states=True`,
                - **attentions**: tuple of per-layer attention weights if `output_attentions=True`,
                - **past_key_values**: the updated cache if `use_cache` is true, else `None`.

                With `return_dict=False`, a tuple of the non-`None` entries among
                `(last_hidden_state, hidden_states, attentions)`. The cache is NOT part of the
                tuple form.

        Example:
            >>> # Encode two consecutive chunks while re-using the attention cache.
            >>> first = encoder(input_features=chunk_0, use_cache=True, return_dict=True)
            >>> second = encoder(
            ...     input_features=chunk_1,
            ...     past_key_values=first.past_key_values,
            ...     use_cache=True,
            ...     return_dict=True,
            ... )
            >>> last_hidden_state = second.last_hidden_state
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Align the features with the convolution weights before projecting them.
        input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        # (batch, channels, frames) -> (batch, frames, channels)
        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        num_new_positions = inputs_embeds.shape[1]

        embed_pos = self.embed_positions.weight
        past_key_values_length = 0
        if use_cache:
            # Normalize every accepted cache flavour to an EncoderDecoderCache.
            if past_key_values is None:
                past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
            elif isinstance(past_key_values, list):
                past_key_values = EncoderDecoderCache(DynamicCache.from_legacy_cache(past_key_values), DynamicCache())
            elif isinstance(past_key_values, DynamicCache):
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())

            past_key_values_length = past_key_values.self_attention_cache.get_usable_length(num_new_positions)
            if num_new_positions + past_key_values_length > embed_pos.shape[0]:
                logger.warning("seems the audio is longer than 30s. repeating the last part of the audio")
                # Use whatever positions remain in the table, then pad with copies of the
                # last position. The pad length is derived from the rows actually taken,
                # so it stays correct even when the cache is already longer than the table
                # (the slice is then empty and every new position gets the last embedding).
                embed_pos_front = embed_pos[past_key_values_length:, :]
                num_repeats = num_new_positions - embed_pos_front.shape[0]
                embed_pos = torch.cat(
                    (
                        embed_pos_front,
                        torch.repeat_interleave(
                            embed_pos[-1, :].unsqueeze(0),
                            num_repeats,
                            dim=0,
                        ),
                    )
                )
            else:
                embed_pos = embed_pos[past_key_values_length : num_new_positions + past_key_values_length, :]
        else:
            embed_pos = embed_pos[:num_new_positions, :]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # The head mask must provide exactly one row per layer.
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        # Start from the incoming cache so the variable is always bound (no layers) and
        # a layer skipped by LayerDrop cannot replace the cache with None.
        next_encoder_cache = past_key_values if use_cache else None

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            # LayerDrop: during training, randomly skip whole layers.
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    to_drop = True

            if to_drop:
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        (head_mask[idx] if head_mask is not None else None),
                        output_attentions,
                        past_key_values,
                        use_cache,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        output_attentions=output_attentions,
                        past_key_values=past_key_values,
                        use_cache=use_cache,
                    )

                hidden_states = layer_outputs[0]

                # The cache is the last element of the layer output; only layers that
                # actually ran produce one.
                if use_cache:
                    next_encoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
            past_key_values=next_encoder_cache,
        )
| |
|
| |
|
| | |
class ConvNeXtBlock(nn.Module):
    """Residual 1-D ConvNeXt block.

    A depthwise dilated convolution is followed by layer norm, a pointwise
    expansion, GELU, a pointwise projection and an optional learnable
    per-channel scale; the result is added back to the input.

    Args:
        dim: Number of input/output channels.
        intermediate_dim: Width of the pointwise expansion.
        kernel: Kernel size of the depthwise convolution.
        dilation: Dilation of the depthwise convolution.
        layer_scale_init_value: Initial value of the per-channel scale. A value
            that is not positive disables the scale (`coef` is `None`).
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        kernel: int,
        dilation: int,
        layer_scale_init_value: float = 1e-6,
    ):
        super().__init__()

        # Depthwise convolution (one group per channel).
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel,
            padding=dilation * (kernel // 2),
            dilation=dilation,
            groups=dim,
        )
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

        if layer_scale_init_value > 0:
            self.coef = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.coef = None

    def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
        """Apply the block to `x` of shape (batch, dim, time); `cond` is unused."""
        # Channels-last layout for the norm and the pointwise linears.
        hidden = self.dwconv(x).transpose(1, 2)
        hidden = self.pwconv1(self.norm(hidden))
        hidden = self.pwconv2(self.act(hidden))
        if self.coef is not None:
            hidden = hidden * self.coef

        # Back to channels-first before the residual connection.
        return hidden.transpose(1, 2) + x
| |
|
| |
|
| | |
class GFSQ(nn.Module):
    """Wrapper around `GroupedResidualFSQ` that converts between features and code indices.

    Args:
        dim: Feature dimension handed to the quantizer.
        levels: Quantization levels handed to the quantizer.
        G: Number of groups.
        R: Number of residual quantizers per group.
        eps: Stored on the module; not used by the methods shown here.
        transpose: When true, inputs and outputs have dimensions 1 and 2 swapped
            relative to what the quantizer consumes/produces.
    """

    def __init__(
        self,
        dim: int,
        levels: List[int],
        G: int,
        R: int,
        eps=1e-5,
        transpose=True,
    ):
        super().__init__()
        self.quantizer = GroupedResidualFSQ(
            dim=dim,
            levels=list(levels),
            num_quantizers=R,
            groups=G,
        )
        self.n_ind = math.prod(levels)
        self.eps = eps
        self.transpose = transpose
        self.G = G
        self.R = R

    def _embed(self, x: torch.Tensor):
        """Map code indices back to features via the quantizer."""
        if self.transpose:
            x = x.transpose(1, 2)
        # Split the last dimension into (G, R) and move the group axis first.
        x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3)
        feat = self.quantizer.get_output_from_indices(x)
        return feat.transpose_(1, 2) if self.transpose else feat

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return super().__call__(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Quantize features `x` and return the flattened code indices.

        The input tensor is left untouched: the transpose is taken out of place
        (as `_embed` already does) instead of mutating the caller's tensor.
        """
        if self.transpose:
            x = x.transpose(1, 2)
        _, ind = self.quantizer(x)
        # Move the leading axis behind (batch, time) and merge it with the last axis.
        ind = ind.permute(1, 2, 0, 3).contiguous()
        ind = ind.view(ind.size(0), ind.size(1), -1)
        return ind.transpose(1, 2) if self.transpose else ind
| |
|
| |
|
| | |
class DVAEDecoder(nn.Module):
    """Convolutional stack: input projection, `n_layer` ConvNeXt blocks, 1x1 output conv.

    Args:
        idim: Number of input channels.
        odim: Number of output channels.
        n_layer: Number of `ConvNeXtBlock` layers.
        bn_dim: Channel width between the two input convolutions.
        hidden: Channel width of the ConvNeXt blocks.
        kernel: Kernel size of each ConvNeXt block.
        dilation: Dilation of each ConvNeXt block.
        up: Stored on the module; not used by `forward`.
    """

    def __init__(
        self,
        idim: int,
        odim: int,
        n_layer=12,
        bn_dim=64,
        hidden=256,
        kernel=7,
        dilation=2,
        up=False,
    ):
        super().__init__()
        self.up = up

        self.conv_in = nn.Sequential(
            nn.Conv1d(idim, bn_dim, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Conv1d(bn_dim, hidden, kernel_size=3, stride=1, padding=1),
        )

        blocks = []
        for _ in range(n_layer):
            blocks.append(ConvNeXtBlock(hidden, hidden * 4, kernel, dilation))
        self.decoder_block = nn.ModuleList(blocks)

        self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)

    def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor:
        """Run `x` through the stack; `conditioning` is forwarded to every block."""
        hidden_states = self.conv_in(x)
        for block in self.decoder_block:
            hidden_states = block(hidden_states, conditioning)
        return self.conv_out(hidden_states)
| |
|
| |
|
| | |
class DVAE(nn.Module):
    """Discrete VAE that maps 100-bin mel spectrograms to code indices and back.

    `forward` runs under `torch.inference_mode()` and either encodes a mel
    spectrogram to quantizer indices (`mode="encode"`) or decodes indices to a
    mel spectrogram (any other mode, the default being `"decode"`).
    """

    def __init__(
        self,
    ):
        super().__init__()

        # Learnable per-mel-bin scale, stored with shape (1, 100, 1).
        self.coef = nn.Parameter(torch.rand(100).reshape(1, 100, 1))

        self.downsample_conv = nn.Sequential(
            nn.Conv1d(100, 512, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Conv1d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.GELU(),
        )

        self.encoder = DVAEDecoder(idim=512, odim=1024, hidden=256, n_layer=12, bn_dim=128)
        self.decoder = DVAEDecoder(idim=512, odim=512, hidden=256, n_layer=12, bn_dim=128)

        self.out_conv = nn.Conv1d(512, 100, kernel_size=3, stride=1, padding=1, bias=False)

        self.vq_layer = GFSQ(dim=1024, levels=(5, 5, 5, 5), G=2, R=2)

    @torch.inference_mode()
    def forward(self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode") -> torch.Tensor:
        """Encode a mel spectrogram to indices, or decode indices to a mel spectrogram.

        Args:
            inp: For `"encode"`, a mel spectrogram that the scale can be expanded to
                via `coef.view(100, 1)` (the input itself is not modified; a clone
                is scaled). For decoding, the tensor handed to the quantizer's
                `_embed`, or used directly when `vq_layer` is `None`.
            mode: `"encode"` selects the encoding path when an encoder and a
                quantizer are present; everything else decodes.

        Returns:
            Quantizer indices when encoding, otherwise the decoded output multiplied
            by `coef`.
        """
        if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None:
            # Work on a copy so the caller's spectrogram stays untouched.
            scaled_mel = inp.clone()
            scaled_mel.div_(self.coef.view(100, 1).expand(scaled_mel.shape))
            latent = self.downsample_conv(scaled_mel).unsqueeze_(0)
            latent = self.encoder(latent)
            return self.vq_layer(latent)

        vq_feats = inp if self.vq_layer is None else self.vq_layer._embed(inp)

        # Split the channel axis in two halves and interleave them along time:
        # (B, C, T) -> (B, 2, C/2, T) -> (B, C/2, T, 2) -> (B, C/2, 2T)
        batch, channels, frames = vq_feats.size(0), vq_feats.size(1), vq_feats.size(2)
        vq_feats = vq_feats.view((batch, 2, channels // 2, frames)).permute(0, 2, 3, 1).flatten(2)

        dec_out = self.out_conv(self.decoder(x=vq_feats))

        # Undo the per-bin scaling in place.
        return dec_out.mul_(self.coef)
| |
|
| |
|
def apply_spk_emb(
    input_ids: torch.Tensor = None,
    spk_emb: torch.Tensor = None,
    input_embeds: torch.Tensor = None,
    spk_emb_token_id: int = 0,
    num_spk_embs: int = 1,
):
    """
    Overwrite the speaker-embedding placeholder slots of `input_embeds` in place.

    For every batch row, the positions where `input_ids` equals `spk_emb_token_id`
    are located, and the span from the first to the last such position is replaced
    with that row's speaker embeddings. Nothing is returned because `input_embeds`
    is modified directly.

    Args:
        input_ids (torch.Tensor): Input ID tensor, shape [batch_size, seq_len_max]
        spk_emb (torch.Tensor): Speaker embedding tensor, shape [batch_size, num_spk_emb, hidden_dim]
        input_embeds (torch.Tensor): Input embedding tensor, shape [batch_size, seq_len_max, hidden_dim]
        spk_emb_token_id (int): ID of the speaker embedding placeholder token
        num_spk_embs (int): Number of placeholder tokens expected in each row

    Returns:
        None

    Raises:
        AssertionError: If a row does not contain exactly `num_spk_embs` placeholders.
    """
    for row in range(input_ids.shape[0]):
        row_ids, row_spk_emb = input_ids[row], spk_emb[row]

        placeholder_positions = (row_ids == spk_emb_token_id).nonzero(as_tuple=False)
        assert placeholder_positions.shape[0] == num_spk_embs

        first = placeholder_positions.min()
        last = placeholder_positions.max()
        input_embeds[row, first : last + 1, :] = row_spk_emb

    return None
| |
|
| |
|
def make_streaming_chunk_mask_generation(
    inputs_embeds: torch.Tensor,
    past_seen_tokens: int,
    streaming_tts_text_mask: torch.Tensor,
    streaming_reserved_length: int = 300,
    streaming_audio_chunk_size: int = 50,
    streaming_text_chunk_size: int = 10,
    num_spk_emb: int = 1,
    use_spk_emb: bool = True,
) -> torch.Tensor:
    """
    Build the additive attention mask used while streaming audio-token generation.

    The mask decides which reserved `text` positions the TTS model may attend to
    for the current step: the further audio generation has progressed, the more
    text chunks become visible. Positions that `streaming_tts_text_mask` marks
    with 0 are always hidden.

    Args:
        inputs_embeds (torch.Tensor): Input embeddings; supplies dtype, device and the
            number of new positions. Batch size must be 1.
        past_seen_tokens (int): Number of tokens already seen by the model.
        streaming_tts_text_mask (torch.Tensor): Validity mask over the leading
            `1 + num_spk_emb * use_spk_emb + streaming_reserved_length + 1` positions;
            entries equal to 0 are masked out.
        streaming_reserved_length (int, optional): Number of positions reserved for text
            tokens. Defaults to 300.
        streaming_audio_chunk_size (int, optional): Number of audio tokens per streaming
            chunk. Defaults to 50.
        streaming_text_chunk_size (int, optional): Number of text tokens unlocked per
            audio chunk. Defaults to 10.
        num_spk_emb (int, optional): Number of speaker-embedding positions. Defaults to 1.
        use_spk_emb (bool, optional): Whether speaker-embedding positions are present.
            Defaults to True.

    Returns:
        torch.Tensor: Mask of shape [1, 1, 1, past_seen_tokens + inputs_embeds.shape[1]]
        holding 0 for visible positions and the dtype's minimum for hidden ones.

    Raises:
        AssertionError: If the batch size is not 1 (only supports batch size of 1 for inference).
    """
    assert inputs_embeds.shape[0] == 1

    mask_dtype, mask_device = inputs_embeds.dtype, inputs_embeds.device
    hidden_value = torch.finfo(mask_dtype).min

    # Positions in front of the text region: one leading token plus optional speaker slots.
    prefix_len = 1 + num_spk_emb * use_spk_emb

    # Everything starts out visible.
    total_len = past_seen_tokens + inputs_embeds.shape[1]
    causal_mask = torch.zeros((1, total_len), dtype=mask_dtype, device=mask_device)

    # Number of text tokens unlocked so far, capped at the reserved length.
    unlocked_chunks = math.ceil((past_seen_tokens - streaming_reserved_length) / streaming_audio_chunk_size)
    unlocked_text = min(unlocked_chunks * streaming_text_chunk_size, streaming_reserved_length)

    first_hidden = unlocked_text + prefix_len
    text_region_end = streaming_reserved_length + prefix_len + 1

    # Hide the text tokens that are not unlocked yet.
    causal_mask[0, first_hidden:text_region_end] = hidden_value

    # Hide positions the caller marked as invalid.
    causal_mask[0, :text_region_end].masked_fill_(streaming_tts_text_mask == 0, hidden_value)

    # Add the broadcast dimensions expected by attention: [1, 1, 1, total_len].
    return causal_mask[None, None]
| |
|
| |
|
| | |
class CustomRepetitionPenaltyLogitsProcessorRepeat:
    """Logits processor that penalizes tokens by how often they occurred recently.

    For every token, the penalty is raised to the power of its occurrence count
    inside the last `past_window` generated ids. Negative scores are multiplied by
    that factor and non-negative scores are divided by it.

    Args:
        penalty: Strictly positive float base of the penalty.
        max_input_ids: Rows of the count matrix at or beyond this index are zeroed,
            i.e. receive no penalty.
        past_window: Number of most recent ids that are counted.

    Raises:
        ValueError: If `penalty` is not a strictly positive float.
    """

    def __init__(self, penalty: float, max_input_ids: int, past_window: int):
        penalty_is_valid = isinstance(penalty, float) and penalty > 0
        if not penalty_is_valid:
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty
        self.max_input_ids = max_input_ids
        self.past_window = past_window

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return `scores` with the repetition penalty applied."""
        window = self.past_window
        if input_ids.size(1) > window:
            # Keep only the most recent `window` ids.
            input_ids = input_ids.narrow(1, -window, window)

        # Occurrence count of every vocabulary entry inside the window.
        counts = F.one_hot(input_ids, scores.size(1)).sum(1)
        surplus_rows = counts.size(0) - self.max_input_ids
        if surplus_rows > 0:
            counts.narrow(0, self.max_input_ids, surplus_rows).zero_()

        factor = torch.pow(self.penalty, counts)
        scores = scores.contiguous()
        return torch.where(scores < 0, scores * factor, scores / factor)
| |
|
| |
|
@dataclass
class ConditionalChatTTSGenerationOutput(ModelOutput):
    """
    Output class for ConditionalChatTTS generation.

    All fields default to `None`, so an instance can be created with any subset of them.

    Args:
        new_ids (torch.LongTensor): Newly generated audio code sequence, shape (batch_size, sequence_length, num_vq).
        audio_input_ids (torch.LongTensor): Updated input IDs including condition and generated audio codes, shape (batch_size, full_sequence_length, num_vq).
        past_key_values (Tuple[Tuple[torch.FloatTensor]]): Tuple containing pre-computed keys and values used for attention mechanism. Each element has shape (batch_size, num_heads, sequence_length, embed_size_per_head).
        finished (bool): Boolean indicating whether generation is complete.

    """

    # Audio codes produced so far.
    new_ids: Optional[torch.LongTensor] = None
    # Full id sequence (condition placeholders + audio codes) to feed into the next call.
    audio_input_ids: Optional[torch.LongTensor] = None
    # Key/value cache to feed into the next call.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Whether generation is complete.
    finished: Optional[bool] = None
| |
|
| |
|
class MultiModalProjector(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) that maps features from `in_dim` to `out_dim`."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear1 = nn.Linear(in_dim, out_dim, bias=True)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(out_dim, out_dim, bias=True)

    def forward(self, audio_features):
        """Project `audio_features` (last dimension `in_dim`) to `out_dim`."""
        return self.linear2(self.relu(self.linear1(audio_features)))
| |
|
| |
|
| | class ConditionalChatTTS(PreTrainedModel): |
| | """A conditional text-to-speech model that can generate speech from text with speaker conditioning. |
| | |
| | This model extends PreTrainedModel to provide text-to-speech capabilities with: |
| | - LLM hidden state conditioning |
| | - Streaming generation |
| | |
| | The model uses a transformer architecture with LLM hidden states and can operate in both |
| | streaming and non-streaming modes for flexible deployment. |
| | |
| | The model process sequence in the following format: |
| | | text bos token | LLM embedding projected to tts embedding space | text tokens (fixed length, reserved for future tokens) | audio bos token | audio tokens (audio token length is not fixed)| audio eos token | |
| | |
| | The format is designed to support LLM-conditioned streaming audio generation. |
| | |
| | Usage: |
| | To support streaming generation, two global variables should be maintained outside of the model. |
| | 1. `audio_input_ids`: stores *discrete* audio codes. It is a tensor with shape [1, sequence length+1, num_vq]. |
| | 2. `past_key_values`: stores the KV cache for both text tokens and audio codes. It is a list of tuples, each tuple contains two tensors with shape [1, num_attention_heads, sequence length, hidden_size // num_attention_heads] |
| | |
| | where `num_vq` is the number of audio codebooks, in default setting, it is `4`. |
| | |
| | 1. Create an empty `past_key_values` with |
| | ```python |
| | initial_kv_cache_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len # where `1` denotes the `bos` token |
| | dtype = model.emb_text.weight.dtype |
| | device = model.emb_text.weight.device |
| | past_key_values = [ |
| | ( |
| | torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device), |
| | torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device) |
| | ) |
| | for _ in range(model.config.num_hidden_layers) |
| | ] |
| | |
| | 2. At the same time, create an empty `audio_input_ids` with shape [1, sequence length, num_vq], `num_vq` denotes multiple layer audio codebooks. But here we also include text tokens in the sequence, but they will be zeros, and will not be used, just a placeholder. |
| | |
| | ```python |
| | initial_audio_input_ids_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len + 1 |
| | # [bos token, speaker embeddings, text tokens, audio bos token] |
| | audio_input_ids = torch.zeros(1, initial_audio_input_ids_length, model.num_vq, dtype=torch.long, device=device) |
| | ``` |
| | |
| | 2. Prefill some text tokens to TTS model (for example, 10 tokens) using `prefill_text` method. |
| | |
| | ```python |
| | outputs = llm.generate(**kwargs) |
| | llm_tokens = some_function_to_extract_llm_tokens(outputs) |
| | lm_spk_emb_last_hidden_states = some_function_to_extract_lm_spk_emb_last_hidden_states(outputs) |
| | tts_text_input_ids = tts_tokenizer.encode(llm_tokenizer.decode(llm_tokens)) |
| | # here assume we are prefilling text token 0 to text token 9 (included), totally 10 tokens. |
| | begin = 0 |
| | end = 9+1 |
| | position_ids = torch.arange(begin, end, dtype=torch.long, device=device).unsqueeze(0)  # shape [1, seq_len] |
| | |
| | past_key_values = model.prefill_text( |
| | input_ids=tts_text_input_ids, |
| | position_ids=position_ids, |
| | past_key_values=past_key_values, |
| | lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states, |
| | ) |
| | ``` |
| | |
| | 3. Make a `streaming_tts_text_mask` to denote which position contains valid text tokens, similar to `attention_mask` in standard causal attention. |
| | |
| | ```python |
| | streaming_tts_text_mask = torch.zeros(model.streaming_text_reserved_len) |
| | streaming_tts_text_mask[0:end] = 1 # marks these positions as holding valid (already prefilled) text tokens |
| | ``` |
| | |
| | 3. Generate audio codes using `generate` method. |
| | |
| | ```python |
| | outputs = model.generate( |
| | input_ids=audio_input_ids, |
| | past_key_values=past_key_values, |
| | streaming_tts_text_mask=streaming_tts_text_mask, |
| | max_new_token=50, |
| | ) |
| | |
| | # update past_key_values and input_ids |
| | past_key_values = outputs.past_key_values |
| | audio_input_ids = outputs.audio_input_ids |
| | ``` |
| | |
| | After calling `generate`, both `past_key_values` and `audio_input_ids` have been extended by up to `max_new_token=50` positions. |
| | |
| | 4. Notice that after prefilling `10` text tokens, the model can generate up to `50` audio tokens, if you want to generate more audio tokens, you need to prefill next `10` text tokens. And it is okay to only generate `25` audio tokens for faster initial response. |
| | |
| | 5. Repeat steps `2,3,4` as needed in your streaming audio generation cases, and make sure your usage follows the guidelines discussed above. |
| | """ |
| |
|
| | config_class = ConditionalChatTTSConfig |
| | _no_split_modules = [] |
| |
|
def __init__(self, config: ConditionalChatTTSConfig):
    """Build every sub-module of the TTS model from ``config``.

    The plain hyper-parameters are mirrored from ``config`` onto the
    instance first. The learnable parts are then created in this order:
    the LLM-to-TTS projector, one code embedding per audio codebook, the
    text embedding, one weight-normalised output head per audio codebook,
    the DVAE and finally the Llama backbone.

    Args:
        config (ConditionalChatTTSConfig): Model configuration.
    """
    super().__init__(config)

    # Scalar settings copied verbatim from the config (speaker conditioning,
    # streaming layout, audio vocabulary and sampling defaults).
    mirrored_fields = (
        "use_speaker_embedding",
        "use_llm_hidden_state",
        "num_spk_embs",
        "spk_emb_token_id",
        "use_text",
        "streaming",
        "streaming_text_chunk_size",
        "streaming_audio_chunk_size",
        "streaming_text_reserved_len",
        "audio_bos_token_id",
        "num_mel_bins",
        "num_vq",
        "num_audio_tokens",
        "top_p",
        "top_k",
        "repetition_penalty",
    )
    for field_name in mirrored_fields:
        setattr(self, field_name, getattr(config, field_name))

    hidden_size = config.hidden_size

    # Maps LLM hidden states (llm_dim) into the TTS hidden space: an MLP when
    # `use_mlp` is set, otherwise a single bias-free linear layer.
    self.projector = (
        MultiModalProjector(config.llm_dim, hidden_size)
        if self.config.use_mlp
        else nn.Linear(config.llm_dim, hidden_size, bias=False)
    )

    # One embedding table per audio codebook, plus one table for text tokens.
    code_embeddings = [nn.Embedding(config.num_audio_tokens, hidden_size) for _ in range(config.num_vq)]
    self.emb_code = nn.ModuleList(code_embeddings)
    self.emb_text = nn.Embedding(config.num_text_tokens, hidden_size)

    # One weight-normalised, bias-free output head per audio codebook.
    code_heads = []
    for _ in range(config.num_vq):
        head = nn.Linear(hidden_size, config.num_audio_tokens, bias=False)
        code_heads.append(weight_norm(head, name="weight"))
    self.head_code = nn.ModuleList(code_heads)

    self.dvae = DVAE()

    # Llama backbone that consumes the merged text/audio embeddings.
    self.model = LlamaModel(
        LlamaConfig(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=config.num_hidden_layers,
            max_position_embeddings=config.max_position_embeddings,
            attn_implementation=config.attn_implementation,
        )
    )
| |
|
@torch.inference_mode()
def merge_inputs_embeds(
    self,
    input_ids: torch.Tensor,
    lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
):
    """Merge `input_ids` and `lm_spk_emb_last_hidden_states` to `inputs_embeds`.

    Text tokens are embedded with `self.emb_text`. If `input_ids` contains
    `self.spk_emb_token_id`, the LLM speaker hidden states are projected into
    the TTS hidden space, L2-normalised along the last dimension and handed to
    `apply_spk_emb` together with `inputs_embeds` (presumably written into the
    speaker-token positions in place, since its return value is not used).

    Args:
        input_ids (torch.Tensor): Input token IDs; the batch dimension must be 1.
        lm_spk_emb_last_hidden_states (Optional[torch.Tensor], optional): Last hidden states of speaker
            embeddings from the language model. Required when `input_ids` contains speaker-embedding
            tokens. Defaults to None.

    Raises:
        NotImplementedError: If `self.use_speaker_embedding` is False (that mode is not implemented).
        AssertionError: If the batch size is not 1, or speaker tokens are present but no hidden
            states were given.

    Returns:
        torch.Tensor: Prepared input embeddings for the model.
    """
    assert input_ids.shape[0] == 1

    if not self.use_speaker_embedding:
        raise NotImplementedError

    inputs_embeds = self.emb_text(input_ids)

    spk_emb_mask = input_ids == self.spk_emb_token_id
    if spk_emb_mask.any():
        assert lm_spk_emb_last_hidden_states is not None

        # The projector is either an MLP exposing `linear1` or a bare `nn.Linear`
        # (see `__init__`); resolve the first linear layer for both so the dtype
        # lookup does not fail for the plain-linear configuration.
        first_linear = getattr(self.projector, "linear1", self.projector)
        lm_spk_emb_last_hidden_states = lm_spk_emb_last_hidden_states.to(first_linear.weight.dtype)

        projected_spk_emb = self.projector(lm_spk_emb_last_hidden_states)
        projected_spk_emb = F.normalize(projected_spk_emb, p=2, dim=-1)
        apply_spk_emb(
            input_ids=input_ids,
            spk_emb=projected_spk_emb,
            input_embeds=inputs_embeds,
            spk_emb_token_id=self.spk_emb_token_id,
            num_spk_embs=self.num_spk_embs,
        )

    return inputs_embeds
| |
|
@torch.inference_mode()
def prefill_text(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.LongTensor,
    past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
    lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
):
    """Prefill a chunk of new text tokens in streaming setting.

    The new tokens are run through the backbone on top of a copy of the cache
    prefix that precedes them, and the resulting keys/values for exactly those
    positions are written back into `past_key_values` in place. The rest of
    the pre-allocated cache is left untouched.

    Args:
        input_ids (Tensor): Tensor of shape [batch_size, seq_len].
        position_ids (LongTensor): Tensor of shape [batch_size, seq_len]. Only the first and last
            entries are used as slice bounds, so the positions are expected to be consecutive.
        past_key_values (List[Tuple[Tensor]]): KV cache of all layers, each layer is a tuple
            (keys, values) indexed as [batch, heads, position, head_dim]. Updated in place.
        lm_spk_emb_last_hidden_states (Tensor, optional): Tensor of shape
            [batch_size, num_spk_emb, llm_dim]. Defaults to None.

    Returns:
        List[Tuple[Tensor]]: The same `past_key_values` object, after the in-place update.

    Note that all `batch_size` should be `1`.
    """
    assert input_ids.shape[0] == 1
    assert past_key_values is not None

    inputs_embeds = self.merge_inputs_embeds(
        input_ids=input_ids,
        lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,
    )

    # Resolve the slice bounds once as Python ints instead of converting the
    # same single-element tensors again for every layer.
    chunk_start = int(position_ids[:, 0])
    chunk_stop = int(position_ids[:, -1]) + 1

    # The backbone appends to whatever cache it is given, so hand it only a
    # private copy of the entries that come before the new chunk.
    prefix_cache = [
        (layer[0][:, :, :chunk_start, :].clone(), layer[1][:, :, :chunk_start, :].clone())
        for layer in past_key_values
    ]

    outputs_prefill: BaseModelOutputWithPast = self.model(
        attention_mask=None,  # no mask is passed for the text prefill
        position_ids=position_ids,
        past_key_values=prefix_cache,
        inputs_embeds=inputs_embeds,
        use_cache=True,
        output_attentions=False,
        cache_position=position_ids,
    )
    updated_cache = outputs_prefill.past_key_values

    # Copy only the freshly computed positions back into the caller's cache.
    # Slice assignment already copies, so no intermediate clone is needed.
    for layer, new_layer in zip(past_key_values, updated_cache):
        layer[0][:, :, chunk_start:chunk_stop, :] = new_layer[0][:, :, chunk_start:chunk_stop]
        layer[1][:, :, chunk_start:chunk_stop, :] = new_layer[1][:, :, chunk_start:chunk_stop]

    return past_key_values
| |
|
@torch.inference_mode()
def prefill_audio_ids(
    self,
    input_ids: torch.Tensor,
    past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
    streaming_tts_text_mask=None,
    add_audio_bos: bool = True,
):
    """Prefill a chunk of audio ids to the model. Used in sliding-window long audio generation.

    The audio codes (typically taken from the previous window) are embedded,
    optionally preceded by the audio BOS embedding, and run through the
    backbone on top of the existing cache.

    Args:
        input_ids (torch.Tensor): (1, seq_len, num_vq) Audio input token ids.
        past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values for attention mechanism.
        streaming_tts_text_mask (optional): Forwarded unchanged to
            `make_streaming_chunk_mask_generation`. Defaults to None.
        add_audio_bos (bool, optional): Whether to prepend the audio BOS embedding. Defaults to True.

    Returns:
        The `past_key_values` returned by the backbone for the extended sequence.
    """
    assert input_ids.shape[0] == 1
    assert past_key_values is not None

    # Each codebook has its own embedding table; a frame is the sum over codebooks.
    per_codebook = [self.emb_code[vq](input_ids[:, :, vq]) for vq in range(self.num_vq)]
    inputs_embeds = torch.stack(per_codebook, 3).sum(3)

    if add_audio_bos:
        # The audio BOS token is looked up in the *text* embedding table.
        bos_ids = torch.tensor([[self.audio_bos_token_id]], dtype=torch.long, device=self.device)
        inputs_embeds = torch.cat([self.emb_text(bos_ids), inputs_embeds], dim=1)

    # Number of positions fed in this call (audio frames, plus one if BOS was added).
    num_new = inputs_embeds.shape[1]

    # New positions continue right after what is already cached.
    cached_len = past_key_values[0][0].shape[2]
    position_ids = torch.arange(
        cached_len, cached_len + num_new, dtype=torch.long, device=self.device
    ).unsqueeze(0)

    causal_mask = make_streaming_chunk_mask_generation(
        inputs_embeds=inputs_embeds,
        past_seen_tokens=cached_len,
        streaming_tts_text_mask=streaming_tts_text_mask,
        streaming_reserved_length=self.streaming_text_reserved_len,
        streaming_text_chunk_size=self.streaming_text_chunk_size,
    )

    outputs: BaseModelOutputWithPast = self.model(
        attention_mask=causal_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=True,
        output_attentions=False,
        cache_position=position_ids.clone(),
    )
    return outputs.past_key_values
| |
|
@torch.inference_mode()
def generate(
    self,
    input_ids: torch.Tensor,
    past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
    temperature: torch.Tensor,
    eos_token: Union[int, torch.Tensor],
    streaming_tts_text_mask=None,
    force_no_stop=False,
    min_new_token=10,
    max_new_token=50,
    logits_warpers: List[LogitsWarper] = [],
    logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [],
    show_tqdm=False,
):
    """Generate audio codes in streaming setting or non-streaming setting.

    One audio frame (one code per codebook) is sampled per step, for at most
    `max_new_token` steps, and generation may run while only part of the text
    has been prefilled.

    Always pass a valid `past_key_values` to the method. The method does not do `prefill` by itself.
    It relies on `prefill_text` method to provide valid `past_key_values`. Please refer to docstring
    of this class for more details.

    In this method, we borrowed a lot of codes from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/gpt.py`.

    Args:
        input_ids (torch.Tensor): Audio input ids of shape (1, seq_len, num_vq), including the
            conditioning prefix.
        past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values for attention
            mechanism. Its cached length must equal `input_ids.shape[1] - 1`.
        temperature (torch.Tensor): Temperature for sampling; reshaped to one value per logits row.
        eos_token (Union[int, torch.Tensor]): End of sequence token.
        streaming_tts_text_mask (Optional[torch.Tensor], optional): Mask for streaming TTS text.
            Defaults to None.
        force_no_stop (bool, optional): If True, `eos_token` is masked out at every step.
            Defaults to False.
        min_new_token (int, optional): `eos_token` is masked out for the first `min_new_token`
            steps of this call. Defaults to 10.
        max_new_token (int, optional): Maximum number of new tokens to generate. Defaults to 50.
        logits_warpers (List[LogitsWarper], optional): List of logits warpers. Defaults to [].
        logits_processors (List[CustomRepetitionPenaltyLogitsProcessorRepeat], optional): List of
            logits processors. Defaults to [].
        show_tqdm (bool, optional): Whether to show progress bar. Defaults to False.

    Returns:
        ConditionalChatTTSGenerationOutput: `new_ids` (the audio codes after the conditioning
        prefix, without the final step when finished), `audio_input_ids` (the full id sequence),
        `past_key_values` and `finished`.
    """
    assert input_ids.shape[0] == 1
    assert past_key_values is not None

    batch_size = input_ids.shape[0]
    device = self.device

    # Length of the conditioning prefix: text bos + speaker-embedding slots
    # (only when enabled) + reserved text slots + audio bos. Everything after
    # it is generated audio.
    condition_length = 1 + self.num_spk_embs * self.use_speaker_embedding + self.streaming_text_reserved_len + 1

    finish = torch.zeros(batch_size, device=input_ids.device).bool()

    # One temperature per logits row (rows are batch * num_vq after flattening).
    temperature = temperature.unsqueeze(0).expand(batch_size, -1).contiguous().view(-1, 1)

    progress = input_ids.shape[1]

    # Pre-allocate room for every step so each new frame is written in place;
    # `input_ids` is always a view of the filled part of this buffer.
    input_ids_buf = torch.zeros(
        batch_size,
        progress + max_new_token,
        input_ids.shape[2],
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    input_ids_buf.narrow(1, 0, progress).copy_(input_ids)
    input_ids = input_ids_buf.narrow(1, 0, progress)

    pbar: Optional[tqdm] = (
        tqdm(
            total=max_new_token,
            desc="code",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
        )
        if show_tqdm
        else None
    )

    for step in range(max_new_token):
        # When nothing but the conditioning prefix exists yet, the step input
        # is the audio BOS token instead of a previously generated frame.
        audio_bos = progress == condition_length

        cached_len = past_key_values[0][0].shape[2]
        assert progress == (cached_len + 1)

        if audio_bos:
            bos_ids = torch.tensor([[self.audio_bos_token_id]], dtype=torch.long, device=device)
            inputs_embeds = self.emb_text(bos_ids)
            del bos_ids
        else:
            # Embed the most recent frame: sum of the per-codebook embeddings.
            last_frame = input_ids.narrow(dim=1, start=input_ids.shape[1] - 1, length=1)
            per_codebook = [self.emb_code[vq](last_frame[:, :, vq]) for vq in range(self.num_vq)]
            inputs_embeds = torch.stack(per_codebook, 3).sum(3)

        position_ids = torch.tensor([cached_len], dtype=torch.long, device=device).unsqueeze(0)

        causal_mask = make_streaming_chunk_mask_generation(
            inputs_embeds=inputs_embeds,
            past_seen_tokens=cached_len,
            streaming_tts_text_mask=streaming_tts_text_mask,
            streaming_reserved_length=self.streaming_text_reserved_len,
            streaming_text_chunk_size=self.streaming_text_chunk_size,
        )

        outputs: BaseModelOutputWithPast = self.model(
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=True,
            output_attentions=False,
            cache_position=position_ids.clone(),
        )
        del position_ids, inputs_embeds, causal_mask

        hidden_states = outputs.last_hidden_state
        past_key_values = outputs.past_key_values

        # One output head per codebook; the last axis of `logits` indexes the codebook.
        with P.cached():
            logits = torch.empty(
                hidden_states.size(0),
                hidden_states.size(1),
                self.num_audio_tokens,
                self.num_vq,
                dtype=torch.float,
                device=device,
            )
            for vq in range(self.num_vq):
                logits[..., vq] = self.head_code[vq](hidden_states)
        del hidden_states

        # Keep the last position only, then flatten to one row per
        # (batch, codebook) pair with the audio vocabulary on the last axis.
        logits = logits.narrow(1, -1, 1).squeeze_(1).float()
        logits = logits.permute(0, 2, 1)
        logits = logits.reshape(-1, logits.size(2))

        # Ids generated so far (after the conditioning prefix), laid out with
        # the same row order as `logits`, for the processors/warpers.
        history = input_ids.narrow(
            1,
            condition_length,
            input_ids.size(1) - condition_length,
        ).permute(0, 2, 1)
        logits_token = history.reshape(history.size(0) * history.size(1), -1).to(device)
        del history

        logits /= temperature

        # Processors and warpers are skipped on the BOS step.
        if not audio_bos:
            for processor in logits_processors:
                logits = processor(logits_token, logits)
            for warper in logits_warpers:
                logits = warper(logits_token, logits)
        del logits_token

        # Forbid EOS while below the minimum length or when stopping is disabled.
        if step < min_new_token or force_no_stop:
            logits[:, eos_token] = -torch.inf

        scores = F.softmax(logits, dim=-1)
        del logits
        idx_next = torch.multinomial(scores, num_samples=1)
        del scores

        # Back to one row per batch item; a frame ends the sequence as soon as
        # any of its codebooks sampled EOS.
        idx_next = idx_next.view(-1, self.num_vq)
        finish.logical_or_(idx_next.eq(eos_token).any(1))

        input_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1))

        # EOS on the very first step: leave before `progress` advances, so the
        # sampled frame is not part of the returned `input_ids` view.
        if step == 0 and finish.any():
            break

        del idx_next
        progress += 1
        input_ids = input_ids_buf.narrow(1, 0, progress)

        if finish.all():
            break

        if pbar is not None:
            pbar.update(1)

    if pbar is not None:
        pbar.close()

    finished = finish.all()

    if show_tqdm and not finished:
        logger.info(f"incomplete result. hit max_new_token: {max_new_token}")

    del input_ids_buf

    if finished:
        # Drop the last step from the newly generated ids.
        new_ids = input_ids[:, condition_length:-1, :]
    else:
        new_ids = input_ids[:, condition_length:, :]

    return ConditionalChatTTSGenerationOutput(
        new_ids=new_ids,
        audio_input_ids=input_ids,
        past_key_values=past_key_values,
        finished=finished,
    )
| |
|
@torch.inference_mode()
def decode_to_mel_specs(
    self,
    result_list: List[torch.Tensor],
):
    """Decode discrete audio codes to mel spectrograms.

    The code sequences are zero-padded on the right to the longest one,
    stacked into a single (batch, num_vq, length) tensor and passed through
    `self.dvae`.

    Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/core.py`

    Args:
        result_list (List[torch.Tensor]): Audio codes output from `generate`, each indexed as
            (length, num_vq).

    Returns:
        torch.Tensor: Mel spectrograms. For an empty `result_list` an empty float32
        `numpy.ndarray` is returned instead.
    """
    if not result_list:
        return np.array([], dtype=np.float32)

    first = result_list[0]
    longest = max(codes.size(0) for codes in result_list)

    # Zero-initialised so shorter sequences are right-padded with zeros.
    batch_result = torch.zeros(
        (len(result_list), first.size(1), longest),
        dtype=first.dtype,
        device=first.device,
    )
    for row, codes in enumerate(result_list):
        # (length, num_vq) -> (num_vq, length), written into the leading columns.
        batch_result[row].narrow(1, 0, codes.size(0)).copy_(codes.permute(1, 0))

    return self.dvae(batch_result)
| |
|
| |
|
| | |
| | def gen_logits( |
| | num_code: int, |
| | top_P=0.7, |
| | top_K=20, |
| | repetition_penalty=1.0, |
| | ): |
| | logits_warpers = [] |
| | if top_P is not None: |
| | logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3)) |
| | if top_K is not None: |
| | logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3)) |
| |
|
| | logits_processors = [] |
| | if repetition_penalty is not None and repetition_penalty != 1: |
| | logits_processors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, num_code, 16)) |
| |
|
| | return logits_warpers, logits_processors |
| |
|
| |
|
| | |
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    **kwargs,
):
    """Assemble the keyword arguments for one forward pass during generation.

    Args:
        input_ids: Token ids; trimmed to the not-yet-processed tokens when a cache is given.
        past_key_values: Either a `Cache` instance or a legacy list of per-layer (key, value)
            tuples. Defaults to None.
        attention_mask: Attention mask; used to trim `input_ids` and to derive
            `position_ids` when those are not supplied. Defaults to None.
        inputs_embeds: Used instead of `input_ids` only when `cache_position[0] == 0`.
            Defaults to None.
        cache_position: Read only when `inputs_embeds` is given, and forwarded to the 4D mask
            builder when a `StaticCache` is used. Defaults to None.
        position_ids: Explicit position ids; derived from `attention_mask` when None.
        use_cache: Forwarded unchanged. Defaults to True.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        dict: `input_ids`, `inputs_embeds` (exactly one of the two is not None),
        `position_ids`, `past_key_values`, `use_cache` and `attention_mask`.
    """
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            past_length = past_key_values.seen_tokens
        else:
            past_length = past_key_values[0][0].shape[2]

        # Keep only the tokens that have not been processed yet.
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            # The mask is longer than input_ids, so part of the input lives only
            # in the cache: keep the trailing tokens the mask covers beyond it.
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        elif past_length < input_ids.shape[1]:
            # input_ids still holds the already-cached prefix: drop it.
            input_ids = input_ids[:, past_length:]
        # Otherwise input_ids is taken to contain only unprocessed tokens.

    if attention_mask is not None and position_ids is None:
        # Derive positions from the mask; masked-out slots get the dummy position 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

            # NOTE(review): the nesting of this clone under `if past_key_values`
            # is reconstructed (it matches upstream Transformers' Llama code) -
            # confirm. It materialises the sliced view with a contiguous layout.
            position_ids = position_ids.clone(memory_format=torch.contiguous_format)

    # `inputs_embeds` is only consumed on the first step (cache position 0);
    # afterwards generation continues from token ids.
    if inputs_embeds is not None and cache_position[0] == 0:
        model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
    else:
        model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

    if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
        # With a static cache, expand a 2D mask to a 4D causal mask that spans
        # the cache's maximum length.
        if model_inputs["inputs_embeds"] is not None:
            batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
            device = model_inputs["inputs_embeds"].device
        else:
            batch_size, sequence_length = model_inputs["input_ids"].shape
            device = model_inputs["input_ids"].device

        dtype = self.lm_head.weight.dtype
        min_dtype = torch.finfo(dtype).min

        attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=past_key_values.get_max_length(),
            dtype=dtype,
            device=device,
            min_dtype=min_dtype,
            cache_position=cache_position,
            batch_size=batch_size,
        )

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
        }
    )
    return model_inputs
| |
|