| | """ |
| | modules/tts_engine.py |
──────────────────────────────────────────────────────────────────────────────
VoiceVerse Pro — Stable Dual-Speaker TTS Engine
| | """ |
| |
|
| | from __future__ import annotations |
| | import io |
| | import logging |
| | import re |
| | import gc |
| | from dataclasses import dataclass |
| | from enum import Enum |
| | from typing import Optional |
| | import numpy as np |
| | import soundfile as sf |
| | import torch |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
class TTSBackend(str, Enum):
    """Supported synthesis backends.

    Mixing in ``str`` means each member also *is* its display label, so the
    values shown in a UI can be fed straight back to ``TTSBackend(...)``.
    """

    SPEECHT5 = "SpeechT5 (Microsoft)"  # local neural model, no network required
    GTTS = "gTTS (Network)"            # Google TTS web service, needs connectivity
| |
|
@dataclass
class TTSConfig:
    """Tunable settings for ``TTSEngine``.

    Field names, order, and defaults are the public interface (they define the
    generated ``__init__`` signature) and must not be reordered.
    """

    # Which synthesis backend to use.
    backend: TTSBackend = TTSBackend.SPEECHT5
    # Hugging Face identifiers for the SpeechT5 stack.
    speecht5_model: str = "microsoft/speecht5_tts"
    speecht5_vocoder: str = "microsoft/speecht5_hifigan"
    speecht5_embeddings_dataset: str = "Matthijs/cmu-arctic-xvectors"

    # Default narrator voice for solo synthesis (row index into the
    # x-vector embeddings dataset above).
    speaker_id: int = 7306

    # Podcast voices: [HOST] turns use female_speaker_id, [GUEST] turns
    # use male_speaker_id (see TTSEngine._run_st5_engine).
    female_speaker_id: int = 7306
    male_speaker_id: int = 1138
    sample_rate: int = 16_000   # Hz; also used to size the inserted silence gaps
    max_chunk_chars: int = 250  # upper bound on text fed to the model per call
| |
|
class TTSEngine:
    """Dual-speaker TTS engine.

    SpeechT5 runs locally through a lazily-built ``transformers`` pipeline;
    gTTS is a network fallback.  The core loop synthesizes small text chunks
    and calls ``gc.collect()`` after each one, which keeps peak memory stable
    on small hosts (the "stable chunk/cleanup pattern" this module is built
    around).
    """

    def __init__(self, config: Optional[TTSConfig] = None) -> None:
        self.config = config or TTSConfig()
        self._st5_pipe = None        # lazily-built transformers TTS pipeline
        self._emb_cache: dict = {}   # speaker_id -> x-vector tensor
        self._emb_dataset = None     # embeddings dataset, loaded at most once

    def synthesize(self, script: str) -> bytes:
        """Render *script* as solo narration.

        Returns WAV bytes (SpeechT5) or MP3 bytes (gTTS).
        Raises ValueError if the script is empty/whitespace.
        """
        script = script.strip()
        if not script:
            raise ValueError("Empty script.")

        with torch.inference_mode():
            if self.config.backend == TTSBackend.SPEECHT5:
                return self._run_st5_engine([(None, script)], solo=True)
            return self._synthesize_gtts(script)

    def synthesize_podcast(self, script: str) -> bytes:
        """Render a ``[HOST]``/``[GUEST]``-tagged script as a two-voice podcast.

        Raises ValueError if the script is empty/whitespace.
        """
        script = script.strip()
        if not script:
            raise ValueError("Empty script.")

        if self.config.backend == TTSBackend.GTTS:
            # gTTS has only one voice; just strip the speaker tags.
            return self._synthesize_gtts(re.sub(r'\[.*?\]', '', script))

        lines = self._parse_podcast_lines(script)
        with torch.inference_mode():
            return self._run_st5_engine(lines, solo=False)

    def _run_st5_engine(self, lines: list[tuple[Optional[str], str]], solo: bool) -> bytes:
        """Core SpeechT5 loop over (speaker, text) turns.

        Inserts 0.2 s of silence between chunks and, in podcast mode, 0.5 s
        between speaker turns.  A failed chunk is logged and skipped rather
        than aborting the whole render.
        """
        pipe = self._get_pipe()
        all_audio: list[np.ndarray] = []
        chunk_gap = np.zeros(int(self.config.sample_rate * 0.2), dtype=np.float32)
        turn_gap = np.zeros(int(self.config.sample_rate * 0.5), dtype=np.float32)

        for speaker, text in lines:
            if solo:
                emb = self._get_embedding(self.config.speaker_id)
            else:
                spk_id = (self.config.female_speaker_id if speaker == "HOST"
                          else self.config.male_speaker_id)
                emb = self._get_embedding(spk_id)

            for chunk in self._split_into_chunks(text, self.config.max_chunk_chars):
                chunk = chunk.strip()
                if not chunk:
                    continue
                try:
                    result = pipe(chunk, forward_params={"speaker_embeddings": emb})
                    all_audio.append(np.array(result["audio"], dtype=np.float32).squeeze())
                    all_audio.append(chunk_gap)
                    del result
                except Exception:
                    # Keep the traceback; a bad chunk should not kill the run.
                    logger.exception("Chunk failed: %r", chunk[:80])
                finally:
                    # Deliberate per-chunk collection: keeps peak RSS flat.
                    gc.collect()

            if not solo:
                all_audio.append(turn_gap)

        if not all_audio:
            raise RuntimeError("TTS produced no audio.")
        return self._numpy_to_wav_bytes(np.concatenate(all_audio), self.config.sample_rate)

    def _get_pipe(self):
        """Build (once) and return the SpeechT5 text-to-speech pipeline on CPU."""
        if self._st5_pipe is None:
            from transformers import pipeline, SpeechT5HifiGan
            vocoder = SpeechT5HifiGan.from_pretrained(self.config.speecht5_vocoder)
            self._st5_pipe = pipeline(
                "text-to-speech",
                model=self.config.speecht5_model,
                vocoder=vocoder,
                device=-1,  # CPU
            )
        return self._st5_pipe

    def _get_embedding(self, speaker_id: int):
        """Return the (1, D) x-vector tensor for *speaker_id*.

        Both the dataset and the per-speaker tensors are cached; the original
        implementation re-loaded the whole dataset for every new speaker.
        """
        if speaker_id not in self._emb_cache:
            if self._emb_dataset is None:
                from datasets import load_dataset
                self._emb_dataset = load_dataset(
                    self.config.speecht5_embeddings_dataset, split="validation")
            vector = self._emb_dataset[speaker_id]["xvector"]
            self._emb_cache[speaker_id] = torch.tensor(
                vector, dtype=torch.float32).view(1, -1)
        return self._emb_cache[speaker_id]

    @staticmethod
    def _parse_podcast_lines(script: str) -> list[tuple[str, str]]:
        """Split a script into (speaker, text) turns keyed by [HOST]/[GUEST] tags.

        Untagged lines continue the current speaker's turn.  A script with no
        tags at all becomes a single HOST turn instead of an empty list, which
        previously surfaced as an opaque "TTS produced no audio" error.
        """
        tag_re = re.compile(r'^\[(HOST|GUEST)\]:?\s*(.*)', re.IGNORECASE)
        turns: list[tuple[str, str]] = []
        speaker: Optional[str] = None
        buffer: list[str] = []

        def flush() -> None:
            if speaker:
                turns.append((speaker, " ".join(buffer)))

        for raw in script.splitlines():
            line = raw.strip()
            if not line:
                continue
            m = tag_re.match(line)
            if m:
                flush()
                speaker, buffer = m.group(1).upper(), [m.group(2)]
            elif speaker:
                buffer.append(line)
        flush()

        if not turns and script.strip():
            turns.append(("HOST", script.strip()))
        return turns

    @staticmethod
    def _split_into_chunks(text: str, max_chars: int) -> list[str]:
        """Greedily pack whole sentences into chunks of at most *max_chars*.

        A single sentence longer than *max_chars* is further split on word
        boundaries (the original could emit oversized chunks).  A lone word
        longer than *max_chars* is kept intact rather than cut mid-word.
        """
        def word_split(sentence: str) -> list[str]:
            # Hard fallback for one sentence that alone exceeds the budget.
            pieces: list[str] = []
            part = ""
            for word in sentence.split():
                candidate = f"{part} {word}".strip()
                if len(candidate) > max_chars and part:
                    pieces.append(part)
                    part = word
                else:
                    part = candidate
            if part:
                pieces.append(part)
            return pieces

        chunks: list[str] = []
        current = ""
        for sentence in re.split(r"(?<=[.!?])\s+", text):
            sentence = sentence.strip()
            if not sentence:
                continue
            if len(sentence) > max_chars:
                if current:
                    chunks.append(current)
                    current = ""
                chunks.extend(word_split(sentence))
            elif current and len(current) + len(sentence) + 1 > max_chars:
                chunks.append(current)
                current = sentence
            else:
                current = f"{current} {sentence}".strip() if current else sentence
        if current:
            chunks.append(current)
        return chunks

    @staticmethod
    def _numpy_to_wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
        """Peak-normalise to 0.95 and encode as 16-bit PCM WAV bytes."""
        peak = float(np.abs(audio).max())
        if peak > 1e-6:  # skip normalisation for (near-)silence: avoids /0
            audio = audio / peak * 0.95
        buf = io.BytesIO()
        sf.write(buf, audio, sample_rate, format="WAV", subtype="PCM_16")
        return buf.getvalue()

    @staticmethod
    def _synthesize_gtts(script: str) -> bytes:
        """Network fallback: synthesize via Google TTS; returns MP3 bytes."""
        from gtts import gTTS
        buf = io.BytesIO()
        gTTS(text=script, lang="en").write_to_fp(buf)
        return buf.getvalue()