|
|
""" |
|
|
Chiluka - Main inference API for TTS synthesis. |
|
|
|
|
|
Example usage: |
|
|
from chiluka import Chiluka |
|
|
|
|
|
# Simple usage (uses bundled models) |
|
|
tts = Chiluka() |
|
|
|
|
|
# Generate speech |
|
|
wav = tts.synthesize( |
|
|
text="Hello, world!", |
|
|
reference_audio="path/to/reference.wav", |
|
|
language="en" |
|
|
) |
|
|
|
|
|
# Save to file |
|
|
tts.save_wav(wav, "output.wav") |
|
|
""" |
|
|
|
|
|
import os |
|
|
import yaml |
|
|
import torch |
|
|
import torchaudio |
|
|
import librosa |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
from typing import Optional, Union |
|
|
|
|
|
from nltk.tokenize import word_tokenize |
|
|
|
|
|
from .models import build_model, load_ASR_models, load_F0_models, load_plbert |
|
|
from .models.diffusion import DiffusionSampler, ADPM2Sampler, KarrasSchedule |
|
|
from .text_utils import TextCleaner |
|
|
from .utils import recursive_munch, length_to_mask |
|
|
|
|
|
|
|
|
|
|
|
# Filesystem layout of the installed package: bundled config, pretrained
# sub-models, and fine-tuned checkpoints are all resolved relative to this file.
PACKAGE_DIR = Path(__file__).parent.absolute()

# Directory holding the pretrained sub-models (ASR/, JDC/, PLBERT/).
DEFAULT_PRETRAINED_DIR = PACKAGE_DIR / "pretrained"

# Bundled fine-tuning YAML config used when no config_path is given.
DEFAULT_CONFIG_PATH = PACKAGE_DIR / "configs" / "config_ft.yml"

# Directory scanned for .pth checkpoints when no checkpoint_path is given.
DEFAULT_CHECKPOINT_DIR = PACKAGE_DIR / "checkpoints"
|
|
|
|
|
|
|
|
def get_default_checkpoint(checkpoint_dir: Optional[Union[str, Path]] = None):
    """Find a checkpoint (.pth) file in the checkpoints directory.

    Args:
        checkpoint_dir: Directory to search. Defaults to the bundled
            ``DEFAULT_CHECKPOINT_DIR`` when None (backward compatible with
            the previous zero-argument form).

    Returns:
        Path (str) of the first checkpoint in sorted filename order, or None
        if the directory is missing or holds no .pth files. Sorting makes the
        selection deterministic — bare ``glob`` order is filesystem-dependent.
    """
    directory = Path(checkpoint_dir) if checkpoint_dir is not None else DEFAULT_CHECKPOINT_DIR
    if directory.exists():
        checkpoints = sorted(directory.glob("*.pth"))
        if checkpoints:
            return str(checkpoints[0])
    return None
|
|
|
|
|
|
|
|
class Chiluka: |
|
|
""" |
|
|
Chiluka TTS - Text-to-Speech synthesis using StyleTTS2. |
|
|
|
|
|
Args: |
|
|
config_path: Path to the YAML config file. If None, uses bundled config. |
|
|
checkpoint_path: Path to the trained model checkpoint (.pth file). If None, uses bundled checkpoint. |
|
|
pretrained_dir: Directory containing pretrained sub-models (ASR/, JDC/, PLBERT/). If None, uses bundled models. |
|
|
device: Device to use ('cuda' or 'cpu'). If None, auto-detects. |
|
|
|
|
|
Example: |
|
|
# Use bundled models (simplest) |
|
|
tts = Chiluka() |
|
|
|
|
|
# Or specify custom paths |
|
|
tts = Chiluka( |
|
|
config_path="my_config.yml", |
|
|
checkpoint_path="my_model.pth" |
|
|
) |
|
|
""" |
|
|
|
|
|
    def __init__(
        self,
        config_path: Optional[str] = None,
        checkpoint_path: Optional[str] = None,
        pretrained_dir: Optional[str] = None,
        device: Optional[str] = None,
    ):
        """Initialize the full TTS pipeline.

        Loads the YAML config, the three pretrained sub-models (ASR aligner,
        F0 extractor, PL-BERT), builds the StyleTTS2 model graph, restores the
        fine-tuned checkpoint, and prepares the diffusion sampler plus the
        text/mel frontends. See the class docstring for argument semantics.

        Raises:
            ValueError: If no checkpoint is supplied and none is bundled.
            FileNotFoundError: If any pretrained sub-model is missing.
        """
        # Prefer CUDA when available unless the caller pinned a device.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Fall back to the bundled assets for anything not supplied.
        config_path = config_path or str(DEFAULT_CONFIG_PATH)
        checkpoint_path = checkpoint_path or get_default_checkpoint()
        pretrained_dir = pretrained_dir or str(DEFAULT_PRETRAINED_DIR)

        if not checkpoint_path:
            raise ValueError(
                "No checkpoint found. Please either:\n"
                "1. Place a .pth checkpoint in: {}\n"
                "2. Specify checkpoint_path parameter".format(DEFAULT_CHECKPOINT_DIR)
            )

        print(f"Loading config from {config_path}...")
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)

        # Expected on-disk layout of the pretrained sub-models directory.
        self.pretrained_dir = Path(pretrained_dir)
        asr_config = self.pretrained_dir / "ASR" / "config.yml"
        asr_path = self.pretrained_dir / "ASR" / "epoch_00080.pth"
        f0_path = self.pretrained_dir / "JDC" / "bst.t7"
        plbert_dir = self.pretrained_dir / "PLBERT"

        # Fail fast with a single aggregated message if anything is missing.
        self._verify_pretrained_models(asr_path, f0_path, plbert_dir)

        print(f"Loading ASR model...")
        self.text_aligner = load_ASR_models(str(asr_path), str(asr_config))

        print(f"Loading F0 model...")
        self.pitch_extractor = load_F0_models(str(f0_path))

        print(f"Loading PL-BERT...")
        self.plbert = load_plbert(str(plbert_dir))

        # Build the model graph from the config's model_params (munch-ified
        # so fields are attribute-accessible, e.g. model_params.decoder.type).
        self.model_params = recursive_munch(self.config["model_params"])
        self.model = build_model(self.model_params, self.text_aligner, self.pitch_extractor, self.plbert)

        print(f"Loading checkpoint from {checkpoint_path}...")
        self._load_checkpoint(checkpoint_path)

        # Inference only: freeze every sub-module and move it to the device.
        for key in self.model:
            self.model[key].eval().to(self.device)

        # Diffusion sampler that produces the style vector in synthesize().
        self.sampler = DiffusionSampler(
            self.model.diffusion.diffusion,
            sampler=ADPM2Sampler(),
            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
            clamp=False,
        )

        # Text frontend: phoneme string -> token ids.
        self.textcleaner = TextCleaner()

        # Mel frontend for reference-audio style extraction. Parameters
        # (80 mels, 2048 FFT, 1200 win, 300 hop) presumably match the
        # training-time feature extraction at 24 kHz — TODO confirm.
        self.to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=80, n_fft=2048, win_length=1200, hop_length=300
        )

        # Lazily-created espeak phonemizer backends, keyed by language code.
        self._phonemizers = {}

        print("✓ Chiluka TTS initialized successfully!")
|
|
|
|
|
    @classmethod
    def from_pretrained(
        cls,
        model: Optional[str] = None,
        repo_id: Optional[str] = None,
        device: Optional[str] = None,
        force_download: bool = False,
        token: Optional[str] = None,
        **kwargs,
    ) -> "Chiluka":
        """
        Load Chiluka TTS from HuggingFace Hub or with auto-downloaded weights.

        This is the recommended way to load Chiluka when you don't have local weights.
        Weights are automatically downloaded and cached on first use.

        Args:
            model: Model variant to load. Options:
                - 'hindi_english' (default) - Hindi + English multi-speaker TTS
                - 'telugu' - Telugu + English single-speaker TTS
            repo_id: HuggingFace Hub repository ID (e.g., 'Seemanth/chiluka-tts').
                If None, uses the default repository.
            device: Device to use ('cuda' or 'cpu'). Auto-detects if None.
            force_download: If True, re-download even if cached.
            token: HuggingFace API token for private repositories.
            **kwargs: Additional arguments passed to Chiluka constructor.

        Returns:
            Initialized Chiluka TTS model ready for inference.

        Examples:
            # Hindi-English model (default)
            >>> tts = Chiluka.from_pretrained()

            # Telugu model
            >>> tts = Chiluka.from_pretrained(model="telugu")

            # Specific HuggingFace repository
            >>> tts = Chiluka.from_pretrained(repo_id="myuser/my-model")

            # Force re-download
            >>> tts = Chiluka.from_pretrained(force_download=True)
        """
        # Imported lazily so the hub dependency is only needed on this path.
        from .hub import download_from_hf, get_model_paths, DEFAULT_HF_REPO, DEFAULT_MODEL

        # Resolve defaults declared by the hub module.
        model = model or DEFAULT_MODEL
        repo_id = repo_id or DEFAULT_HF_REPO

        # Fetch (or reuse cached) weights from HuggingFace Hub.
        download_from_hf(
            repo_id=repo_id,
            force_download=force_download,
            token=token,
        )

        # Map the variant name to concrete local file paths.
        paths = get_model_paths(model=model, repo_id=repo_id)

        return cls(
            config_path=paths["config_path"],
            checkpoint_path=paths["checkpoint_path"],
            pretrained_dir=paths["pretrained_dir"],
            device=device,
            **kwargs,
        )
|
|
|
|
|
def _verify_pretrained_models(self, asr_path, f0_path, plbert_dir): |
|
|
"""Verify all pretrained models exist.""" |
|
|
missing = [] |
|
|
if not asr_path.exists(): |
|
|
missing.append(f"ASR model: {asr_path}") |
|
|
if not f0_path.exists(): |
|
|
missing.append(f"F0 model: {f0_path}") |
|
|
if not plbert_dir.exists(): |
|
|
missing.append(f"PLBERT directory: {plbert_dir}") |
|
|
|
|
|
if missing: |
|
|
raise FileNotFoundError( |
|
|
"Missing pretrained models:\n" + |
|
|
"\n".join(f" - {m}" for m in missing) + |
|
|
f"\n\nExpected in: {self.pretrained_dir}" |
|
|
) |
|
|
|
|
|
def _load_checkpoint(self, checkpoint_path: str): |
|
|
"""Load model weights from checkpoint.""" |
|
|
checkpoint = torch.load(checkpoint_path, map_location=self.device) |
|
|
for key in self.model: |
|
|
if key in checkpoint["net"]: |
|
|
try: |
|
|
self.model[key].load_state_dict(checkpoint["net"][key]) |
|
|
except Exception: |
|
|
state_dict = checkpoint["net"][key] |
|
|
new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} |
|
|
self.model[key].load_state_dict(new_state_dict) |
|
|
|
|
|
def _get_phonemizer(self, language: str): |
|
|
"""Get or create phonemizer backend for a language.""" |
|
|
if language not in self._phonemizers: |
|
|
import phonemizer |
|
|
self._phonemizers[language] = phonemizer.backend.EspeakBackend( |
|
|
language=language, preserve_punctuation=True, with_stress=True |
|
|
) |
|
|
return self._phonemizers[language] |
|
|
|
|
|
def _preprocess_mel(self, wave: np.ndarray, mean: float = -4, std: float = 4) -> torch.Tensor: |
|
|
"""Convert waveform to normalized mel spectrogram.""" |
|
|
wave_tensor = torch.from_numpy(wave).float() |
|
|
mel_tensor = self.to_mel(wave_tensor) |
|
|
mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std |
|
|
return mel_tensor |
|
|
|
|
|
def compute_style(self, audio_path: str, sr: int = 24000) -> torch.Tensor: |
|
|
""" |
|
|
Compute style embedding from reference audio. |
|
|
|
|
|
Args: |
|
|
audio_path: Path to reference audio file |
|
|
sr: Target sample rate |
|
|
|
|
|
Returns: |
|
|
Style embedding tensor |
|
|
""" |
|
|
wave, orig_sr = librosa.load(audio_path, sr=sr) |
|
|
audio, _ = librosa.effects.trim(wave, top_db=30) |
|
|
if orig_sr != sr: |
|
|
audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=sr) |
|
|
|
|
|
mel_tensor = self._preprocess_mel(audio).to(self.device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
ref_s = self.model.style_encoder(mel_tensor.unsqueeze(1)) |
|
|
ref_p = self.model.predictor_encoder(mel_tensor.unsqueeze(1)) |
|
|
|
|
|
return torch.cat([ref_s, ref_p], dim=1) |
|
|
|
|
|
    def synthesize(
        self,
        text: str,
        reference_audio: str,
        language: str = "en",
        alpha: float = 0.3,
        beta: float = 0.7,
        diffusion_steps: int = 5,
        embedding_scale: float = 1.0,
        sr: int = 24000,
    ) -> np.ndarray:
        """
        Synthesize speech from text.

        Args:
            text: Input text to synthesize
            reference_audio: Path to reference audio for style transfer
            language: Language code for phonemization (e.g., 'en', 'te', 'hi')
            alpha: Style mixing coefficient for acoustic features (0-1);
                weight of the diffusion-predicted style vs. the reference style
            beta: Style mixing coefficient for prosodic features (0-1);
                same convention as alpha
            diffusion_steps: Number of diffusion sampling steps
            embedding_scale: Classifier-free guidance scale
            sr: Sample rate
        Returns:
            Generated audio waveform as numpy array
        """
        # Reference style from the audio prompt: first half acoustic,
        # second half prosodic (see compute_style / the slicing below).
        ref_s = self.compute_style(reference_audio, sr=sr)

        # --- Text frontend: phonemize, then normalize token spacing. ---
        phonemizer = self._get_phonemizer(language)
        text = text.strip()
        ps = phonemizer.phonemize([text])
        ps = word_tokenize(ps[0])
        ps = " ".join(ps)

        # Phoneme string -> token ids; a 0 id is prepended (presumably a
        # pad/BOS symbol — confirm against TextCleaner's symbol table).
        tokens = self.textcleaner(ps)
        tokens.insert(0, 0)
        tokens = torch.LongTensor(tokens).to(self.device).unsqueeze(0)

        # Truncate to PL-BERT's positional-embedding capacity.
        max_len = self.model.bert.config.max_position_embeddings
        if tokens.shape[-1] > max_len:
            tokens = tokens[:, :max_len]

        with torch.no_grad():
            input_lengths = torch.LongTensor([tokens.shape[-1]]).to(self.device)
            text_mask = length_to_mask(input_lengths).to(self.device)

            # Text encodings: acoustic (t_en) and BERT-based prosodic (d_en).
            t_en = self.model.text_encoder(tokens, input_lengths, text_mask)
            bert_dur = self.model.bert(tokens, attention_mask=(~text_mask).int())
            d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)

            # Sample a style vector conditioned on the text embedding and the
            # reference style (classifier-free guidance via embedding_scale).
            s_pred = self.sampler(
                noise=torch.randn((1, 256)).unsqueeze(1).to(self.device),
                embedding=bert_dur,
                embedding_scale=embedding_scale,
                features=ref_s,
                num_steps=diffusion_steps,
            ).squeeze(1)

            s = s_pred[:, 128:]    # prosodic half of the predicted style
            ref = s_pred[:, :128]  # acoustic half of the predicted style

            # Blend predicted style with the reference style halves.
            ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
            s = beta * s + (1 - beta) * ref_s[:, 128:]

            # Predict per-token frame durations (at least 1 frame each).
            d = self.model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
            x, _ = self.model.predictor.lstm(d)
            duration = self.model.predictor.duration_proj(x)
            duration = torch.sigmoid(duration).sum(axis=-1)
            pred_dur = torch.round(duration.squeeze()).clamp(min=1)

            # Build a hard token->frame alignment matrix from the durations:
            # row i is 1 over the frame span assigned to token i.
            pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
            c_frame = 0
            for i in range(pred_aln_trg.size(0)):
                pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
                c_frame += int(pred_dur[i].data)

            # Expand prosodic encodings from token rate to frame rate.
            en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(self.device))

            # HiFi-GAN decoder variant expects features shifted right by one
            # frame (frame 0 duplicated) — presumably an alignment quirk of
            # that decoder; confirm against the decoder implementation.
            if self.model_params.decoder.type == "hifigan":
                asr_new = torch.zeros_like(en)
                asr_new[:, :, 0] = en[:, :, 0]
                asr_new[:, :, 1:] = en[:, :, 0:-1]
                en = asr_new

            # Frame-level F0 and energy from the expanded prosody features.
            F0_pred, N_pred = self.model.predictor.F0Ntrain(en, s)

            # Expand the acoustic text encodings with the same alignment.
            asr = (t_en @ pred_aln_trg.unsqueeze(0).to(self.device))
            if self.model_params.decoder.type == "hifigan":
                asr_new = torch.zeros_like(asr)
                asr_new[:, :, 0] = asr[:, :, 0]
                asr_new[:, :, 1:] = asr[:, :, 0:-1]
                asr = asr_new

            # Vocode to a waveform using the blended acoustic style.
            out = self.model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

        # Drop the last 50 samples — presumably trims a trailing vocoder edge
        # artifact; confirm before changing.
        return out.squeeze().cpu().numpy()[..., :-50]
|
|
|
|
|
def save_wav(self, wav: np.ndarray, path: str, sr: int = 24000): |
|
|
""" |
|
|
Save waveform to WAV file. |
|
|
|
|
|
Args: |
|
|
wav: Audio waveform as numpy array |
|
|
path: Output file path |
|
|
sr: Sample rate |
|
|
""" |
|
|
import scipy.io.wavfile as wavfile |
|
|
wav_int16 = (wav * 32767).clip(-32768, 32767).astype(np.int16) |
|
|
wavfile.write(path, sr, wav_int16) |
|
|
print(f"Saved audio to {path}") |
|
|
|
|
|
def play(self, wav: np.ndarray, sr: int = 24000): |
|
|
""" |
|
|
Play audio through speakers (requires pyaudio). |
|
|
|
|
|
Args: |
|
|
wav: Audio waveform as numpy array |
|
|
sr: Sample rate |
|
|
""" |
|
|
try: |
|
|
import pyaudio |
|
|
except ImportError: |
|
|
raise ImportError("pyaudio is required for playback. Install with: pip install pyaudio") |
|
|
|
|
|
audio_int16 = (wav * 32767.0).clip(-32768, 32767).astype("int16").tobytes() |
|
|
p = pyaudio.PyAudio() |
|
|
stream = p.open(format=pyaudio.paInt16, channels=1, rate=sr, output=True) |
|
|
stream.write(audio_int16) |
|
|
stream.stop_stream() |
|
|
stream.close() |
|
|
p.terminate() |
|
|
|