"""
Chiluka - Main inference API for TTS synthesis.

Example usage:
    from chiluka import Chiluka

    # Simple usage (uses bundled models)
    tts = Chiluka()

    # Generate speech
    wav = tts.synthesize(
        text="Hello, world!",
        reference_audio="path/to/reference.wav",
        language="en"
    )

    # Save to file
    tts.save_wav(wav, "output.wav")
"""

import os
import yaml
import torch
import torchaudio
import librosa
import numpy as np
from pathlib import Path
from typing import Optional, Union

from nltk.tokenize import word_tokenize

from .models import build_model, load_ASR_models, load_F0_models, load_plbert
from .models.diffusion import DiffusionSampler, ADPM2Sampler, KarrasSchedule
from .text_utils import TextCleaner
from .utils import recursive_munch, length_to_mask

# Get package directory
PACKAGE_DIR = Path(__file__).parent.absolute()
DEFAULT_PRETRAINED_DIR = PACKAGE_DIR / "pretrained"
DEFAULT_CONFIG_PATH = PACKAGE_DIR / "configs" / "config_ft.yml"
DEFAULT_CHECKPOINT_DIR = PACKAGE_DIR / "checkpoints"


def get_default_checkpoint() -> Optional[str]:
    """Return the path of a bundled checkpoint, or None when there is none.

    Candidates are the ``*.pth`` files directly inside
    ``DEFAULT_CHECKPOINT_DIR``. They are sorted by path before the first one
    is picked, because the order in which ``glob`` yields files is
    filesystem-dependent and would otherwise make the choice arbitrary when
    several checkpoints are present.

    Returns:
        The checkpoint path as a string, or None if the directory does not
        exist or holds no ``.pth`` file.
    """
    if not DEFAULT_CHECKPOINT_DIR.exists():
        return None
    checkpoints = sorted(DEFAULT_CHECKPOINT_DIR.glob("*.pth"))
    return str(checkpoints[0]) if checkpoints else None


class Chiluka:
    """
    Chiluka TTS - Text-to-Speech synthesis using StyleTTS2.

    Args:
        config_path: Path to the YAML config file. If None, uses bundled config.
        checkpoint_path: Path to the trained model checkpoint (.pth file).
            If None, uses bundled checkpoint.
        pretrained_dir: Directory containing pretrained sub-models
            (ASR/, JDC/, PLBERT/). If None, uses bundled models.
        device: Device to use ('cuda' or 'cpu'). If None, auto-detects.

    Raises:
        ValueError: If no checkpoint path is given and no bundled checkpoint
            is found.
        FileNotFoundError: If a required pretrained sub-model is missing.

    Example:
        # Use bundled models (simplest)
        tts = Chiluka()

        # Or specify custom paths
        tts = Chiluka(
            config_path="my_config.yml",
            checkpoint_path="my_model.pth"
        )
    """

    def __init__(
        self,
        config_path: Optional[str] = None,
        checkpoint_path: Optional[str] = None,
        pretrained_dir: Optional[str] = None,
        device: Optional[str] = None,
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Resolve paths - use bundled defaults if not specified
        config_path = config_path or str(DEFAULT_CONFIG_PATH)
        checkpoint_path = checkpoint_path or get_default_checkpoint()
        pretrained_dir = pretrained_dir or str(DEFAULT_PRETRAINED_DIR)

        if not checkpoint_path:
            raise ValueError(
                "No checkpoint found. Please either:\n"
                "1. Place a .pth checkpoint in: {}\n"
                "2. Specify checkpoint_path parameter".format(DEFAULT_CHECKPOINT_DIR)
            )

        # Load config (explicit encoding so the result does not depend on the
        # platform's default locale)
        print(f"Loading config from {config_path}...")
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = yaml.safe_load(f)

        # Resolve pretrained paths
        self.pretrained_dir = Path(pretrained_dir)
        asr_config = self.pretrained_dir / "ASR" / "config.yml"
        asr_path = self.pretrained_dir / "ASR" / "epoch_00080.pth"
        f0_path = self.pretrained_dir / "JDC" / "bst.t7"
        plbert_dir = self.pretrained_dir / "PLBERT"

        # Verify pretrained models exist
        self._verify_pretrained_models(asr_path, f0_path, plbert_dir)

        # Load pretrained models
        print("Loading ASR model...")
        self.text_aligner = load_ASR_models(str(asr_path), str(asr_config))

        print("Loading F0 model...")
        self.pitch_extractor = load_F0_models(str(f0_path))

        print("Loading PL-BERT...")
        self.plbert = load_plbert(str(plbert_dir))

        # Build model
        self.model_params = recursive_munch(self.config["model_params"])
        self.model = build_model(
            self.model_params, self.text_aligner, self.pitch_extractor, self.plbert
        )

        # Load checkpoint
        print(f"Loading checkpoint from {checkpoint_path}...")
        self._load_checkpoint(checkpoint_path)

        # Move to device and set to eval mode
        for key in self.model:
            self.model[key].eval().to(self.device)

        # Build sampler
        self.sampler = DiffusionSampler(
            self.model.diffusion.diffusion,
            sampler=ADPM2Sampler(),
            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
            clamp=False,
        )

        # Text cleaner
        self.textcleaner = TextCleaner()

        # Mel spectrogram transform
        self.to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=80, n_fft=2048, win_length=1200, hop_length=300
        )

        # Cache for phonemizer backends
        self._phonemizers = {}

        print("✓ Chiluka TTS initialized successfully!")

    @classmethod
    def from_pretrained(
        cls,
        model: Optional[str] = None,
        repo_id: Optional[str] = None,
        device: Optional[str] = None,
        force_download: bool = False,
        token: Optional[str] = None,
        **kwargs,
    ) -> "Chiluka":
        """
        Load Chiluka TTS from HuggingFace Hub or with auto-downloaded weights.

        This is the recommended way to load Chiluka when you don't have local
        weights. Weights are automatically downloaded and cached on first use.

        Args:
            model: Model variant to load. Options:
                - 'hindi_english' (default) - Hindi + English multi-speaker TTS
                - 'telugu' - Telugu + English single-speaker TTS
            repo_id: HuggingFace Hub repository ID (e.g., 'Seemanth/chiluka-tts').
                If None, uses the default repository.
            device: Device to use ('cuda' or 'cpu'). Auto-detects if None.
            force_download: If True, re-download even if cached.
            token: HuggingFace API token for private repositories.
            **kwargs: Additional arguments passed to Chiluka constructor.

        Returns:
            Initialized Chiluka TTS model ready for inference.

        Examples:
            # Hindi-English model (default)
            >>> tts = Chiluka.from_pretrained()

            # Telugu model
            >>> tts = Chiluka.from_pretrained(model="telugu")

            # Specific HuggingFace repository
            >>> tts = Chiluka.from_pretrained(repo_id="myuser/my-model")

            # Force re-download
            >>> tts = Chiluka.from_pretrained(force_download=True)
        """
        # Imported lazily so the hub helpers are only needed on this path.
        from .hub import download_from_hf, get_model_paths, DEFAULT_HF_REPO, DEFAULT_MODEL

        model = model or DEFAULT_MODEL
        repo_id = repo_id or DEFAULT_HF_REPO

        # Download model files (or use cache)
        download_from_hf(
            repo_id=repo_id,
            force_download=force_download,
            token=token,
        )

        # Get paths to model files for the selected variant
        paths = get_model_paths(model=model, repo_id=repo_id)

        return cls(
            config_path=paths["config_path"],
            checkpoint_path=paths["checkpoint_path"],
            pretrained_dir=paths["pretrained_dir"],
            device=device,
            **kwargs,
        )

    def _verify_pretrained_models(self, asr_path, f0_path, plbert_dir):
        """Check that the pretrained sub-model files/directories exist.

        Args:
            asr_path: Path object of the ASR checkpoint.
            f0_path: Path object of the F0 checkpoint.
            plbert_dir: Path object of the PLBERT directory.

        Raises:
            FileNotFoundError: Listing every missing item, if any is missing.
        """
        required = (
            ("ASR model", asr_path),
            ("F0 model", f0_path),
            ("PLBERT directory", plbert_dir),
        )
        missing = [f"{label}: {path}" for label, path in required if not path.exists()]

        if missing:
            raise FileNotFoundError(
                "Missing pretrained models:\n"
                + "\n".join(f" - {m}" for m in missing)
                + f"\n\nExpected in: {self.pretrained_dir}"
            )

    def _load_checkpoint(self, checkpoint_path: str):
        """Load model weights from the ``"net"`` entry of a checkpoint.

        For every sub-model in ``self.model`` that has an entry in
        ``checkpoint["net"]``, the state dict is loaded as-is first. If that
        is rejected, it is retried with a leading ``"module."`` removed from
        each key. Sub-models absent from the checkpoint are left untouched.

        Args:
            checkpoint_path: Path to the ``.pth`` checkpoint.

        Raises:
            RuntimeError: If a state dict still does not match after the
                prefix has been stripped.
        """
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        params = checkpoint["net"]
        prefix = "module."

        for key in self.model:
            if key not in params:
                continue
            state_dict = params[key]
            try:
                self.model[key].load_state_dict(state_dict)
            except RuntimeError:
                # Strip the prefix only at the start of a key; a plain
                # str.replace would also mangle names that merely contain
                # "module." somewhere in the middle (e.g. "submodule.weight").
                stripped = {
                    (k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in state_dict.items()
                }
                self.model[key].load_state_dict(stripped)

    def _get_phonemizer(self, language: str):
        """Get or create phonemizer backend for a language.

        Backends are cached per language code in ``self._phonemizers`` so each
        one is constructed at most once per instance.
        """
        backend = self._phonemizers.get(language)
        if backend is None:
            import phonemizer

            backend = phonemizer.backend.EspeakBackend(
                language=language, preserve_punctuation=True, with_stress=True
            )
            self._phonemizers[language] = backend
        return backend

    def _preprocess_mel(self, wave: np.ndarray, mean: float = -4, std: float = 4) -> torch.Tensor:
        """Convert waveform to normalized mel spectrogram.

        Args:
            wave: Waveform as a numpy array.
            mean: Value subtracted from the log-mel spectrogram.
            std: Value the shifted log-mel spectrogram is divided by.

        Returns:
            ``(log(1e-5 + mel) - mean) / std`` with a leading dimension added.
        """
        wave_tensor = torch.from_numpy(wave).float()
        mel_tensor = self.to_mel(wave_tensor)
        # 1e-5 keeps the log finite where the mel energy is zero.
        return (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std

    def compute_style(self, audio_path: str, sr: int = 24000) -> torch.Tensor:
        """
        Compute style embedding from reference audio.

        Args:
            audio_path: Path to reference audio file
            sr: Target sample rate

        Returns:
            Style embedding tensor: the style encoder output and the predictor
            encoder output concatenated along dim 1.
        """
        # librosa.load already resamples to `sr`, so no further resampling is
        # needed afterwards.
        wave, _ = librosa.load(audio_path, sr=sr)
        audio, _ = librosa.effects.trim(wave, top_db=30)

        mel_tensor = self._preprocess_mel(audio).to(self.device)

        with torch.no_grad():
            ref_s = self.model.style_encoder(mel_tensor.unsqueeze(1))
            ref_p = self.model.predictor_encoder(mel_tensor.unsqueeze(1))

        return torch.cat([ref_s, ref_p], dim=1)

    def _text_to_tokens(self, text: str, language: str) -> torch.Tensor:
        """Phonemize ``text`` and return token ids as a (1, n_tokens) tensor.

        A start token (0) is prepended, and the sequence is cut to the BERT
        model's ``max_position_embeddings`` when it is longer than that.
        """
        phonemizer = self._get_phonemizer(language)
        phonemes = phonemizer.phonemize([text])
        phonemes = " ".join(word_tokenize(phonemes[0]))

        tokens = self.textcleaner(phonemes)
        tokens.insert(0, 0)  # Add start token
        tokens = torch.LongTensor(tokens).to(self.device).unsqueeze(0)

        # Truncate if too long
        max_len = self.model.bert.config.max_position_embeddings
        if tokens.shape[-1] > max_len:
            tokens = tokens[:, :max_len]
        return tokens

    @staticmethod
    def _build_alignment(pred_dur: torch.Tensor) -> torch.Tensor:
        """Build a 0/1 (n_tokens, total_frames) matrix from per-token durations.

        Row ``i`` is 1 on the consecutive frame range assigned to token ``i``.
        """
        durations = [int(d) for d in pred_dur.tolist()]
        alignment = torch.zeros(len(durations), sum(durations))
        start = 0
        for row, n_frames in enumerate(durations):
            alignment[row, start:start + n_frames] = 1
            start += n_frames
        return alignment

    @staticmethod
    def _shift_frames_right(features: torch.Tensor) -> torch.Tensor:
        """Shift ``features`` one step along the last axis, repeating frame 0."""
        shifted = torch.zeros_like(features)
        shifted[:, :, 0] = features[:, :, 0]
        shifted[:, :, 1:] = features[:, :, 0:-1]
        return shifted

    def synthesize(
        self,
        text: str,
        reference_audio: str,
        language: str = "en",
        alpha: float = 0.3,
        beta: float = 0.7,
        diffusion_steps: int = 5,
        embedding_scale: float = 1.0,
        sr: int = 24000,
    ) -> np.ndarray:
        """
        Synthesize speech from text.

        Args:
            text: Input text to synthesize
            reference_audio: Path to reference audio for style transfer
            language: Language code for phonemization (e.g., 'en', 'te', 'hi')
            alpha: Style mixing coefficient for acoustic features (0-1).
                Weight of the sampled style; ``1 - alpha`` goes to the
                reference style.
            beta: Style mixing coefficient for prosodic features (0-1).
                Weight of the sampled style; ``1 - beta`` goes to the
                reference style.
            diffusion_steps: Number of diffusion sampling steps
            embedding_scale: Classifier-free guidance scale
            sr: Sample rate

        Returns:
            Generated audio waveform as numpy array

        Raises:
            ValueError: If ``text`` is empty or only whitespace.
        """
        text = text.strip()
        if not text:
            raise ValueError("text must contain at least one non-whitespace character")

        # Compute style from reference
        ref_s = self.compute_style(reference_audio, sr=sr)

        # Phonemize text and convert to tokens
        tokens = self._text_to_tokens(text, language)

        with torch.no_grad():
            input_lengths = torch.LongTensor([tokens.shape[-1]]).to(self.device)
            text_mask = length_to_mask(input_lengths).to(self.device)

            # Encode text
            t_en = self.model.text_encoder(tokens, input_lengths, text_mask)
            bert_dur = self.model.bert(tokens, attention_mask=(~text_mask).int())
            d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)

            # Sample style
            s_pred = self.sampler(
                noise=torch.randn((1, 256)).unsqueeze(1).to(self.device),
                embedding=bert_dur,
                embedding_scale=embedding_scale,
                features=ref_s,
                num_steps=diffusion_steps,
            ).squeeze(1)

            # First 128 columns go to the decoder, the rest to the predictor.
            s = s_pred[:, 128:]
            ref = s_pred[:, :128]

            # Mix styles
            ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
            s = beta * s + (1 - beta) * ref_s[:, 128:]

            # Predict duration
            d = self.model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
            x, _ = self.model.predictor.lstm(d)
            duration = self.model.predictor.duration_proj(x)
            duration = torch.sigmoid(duration).sum(axis=-1)
            # reshape(-1) rather than squeeze(): squeeze() would collapse a
            # single-token input to a 0-dim tensor that cannot be indexed.
            pred_dur = torch.round(duration.reshape(-1)).clamp(min=1)

            # Build alignment
            pred_aln_trg = self._build_alignment(pred_dur).unsqueeze(0).to(self.device)

            en = d.transpose(-1, -2) @ pred_aln_trg

            # Adjust for hifigan decoder
            is_hifigan = self.model_params.decoder.type == "hifigan"
            if is_hifigan:
                en = self._shift_frames_right(en)

            # Predict F0 and energy
            F0_pred, N_pred = self.model.predictor.F0Ntrain(en, s)

            # Encode for decoder
            asr = t_en @ pred_aln_trg
            if is_hifigan:
                asr = self._shift_frames_right(asr)

            # Decode waveform
            out = self.model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

        # The final 50 samples are dropped.
        # NOTE(review): presumably to cut a trailing artifact - confirm.
        return out.squeeze().cpu().numpy()[..., :-50]

    @staticmethod
    def _to_int16(wav: np.ndarray) -> np.ndarray:
        """Scale a waveform by 32767, clip to the int16 range and cast to int16."""
        return (np.asarray(wav) * 32767).clip(-32768, 32767).astype(np.int16)

    def save_wav(self, wav: np.ndarray, path: str, sr: int = 24000):
        """
        Save waveform to WAV file.

        Args:
            wav: Audio waveform as numpy array
            path: Output file path
            sr: Sample rate
        """
        import scipy.io.wavfile as wavfile

        wavfile.write(path, sr, self._to_int16(wav))
        print(f"Saved audio to {path}")

    def play(self, wav: np.ndarray, sr: int = 24000):
        """
        Play audio through speakers (requires pyaudio).

        Args:
            wav: Audio waveform as numpy array
            sr: Sample rate

        Raises:
            ImportError: If pyaudio is not installed.
        """
        try:
            import pyaudio
        except ImportError as err:
            raise ImportError(
                "pyaudio is required for playback. Install with: pip install pyaudio"
            ) from err

        audio_int16 = self._to_int16(wav).tobytes()

        p = pyaudio.PyAudio()
        try:
            stream = p.open(format=pyaudio.paInt16, channels=1, rate=sr, output=True)
            try:
                stream.write(audio_int16)
                stream.stop_stream()
            finally:
                # Release the stream even if writing fails part-way.
                stream.close()
        finally:
            p.terminate()