import os

os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'
import time
from subprocess import CalledProcessError
from typing import Dict, List

import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from omegaconf import OmegaConf
from tqdm import tqdm

import warnings

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

from indextts.BigVGAN.models import BigVGAN as Generator
from indextts.gpt.model import UnifiedVoice
from indextts.utils.checkpoint import load_checkpoint
from indextts.utils.feature_extractors import MelSpectrogramFeatures

from indextts.utils.front import TextNormalizer, TextTokenizer

class IndexTTS:
    def __init__(
        self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=True, device=None,
        use_cuda_kernel=None,
    ):
        """
        Args:
            cfg_path (str): path to the config file.
            model_dir (str): path to the model directory.
            use_fp16 (bool): whether to use fp16.
            device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it is chosen automatically based on the availability of CUDA, XPU, or MPS.
            use_cuda_kernel (None | bool): whether to use BigVGAN's custom fused-activation CUDA kernel; only relevant on a CUDA device.
        """
        if device is not None:
            self.device = device
            self.use_fp16 = False if device == "cpu" else use_fp16
            self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
        elif torch.cuda.is_available():
            self.device = "cuda:0"
            self.use_fp16 = use_fp16
            self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            self.device = "xpu"
            self.use_fp16 = use_fp16
            self.use_cuda_kernel = False
        elif hasattr(torch, "mps") and torch.backends.mps.is_available():
            self.device = "mps"
            self.use_fp16 = False
            self.use_cuda_kernel = False
        else:
            self.device = "cpu"
            self.use_fp16 = False
            self.use_cuda_kernel = False
            print(">> Be patient, it may take a while to run in CPU mode.")

        self.cfg = OmegaConf.load(cfg_path)
        self.model_dir = model_dir
        self.dtype = torch.float16 if self.use_fp16 else None
        self.stop_mel_token = self.cfg.gpt.stop_mel_token
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | self.gpt = UnifiedVoice(**self.cfg.gpt) |
| | self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint) |
| | load_checkpoint(self.gpt, self.gpt_path) |
| | self.gpt = self.gpt.to(self.device) |
| | if self.use_fp16: |
| | self.gpt.eval().half() |
| | else: |
| | self.gpt.eval() |
| | print(">> GPT weights restored from:", self.gpt_path) |
| | if self.use_fp16: |
| | try: |
| | import deepspeed |
| |
|
| | use_deepspeed = True |
| | except (ImportError, OSError, CalledProcessError) as e: |
| | use_deepspeed = False |
| | print(f">> DeepSpeed加载失败,回退到标准推理: {e}") |
| |
|
| | self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=True) |
| | else: |
| | self.gpt.post_init_gpt2_config(use_deepspeed=False, kv_cache=False, half=False) |
| |
|
        if self.use_cuda_kernel:
            # Preload the custom CUDA kernel for BigVGAN's anti-aliased activation.
            try:
                from indextts.BigVGAN.alias_free_activation.cuda import load

                anti_alias_activation_cuda = load.load()
                print(">> Preload custom CUDA kernel for BigVGAN", anti_alias_activation_cuda)
            except Exception:
                print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
                self.use_cuda_kernel = False
        self.bigvgan = Generator(self.cfg.bigvgan, use_cuda_kernel=self.use_cuda_kernel)
        self.bigvgan_path = os.path.join(self.model_dir, self.cfg.bigvgan_checkpoint)
        vocoder_dict = torch.load(self.bigvgan_path, map_location="cpu")
        self.bigvgan.load_state_dict(vocoder_dict["generator"])
        self.bigvgan = self.bigvgan.to(self.device)
        # Remove weight norm for inference.
        self.bigvgan.remove_weight_norm()
        self.bigvgan.eval()
        print(">> bigvgan weights restored from:", self.bigvgan_path)
        self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset["bpe_model"])
        self.normalizer = TextNormalizer()
        self.normalizer.load()
        print(">> TextNormalizer loaded")
        self.tokenizer = TextTokenizer(self.bpe_path, self.normalizer)
        print(">> bpe model loaded from:", self.bpe_path)
        # Cache the conditioning mel of the most recent reference audio.
        self.cache_audio_prompt = None
        self.cache_cond_mel = None
        # Gradio progress callback, set externally.
        self.gr_progress = None
        self.model_version = self.cfg.version if hasattr(self.cfg, "version") else None

    def remove_long_silence(self, codes: torch.Tensor, silent_token=52, max_consecutive=30):
        """
        Shrink long runs of special tokens (silent_token and stop_mel_token) in codes.
        codes: [B, T]
        """
        code_lens = []
        codes_list = []
        device = codes.device
        dtype = codes.dtype
        isfix = False
        for i in range(0, codes.shape[0]):
            code = codes[i]
            if not torch.any(code == self.stop_mel_token).item():
                len_ = code.size(0)
            else:
                stop_mel_idx = (code == self.stop_mel_token).nonzero(as_tuple=False)
                len_ = stop_mel_idx[0].item() if len(stop_mel_idx) > 0 else code.size(0)

            count = torch.sum(code == silent_token).item()
            if count > max_consecutive:
                # Trim runs of silent tokens down to at most 10 in a row.
                ncode_idx = []
                n = 0
                for k in range(len_):
                    assert code[k] != self.stop_mel_token, \
                        f"stop_mel_token {self.stop_mel_token} should not appear before the first stop index"
                    if code[k] != silent_token:
                        ncode_idx.append(k)
                        n = 0
                    elif code[k] == silent_token and n < 10:
                        ncode_idx.append(k)
                        n += 1
                len_ = len(ncode_idx)
                codes_list.append(code[ncode_idx])
                isfix = True
            else:
                # Keep the code as-is, up to (but excluding) the stop token.
                codes_list.append(code[:len_])
            code_lens.append(len_)
        if isfix:
            if len(codes_list) > 1:
                codes = pad_sequence(codes_list, batch_first=True, padding_value=self.stop_mel_token)
            else:
                codes = codes_list[0].unsqueeze(0)
        else:
            # No sample was shrunk; only the lengths may have changed.
            pass
        # Truncate the shared time axis to the longest remaining sample.
        max_len = max(code_lens)
        if max_len < codes.shape[1]:
            codes = codes[:, :max_len]
        code_lens = torch.tensor(code_lens, dtype=torch.long, device=device)
        return codes, code_lens
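
    # Illustrative sketch (not part of the original module): with the default
    # silent_token=52, a row dominated by token 52 has its run capped at 10,
    # assuming `tts` is an IndexTTS instance:
    #
    #   codes = torch.full((1, 40), 52, dtype=torch.long)
    #   codes[0, -1] = tts.stop_mel_token
    #   fixed, lens = tts.remove_long_silence(codes)
    #   # fixed.shape -> torch.Size([1, 10]); lens -> tensor([10])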

    def bucket_segments(self, segments, bucket_max_size=4) -> List[List[Dict]]:
        """
        Bucket segments by token length so each batch holds similar-length inputs.
        If there are no more than ``bucket_max_size`` segments, they are returned
        as a single bucket. With ``bucket_max_size=1`` (the CPU path), every
        segment ends up in its own bucket, i.e. no batching.
        """
        outputs: List[Dict] = []
        for idx, sent in enumerate(segments):
            outputs.append({"idx": idx, "sent": sent, "len": len(sent)})

        if len(outputs) > bucket_max_size:
            # Sort by length, then open a new bucket whenever the current segment
            # is much longer than the bucket's median or the bucket is full.
            buckets: List[List[Dict]] = []
            factor = 1.5
            last_bucket = None
            last_bucket_sent_len_median = 0

            for sent in sorted(outputs, key=lambda x: x["len"]):
                current_sent_len = sent["len"]
                if current_sent_len == 0:
                    print(">> skip empty segment")
                    continue
                if last_bucket is None \
                        or current_sent_len >= int(last_bucket_sent_len_median * factor) \
                        or len(last_bucket) >= bucket_max_size:
                    # Open a new bucket.
                    buckets.append([sent])
                    last_bucket = buckets[-1]
                    last_bucket_sent_len_median = current_sent_len
                else:
                    # The current bucket still has room; track its median length.
                    last_bucket.append(sent)
                    mid = len(last_bucket) // 2
                    last_bucket_sent_len_median = last_bucket[mid]["len"]
            last_bucket = None
            # Merge singleton buckets into buckets that still have room.
            out_buckets: List[List[Dict]] = []
            only_ones: List[Dict] = []
            for b in buckets:
                if len(b) == 1:
                    only_ones.append(b[0])
                else:
                    out_buckets.append(b)
            if len(only_ones) > 0:
                for i in range(len(out_buckets)):
                    b = out_buckets[i]
                    if len(b) < bucket_max_size:
                        b.append(only_ones.pop(0))
                        if len(only_ones) == 0:
                            break
                # Any remaining singletons form buckets of their own.
                if len(only_ones) > 0:
                    out_buckets.extend(
                        [only_ones[i:i + bucket_max_size] for i in range(0, len(only_ones), bucket_max_size)])
            return out_buckets
        return [outputs]
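
    # Illustrative sketch (not in the original module): five segments with token
    # lengths 10, 11, 12, 30 and 31 split into two similar-length buckets,
    # assuming `tts` is an IndexTTS instance:
    #
    #   segments = [["tok"] * n for n in (10, 11, 12, 30, 31)]
    #   buckets = tts.bucket_segments(segments, bucket_max_size=4)
    #   # [[d["idx"] for d in b] for b in buckets] -> [[0, 1, 2], [3, 4]]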

    def pad_tokens_cat(self, tokens: List[torch.Tensor]) -> torch.Tensor:
        if self.model_version and self.model_version >= 1.5:
            # 1.5+ models: pad on the right with stop_text_token.
            tokens = [t.squeeze(0) for t in tokens]
            return pad_sequence(tokens, batch_first=True, padding_value=self.cfg.gpt.stop_text_token,
                                padding_side="right")
        # Older models: pad with up to 8 stop_text_token values, then start_text_token.
        max_len = max(t.size(1) for t in tokens)
        outputs = []
        for tensor in tokens:
            pad_len = max_len - tensor.size(1)
            if pad_len > 0:
                n = min(8, pad_len)
                tensor = torch.nn.functional.pad(tensor, (0, n), value=self.cfg.gpt.stop_text_token)
                tensor = torch.nn.functional.pad(tensor, (0, pad_len - n), value=self.cfg.gpt.start_text_token)
            tensor = tensor[:, :max_len]
            outputs.append(tensor)
        tokens = torch.cat(outputs, dim=0)
        return tokens

    def torch_empty_cache(self):
        try:
            if "cuda" in str(self.device):
                torch.cuda.empty_cache()
            elif "mps" in str(self.device):
                torch.mps.empty_cache()
        except Exception:
            pass

    def _set_gr_progress(self, value, desc):
        if self.gr_progress is not None:
            self.gr_progress(value, desc=desc)

    # Fast inference: batch similar-length segments and decode them in chunks.
    def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_segment=100,
                   segments_bucket_max_size=4, **generation_kwargs):
        """
        Args:
            ``max_text_tokens_per_segment``: maximum number of text tokens per segment, default ``100``; tune it to your GPU.
                - Smaller: more batches, *faster* inference, higher memory use, may affect quality.
                - Larger: fewer batches, *slower* inference; memory use and quality approach regular (non-fast) inference.
            ``segments_bucket_max_size``: maximum bucket capacity for segment bucketing, default ``4``; tune it to your GPU memory.
                - Larger: fewer buckets, larger batches, *faster* inference, higher memory use, may affect quality.
                - Smaller: more buckets, smaller batches, *slower* inference; memory use and quality approach regular (non-fast) inference.
        """
| | print(">> starting fast inference...") |
| |
|
| | self._set_gr_progress(0, "starting fast inference...") |
| | if verbose: |
| | print(f"origin text:{text}") |
| | start_time = time.perf_counter() |
| |
|
| | |
| | if self.cache_cond_mel is None or self.cache_audio_prompt != audio_prompt: |
| | audio, sr = torchaudio.load(audio_prompt) |
| | audio = torch.mean(audio, dim=0, keepdim=True) |
| | if audio.shape[0] > 1: |
| | audio = audio[0].unsqueeze(0) |
| | audio = torchaudio.transforms.Resample(sr, 24000)(audio) |
| |
|
| | max_audio_length_seconds = 50 |
| | max_audio_samples = int(max_audio_length_seconds * 24000) |
| | |
| | if audio.shape[1] > max_audio_samples: |
| | if verbose: |
| | print(f"Audio too long ({audio.shape[1]} samples), truncating to {max_audio_samples} samples") |
| | audio = audio[:, :max_audio_samples] |
| |
|
| | cond_mel = MelSpectrogramFeatures()(audio).to(self.device) |
| | cond_mel_frame = cond_mel.shape[-1] |
| | if verbose: |
| | print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype) |
| |
|
| | self.cache_audio_prompt = audio_prompt |
| | self.cache_cond_mel = cond_mel |
| | else: |
| | cond_mel = self.cache_cond_mel |
| | cond_mel_frame = cond_mel.shape[-1] |
| | pass |
| |
|
| | auto_conditioning = cond_mel |
| | cond_mel_lengths = torch.tensor([cond_mel_frame], device=self.device) |

        # Tokenize the text and split it into segments.
        text_tokens_list = self.tokenizer.tokenize(text)

        segments = self.tokenizer.split_segments(text_tokens_list,
                                                 max_text_tokens_per_segment=max_text_tokens_per_segment)
        if verbose:
            print(">> text token count:", len(text_tokens_list))
            print("   segments count:", len(segments))
            print("   max_text_tokens_per_segment:", max_text_tokens_per_segment)
            print(*segments, sep="\n")
        do_sample = generation_kwargs.pop("do_sample", True)
        top_p = generation_kwargs.pop("top_p", 0.8)
        top_k = generation_kwargs.pop("top_k", 30)
        temperature = generation_kwargs.pop("temperature", 1.0)
        autoregressive_batch_size = 1
        length_penalty = generation_kwargs.pop("length_penalty", 0.0)
        num_beams = generation_kwargs.pop("num_beams", 3)
        repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0)
        max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 600)
        sampling_rate = 24000

        wavs = []
        gpt_gen_time = 0
        gpt_forward_time = 0
        bigvgan_time = 0

        # Convert each segment to token ids, grouped into buckets for batching.
        all_text_tokens: List[List[torch.Tensor]] = []
        self._set_gr_progress(0.1, "text processing...")
        bucket_max_size = segments_bucket_max_size if self.device != "cpu" else 1
        all_segments = self.bucket_segments(segments, bucket_max_size=bucket_max_size)
        bucket_count = len(all_segments)
        if verbose:
            print(">> segments bucket_count:", bucket_count,
                  "bucket sizes:", [(len(s), [t["idx"] for t in s]) for s in all_segments],
                  "bucket_max_size:", bucket_max_size)
        for segments in all_segments:
            temp_tokens: List[torch.Tensor] = []
            all_text_tokens.append(temp_tokens)
            for item in segments:
                sent = item["sent"]
                text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
                text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
                if verbose:
                    print(text_tokens)
                    print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
                    # Debug: check that the token ids round-trip back to the segment tokens.
                    text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
                    print("text_token_syms is same as segment tokens", text_token_syms == sent)
                temp_tokens.append(text_tokens)

        # Stage 1: autoregressive GPT inference, one batch per bucket.
        all_batch_num = sum(len(s) for s in all_segments)
        all_batch_codes = []
        processed_num = 0
        for item_tokens in all_text_tokens:
            batch_num = len(item_tokens)
            if batch_num > 1:
                batch_text_tokens = self.pad_tokens_cat(item_tokens)
            else:
                batch_text_tokens = item_tokens[0]
            processed_num += batch_num
            self._set_gr_progress(0.2 + 0.3 * processed_num / all_batch_num,
                                  f"gpt speech inference {processed_num}/{all_batch_num}...")
            m_start_time = time.perf_counter()
            with torch.no_grad():
                with torch.amp.autocast(batch_text_tokens.device.type, enabled=self.dtype is not None,
                                        dtype=self.dtype):
                    temp_codes = self.gpt.inference_speech(auto_conditioning, batch_text_tokens,
                                                           cond_mel_lengths=cond_mel_lengths,
                                                           do_sample=do_sample,
                                                           top_p=top_p,
                                                           top_k=top_k,
                                                           temperature=temperature,
                                                           num_return_sequences=autoregressive_batch_size,
                                                           length_penalty=length_penalty,
                                                           num_beams=num_beams,
                                                           repetition_penalty=repetition_penalty,
                                                           max_generate_length=max_mel_tokens,
                                                           **generation_kwargs)
            all_batch_codes.append(temp_codes)
            gpt_gen_time += time.perf_counter() - m_start_time

        # Stage 2: run the GPT forward pass to get latents for each generated code sequence.
        self._set_gr_progress(0.5, "gpt latents inference...")
        all_idxs = []
        all_latents = []
        has_warned = False
        for batch_codes, batch_tokens, batch_segments in zip(all_batch_codes, all_text_tokens, all_segments):
            for i in range(batch_codes.shape[0]):
                codes = batch_codes[i]
                if not has_warned and codes[-1] != self.stop_mel_token:
                    warnings.warn(
                        f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
                        f"Consider reducing `max_text_tokens_per_segment`({max_text_tokens_per_segment}) or increasing `max_mel_tokens`.",
                        category=RuntimeWarning
                    )
                    has_warned = True
                codes = codes.unsqueeze(0)
                if verbose:
                    print("codes:", codes.shape)
                    print(codes)
                codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
                if verbose:
                    print("fix codes:", codes.shape)
                    print(codes)
                    print("code_lens:", code_lens)
                text_tokens = batch_tokens[i]
                all_idxs.append(batch_segments[i]["idx"])
                m_start_time = time.perf_counter()
                with torch.no_grad():
                    with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
                        latent = \
                            self.gpt(auto_conditioning, text_tokens,
                                     torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
                                     code_lens * self.gpt.mel_length_compression,
                                     cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
                                                                   device=text_tokens.device),
                                     return_latent=True, clip_inputs=False)
                gpt_forward_time += time.perf_counter() - m_start_time
                all_latents.append(latent)
        del all_batch_codes, all_text_tokens, all_segments
        # Restore the original segment order, then group latents into chunks for decoding.
        chunk_size = 2
        all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))]
        if verbose:
            print(">> all_latents:", len(all_latents))
            print("   latents length:", [l.shape[1] for l in all_latents])
        chunk_latents = [all_latents[i: i + chunk_size] for i in range(0, len(all_latents), chunk_size)]
        chunk_length = len(chunk_latents)
        latent_length = len(all_latents)

        # Stage 3: BigVGAN vocoder decoding, chunk by chunk.
        self._set_gr_progress(0.7, "bigvgan decoding...")
        tqdm_progress = tqdm(total=latent_length, desc="bigvgan")
        for items in chunk_latents:
            tqdm_progress.update(len(items))
            latent = torch.cat(items, dim=1)
            with torch.no_grad():
                with torch.amp.autocast(latent.device.type, enabled=self.dtype is not None, dtype=self.dtype):
                    m_start_time = time.perf_counter()
                    wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2))
                    bigvgan_time += time.perf_counter() - m_start_time
                    wav = wav.squeeze(1)
            wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
            wavs.append(wav.cpu())

        # Clean up before assembling the final waveform.
        tqdm_progress.close()
        del all_latents, chunk_latents
        end_time = time.perf_counter()
        self.torch_empty_cache()

        # Stage 4: report stats, then save or return the audio.
        self._set_gr_progress(0.9, "saving audio...")
        wav = torch.cat(wavs, dim=1)
        wav_length = wav.shape[-1] / sampling_rate
        print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds")
        print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds")
        print(f">> gpt_forward_time: {gpt_forward_time:.2f} seconds")
        print(f">> bigvgan_time: {bigvgan_time:.2f} seconds")
        print(f">> Total fast inference time: {end_time - start_time:.2f} seconds")
        print(f">> Generated audio length: {wav_length:.2f} seconds")
        print(f">> [fast] bigvgan chunk_length: {chunk_length}")
        print(f">> [fast] batch_num: {all_batch_num} bucket_max_size: {bucket_max_size}",
              f"bucket_count: {bucket_count}" if bucket_max_size > 1 else "")
        print(f">> [fast] RTF: {(end_time - start_time) / wav_length:.4f}")

        wav = wav.cpu()
        if output_path:
            if os.path.dirname(output_path) != "":
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
            torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
            print(">> wav file saved to:", output_path)
            return output_path
        else:
            # Return a Gradio-compatible (sample_rate, samples) tuple.
            wav_data = wav.type(torch.int16)
            wav_data = wav_data.numpy().T
            return (sampling_rate, wav_data)
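
    # Illustrative sketch of the in-memory return path (not in the original
    # module): with a falsy `output_path`, infer_fast returns a tuple of the
    # sample rate and an int16 array shaped (num_samples, num_channels):
    #
    #   sr, pcm = tts.infer_fast(prompt_wav, text, output_path=None)
    #   # sr == 24000; pcm.dtype is int16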

    # Regular (non-fast) inference: segments are synthesized one at a time.
    def infer(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_segment=120,
              **generation_kwargs):
        print(">> starting inference...")
        self._set_gr_progress(0, "starting inference...")
        if verbose:
            print(f"origin text:{text}")
        start_time = time.perf_counter()

        # Extract the conditioning mel from the reference audio (cached per prompt).
        if self.cache_cond_mel is None or self.cache_audio_prompt != audio_prompt:
            audio, sr = torchaudio.load(audio_prompt)
            # Downmix to mono; torch.mean collapses the channel dim.
            audio = torch.mean(audio, dim=0, keepdim=True)
            audio = torchaudio.transforms.Resample(sr, 24000)(audio)
            cond_mel = MelSpectrogramFeatures()(audio).to(self.device)
            cond_mel_frame = cond_mel.shape[-1]
            if verbose:
                print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype)

            self.cache_audio_prompt = audio_prompt
            self.cache_cond_mel = cond_mel
        else:
            cond_mel = self.cache_cond_mel
            cond_mel_frame = cond_mel.shape[-1]

        self._set_gr_progress(0.1, "text processing...")
        auto_conditioning = cond_mel
        text_tokens_list = self.tokenizer.tokenize(text)
        segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment)
        if verbose:
            print("text token count:", len(text_tokens_list))
            print("segments count:", len(segments))
            print("max_text_tokens_per_segment:", max_text_tokens_per_segment)
            print(*segments, sep="\n")
        do_sample = generation_kwargs.pop("do_sample", True)
        top_p = generation_kwargs.pop("top_p", 0.8)
        top_k = generation_kwargs.pop("top_k", 30)
        temperature = generation_kwargs.pop("temperature", 1.0)
        autoregressive_batch_size = 1
        length_penalty = generation_kwargs.pop("length_penalty", 0.0)
        num_beams = generation_kwargs.pop("num_beams", 3)
        repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0)
        max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 600)
        sampling_rate = 24000

        wavs = []
        gpt_gen_time = 0
        gpt_forward_time = 0
        bigvgan_time = 0
        progress = 0
        has_warned = False
        for sent in segments:
            text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
            text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
            if verbose:
                print(text_tokens)
                print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
                # Debug: check that the token ids round-trip back to the segment tokens.
                text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
                print("text_token_syms is same as segment tokens", text_token_syms == sent)

            progress += 1
            self._set_gr_progress(0.2 + 0.4 * (progress - 1) / len(segments),
                                  f"gpt speech inference {progress}/{len(segments)}...")
            m_start_time = time.perf_counter()
            with torch.no_grad():
                with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
                    codes = self.gpt.inference_speech(auto_conditioning, text_tokens,
                                                      cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
                                                                                    device=text_tokens.device),
                                                      do_sample=do_sample,
                                                      top_p=top_p,
                                                      top_k=top_k,
                                                      temperature=temperature,
                                                      num_return_sequences=autoregressive_batch_size,
                                                      length_penalty=length_penalty,
                                                      num_beams=num_beams,
                                                      repetition_penalty=repetition_penalty,
                                                      max_generate_length=max_mel_tokens,
                                                      **generation_kwargs)
            gpt_gen_time += time.perf_counter() - m_start_time
            if not has_warned and (codes[:, -1] != self.stop_mel_token).any():
                warnings.warn(
                    f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
                    f"Input text tokens: {text_tokens.shape[1]}. "
                    f"Consider reducing `max_text_tokens_per_segment`({max_text_tokens_per_segment}) or increasing `max_mel_tokens`.",
                    category=RuntimeWarning
                )
                has_warned = True

            code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
            if verbose:
                print(codes, type(codes))
                print(f"codes shape: {codes.shape}, codes type: {codes.dtype}")
                print(f"code len: {code_lens}")

            # Remove overly long silences, which can cause artifacts in the output.
            codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
            if verbose:
                print(codes, type(codes))
                print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
                print(f"code len: {code_lens}")
            self._set_gr_progress(0.2 + 0.4 * progress / len(segments),
                                  f"gpt latents inference {progress}/{len(segments)}...")
            m_start_time = time.perf_counter()
            # Run the GPT forward pass and the vocoder without tracking gradients.
            with torch.no_grad():
                with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
                    latent = \
                        self.gpt(auto_conditioning, text_tokens,
                                 torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
                                 code_lens * self.gpt.mel_length_compression,
                                 cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
                                                               device=text_tokens.device),
                                 return_latent=True, clip_inputs=False)
                    gpt_forward_time += time.perf_counter() - m_start_time

                    m_start_time = time.perf_counter()
                    wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2))
                    bigvgan_time += time.perf_counter() - m_start_time
                    wav = wav.squeeze(1)

            wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
            if verbose:
                print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max())
            wavs.append(wav.cpu())
        end_time = time.perf_counter()
        self._set_gr_progress(0.9, "saving audio...")
        wav = torch.cat(wavs, dim=1)
        wav_length = wav.shape[-1] / sampling_rate
        print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds")
        print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds")
        print(f">> gpt_forward_time: {gpt_forward_time:.2f} seconds")
        print(f">> bigvgan_time: {bigvgan_time:.2f} seconds")
        print(f">> Total inference time: {end_time - start_time:.2f} seconds")
        print(f">> Generated audio length: {wav_length:.2f} seconds")
        print(f">> RTF: {(end_time - start_time) / wav_length:.4f}")

        # Save or return the generated audio.
        wav = wav.cpu()
        if output_path:
            # Overwrite any stale output file.
            if os.path.isfile(output_path):
                os.remove(output_path)
                print(">> remove old wav file:", output_path)
            if os.path.dirname(output_path) != "":
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
            torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
            print(">> wav file saved to:", output_path)
            return output_path
        else:
            # Return a Gradio-compatible (sample_rate, samples) tuple.
            wav_data = wav.type(torch.int16)
            wav_data = wav_data.numpy().T
            return (sampling_rate, wav_data)

if __name__ == "__main__":
    prompt_wav = "examples/voice_01.wav"
    # Sample text (Chinese): "Welcome to try indextts2, and please give us your feedback. Thank you."
    text = '欢迎大家来体验indextts2,并给予我们意见与反馈,谢谢大家。'

    tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_cuda_kernel=False)
    tts.infer(audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True)
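    # Illustrative alternative (not in the original module): the fast path with
    # the same arguments, writing to a separate file:
    #   tts.infer_fast(audio_prompt=prompt_wav, text=text, output_path="gen_fast.wav", verbose=True)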