[Fix] How to use this model with modern Coqui-TTS (CLI & Python Solution)

by Professor - opened

If you are trying to follow the usage instructions in the Model Card, you’ve likely realized they are broken. The command relies on a specific git commit from 2021 that no longer compiles on modern Python environments.

Even if you install the latest coqui-tts, you will hit errors like `KeyError: 'output_sample_rate'` or crashes during inference, because the model config is outdated.

The Solution
I have created a standalone Colab Notebook that fixes these compatibility issues automatically. It patches the config in-memory and applies a runtime fix so you can run this model with the latest stable libraries.

🚀 Open in Google Colab


For Python Users (Script Version)
If you prefer running this in your own script instead of Colab, you need to apply a "Monkey Patch" to fix a bug in how newer VITS implementations handle speaker embeddings.
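In Python, a monkey patch just means rebinding an attribute (here, a method) on a live object at runtime; the class and every other instance are untouched, but calls on that object go through your replacement. A minimal self-contained illustration of the pattern (the `Synth` class below is a stand-in for demonstration, not the real `Vits` model):

```python
class Synth:
    """Stand-in for a model whose internal hook we need to override."""

    def _set_cond_input(self, aux_input):
        # Pretend this is the buggy library version: it expects a key
        # ("speaker_ids") that our checkpoint never provides.
        return aux_input["speaker_ids"], None

    def inference(self, aux_input):
        return self._set_cond_input(aux_input)


model = Synth()

# Rebind the hook on this one instance only. Because the replacement is a
# plain function stored as an instance attribute, it is NOT bound, so no
# `self` argument is passed -- which is why it takes only `aux_input`.
def fixed_set_cond_input(aux_input):
    return None, aux_input["d_vector"]

model._set_cond_input = fixed_set_cond_input

print(model.inference({"d_vector": [0.1, 0.2]}))  # (None, [0.1, 0.2])
```

This is exactly the shape of the fix below: the replacement function has no `self` parameter because it is attached to the instance, not the class.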

Here is the full working code snippet:

import torch
from TTS.tts.models.vits import Vits
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from scipy.io.wavfile import write

# 1. Load & patch the config in memory
conf = VitsConfig()
conf.load_json("config.json")
conf.output_sample_rate = 22050          # missing in the old config -> KeyError on new versions
conf.audio.output_sample_rate = 22050
conf.phoneme_language = "en"
conf.use_speaker_embedding = True
conf.model_args.use_speaker_embedding = True
conf.model_args.d_vector_file = False    # important: stop the loader from looking for a d-vector file

# 2. Load the model (the tokenizer is used directly below, so the model itself doesn't need it)
tokenizer = TTSTokenizer.init_from_config(conf)[0]
model = Vits(config=conf, ap=None, tokenizer=None, speaker_manager=None)
model.load_checkpoint(config=conf, checkpoint_path="model.pth", eval=True)

# 3. CRITICAL RUNTIME FIX (the monkey patch)
# Newer library versions unpack conditioning inputs this checkpoint was never
# trained with; override the hook so only the d-vector is passed through.
def fixed_set_cond_input(aux_input):
    return None, aux_input["d_vector"], None, None
model._set_cond_input = fixed_set_cond_input

# 4. Inference: compute a d-vector from a reference clip, then synthesize
speaker_manager = SpeakerManager(
    encoder_model_path="SE_checkpoint.pth.tar",
    encoder_config_path="config_se.json",
)
d_vectors = speaker_manager.compute_embedding_from_clip(["conditioning_audio.wav"])
d_vector_tensor = torch.tensor(d_vectors, dtype=torch.float32).unsqueeze(0).unsqueeze(-1)

token_ids = tokenizer.text_to_ids("Muraho, nishimiye gukoresha iri koranabuhanga.")
x = torch.LongTensor(token_ids).unsqueeze(0)

outputs = model.inference(x, aux_input={"d_vector": d_vector_tensor})
write("output.wav", 22050, outputs["model_outputs"].squeeze().detach().numpy())
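After the script runs, you can sanity-check output.wav with no extra dependencies: the standard-library wave module reads the header back. The silent one-second clip written below is only a stand-in so the snippet runs on its own; on your machine you would skip that part and open the real synthesized file.

```python
import wave

# Stand-in for the script's output: one second of 16-bit mono silence at 22.05 kHz.
with wave.open("output.wav", "wb") as f:
    f.setnchannels(1)
    f.setsampwidth(2)            # 16-bit samples
    f.setframerate(22050)
    f.writeframes(b"\x00\x00" * 22050)

# The actual check: the sample rate must match conf.output_sample_rate,
# otherwise playback will be pitch-shifted.
with wave.open("output.wav", "rb") as f:
    assert f.getframerate() == 22050
    print(f.getnframes() / f.getframerate(), "seconds")  # 1.0 seconds
```

If the rate printed here disagrees with the 22050 you patched into the config, the wrong value was passed to `write()`.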
