File size: 6,997 Bytes
bcf5d4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
modules/tts_engine.py
──────────────────────────────────────────────────────────────────────────────
VoiceVerse Pro — Stable Dual-Speaker TTS Engine
"""

from __future__ import annotations
import io
import logging
import re
import gc
from dataclasses import dataclass
from enum import Enum
from typing import Optional
import numpy as np
import soundfile as sf
import torch

logger = logging.getLogger(__name__)

class TTSBackend(str, Enum):
    """Available synthesis backends; the string values double as display labels."""
    SPEECHT5 = "SpeechT5 (Microsoft)"  # local transformers pipeline (offline)
    GTTS = "gTTS (Network)"  # Google TTS fallback; requires network access

@dataclass
class TTSConfig:
    """Tunable settings for :class:`TTSEngine`."""
    backend: TTSBackend = TTSBackend.SPEECHT5
    # Hugging Face identifiers for the SpeechT5 model, vocoder, and the
    # x-vector dataset that supplies speaker embeddings.
    speecht5_model: str = "microsoft/speecht5_tts"
    speecht5_vocoder: str = "microsoft/speecht5_hifigan"
    speecht5_embeddings_dataset: str = "Matthijs/cmu-arctic-xvectors"
    # Single speaker default
    speaker_id: int = 7306
    # Podcast defaults
    female_speaker_id: int = 7306  # voice used for [HOST] turns
    male_speaker_id: int = 1138  # voice used for [GUEST] turns
    sample_rate: int = 16_000  # Hz; presumably matches SpeechT5 output rate — TODO confirm
    max_chunk_chars: int = 250  # max characters sent to the pipeline per call

class TTSEngine:
    """Dual-speaker TTS engine backed by SpeechT5 (local) or gTTS (network).

    The SpeechT5 pipeline, the x-vector dataset, and per-speaker embedding
    tensors are all lazily loaded and cached on the instance, so repeated
    synthesis calls do not repeat model/dataset loading.
    """

    def __init__(self, config: Optional[TTSConfig] = None) -> None:
        self.config = config or TTSConfig()
        self._st5_pipe = None  # lazily-built transformers text-to-speech pipeline
        self._emb_cache: dict = {}  # speaker_id -> (1, D) float32 embedding tensor
        self._emb_ds = None  # cached x-vector dataset; loaded at most once

    def synthesize(self, script: str) -> bytes:
        """Render *script* with a single voice.

        Returns WAV bytes for the SpeechT5 backend, MP3 bytes for gTTS.

        Raises:
            ValueError: if *script* is empty or whitespace-only.
        """
        script = script.strip()
        if not script:
            raise ValueError("Empty script.")

        with torch.inference_mode():
            if self.config.backend == TTSBackend.SPEECHT5:
                # Reuse the multi-speaker loop with one speaker for consistency.
                return self._run_st5_engine([(None, script)], solo=True)
            return self._synthesize_gtts(script)

    def synthesize_podcast(self, script: str) -> bytes:
        """Render a [HOST]/[GUEST]-tagged script with two alternating voices.

        Raises:
            ValueError: if *script* is empty or whitespace-only.
        """
        script = script.strip()
        if not script:
            raise ValueError("Empty script.")

        if self.config.backend == TTSBackend.GTTS:
            # gTTS has a single voice; just strip the [SPEAKER] tags.
            return self._synthesize_gtts(re.sub(r'\[.*?\]', '', script))

        lines = self._parse_podcast_lines(script)
        with torch.inference_mode():
            return self._run_st5_engine(lines, solo=False)

    def _run_st5_engine(self, lines: list[tuple[Optional[str], str]], solo: bool) -> bytes:
        """Core synthesis loop: chunked generation with aggressive cleanup.

        Args:
            lines: (speaker, text) turns; speaker is ignored when *solo*.
            solo: if True, every chunk uses ``config.speaker_id``.

        Raises:
            RuntimeError: if no chunk produced audio.
        """
        pipe = self._get_pipe()
        all_audio: list[np.ndarray] = []

        for speaker, text in lines:
            # Pick the embedding: solo voice, or HOST=female / otherwise male.
            if solo:
                emb = self._get_embedding(self.config.speaker_id)
            else:
                spk_id = self.config.female_speaker_id if speaker == "HOST" else self.config.male_speaker_id
                emb = self._get_embedding(spk_id)

            for chunk in self._split_into_chunks(text, self.config.max_chunk_chars):
                chunk = chunk.strip()
                if not chunk:
                    continue
                try:
                    result = pipe(chunk, forward_params={"speaker_embeddings": emb})
                    all_audio.append(np.array(result["audio"], dtype=np.float32).squeeze())

                    # Short silence between sentence chunks.
                    all_audio.append(np.zeros(int(self.config.sample_rate * 0.2), dtype=np.float32))

                    # CRITICAL CLEANUP: free pipeline intermediates promptly
                    # so long scripts keep a flat memory profile.
                    del result
                    gc.collect()
                except Exception as exc:
                    # Best-effort: skip the failed chunk, keep the rest.
                    logger.error("Chunk failed: %s", exc)
                    gc.collect()

            # Longer pause between speaker turns.
            if not solo:
                all_audio.append(np.zeros(int(self.config.sample_rate * 0.5), dtype=np.float32))

        if not all_audio:
            raise RuntimeError("TTS produced no audio.")
        return self._numpy_to_wav_bytes(np.concatenate(all_audio), self.config.sample_rate)

    def _get_pipe(self):
        """Lazily build and cache the SpeechT5 pipeline (CPU, device=-1)."""
        if self._st5_pipe is None:
            from transformers import pipeline, SpeechT5HifiGan
            vocoder = SpeechT5HifiGan.from_pretrained(self.config.speecht5_vocoder)
            self._st5_pipe = pipeline("text-to-speech", model=self.config.speecht5_model, vocoder=vocoder, device=-1)
        return self._st5_pipe

    def _get_embedding(self, speaker_id: int):
        """Return the (1, D) x-vector tensor for *speaker_id*.

        Fix: the x-vector dataset is now loaded once and reused; previously
        every uncached speaker id triggered a full ``load_dataset`` call.
        """
        if speaker_id not in self._emb_cache:
            if self._emb_ds is None:
                from datasets import load_dataset
                self._emb_ds = load_dataset(self.config.speecht5_embeddings_dataset, split="validation")
            vector = self._emb_ds[speaker_id]["xvector"]
            self._emb_cache[speaker_id] = torch.tensor(vector, dtype=torch.float32).view(1, -1)
        return self._emb_cache[speaker_id]

    @staticmethod
    def _parse_podcast_lines(script: str) -> list[tuple[str, str]]:
        """Parse *script* into (speaker, text) turns.

        A line beginning with ``[HOST]`` or ``[GUEST]`` (case-insensitive,
        optional colon) opens a new turn; untagged lines continue the current
        turn; text before the first tag is dropped.
        """
        result = []
        current_speaker, current_text = None, []
        for line in script.splitlines():
            s = line.strip()
            if not s:
                continue
            h_match = re.match(r'^\[HOST\]:?\s*(.*)', s, re.IGNORECASE)
            g_match = re.match(r'^\[GUEST\]:?\s*(.*)', s, re.IGNORECASE)
            if h_match:
                if current_speaker:
                    result.append((current_speaker, " ".join(current_text)))
                current_speaker, current_text = "HOST", [h_match.group(1)]
            elif g_match:
                if current_speaker:
                    result.append((current_speaker, " ".join(current_text)))
                current_speaker, current_text = "GUEST", [g_match.group(1)]
            elif current_speaker:
                current_text.append(s)
        if current_speaker:
            result.append((current_speaker, " ".join(current_text)))
        return result

    @staticmethod
    def _split_into_chunks(text: str, max_chars: int) -> list[str]:
        """Split *text* into chunks of at most *max_chars* characters,
        preferring sentence boundaries.

        Fix: a single sentence longer than *max_chars* is now hard-split at
        *max_chars* intervals instead of being emitted oversized (SpeechT5
        quality degrades on very long inputs).
        """
        sentences = re.split(r"(?<=[.!?])\s+", text)
        chunks: list[str] = []
        current = ""
        for s in sentences:
            s = s.strip()
            if not s:
                continue
            # Hard-split pathological sentences that alone exceed the budget.
            while len(s) > max_chars:
                if current:
                    chunks.append(current)
                    current = ""
                chunks.append(s[:max_chars].strip())
                s = s[max_chars:].strip()
            if not s:
                continue
            if current and len(current) + len(s) + 1 > max_chars:
                chunks.append(current)
                current = s
            else:
                current = f"{current} {s}" if current else s
        if current:
            chunks.append(current)
        return chunks

    @staticmethod
    def _numpy_to_wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
        """Peak-normalize *audio* to 0.95 full scale and encode as 16-bit PCM WAV."""
        max_val = np.abs(audio).max()
        if max_val > 1e-6:
            # Normalize just below full scale to avoid clipping in PCM_16.
            audio = audio / max_val * 0.95
        buf = io.BytesIO()
        sf.write(buf, audio, sample_rate, format="WAV", subtype="PCM_16")
        return buf.getvalue()

    @staticmethod
    def _synthesize_gtts(script: str) -> bytes:
        """Network fallback: render *script* via Google TTS; returns MP3 bytes."""
        from gtts import gTTS
        buf = io.BytesIO()
        gTTS(text=script, lang="en").write_to_fp(buf)
        return buf.getvalue()