# chiluka-tts / inference.py
# Uploaded by Seemanth: "Add Chiluka TTS models (Hindi-English + Telugu)"
# (commit 13f85be, verified)
"""
Chiluka - Main inference API for TTS synthesis.
Example usage:
from chiluka import Chiluka
# Simple usage (uses bundled models)
tts = Chiluka()
# Generate speech
wav = tts.synthesize(
text="Hello, world!",
reference_audio="path/to/reference.wav",
language="en"
)
# Save to file
tts.save_wav(wav, "output.wav")
"""
import os
import yaml
import torch
import torchaudio
import librosa
import numpy as np
from pathlib import Path
from typing import Optional, Union
from nltk.tokenize import word_tokenize
from .models import build_model, load_ASR_models, load_F0_models, load_plbert
from .models.diffusion import DiffusionSampler, ADPM2Sampler, KarrasSchedule
from .text_utils import TextCleaner
from .utils import recursive_munch, length_to_mask
# Absolute location of this package; all bundled assets are resolved below it.
PACKAGE_DIR = Path(__file__).absolute().parent

# Bundled asset locations, used whenever the caller does not pass explicit paths.
DEFAULT_PRETRAINED_DIR = PACKAGE_DIR.joinpath("pretrained")
DEFAULT_CONFIG_PATH = PACKAGE_DIR.joinpath("configs", "config_ft.yml")
DEFAULT_CHECKPOINT_DIR = PACKAGE_DIR.joinpath("checkpoints")
def get_default_checkpoint() -> Optional[str]:
    """Return the bundled checkpoint path, or None when none is available.

    Looks for ``*.pth`` files directly inside ``DEFAULT_CHECKPOINT_DIR``.
    Candidates are sorted by path before the first one is picked, because
    ``Path.glob`` yields entries in an arbitrary, filesystem-dependent order;
    without sorting, the selected checkpoint could differ between machines
    when several checkpoints are present.

    Returns:
        The checkpoint path as a string, or None if the directory is missing
        or contains no ``.pth`` file.
    """
    if not DEFAULT_CHECKPOINT_DIR.is_dir():
        return None
    checkpoints = sorted(DEFAULT_CHECKPOINT_DIR.glob("*.pth"))
    return str(checkpoints[0]) if checkpoints else None
class Chiluka:
    """
    Chiluka TTS - Text-to-Speech synthesis using StyleTTS2.

    Args:
        config_path: Path to the YAML config file. If None, uses bundled config.
        checkpoint_path: Path to the trained model checkpoint (.pth file). If None, uses bundled checkpoint.
        pretrained_dir: Directory containing pretrained sub-models (ASR/, JDC/, PLBERT/). If None, uses bundled models.
        device: Device to use ('cuda' or 'cpu'). If None, auto-detects.

    Example:
        # Use bundled models (simplest)
        tts = Chiluka()

        # Or specify custom paths
        tts = Chiluka(
            config_path="my_config.yml",
            checkpoint_path="my_model.pth"
        )
    """

    def __init__(
        self,
        config_path: Optional[str] = None,
        checkpoint_path: Optional[str] = None,
        pretrained_dir: Optional[str] = None,
        device: Optional[str] = None,
    ):
        """Load the config, pretrained sub-models and checkpoint, then build the sampler.

        Raises:
            ValueError: If no checkpoint path was given and no bundled one exists.
            FileNotFoundError: If a required pretrained sub-model is missing.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Resolve paths - use bundled defaults if not specified
        config_path = config_path or str(DEFAULT_CONFIG_PATH)
        checkpoint_path = checkpoint_path or get_default_checkpoint()
        pretrained_dir = pretrained_dir or str(DEFAULT_PRETRAINED_DIR)

        if not checkpoint_path:
            raise ValueError(
                "No checkpoint found. Please either:\n"
                f"1. Place a .pth checkpoint in: {DEFAULT_CHECKPOINT_DIR}\n"
                "2. Specify checkpoint_path parameter"
            )

        # Load config (explicit encoding so parsing does not depend on the locale)
        print(f"Loading config from {config_path}...")
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = yaml.safe_load(f)

        # Resolve pretrained paths
        self.pretrained_dir = Path(pretrained_dir)
        asr_config = self.pretrained_dir / "ASR" / "config.yml"
        asr_path = self.pretrained_dir / "ASR" / "epoch_00080.pth"
        f0_path = self.pretrained_dir / "JDC" / "bst.t7"
        plbert_dir = self.pretrained_dir / "PLBERT"

        # Verify pretrained models exist before any (slow) loading starts
        self._verify_pretrained_models(asr_path, f0_path, plbert_dir, asr_config)

        # Load pretrained models
        print("Loading ASR model...")
        self.text_aligner = load_ASR_models(str(asr_path), str(asr_config))
        print("Loading F0 model...")
        self.pitch_extractor = load_F0_models(str(f0_path))
        print("Loading PL-BERT...")
        self.plbert = load_plbert(str(plbert_dir))

        # Build model
        self.model_params = recursive_munch(self.config["model_params"])
        self.model = build_model(
            self.model_params, self.text_aligner, self.pitch_extractor, self.plbert
        )

        # Load checkpoint
        print(f"Loading checkpoint from {checkpoint_path}...")
        self._load_checkpoint(checkpoint_path)

        # Move to device and set to eval mode
        for key in self.model:
            self.model[key].eval().to(self.device)

        # Build sampler
        self.sampler = DiffusionSampler(
            self.model.diffusion.diffusion,
            sampler=ADPM2Sampler(),
            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
            clamp=False,
        )

        # Text cleaner
        self.textcleaner = TextCleaner()

        # Mel spectrogram transform
        self.to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=80, n_fft=2048, win_length=1200, hop_length=300
        )

        # Cache for phonemizer backends, keyed by language code
        self._phonemizers = {}

        print("✓ Chiluka TTS initialized successfully!")

    @classmethod
    def from_pretrained(
        cls,
        model: Optional[str] = None,
        repo_id: Optional[str] = None,
        device: Optional[str] = None,
        force_download: bool = False,
        token: Optional[str] = None,
        **kwargs,
    ) -> "Chiluka":
        """
        Load Chiluka TTS from HuggingFace Hub or with auto-downloaded weights.

        This is the recommended way to load Chiluka when you don't have local weights.
        Weights are automatically downloaded and cached on first use.

        Args:
            model: Model variant to load. Options:
                - 'hindi_english' (default) - Hindi + English multi-speaker TTS
                - 'telugu' - Telugu + English single-speaker TTS
            repo_id: HuggingFace Hub repository ID (e.g., 'Seemanth/chiluka-tts').
                If None, uses the default repository.
            device: Device to use ('cuda' or 'cpu'). Auto-detects if None.
            force_download: If True, re-download even if cached.
            token: HuggingFace API token for private repositories.
            **kwargs: Additional arguments passed to Chiluka constructor.

        Returns:
            Initialized Chiluka TTS model ready for inference.

        Examples:
            # Hindi-English model (default)
            >>> tts = Chiluka.from_pretrained()

            # Telugu model
            >>> tts = Chiluka.from_pretrained(model="telugu")

            # Specific HuggingFace repository
            >>> tts = Chiluka.from_pretrained(repo_id="myuser/my-model")

            # Force re-download
            >>> tts = Chiluka.from_pretrained(force_download=True)
        """
        # Imported lazily so the hub helpers are only needed on this code path.
        from .hub import download_from_hf, get_model_paths, DEFAULT_HF_REPO, DEFAULT_MODEL

        model = model or DEFAULT_MODEL
        repo_id = repo_id or DEFAULT_HF_REPO

        # Download model files (or use cache)
        download_from_hf(
            repo_id=repo_id,
            force_download=force_download,
            token=token,
        )

        # Get paths to model files for the selected variant
        paths = get_model_paths(model=model, repo_id=repo_id)
        return cls(
            config_path=paths["config_path"],
            checkpoint_path=paths["checkpoint_path"],
            pretrained_dir=paths["pretrained_dir"],
            device=device,
            **kwargs,
        )

    def _verify_pretrained_models(self, asr_path, f0_path, plbert_dir, asr_config=None):
        """Verify all pretrained models exist.

        Args:
            asr_path: Path to the ASR model weights.
            f0_path: Path to the F0 model weights.
            plbert_dir: Path to the PL-BERT directory.
            asr_config: Optional path to the ASR config file; checked only
                when given, so existing three-argument callers keep working.

        Raises:
            FileNotFoundError: Listing every missing item in one message.
        """
        required = [
            ("ASR model", asr_path),
            ("F0 model", f0_path),
            ("PLBERT directory", plbert_dir),
        ]
        if asr_config is not None:
            required.append(("ASR config", asr_config))

        missing = [f"{label}: {path}" for label, path in required if not path.exists()]
        if missing:
            raise FileNotFoundError(
                "Missing pretrained models:\n"
                + "\n".join(f" - {m}" for m in missing)
                + f"\n\nExpected in: {self.pretrained_dir}"
            )

    def _load_checkpoint(self, checkpoint_path: str):
        """Load model weights from checkpoint.

        The checkpoint must hold a ``"net"`` mapping from sub-model name to
        state dict. Sub-models of ``self.model`` that are absent from the
        checkpoint are left untouched.

        NOTE(security): ``torch.load`` unpickles the file; only load
        checkpoints from sources you trust.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        params = checkpoint["net"]
        prefix = "module."

        for key in self.model:
            if key not in params:
                continue
            state_dict = params[key]
            try:
                self.model[key].load_state_dict(state_dict)
            except RuntimeError:
                # Key mismatch: retry assuming the weights were saved from a
                # DataParallel-style wrapper, which prepends "module." to every
                # key. Only the leading prefix is stripped, so sub-modules that
                # are themselves named "module" keep their keys intact.
                stripped = {
                    (k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in state_dict.items()
                }
                self.model[key].load_state_dict(stripped)

    def _get_phonemizer(self, language: str):
        """Get or create phonemizer backend for a language."""
        if language not in self._phonemizers:
            # Imported lazily: the backend is only needed once synthesis starts.
            import phonemizer

            self._phonemizers[language] = phonemizer.backend.EspeakBackend(
                language=language, preserve_punctuation=True, with_stress=True
            )
        return self._phonemizers[language]

    def _preprocess_mel(self, wave: np.ndarray, mean: float = -4, std: float = 4) -> torch.Tensor:
        """Convert waveform to normalized mel spectrogram.

        Applies ``self.to_mel``, adds a leading dimension, takes the log
        (offset by 1e-5 to avoid log(0)) and normalizes with ``mean``/``std``.
        """
        wave_tensor = torch.from_numpy(wave).float()
        mel_tensor = self.to_mel(wave_tensor)
        return (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std

    def compute_style(self, audio_path: str, sr: int = 24000) -> torch.Tensor:
        """
        Compute style embedding from reference audio.

        Args:
            audio_path: Path to reference audio file
            sr: Target sample rate

        Returns:
            Style embedding tensor: the style-encoder output concatenated
            with the predictor-encoder output along dim 1.
        """
        # librosa.load already resamples to `sr`, so no further resampling is needed.
        wave, _ = librosa.load(audio_path, sr=sr)
        # Trim leading/trailing audio quieter than 30 dB below the peak.
        audio, _ = librosa.effects.trim(wave, top_db=30)

        mel_input = self._preprocess_mel(audio).to(self.device).unsqueeze(1)
        with torch.no_grad():
            ref_s = self.model.style_encoder(mel_input)
            ref_p = self.model.predictor_encoder(mel_input)
        return torch.cat([ref_s, ref_p], dim=1)

    @staticmethod
    def _build_alignment(pred_dur: torch.Tensor) -> torch.Tensor:
        """Expand per-token frame counts into a 0/1 (tokens x frames) matrix.

        Row ``i`` is 1 on the consecutive frames assigned to token ``i``.
        The input is flattened first, so a single-token (0-dim) duration
        works as well.
        """
        # One host transfer for all counts instead of one sync per token.
        counts = [int(n) for n in pred_dur.reshape(-1).tolist()]
        alignment = torch.zeros(len(counts), sum(counts))
        start = 0
        for row, n_frames in enumerate(counts):
            alignment[row, start:start + n_frames] = 1
            start += n_frames
        return alignment

    @staticmethod
    def _shift_right(features: torch.Tensor) -> torch.Tensor:
        """Shift a 3-D tensor one step along its last axis, repeating the first frame."""
        shifted = torch.zeros_like(features)
        shifted[:, :, 0] = features[:, :, 0]
        shifted[:, :, 1:] = features[:, :, 0:-1]
        return shifted

    @staticmethod
    def _to_int16(wav: np.ndarray) -> np.ndarray:
        """Scale a float waveform by 32767 and clip it into the int16 range."""
        return (wav * 32767).clip(-32768, 32767).astype(np.int16)

    def synthesize(
        self,
        text: str,
        reference_audio: str,
        language: str = "en",
        alpha: float = 0.3,
        beta: float = 0.7,
        diffusion_steps: int = 5,
        embedding_scale: float = 1.0,
        sr: int = 24000,
    ) -> np.ndarray:
        """
        Synthesize speech from text.

        Args:
            text: Input text to synthesize
            reference_audio: Path to reference audio for style transfer
            language: Language code for phonemization (e.g., 'en', 'te', 'hi')
            alpha: Style mixing coefficient for acoustic features (0-1)
            beta: Style mixing coefficient for prosodic features (0-1)
            diffusion_steps: Number of diffusion sampling steps
            embedding_scale: Classifier-free guidance scale
            sr: Sample rate

        Returns:
            Generated audio waveform as numpy array (the last 50 samples of
            the decoder output are dropped).
        """
        # Compute style from reference
        ref_s = self.compute_style(reference_audio, sr=sr)

        # Phonemize text
        phonemizer = self._get_phonemizer(language)
        phonemes = phonemizer.phonemize([text.strip()])
        phonemes = " ".join(word_tokenize(phonemes[0]))

        # Convert to tokens
        tokens = self.textcleaner(phonemes)
        tokens.insert(0, 0)  # Add start token
        tokens = torch.LongTensor(tokens).to(self.device).unsqueeze(0)

        # Truncate if too long for the BERT position embeddings
        max_len = self.model.bert.config.max_position_embeddings
        if tokens.shape[-1] > max_len:
            tokens = tokens[:, :max_len]

        is_hifigan = self.model_params.decoder.type == "hifigan"

        with torch.no_grad():
            input_lengths = torch.LongTensor([tokens.shape[-1]]).to(self.device)
            text_mask = length_to_mask(input_lengths).to(self.device)

            # Encode text
            t_en = self.model.text_encoder(tokens, input_lengths, text_mask)
            bert_dur = self.model.bert(tokens, attention_mask=(~text_mask).int())
            d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)

            # Sample style
            s_pred = self.sampler(
                noise=torch.randn((1, 256)).unsqueeze(1).to(self.device),
                embedding=bert_dur,
                embedding_scale=embedding_scale,
                features=ref_s,
                num_steps=diffusion_steps,
            ).squeeze(1)

            # First 128 dims feed the decoder, the rest feed the predictor.
            s = s_pred[:, 128:]
            ref = s_pred[:, :128]

            # Mix sampled styles with the reference styles
            ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
            s = beta * s + (1 - beta) * ref_s[:, 128:]

            # Predict duration
            d = self.model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
            x, _ = self.model.predictor.lstm(d)
            duration = self.model.predictor.duration_proj(x)
            duration = torch.sigmoid(duration).sum(axis=-1)
            # Flatten rather than squeeze(): squeeze() would collapse a
            # single-token input to a 0-dim tensor that cannot be indexed.
            # Presumably `duration` is (1, n_tokens) here - TODO confirm.
            pred_dur = torch.round(duration.reshape(-1)).clamp(min=1)

            # Build alignment once and move it to the device once
            pred_aln_trg = self._build_alignment(pred_dur).unsqueeze(0).to(self.device)

            en = d.transpose(-1, -2) @ pred_aln_trg
            # Adjust for hifigan decoder
            if is_hifigan:
                en = self._shift_right(en)

            # Predict F0 and energy
            F0_pred, N_pred = self.model.predictor.F0Ntrain(en, s)

            # Encode for decoder
            asr = t_en @ pred_aln_trg
            if is_hifigan:
                asr = self._shift_right(asr)

            # Decode waveform
            out = self.model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

        return out.squeeze().cpu().numpy()[..., :-50]

    def save_wav(self, wav: np.ndarray, path: str, sr: int = 24000):
        """
        Save waveform to WAV file.

        Args:
            wav: Audio waveform as numpy array
            path: Output file path
            sr: Sample rate
        """
        import scipy.io.wavfile as wavfile

        wavfile.write(path, sr, self._to_int16(wav))
        print(f"Saved audio to {path}")

    def play(self, wav: np.ndarray, sr: int = 24000):
        """
        Play audio through speakers (requires pyaudio).

        Args:
            wav: Audio waveform as numpy array
            sr: Sample rate

        Raises:
            ImportError: If pyaudio is not installed.
        """
        try:
            import pyaudio
        except ImportError as err:
            raise ImportError(
                "pyaudio is required for playback. Install with: pip install pyaudio"
            ) from err

        audio_bytes = self._to_int16(wav).tobytes()

        p = pyaudio.PyAudio()
        try:
            stream = p.open(format=pyaudio.paInt16, channels=1, rate=sr, output=True)
            try:
                stream.write(audio_bytes)
                stream.stop_stream()
            finally:
                # Release the stream even if playback fails.
                stream.close()
        finally:
            # Always release the PyAudio instance.
            p.terminate()