# NOTE: Hugging Face Space status banner ("Spaces: Running") captured by the
# page scrape — not part of the module source.
| """ | |
| modules/tts_engine.py | |
| ────────────────────────────────────────────────────────────────────────────── | |
| VoiceVerse Pro — Stable Dual-Speaker TTS Engine | |
| """ | |
| from __future__ import annotations | |
| import io | |
| import logging | |
| import re | |
| import gc | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from typing import Optional | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| logger = logging.getLogger(__name__) | |
class TTSBackend(str, Enum):
    """Available text-to-speech backends.

    Subclasses ``str`` so members compare equal to (and serialize as)
    their human-readable labels, which UI dropdowns display directly.
    """

    # Local Microsoft SpeechT5 pipeline (offline inference).
    SPEECHT5 = "SpeechT5 (Microsoft)"
    # Google Translate TTS web service (needs network access).
    GTTS = "gTTS (Network)"
@dataclass
class TTSConfig:
    """Tunable settings for :class:`TTSEngine`.

    BUG FIX: the ``@dataclass`` decorator was missing even though
    ``dataclass`` is imported at the top of the file; without it the
    annotated defaults are bare class attributes and keyword
    construction (``TTSConfig(sample_rate=22050)``) raises TypeError.
    ``TTSConfig()`` and plain attribute reads behave exactly as before.
    """

    backend: TTSBackend = TTSBackend.SPEECHT5
    speecht5_model: str = "microsoft/speecht5_tts"
    speecht5_vocoder: str = "microsoft/speecht5_hifigan"
    speecht5_embeddings_dataset: str = "Matthijs/cmu-arctic-xvectors"
    # Single-speaker default (index into the x-vector validation split).
    speaker_id: int = 7306
    # Podcast defaults: HOST uses the female voice, GUEST the male voice.
    female_speaker_id: int = 7306
    male_speaker_id: int = 1138
    # SpeechT5 natively outputs 16 kHz audio.
    sample_rate: int = 16_000
    # Max characters per synthesis chunk (SpeechT5 degrades on long inputs).
    max_chunk_chars: int = 250
class TTSEngine:
    """Stable dual-speaker TTS engine (local SpeechT5 or network gTTS).

    The SpeechT5 pipeline, vocoder and x-vector speaker embeddings are
    loaded lazily on first use and cached on the instance.  Synthesis is
    done sentence-chunk by sentence-chunk under ``torch.inference_mode``
    with explicit ``gc.collect()`` calls to keep peak memory low.

    BUG FIX: the four helpers at the bottom (``_parse_podcast_lines``,
    ``_split_into_chunks``, ``_numpy_to_wav_bytes``, ``_synthesize_gtts``)
    were defined without ``self`` yet always invoked as ``self.helper(...)``,
    which would have passed the instance as the first argument.  They are
    now ``@staticmethod``s with their original signatures intact.
    """

    def __init__(self, config: Optional[TTSConfig] = None) -> None:
        self.config = config or TTSConfig()
        self._st5_pipe = None       # lazily-built transformers TTS pipeline
        self._emb_cache: dict = {}  # speaker_id -> embedding tensor (cached)

    def synthesize(self, script: str) -> bytes:
        """Render *script* as single-speaker narration.

        Returns WAV bytes for the SpeechT5 backend, MP3 bytes for gTTS.
        Raises ValueError if *script* is empty after stripping.
        """
        script = script.strip()
        if not script:
            raise ValueError("Empty script.")
        with torch.inference_mode():
            if self.config.backend == TTSBackend.SPEECHT5:
                # Reuse the podcast core loop with one speaker for consistency.
                return self._run_st5_engine([(None, script)], solo=True)
            return self._synthesize_gtts(script)

    def synthesize_podcast(self, script: str) -> bytes:
        """Render a ``[HOST]`` / ``[GUEST]`` tagged script with two voices.

        gTTS offers no speaker control, so for that backend the speaker
        tags are stripped and the whole script is read by a single voice.
        Raises ValueError if *script* is empty after stripping.
        """
        script = script.strip()
        if not script:
            raise ValueError("Empty script.")
        if self.config.backend == TTSBackend.GTTS:
            return self._synthesize_gtts(re.sub(r'\[.*?\]', '', script))
        lines = self._parse_podcast_lines(script)
        with torch.inference_mode():
            return self._run_st5_engine(lines, solo=False)

    def _run_st5_engine(self, lines: list[tuple[Optional[str], str]], solo: bool) -> bytes:
        """Core SpeechT5 loop: chunked inference with aggressive cleanup.

        *lines* is a list of ``(speaker, text)`` turns; *speaker* is
        ignored when *solo* is True.  Raises RuntimeError if no chunk
        produced audio.
        """
        pipe = self._get_pipe()
        all_audio: list[np.ndarray] = []
        for speaker, text in lines:
            # Pick the embedding for this turn: fixed voice in solo mode,
            # otherwise HOST -> female voice, anything else -> male voice.
            if solo:
                emb = self._get_embedding(self.config.speaker_id)
            else:
                spk_id = (self.config.female_speaker_id
                          if speaker == "HOST"
                          else self.config.male_speaker_id)
                emb = self._get_embedding(spk_id)
            for chunk in self._split_into_chunks(text, self.config.max_chunk_chars):
                if not chunk.strip():
                    continue
                try:
                    result = pipe(chunk.strip(), forward_params={"speaker_embeddings": emb})
                    audio_np = np.array(result["audio"], dtype=np.float32).squeeze()
                    all_audio.append(audio_np)
                    # Short (0.2 s) silence between sentence chunks.
                    all_audio.append(np.zeros(int(self.config.sample_rate * 0.2), dtype=np.float32))
                    # CRITICAL CLEANUP: release pipeline output before the
                    # next chunk to cap peak memory on small hosts.
                    del result
                    gc.collect()
                except Exception as exc:
                    # Best-effort: log and skip the failing chunk, keep the rest.
                    logger.error("Chunk failed: %s", exc)
                    gc.collect()
            # Longer (0.5 s) pause between speaker turns.
            if not solo:
                all_audio.append(np.zeros(int(self.config.sample_rate * 0.5), dtype=np.float32))
        if not all_audio:
            raise RuntimeError("TTS produced no audio.")
        return self._numpy_to_wav_bytes(np.concatenate(all_audio), self.config.sample_rate)

    def _get_pipe(self):
        """Create (once) and return the SpeechT5 text-to-speech pipeline."""
        if self._st5_pipe is not None:
            return self._st5_pipe
        # Imported lazily so the module loads without transformers installed.
        from transformers import pipeline, SpeechT5HifiGan
        vocoder = SpeechT5HifiGan.from_pretrained(self.config.speecht5_vocoder)
        # device=-1 pins inference to CPU.
        self._st5_pipe = pipeline(
            "text-to-speech",
            model=self.config.speecht5_model,
            vocoder=vocoder,
            device=-1,
        )
        return self._st5_pipe

    def _get_embedding(self, speaker_id: int):
        """Return the (cached) x-vector embedding tensor for *speaker_id*."""
        if speaker_id in self._emb_cache:
            return self._emb_cache[speaker_id]
        from datasets import load_dataset
        ds = load_dataset(self.config.speecht5_embeddings_dataset, split="validation")
        vector = ds[speaker_id]["xvector"]
        # Shape (1, dim) as expected by the pipeline's speaker_embeddings.
        self._emb_cache[speaker_id] = torch.tensor(vector, dtype=torch.float32).view(1, -1)
        return self._emb_cache[speaker_id]

    @staticmethod
    def _parse_podcast_lines(script: str) -> list[tuple[str, str]]:
        """Parse ``[HOST]`` / ``[GUEST]`` tagged lines into (speaker, text) turns.

        Untagged lines continue the current speaker's turn (joined with
        spaces); text before the first tag is discarded.
        """
        result = []
        current_speaker, current_text = None, []
        for line in script.splitlines():
            s = line.strip()
            if not s:
                continue
            h_match = re.match(r'^\[HOST\]:?\s*(.*)', s, re.IGNORECASE)
            g_match = re.match(r'^\[GUEST\]:?\s*(.*)', s, re.IGNORECASE)
            if h_match:
                if current_speaker:
                    result.append((current_speaker, " ".join(current_text)))
                current_speaker, current_text = "HOST", [h_match.group(1)]
            elif g_match:
                if current_speaker:
                    result.append((current_speaker, " ".join(current_text)))
                current_speaker, current_text = "GUEST", [g_match.group(1)]
            elif current_speaker:
                current_text.append(s)
        if current_speaker:
            result.append((current_speaker, " ".join(current_text)))
        return result

    @staticmethod
    def _split_into_chunks(text: str, max_chars: int) -> list[str]:
        """Greedily pack whole sentences into chunks of at most *max_chars*.

        A single sentence longer than *max_chars* becomes its own chunk
        (sentences are never split mid-way).
        """
        sentences = re.split(r"(?<=[.!?])\s+", text)
        chunks, current = [], ""
        for s in sentences:
            if not s.strip():
                continue
            if len(current) + len(s) + 1 > max_chars and current:
                chunks.append(current.strip())
                current = s
            else:
                current = f"{current} {s}".strip() if current else s
        if current.strip():
            chunks.append(current.strip())
        return chunks

    @staticmethod
    def _numpy_to_wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
        """Peak-normalize *audio* to 0.95 and encode as 16-bit PCM WAV bytes."""
        max_val = np.abs(audio).max()
        if max_val > 1e-6:  # skip normalization for (near-)silent output
            audio = audio / max_val * 0.95
        buf = io.BytesIO()
        sf.write(buf, audio, sample_rate, format="WAV", subtype="PCM_16")
        buf.seek(0)
        return buf.read()

    @staticmethod
    def _synthesize_gtts(script: str) -> bytes:
        """Render *script* via the gTTS web service; returns MP3 bytes."""
        from gtts import gTTS
        buf = io.BytesIO()
        gTTS(text=script, lang="en").write_to_fp(buf)
        return buf.getvalue()