"""
Main TTS engine for SYSPIN multi-lingual TTS.

Loads and runs VITS models for inference.

Supports:
- JIT-traced models (.pt) - Hindi, Bengali, Kannada, etc.
- Coqui TTS checkpoints (.pth) - Bhojpuri, etc.
- Facebook MMS models - Gujarati

Includes style/prosody control.
"""

import os
import logging
from pathlib import Path
from typing import Dict, Optional, Union, List, Tuple, Any
import numpy as np
import torch
from dataclasses import dataclass

from .config import LANGUAGE_CONFIGS, LanguageConfig, MODELS_DIR, STYLE_PRESETS
from .tokenizer import TTSTokenizer, CharactersConfig, TextNormalizer
from .downloader import ModelDownloader

logger = logging.getLogger(__name__)

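# NOTE (assumption): STYLE_PRESETS imported from .config is expected to map
# preset names to the prosody multipliers consumed below, roughly of the form
#     STYLE_PRESETS = {
#         "default": {"speed": 1.0, "pitch": 1.0, "energy": 1.0},
#         "happy":   {"speed": 1.05, "pitch": 1.1, "energy": 1.1},
#         ...
#     }
# The actual preset names and values live in .config; the numbers above are
# only illustrative.
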
@dataclass
class TTSOutput:
    """Output from TTS synthesis."""

    audio: np.ndarray
    sample_rate: int
    duration: float
    voice: str
    text: str
    style: Optional[str] = None

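# Example (illustrative sketch): writing a TTSOutput to disk with soundfile,
# mirroring what TTSEngine.synthesize_to_file does further below.
#
#     out = TTSEngine().synthesize("Namaste", voice="hi_male")
#     import soundfile as sf
#     sf.write("namaste.wav", out.audio, out.sample_rate)
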
class StyleProcessor:
    """
    Simple prosody/style control via audio post-processing.
    Supports pitch shifting, speed change, and energy modification.
    """

    @staticmethod
    def apply_pitch_shift(
        audio: np.ndarray, sample_rate: int, pitch_factor: float
    ) -> np.ndarray:
        """
        Shift pitch without changing duration using a phase vocoder.
        pitch_factor > 1.0 = higher pitch, < 1.0 = lower pitch.
        """
        if pitch_factor == 1.0:
            return audio

        try:
            import librosa

            # Convert the multiplicative factor to semitones
            # (e.g. pitch_factor=2.0 -> +12 semitones, one octave up).
            semitones = 12 * np.log2(pitch_factor)
            shifted = librosa.effects.pitch_shift(
                audio.astype(np.float32), sr=sample_rate, n_steps=semitones
            )
            return shifted
        except ImportError:
            # Fallback without librosa: crude resample down/up with scipy.
            # Quality is lower, but duration is preserved.
            from scipy import signal

            stretched = signal.resample(audio, int(len(audio) / pitch_factor))
            return signal.resample(stretched, len(audio))

    @staticmethod
    def apply_speed_change(
        audio: np.ndarray, sample_rate: int, speed_factor: float
    ) -> np.ndarray:
        """
        Change speed/tempo without changing pitch.
        speed_factor > 1.0 = faster, < 1.0 = slower.
        """
        if speed_factor == 1.0:
            return audio

        try:
            import librosa

            stretched = librosa.effects.time_stretch(
                audio.astype(np.float32), rate=speed_factor
            )
            return stretched
        except ImportError:
            # Fallback without librosa: plain resampling (also shifts pitch).
            from scipy import signal

            target_length = int(len(audio) / speed_factor)
            return signal.resample(audio, target_length)

    @staticmethod
    def apply_energy_change(audio: np.ndarray, energy_factor: float) -> np.ndarray:
        """
        Modify audio energy/volume.
        energy_factor > 1.0 = louder, < 1.0 = softer.
        """
        if energy_factor == 1.0:
            return audio

        modified = audio * energy_factor

        # Soft-clip when boosting to avoid hard clipping artifacts.
        if energy_factor > 1.0:
            max_val = np.max(np.abs(modified))
            if max_val > 0.95:
                modified = np.tanh(modified * 2) * 0.95

        return modified

    @staticmethod
    def apply_style(
        audio: np.ndarray,
        sample_rate: int,
        speed: float = 1.0,
        pitch: float = 1.0,
        energy: float = 1.0,
    ) -> np.ndarray:
        """Apply all style modifications."""
        result = audio

        if pitch != 1.0:
            result = StyleProcessor.apply_pitch_shift(result, sample_rate, pitch)

        if speed != 1.0:
            result = StyleProcessor.apply_speed_change(result, sample_rate, speed)

        if energy != 1.0:
            result = StyleProcessor.apply_energy_change(result, energy)

        return result

    @staticmethod
    def get_preset(preset_name: str) -> Dict[str, float]:
        """Get style parameters from a preset name."""
        return STYLE_PRESETS.get(preset_name, STYLE_PRESETS["default"])

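# Example (illustrative sketch): applying a preset's parameters manually via
# StyleProcessor; preset names depend on what STYLE_PRESETS actually defines.
#
#     params = StyleProcessor.get_preset("calm")
#     styled = StyleProcessor.apply_style(
#         audio, sample_rate,
#         speed=params["speed"], pitch=params["pitch"], energy=params["energy"],
#     )
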
class TTSEngine:
    """
    Multi-lingual TTS engine using SYSPIN VITS models.

    Supports 11 Indian languages with male/female voices:
    - Hindi, Bengali, Marathi, Telugu, Kannada
    - Bhojpuri, Chhattisgarhi, Maithili, Magahi, English
    - Gujarati (via Facebook MMS)

    Features:
    - Style/prosody control (pitch, speed, energy)
    - Preset styles (happy, sad, calm, excited, etc.)
    - JIT-traced models (.pt) and Coqui TTS checkpoints (.pth)
    """

    def __init__(
        self,
        models_dir: str = MODELS_DIR,
        device: str = "auto",
        preload_voices: Optional[List[str]] = None,
    ):
        """
        Initialize the TTS engine.

        Args:
            models_dir: Directory containing downloaded models
            device: Device to run inference on ('cpu', 'cuda', 'mps', or 'auto')
            preload_voices: List of voice keys to preload into memory
        """
        self.models_dir = Path(models_dir)
        self.device = self._get_device(device)

        # JIT-traced VITS models and their tokenizers
        self._models: Dict[str, torch.jit.ScriptModule] = {}
        self._tokenizers: Dict[str, TTSTokenizer] = {}

        # Coqui TTS synthesizers (.pth checkpoints)
        self._coqui_models: Dict[str, Any] = {}

        # Facebook MMS models and tokenizers
        self._mms_models: Dict[str, Any] = {}
        self._mms_tokenizers: Dict[str, Any] = {}

        self.downloader = ModelDownloader(models_dir)
        self.normalizer = TextNormalizer()
        self.style_processor = StyleProcessor()

        if preload_voices:
            for voice in preload_voices:
                self.load_voice(voice)

        logger.info(f"TTS Engine initialized on device: {self.device}")

    def _get_device(self, device: str) -> torch.device:
        """Determine the best device for inference."""
        if device == "auto":
            if torch.cuda.is_available():
                return torch.device("cuda")
            # The docstring advertises 'mps'; prefer the Apple Silicon GPU
            # when available before falling back to CPU.
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                return torch.device("mps")
            else:
                return torch.device("cpu")
        return torch.device(device)

    def load_voice(self, voice_key: str, download_if_missing: bool = True) -> bool:
        """
        Load a voice model into memory.

        Args:
            voice_key: Key from LANGUAGE_CONFIGS (e.g., 'hi_male')
            download_if_missing: Download the model if not found locally

        Returns:
            True if loaded successfully
        """
        # Already loaded?
        if voice_key in self._models or voice_key in self._coqui_models:
            return True

        if voice_key not in LANGUAGE_CONFIGS:
            raise ValueError(f"Unknown voice: {voice_key}")

        config = LANGUAGE_CONFIGS[voice_key]
        model_dir = self.models_dir / voice_key

        if not model_dir.exists():
            if download_if_missing:
                logger.info(f"Model not found, downloading {voice_key}...")
                self.downloader.download_model(voice_key)
            else:
                raise FileNotFoundError(f"Model directory not found: {model_dir}")

        # Dispatch on the model file type found in the directory.
        pth_files = list(model_dir.glob("*.pth"))
        pt_files = list(model_dir.glob("*.pt"))

        if pth_files:
            # Coqui TTS checkpoint
            return self._load_coqui_voice(voice_key, model_dir, pth_files[0])
        elif pt_files:
            # JIT-traced VITS model
            return self._load_jit_voice(voice_key, model_dir, pt_files[0])
        else:
            raise FileNotFoundError(f"No .pt or .pth model file found in {model_dir}")

    def _load_jit_voice(
        self, voice_key: str, model_dir: Path, model_path: Path
    ) -> bool:
        """Load a JIT-traced VITS model (.pt file)."""
        # Build the tokenizer from the character set shipped with the model.
        chars_path = model_dir / "chars.txt"
        if chars_path.exists():
            tokenizer = TTSTokenizer.from_chars_file(str(chars_path))
        else:
            # Some releases name the character file differently.
            chars_files = list(model_dir.glob("*chars*.txt"))
            if chars_files:
                tokenizer = TTSTokenizer.from_chars_file(str(chars_files[0]))
            else:
                raise FileNotFoundError(f"No chars.txt found in {model_dir}")

        logger.info(f"Loading JIT model from {model_path}")
        model = torch.jit.load(str(model_path), map_location=self.device)
        model.eval()

        self._models[voice_key] = model
        self._tokenizers[voice_key] = tokenizer

        logger.info(f"Loaded JIT voice: {voice_key}")
        return True

    def _load_coqui_voice(
        self, voice_key: str, model_dir: Path, checkpoint_path: Path
    ) -> bool:
        """Load a Coqui TTS checkpoint model (.pth file)."""
        config_path = model_dir / "config.json"
        if not config_path.exists():
            raise FileNotFoundError(f"No config.json found in {model_dir}")

        try:
            from TTS.utils.synthesizer import Synthesizer

            logger.info(f"Loading Coqui TTS checkpoint from {checkpoint_path}")

            use_cuda = self.device.type == "cuda"
            synthesizer = Synthesizer(
                tts_checkpoint=str(checkpoint_path),
                tts_config_path=str(config_path),
                use_cuda=use_cuda,
            )

            self._coqui_models[voice_key] = synthesizer

            logger.info(f"Loaded Coqui voice: {voice_key}")
            return True

        except ImportError:
            raise ImportError(
                "Coqui TTS library not installed. Install it with: pip install TTS"
            )

    def _synthesize_coqui(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
        """Synthesize using a Coqui TTS model (Bhojpuri, etc.)."""
        if voice_key not in self._coqui_models:
            self.load_voice(voice_key)

        synthesizer = self._coqui_models[voice_key]

        # Coqui returns a Python list of samples; convert to a float32 array
        # and use the synthesizer's own output sample rate.
        wav = synthesizer.tts(text)
        audio_np = np.array(wav, dtype=np.float32)
        sample_rate = synthesizer.output_sample_rate

        return audio_np, sample_rate

    def _load_mms_voice(self, voice_key: str) -> bool:
        """Load a Facebook MMS model (used for Gujarati)."""
        if voice_key in self._mms_models:
            return True

        config = LANGUAGE_CONFIGS[voice_key]
        logger.info(f"Loading MMS model: {config.hf_model_id}")

        try:
            from transformers import VitsModel, AutoTokenizer

            model = VitsModel.from_pretrained(config.hf_model_id)
            tokenizer = AutoTokenizer.from_pretrained(config.hf_model_id)

            model = model.to(self.device)
            model.eval()

            self._mms_models[voice_key] = model
            self._mms_tokenizers[voice_key] = tokenizer

            logger.info(f"Loaded MMS voice: {voice_key}")
            return True

        except Exception as e:
            logger.error(f"Failed to load MMS model: {e}")
            raise

    def _synthesize_mms(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
        """Synthesize using a Facebook MMS model (used for Gujarati)."""
        if voice_key not in self._mms_models:
            self._load_mms_voice(voice_key)

        model = self._mms_models[voice_key]
        tokenizer = self._mms_tokenizers[voice_key]
        config = LANGUAGE_CONFIGS[voice_key]

        inputs = tokenizer(text, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            output = model(**inputs)

        audio = output.waveform.squeeze().cpu().numpy()
        return audio, config.sample_rate

    def unload_voice(self, voice_key: str):
        """Unload a voice to free memory."""
        if voice_key in self._models:
            del self._models[voice_key]
            del self._tokenizers[voice_key]
        if voice_key in self._coqui_models:
            del self._coqui_models[voice_key]
        if voice_key in self._mms_models:
            del self._mms_models[voice_key]
            del self._mms_tokenizers[voice_key]
        if self.device.type == "cuda":
            torch.cuda.empty_cache()
        logger.info(f"Unloaded voice: {voice_key}")

    def synthesize(
        self,
        text: str,
        voice: str = "hi_male",
        speed: float = 1.0,
        pitch: float = 1.0,
        energy: float = 1.0,
        style: Optional[str] = None,
        normalize_text: bool = True,
    ) -> TTSOutput:
        """
        Synthesize speech from text with style control.

        Args:
            text: Input text to synthesize
            voice: Voice key (e.g., 'hi_male', 'bn_female', 'gu_mms')
            speed: Speech speed multiplier (0.5-2.0)
            pitch: Pitch multiplier (0.5-2.0), >1 = higher
            energy: Energy/volume multiplier (0.5-2.0)
            style: Style preset name (e.g., 'happy', 'sad', 'calm')
            normalize_text: Whether to apply text normalization

        Returns:
            TTSOutput with audio array and metadata
        """
        # A style preset scales the explicit speed/pitch/energy values.
        if style and style in STYLE_PRESETS:
            preset = STYLE_PRESETS[style]
            speed = speed * preset["speed"]
            pitch = pitch * preset["pitch"]
            energy = energy * preset["energy"]

        if voice not in LANGUAGE_CONFIGS:
            raise ValueError(f"Unknown voice: {voice}")
        config = LANGUAGE_CONFIGS[voice]

        if normalize_text:
            text = self.normalizer.clean_text(text, config.code)

        # Route to the right backend for this voice.
        if "mms" in voice:
            audio_np, sample_rate = self._synthesize_mms(text, voice)
        elif voice in self._coqui_models:
            audio_np, sample_rate = self._synthesize_coqui(text, voice)
        else:
            # Lazily load the voice; loading decides whether it is a JIT
            # model or a Coqui checkpoint.
            if voice not in self._models:
                self.load_voice(voice)

            if voice in self._coqui_models:
                audio_np, sample_rate = self._synthesize_coqui(text, voice)
            else:
                # JIT-traced VITS inference
                model = self._models[voice]
                tokenizer = self._tokenizers[voice]

                token_ids = tokenizer.text_to_ids(text)
                x = torch.from_numpy(np.array(token_ids)).unsqueeze(0).to(self.device)

                with torch.no_grad():
                    audio = model(x)

                audio_np = audio.squeeze().cpu().numpy()
                sample_rate = config.sample_rate

        # Post-process prosody (pitch, speed, energy).
        audio_np = self.style_processor.apply_style(
            audio_np, sample_rate, speed=speed, pitch=pitch, energy=energy
        )

        duration = len(audio_np) / sample_rate

        return TTSOutput(
            audio=audio_np,
            sample_rate=sample_rate,
            duration=duration,
            voice=voice,
            text=text,
            style=style,
        )

    def synthesize_to_file(
        self,
        text: str,
        output_path: str,
        voice: str = "hi_male",
        speed: float = 1.0,
        pitch: float = 1.0,
        energy: float = 1.0,
        style: Optional[str] = None,
        normalize_text: bool = True,
    ) -> str:
        """
        Synthesize speech and save it to a file.

        Args:
            text: Input text to synthesize
            output_path: Path to save the audio file
            voice: Voice key
            speed: Speech speed multiplier
            pitch: Pitch multiplier
            energy: Energy multiplier
            style: Style preset name
            normalize_text: Whether to apply text normalization

        Returns:
            Path to the saved file
        """
        import soundfile as sf

        output = self.synthesize(
            text, voice, speed, pitch, energy, style, normalize_text
        )
        sf.write(output_path, output.audio, output.sample_rate)

        logger.info(f"Saved audio to {output_path} (duration: {output.duration:.2f}s)")
        return output_path

    def get_loaded_voices(self) -> List[str]:
        """Get the list of currently loaded voices."""
        return (
            list(self._models.keys())
            + list(self._coqui_models.keys())
            + list(self._mms_models.keys())
        )

    def get_available_voices(self) -> Dict[str, Dict]:
        """Get all available voices with their status."""
        voices = {}
        for key, config in LANGUAGE_CONFIGS.items():
            is_mms = "mms" in key
            model_dir = self.models_dir / key

            # Determine the backend type for this voice.
            if is_mms:
                model_type = "mms"
            elif model_dir.exists() and list(model_dir.glob("*.pth")):
                model_type = "coqui"
            else:
                model_type = "vits"

            voices[key] = {
                "name": config.name,
                "code": config.code,
                # Check 'female' first: the substring 'male' also matches
                # keys like 'hi_female'.
                "gender": (
                    "female"
                    if "female" in key
                    else ("male" if "male" in key else "neutral")
                ),
                "loaded": key in self._models
                or key in self._coqui_models
                or key in self._mms_models,
                "downloaded": is_mms
                or self.downloader.get_model_path(key) is not None,
                "type": model_type,
            }
        return voices

    def get_style_presets(self) -> Dict[str, Dict]:
        """Get the available style presets."""
        return STYLE_PRESETS

    def batch_synthesize(
        self, texts: List[str], voice: str = "hi_male", speed: float = 1.0
    ) -> List[TTSOutput]:
        """Synthesize multiple texts with the same voice and speed."""
        return [self.synthesize(text, voice, speed) for text in texts]

def synthesize(
    text: str, voice: str = "hi_male", output_path: Optional[str] = None
) -> Union[TTSOutput, str]:
    """
    Quick one-shot synthesis helper.

    Note: this creates a fresh TTSEngine on every call; reuse a TTSEngine
    instance directly when synthesizing repeatedly.

    Args:
        text: Text to synthesize
        voice: Voice key
        output_path: If provided, saves to file and returns the path

    Returns:
        TTSOutput if no output_path is given, else the path to the saved file
    """
    engine = TTSEngine()

    if output_path:
        return engine.synthesize_to_file(text, output_path, voice)
    return engine.synthesize(text, voice)

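# Minimal command-line smoke test (illustrative sketch; assumes at least the
# default 'hi_male' voice can be downloaded and loaded on this machine).
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="SYSPIN TTS quick synthesis")
    parser.add_argument("text", help="Text to synthesize")
    parser.add_argument("--voice", default="hi_male", help="Voice key, e.g. hi_male")
    parser.add_argument("--out", default="output.wav", help="Output WAV path")
    parser.add_argument("--style", default=None, help="Optional style preset name")
    args = parser.parse_args()

    engine = TTSEngine()
    path = engine.synthesize_to_file(
        args.text, args.out, voice=args.voice, style=args.style
    )
    print(f"Wrote {path}")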