import torch
import torchaudio
from indextts.infer import IndexTTS
from indextts.utils.feature_extractors import MelSpectrogramFeatures
from torch.nn import functional as F


if __name__ == "__main__":
    """
    Test the padding of text tokens in inference.
    ```
    python tests/padding_test.py checkpoints
    python tests/padding_test.py IndexTTS-1.5
    ```
    """
    import transformers
    transformers.set_seed(42)
    import sys
    sys.path.append("..")
    if len(sys.argv) > 1:
        model_dir = sys.argv[1]
    else:
        model_dir = "checkpoints"
    audio_prompt = "tests/sample_prompt.wav"
    tts = IndexTTS(cfg_path=f"{model_dir}/config.yaml", model_dir=model_dir, is_fp16=False, use_cuda_kernel=False)
    text = "晕 XUAN4 是 一 种 not very good GAN3 觉"
    text_tokens = tts.tokenizer.encode(text)
    text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=tts.device).unsqueeze(0)
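
    # Build the conditioning input: load the reference prompt, downmix to mono,
    # resample to 24 kHz, and extract a mel spectrogram.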
    audio, sr = torchaudio.load(audio_prompt)
    audio = torch.mean(audio, dim=0, keepdim=True)
    audio = torchaudio.transforms.Resample(sr, 24000)(audio)
    auto_conditioning = MelSpectrogramFeatures()(audio).to(tts.device)
    cond_mel_lengths = torch.tensor([auto_conditioning.shape[-1]]).to(tts.device)
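
    # Shared decoding settings: greedy decoding with a single beam and at most
    # 100 generated tokens (top_p/top_k/temperature are inert while
    # do_sample is False).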
    with torch.no_grad():
        kwargs = {
            "cond_mel_lengths": cond_mel_lengths,
            "do_sample": False,
            "top_p": 0.8,
            "top_k": None,
            "temperature": 1.0,
            "num_return_sequences": 1,
            "length_penalty": 0.0,
            "num_beams": 1,
            "repetition_penalty": 10.0,
            "max_generate_length": 100,
        }
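
        # Baseline: decode the unpadded text tokens once for reference.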
        baseline = tts.gpt.inference_speech(auto_conditioning, text_tokens, **kwargs)
        baseline = baseline.squeeze(0)
        print("Inference padded text tokens...")
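        # Padded variants: left-pad with 0 and/or right-pad with 1 (presumably
        # the model's start/stop text token IDs, given their placement). A
        # padding-robust implementation should reproduce the baseline output.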
        pad_text_tokens = [
            F.pad(text_tokens, (8, 0), value=0),
            F.pad(text_tokens, (0, 8), value=1),
            F.pad(F.pad(text_tokens, (4, 0), value=0), (0, 4), value=1),
            F.pad(F.pad(text_tokens, (6, 0), value=0), (0, 2), value=1),
            F.pad(F.pad(text_tokens, (0, 4), value=0), (0, 4), value=1),
        ]
        output_for_padded = []
        for t in pad_text_tokens:
            # Decode each padded variant with the same settings as the baseline.
            out = tts.gpt.inference_speech(auto_conditioning, t, **kwargs)
            output_for_padded.append(out.squeeze(0))
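
        # Repeat the test with all padded variants stacked into a single batch.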
        print("Inference padded text tokens as one batch...")
        batched_text_tokens = torch.cat(pad_text_tokens, dim=0).to(tts.device)
        assert len(pad_text_tokens) == batched_text_tokens.shape[0] and batched_text_tokens.ndim == 2
        batch_output = tts.gpt.inference_speech(auto_conditioning, batched_text_tokens, **kwargs)
        del pad_text_tokens
        mismatch_idx = []
        print("baseline:", baseline.shape, baseline)
        print("--" * 10)
        print("baseline vs padded output:")
        for i in range(len(output_for_padded)):
            if not baseline.equal(output_for_padded[i]):
                mismatch_idx.append(i)

        if len(mismatch_idx) > 0:
            print("mismatch:", mismatch_idx)
            for i in mismatch_idx:
                print(f"[{i}]: {output_for_padded[i]}")
        else:
            print("all matched")

        del output_for_padded
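
        # Compare each row of the batched output against the baseline.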
| | print("--"*10) |
| | print("baseline vs batched output:") |
| | mismatch_idx = [] |
| | for i in range(batch_output.shape[0]): |
| | if not baseline.equal(batch_output[i]): |
| | mismatch_idx.append(i) |
| | if len(mismatch_idx) > 0: |
| | print("mismatch:", mismatch_idx) |
| | for i in mismatch_idx: |
| | print(f"[{i}]: {batch_output[i]}") |
| | |
| | else: |
| | print("all matched") |
| | |
| | print("Test finished.") |