|
|
import os
|
|
|
|
|
|
os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'
|
|
|
import time
|
|
|
from subprocess import CalledProcessError
|
|
|
from typing import Dict, List
|
|
|
|
|
|
import torch
|
|
|
import torchaudio
|
|
|
from torch.nn.utils.rnn import pad_sequence
|
|
|
from omegaconf import OmegaConf
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
import warnings
|
|
|
|
|
|
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
|
warnings.filterwarnings("ignore", category=UserWarning)
|
|
|
|
|
|
from indextts.BigVGAN.models import BigVGAN as Generator
|
|
|
from indextts.gpt.model import UnifiedVoice
|
|
|
from indextts.utils.checkpoint import load_checkpoint
|
|
|
from indextts.utils.feature_extractors import MelSpectrogramFeatures
|
|
|
|
|
|
from indextts.utils.front import TextNormalizer, TextTokenizer
|
|
|
|
|
|
|
|
|
class IndexTTS:
|
|
|
def __init__(
|
|
|
self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=True, device=None,
|
|
|
use_cuda_kernel=None,
|
|
|
):
|
|
|
"""
|
|
|
Args:
|
|
|
cfg_path (str): path to the config file.
|
|
|
model_dir (str): path to the model directory.
|
|
|
use_fp16 (bool): whether to use fp16.
|
|
|
device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
|
|
|
use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
|
|
|
"""
|
|
|
if device is not None:
|
|
|
self.device = device
|
|
|
self.use_fp16 = False if device == "cpu" else use_fp16
|
|
|
self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
|
|
|
elif torch.cuda.is_available():
|
|
|
self.device = "cuda:0"
|
|
|
self.use_fp16 = use_fp16
|
|
|
self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
|
|
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
|
self.device = "xpu"
|
|
|
self.use_fp16 = use_fp16
|
|
|
self.use_cuda_kernel = False
|
|
|
elif hasattr(torch, "mps") and torch.backends.mps.is_available():
|
|
|
self.device = "mps"
|
|
|
self.use_fp16 = False
|
|
|
self.use_cuda_kernel = False
|
|
|
else:
|
|
|
self.device = "cpu"
|
|
|
self.use_fp16 = False
|
|
|
self.use_cuda_kernel = False
|
|
|
print(">> Be patient, it may take a while to run in CPU mode.")
|
|
|
|
|
|
self.cfg = OmegaConf.load(cfg_path)
|
|
|
self.model_dir = model_dir
|
|
|
self.dtype = torch.float16 if self.use_fp16 else None
|
|
|
self.stop_mel_token = self.cfg.gpt.stop_mel_token
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.gpt = UnifiedVoice(**self.cfg.gpt)
|
|
|
self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
|
|
|
load_checkpoint(self.gpt, self.gpt_path)
|
|
|
self.gpt = self.gpt.to(self.device)
|
|
|
if self.use_fp16:
|
|
|
self.gpt.eval().half()
|
|
|
else:
|
|
|
self.gpt.eval()
|
|
|
print(">> GPT weights restored from:", self.gpt_path)
|
|
|
if self.use_fp16:
|
|
|
try:
|
|
|
import deepspeed
|
|
|
|
|
|
use_deepspeed = True
|
|
|
except (ImportError, OSError, CalledProcessError) as e:
|
|
|
use_deepspeed = False
|
|
|
print(f">> DeepSpeed加载失败,回退到标准推理: {e}")
|
|
|
|
|
|
self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=True)
|
|
|
else:
|
|
|
self.gpt.post_init_gpt2_config(use_deepspeed=False, kv_cache=False, half=False)
|
|
|
|
|
|
if self.use_cuda_kernel:
|
|
|
|
|
|
try:
|
|
|
from indextts.BigVGAN.alias_free_activation.cuda import load
|
|
|
|
|
|
anti_alias_activation_cuda = load.load()
|
|
|
print(">> Preload custom CUDA kernel for BigVGAN", anti_alias_activation_cuda)
|
|
|
except:
|
|
|
print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
|
|
|
self.use_cuda_kernel = False
|
|
|
self.bigvgan = Generator(self.cfg.bigvgan, use_cuda_kernel=self.use_cuda_kernel)
|
|
|
self.bigvgan_path = os.path.join(self.model_dir, self.cfg.bigvgan_checkpoint)
|
|
|
vocoder_dict = torch.load(self.bigvgan_path, map_location="cpu")
|
|
|
self.bigvgan.load_state_dict(vocoder_dict["generator"])
|
|
|
self.bigvgan = self.bigvgan.to(self.device)
|
|
|
|
|
|
self.bigvgan.remove_weight_norm()
|
|
|
self.bigvgan.eval()
|
|
|
print(">> bigvgan weights restored from:", self.bigvgan_path)
|
|
|
self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset["bpe_model"])
|
|
|
self.normalizer = TextNormalizer()
|
|
|
self.normalizer.load()
|
|
|
print(">> TextNormalizer loaded")
|
|
|
self.tokenizer = TextTokenizer(self.bpe_path, self.normalizer)
|
|
|
print(">> bpe model loaded from:", self.bpe_path)
|
|
|
|
|
|
self.cache_audio_prompt = None
|
|
|
self.cache_cond_mel = None
|
|
|
|
|
|
self.gr_progress = None
|
|
|
self.model_version = self.cfg.version if hasattr(self.cfg, "version") else None
|
|
|
|
|
|
def remove_long_silence(self, codes: torch.Tensor, silent_token=52, max_consecutive=30):
|
|
|
"""
|
|
|
Shrink special tokens (silent_token and stop_mel_token) in codes
|
|
|
codes: [B, T]
|
|
|
"""
|
|
|
code_lens = []
|
|
|
codes_list = []
|
|
|
device = codes.device
|
|
|
dtype = codes.dtype
|
|
|
isfix = False
|
|
|
for i in range(0, codes.shape[0]):
|
|
|
code = codes[i]
|
|
|
if not torch.any(code == self.stop_mel_token).item():
|
|
|
len_ = code.size(0)
|
|
|
else:
|
|
|
stop_mel_idx = (code == self.stop_mel_token).nonzero(as_tuple=False)
|
|
|
len_ = stop_mel_idx[0].item() if len(stop_mel_idx) > 0 else code.size(0)
|
|
|
|
|
|
count = torch.sum(code == silent_token).item()
|
|
|
if count > max_consecutive:
|
|
|
|
|
|
ncode_idx = []
|
|
|
n = 0
|
|
|
for k in range(len_):
|
|
|
assert code[
|
|
|
k] != self.stop_mel_token, f"stop_mel_token {self.stop_mel_token} should be shrinked here"
|
|
|
if code[k] != silent_token:
|
|
|
ncode_idx.append(k)
|
|
|
n = 0
|
|
|
elif code[k] == silent_token and n < 10:
|
|
|
ncode_idx.append(k)
|
|
|
n += 1
|
|
|
|
|
|
|
|
|
|
|
|
len_ = len(ncode_idx)
|
|
|
codes_list.append(code[ncode_idx])
|
|
|
isfix = True
|
|
|
else:
|
|
|
|
|
|
codes_list.append(code[:len_])
|
|
|
code_lens.append(len_)
|
|
|
if isfix:
|
|
|
if len(codes_list) > 1:
|
|
|
codes = pad_sequence(codes_list, batch_first=True, padding_value=self.stop_mel_token)
|
|
|
else:
|
|
|
codes = codes_list[0].unsqueeze(0)
|
|
|
else:
|
|
|
|
|
|
pass
|
|
|
|
|
|
max_len = max(code_lens)
|
|
|
if max_len < codes.shape[1]:
|
|
|
codes = codes[:, :max_len]
|
|
|
code_lens = torch.tensor(code_lens, dtype=torch.long, device=device)
|
|
|
return codes, code_lens
|
|
|
|
|
|
def bucket_segments(self, segments, bucket_max_size=4) -> List[List[Dict]]:
|
|
|
"""
|
|
|
Segment data bucketing.
|
|
|
if ``bucket_max_size=1``, return all segments in one bucket.
|
|
|
"""
|
|
|
outputs: List[Dict] = []
|
|
|
for idx, sent in enumerate(segments):
|
|
|
outputs.append({"idx": idx, "sent": sent, "len": len(sent)})
|
|
|
|
|
|
if len(outputs) > bucket_max_size:
|
|
|
|
|
|
buckets: List[List[Dict]] = []
|
|
|
factor = 1.5
|
|
|
last_bucket = None
|
|
|
last_bucket_sent_len_median = 0
|
|
|
|
|
|
for sent in sorted(outputs, key=lambda x: x["len"]):
|
|
|
current_sent_len = sent["len"]
|
|
|
if current_sent_len == 0:
|
|
|
print(">> skip empty segment")
|
|
|
continue
|
|
|
if last_bucket is None \
|
|
|
or current_sent_len >= int(last_bucket_sent_len_median * factor) \
|
|
|
or len(last_bucket) >= bucket_max_size:
|
|
|
|
|
|
buckets.append([sent])
|
|
|
last_bucket = buckets[-1]
|
|
|
last_bucket_sent_len_median = current_sent_len
|
|
|
else:
|
|
|
|
|
|
last_bucket.append(sent)
|
|
|
mid = len(last_bucket) // 2
|
|
|
last_bucket_sent_len_median = last_bucket[mid]["len"]
|
|
|
last_bucket = None
|
|
|
|
|
|
out_buckets: List[List[Dict]] = []
|
|
|
only_ones: List[Dict] = []
|
|
|
for b in buckets:
|
|
|
if len(b) == 1:
|
|
|
only_ones.append(b[0])
|
|
|
else:
|
|
|
out_buckets.append(b)
|
|
|
if len(only_ones) > 0:
|
|
|
|
|
|
|
|
|
for i in range(len(out_buckets)):
|
|
|
b = out_buckets[i]
|
|
|
if len(b) < bucket_max_size:
|
|
|
b.append(only_ones.pop(0))
|
|
|
if len(only_ones) == 0:
|
|
|
break
|
|
|
|
|
|
if len(only_ones) > 0:
|
|
|
out_buckets.extend(
|
|
|
[only_ones[i:i + bucket_max_size] for i in range(0, len(only_ones), bucket_max_size)])
|
|
|
return out_buckets
|
|
|
return [outputs]
|
|
|
|
|
|
def pad_tokens_cat(self, tokens: List[torch.Tensor]) -> torch.Tensor:
|
|
|
if self.model_version and self.model_version >= 1.5:
|
|
|
|
|
|
|
|
|
tokens = [t.squeeze(0) for t in tokens]
|
|
|
return pad_sequence(tokens, batch_first=True, padding_value=self.cfg.gpt.stop_text_token,
|
|
|
padding_side="right")
|
|
|
max_len = max(t.size(1) for t in tokens)
|
|
|
outputs = []
|
|
|
for tensor in tokens:
|
|
|
pad_len = max_len - tensor.size(1)
|
|
|
if pad_len > 0:
|
|
|
n = min(8, pad_len)
|
|
|
tensor = torch.nn.functional.pad(tensor, (0, n), value=self.cfg.gpt.stop_text_token)
|
|
|
tensor = torch.nn.functional.pad(tensor, (0, pad_len - n), value=self.cfg.gpt.start_text_token)
|
|
|
tensor = tensor[:, :max_len]
|
|
|
outputs.append(tensor)
|
|
|
tokens = torch.cat(outputs, dim=0)
|
|
|
return tokens
|
|
|
|
|
|
def torch_empty_cache(self):
|
|
|
try:
|
|
|
if "cuda" in str(self.device):
|
|
|
torch.cuda.empty_cache()
|
|
|
elif "mps" in str(self.device):
|
|
|
torch.mps.empty_cache()
|
|
|
except Exception as e:
|
|
|
pass
|
|
|
|
|
|
def _set_gr_progress(self, value, desc):
|
|
|
if self.gr_progress is not None:
|
|
|
self.gr_progress(value, desc=desc)
|
|
|
|
|
|
|
|
|
    def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_segment=100,
                   segments_bucket_max_size=4, **generation_kwargs):
        """
        Fast batched inference: split the text into segments, bucket segments of
        similar length, and run GPT generation / BigVGAN decoding in batches.

        Args:
            ``max_text_tokens_per_segment``: maximum number of tokens per segment
              (default ``100``); tune according to the GPU hardware.
                - smaller: more batches, *faster* inference, more memory, may affect quality
                - larger: fewer batches, *slower* inference; memory use and quality closer to non-fast inference
            ``segments_bucket_max_size``: maximum capacity of a segment bucket
              (default ``4``); tune according to GPU memory.
                - larger: fewer buckets, more batches, *faster* inference, more memory, may affect quality
                - smaller: more buckets, fewer batches, *slower* inference; memory use and quality closer to non-fast inference
        """
        print(">> starting fast inference...")

        self._set_gr_progress(0, "starting fast inference...")
        if verbose:
            print(f"origin text:{text}")
        start_time = time.perf_counter()

        # Build (or reuse) the conditioning mel-spectrogram from the reference audio.
        if self.cache_cond_mel is None or self.cache_audio_prompt != audio_prompt:
            audio, sr = torchaudio.load(audio_prompt)
            audio = torch.mean(audio, dim=0, keepdim=True)
            if audio.shape[0] > 1:
                audio = audio[0].unsqueeze(0)
            audio = torchaudio.transforms.Resample(sr, 24000)(audio)

            # Cap the reference audio length to bound conditioning cost.
            max_audio_length_seconds = 50
            max_audio_samples = int(max_audio_length_seconds * 24000)

            if audio.shape[1] > max_audio_samples:
                if verbose:
                    print(f"Audio too long ({audio.shape[1]} samples), truncating to {max_audio_samples} samples")
                audio = audio[:, :max_audio_samples]

            cond_mel = MelSpectrogramFeatures()(audio).to(self.device)
            cond_mel_frame = cond_mel.shape[-1]
            if verbose:
                print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype)

            self.cache_audio_prompt = audio_prompt
            self.cache_cond_mel = cond_mel
        else:
            cond_mel = self.cache_cond_mel
            cond_mel_frame = cond_mel.shape[-1]
            pass

        auto_conditioning = cond_mel
        cond_mel_lengths = torch.tensor([cond_mel_frame], device=self.device)

        # Tokenize and split the text into segments.
        text_tokens_list = self.tokenizer.tokenize(text)

        segments = self.tokenizer.split_segments(text_tokens_list,
                                                 max_text_tokens_per_segment=max_text_tokens_per_segment)
        if verbose:
            print(">> text token count:", len(text_tokens_list))
            print(" segments count:", len(segments))
            print(" max_text_tokens_per_segment:", max_text_tokens_per_segment)
            print(*segments, sep="\n")
        # Pop known generation options; the rest pass through to the model.
        do_sample = generation_kwargs.pop("do_sample", True)
        top_p = generation_kwargs.pop("top_p", 0.8)
        top_k = generation_kwargs.pop("top_k", 30)
        temperature = generation_kwargs.pop("temperature", 1.0)
        autoregressive_batch_size = 1
        length_penalty = generation_kwargs.pop("length_penalty", 0.0)
        num_beams = generation_kwargs.pop("num_beams", 3)
        repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0)
        max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 600)
        sampling_rate = 24000

        wavs = []
        gpt_gen_time = 0
        gpt_forward_time = 0
        bigvgan_time = 0

        # Bucket the segments by length so similarly sized segments batch together.
        all_text_tokens: List[List[torch.Tensor]] = []
        self._set_gr_progress(0.1, "text processing...")
        bucket_max_size = segments_bucket_max_size if self.device != "cpu" else 1
        all_segments = self.bucket_segments(segments, bucket_max_size=bucket_max_size)
        bucket_count = len(all_segments)
        if verbose:
            print(">> segments bucket_count:", bucket_count,
                  "bucket sizes:", [(len(s), [t["idx"] for t in s]) for s in all_segments],
                  "bucket_max_size:", bucket_max_size)
        for segments in all_segments:
            temp_tokens: List[torch.Tensor] = []
            all_text_tokens.append(temp_tokens)
            for item in segments:
                sent = item["sent"]
                text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
                text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
                if verbose:
                    print(text_tokens)
                    print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")

                    text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
                    print("text_token_syms is same as segment tokens", text_token_syms == sent)
                temp_tokens.append(text_tokens)

        # Stage 1: autoregressive GPT generation of mel codes, one bucket per batch.
        all_batch_num = sum(len(s) for s in all_segments)
        all_batch_codes = []
        processed_num = 0
        for item_tokens in all_text_tokens:
            batch_num = len(item_tokens)
            if batch_num > 1:
                batch_text_tokens = self.pad_tokens_cat(item_tokens)
            else:
                batch_text_tokens = item_tokens[0]
            processed_num += batch_num

            self._set_gr_progress(0.2 + 0.3 * processed_num / all_batch_num,
                                  f"gpt speech inference {processed_num}/{all_batch_num}...")
            m_start_time = time.perf_counter()
            with torch.no_grad():
                with torch.amp.autocast(batch_text_tokens.device.type, enabled=self.dtype is not None,
                                        dtype=self.dtype):
                    temp_codes = self.gpt.inference_speech(auto_conditioning, batch_text_tokens,
                                                           cond_mel_lengths=cond_mel_lengths,
                                                           do_sample=do_sample,
                                                           top_p=top_p,
                                                           top_k=top_k,
                                                           temperature=temperature,
                                                           num_return_sequences=autoregressive_batch_size,
                                                           length_penalty=length_penalty,
                                                           num_beams=num_beams,
                                                           repetition_penalty=repetition_penalty,
                                                           max_generate_length=max_mel_tokens,
                                                           **generation_kwargs)
                    all_batch_codes.append(temp_codes)
            gpt_gen_time += time.perf_counter() - m_start_time

        # Stage 2: run the GPT forward pass to get latents for every code sequence.
        self._set_gr_progress(0.5, "gpt latents inference...")
        all_idxs = []
        all_latents = []
        has_warned = False
        for batch_codes, batch_tokens, batch_segments in zip(all_batch_codes, all_text_tokens, all_segments):
            for i in range(batch_codes.shape[0]):
                codes = batch_codes[i]
                # Missing stop token at the end means generation hit max_mel_tokens.
                if not has_warned and codes[-1] != self.stop_mel_token:
                    warnings.warn(
                        f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
                        f"Consider reducing `max_text_tokens_per_segment`({max_text_tokens_per_segment}) or increasing `max_mel_tokens`.",
                        category=RuntimeWarning
                    )
                    has_warned = True
                codes = codes.unsqueeze(0)
                if verbose:
                    print("codes:", codes.shape)
                    print(codes)
                # Trim excessive silence runs in the generated codes.
                codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
                if verbose:
                    print("fix codes:", codes.shape)
                    print(codes)
                    print("code_lens:", code_lens)
                text_tokens = batch_tokens[i]
                # Record the original segment index so output order can be restored.
                all_idxs.append(batch_segments[i]["idx"])
                m_start_time = time.perf_counter()
                with torch.no_grad():
                    with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
                        latent = \
                            self.gpt(auto_conditioning, text_tokens,
                                     torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
                                     code_lens * self.gpt.mel_length_compression,
                                     cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
                                                                   device=text_tokens.device),
                                     return_latent=True, clip_inputs=False)
                gpt_forward_time += time.perf_counter() - m_start_time
                all_latents.append(latent)
        del all_batch_codes, all_text_tokens, all_segments

        # Stage 3: decode latents to waveforms with BigVGAN, a few segments per chunk.
        chunk_size = 2
        # Restore original text order (buckets were sorted by segment length).
        all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))]
        if verbose:
            print(">> all_latents:", len(all_latents))
            print(" latents length:", [l.shape[1] for l in all_latents])
        chunk_latents = [all_latents[i: i + chunk_size] for i in range(0, len(all_latents), chunk_size)]
        chunk_length = len(chunk_latents)
        latent_length = len(all_latents)

        self._set_gr_progress(0.7, "bigvgan decoding...")
        tqdm_progress = tqdm(total=latent_length, desc="bigvgan")
        for items in chunk_latents:
            tqdm_progress.update(len(items))
            latent = torch.cat(items, dim=1)
            with torch.no_grad():
                with torch.amp.autocast(latent.device.type, enabled=self.dtype is not None, dtype=self.dtype):
                    m_start_time = time.perf_counter()
                    wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2))
                    bigvgan_time += time.perf_counter() - m_start_time
                    wav = wav.squeeze(1)
                pass
            # Scale to int16 range and clamp to avoid overflow on save.
            wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
            wavs.append(wav.cpu())

        tqdm_progress.close()
        del all_latents, chunk_latents
        end_time = time.perf_counter()
        self.torch_empty_cache()

        # Concatenate the per-chunk waveforms and report timing statistics.
        self._set_gr_progress(0.9, "saving audio...")
        wav = torch.cat(wavs, dim=1)
        wav_length = wav.shape[-1] / sampling_rate
        print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds")
        print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds")
        print(f">> gpt_forward_time: {gpt_forward_time:.2f} seconds")
        print(f">> bigvgan_time: {bigvgan_time:.2f} seconds")
        print(f">> Total fast inference time: {end_time - start_time:.2f} seconds")
        print(f">> Generated audio length: {wav_length:.2f} seconds")
        print(f">> [fast] bigvgan chunk_length: {chunk_length}")
        print(f">> [fast] batch_num: {all_batch_num} bucket_max_size: {bucket_max_size}",
              f"bucket_count: {bucket_count}" if bucket_max_size > 1 else "")
        print(f">> [fast] RTF: {(end_time - start_time) / wav_length:.4f}")

        # Save to disk, or return (sample_rate, int16 ndarray) for direct playback.
        wav = wav.cpu()
        if output_path:

            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
            print(">> wav file saved to:", output_path)
            return output_path
        else:

            wav_data = wav.type(torch.int16)
            wav_data = wav_data.numpy().T
            return (sampling_rate, wav_data)
|
|
|
|
|
|
|
|
|
    def infer(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_segment=120,
              **generation_kwargs):
        """
        Standard (sequential) inference: synthesize ``text`` in the voice of
        ``audio_prompt``, processing one text segment at a time.

        Args:
            audio_prompt: path to the reference (speaker) audio file.
            text: text to synthesize.
            output_path: where to save the generated wav; if falsy, returns
                ``(sampling_rate, int16 numpy array)`` instead.
            verbose (bool): print debugging details.
            max_text_tokens_per_segment (int): maximum tokens per text segment.
            **generation_kwargs: sampling options forwarded to GPT generation.
        """
        print(">> starting inference...")
        self._set_gr_progress(0, "starting inference...")
        if verbose:
            print(f"origin text:{text}")
        start_time = time.perf_counter()

        # Build (or reuse) the conditioning mel-spectrogram from the reference audio.
        if self.cache_cond_mel is None or self.cache_audio_prompt != audio_prompt:
            audio, sr = torchaudio.load(audio_prompt)
            audio = torch.mean(audio, dim=0, keepdim=True)
            if audio.shape[0] > 1:
                audio = audio[0].unsqueeze(0)
            audio = torchaudio.transforms.Resample(sr, 24000)(audio)
            cond_mel = MelSpectrogramFeatures()(audio).to(self.device)
            cond_mel_frame = cond_mel.shape[-1]
            if verbose:
                print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype)

            self.cache_audio_prompt = audio_prompt
            self.cache_cond_mel = cond_mel
        else:
            cond_mel = self.cache_cond_mel
            cond_mel_frame = cond_mel.shape[-1]
            pass

        self._set_gr_progress(0.1, "text processing...")
        auto_conditioning = cond_mel
        text_tokens_list = self.tokenizer.tokenize(text)
        segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment)
        if verbose:
            print("text token count:", len(text_tokens_list))
            print("segments count:", len(segments))
            print("max_text_tokens_per_segment:", max_text_tokens_per_segment)
            print(*segments, sep="\n")
        # Pop known generation options; the rest pass through to the model.
        do_sample = generation_kwargs.pop("do_sample", True)
        top_p = generation_kwargs.pop("top_p", 0.8)
        top_k = generation_kwargs.pop("top_k", 30)
        temperature = generation_kwargs.pop("temperature", 1.0)
        autoregressive_batch_size = 1
        length_penalty = generation_kwargs.pop("length_penalty", 0.0)
        num_beams = generation_kwargs.pop("num_beams", 3)
        repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0)
        max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 600)
        sampling_rate = 24000

        wavs = []
        gpt_gen_time = 0
        gpt_forward_time = 0
        bigvgan_time = 0
        progress = 0
        has_warned = False
        for sent in segments:
            text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
            text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)

            if verbose:
                print(text_tokens)
                print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")

                text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
                print("text_token_syms is same as segment tokens", text_token_syms == sent)

            progress += 1
            self._set_gr_progress(0.2 + 0.4 * (progress - 1) / len(segments),
                                  f"gpt latents inference {progress}/{len(segments)}...")
            m_start_time = time.perf_counter()
            with torch.no_grad():
                with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
                    # Stage 1: autoregressive generation of mel codes.
                    codes = self.gpt.inference_speech(auto_conditioning, text_tokens,
                                                      cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
                                                                                    device=text_tokens.device),
                                                      do_sample=do_sample,
                                                      top_p=top_p,
                                                      top_k=top_k,
                                                      temperature=temperature,
                                                      num_return_sequences=autoregressive_batch_size,
                                                      length_penalty=length_penalty,
                                                      num_beams=num_beams,
                                                      repetition_penalty=repetition_penalty,
                                                      max_generate_length=max_mel_tokens,
                                                      **generation_kwargs)
                gpt_gen_time += time.perf_counter() - m_start_time
                # Missing stop token at the end means generation hit max_mel_tokens.
                if not has_warned and (codes[:, -1] != self.stop_mel_token).any():
                    warnings.warn(
                        f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
                        f"Input text tokens: {text_tokens.shape[1]}. "
                        f"Consider reducing `max_text_tokens_per_segment`({max_text_tokens_per_segment}) or increasing `max_mel_tokens`.",
                        category=RuntimeWarning
                    )
                    has_warned = True

                code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype)
                if verbose:
                    print(codes, type(codes))
                    print(f"codes shape: {codes.shape}, codes type: {codes.dtype}")
                    print(f"code len: {code_lens}")

                # Trim excessive silence runs in the generated codes.
                codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30)
                if verbose:
                    print(codes, type(codes))
                    print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
                    print(f"code len: {code_lens}")
                self._set_gr_progress(0.2 + 0.4 * progress / len(segments),
                                      f"gpt speech inference {progress}/{len(segments)}...")
                m_start_time = time.perf_counter()

                with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
                    # Stage 2: GPT forward pass for latents, then BigVGAN vocoding.
                    latent = \
                        self.gpt(auto_conditioning, text_tokens,
                                 torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
                                 code_lens * self.gpt.mel_length_compression,
                                 cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
                                                               device=text_tokens.device),
                                 return_latent=True, clip_inputs=False)
                    gpt_forward_time += time.perf_counter() - m_start_time

                    m_start_time = time.perf_counter()
                    wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2))
                    bigvgan_time += time.perf_counter() - m_start_time
                    wav = wav.squeeze(1)

                # Scale to int16 range and clamp to avoid overflow on save.
                wav = torch.clamp(32767 * wav, -32767.0, 32767.0)
                if verbose:
                    print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max())

                wavs.append(wav.cpu())
        end_time = time.perf_counter()
        self._set_gr_progress(0.9, "saving audio...")
        wav = torch.cat(wavs, dim=1)
        wav_length = wav.shape[-1] / sampling_rate
        print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds")
        print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds")
        print(f">> gpt_forward_time: {gpt_forward_time:.2f} seconds")
        print(f">> bigvgan_time: {bigvgan_time:.2f} seconds")
        print(f">> Total inference time: {end_time - start_time:.2f} seconds")
        print(f">> Generated audio length: {wav_length:.2f} seconds")
        print(f">> RTF: {(end_time - start_time) / wav_length:.4f}")

        # Save to disk, or return (sample_rate, int16 ndarray) for direct playback.
        wav = wav.cpu()
        if output_path:
            # Remove any pre-existing file at the target path before saving.
            if os.path.isfile(output_path):
                os.remove(output_path)
                print(">> remove old wav file:", output_path)
            if os.path.dirname(output_path) != "":
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
            torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
            print(">> wav file saved to:", output_path)
            return output_path
        else:

            wav_data = wav.type(torch.int16)
            wav_data = wav_data.numpy().T
            return (sampling_rate, wav_data)
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: clone the example voice and synthesize a short greeting.
    reference_wav = "examples/voice_01.wav"
    demo_text = '欢迎大家来体验indextts2,并给予我们意见与反馈,谢谢大家。'

    tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_cuda_kernel=False)
    tts.infer(audio_prompt=reference_wav, text=demo_text, output_path="gen.wav", verbose=True)
|
|
|
|