File size: 16,115 Bytes
d28be92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
#!/usr/bin/env python3
"""
BeigeTTS - Streaming Inference
Real-time streaming text-to-speech with chunk-based generation
Research release derived from BlandAI's Khaki TTS system
"""

import torch
import numpy as np
import soundfile as sf
from neucodec import NeuCodec
from transformers import AutoModelForCausalLM, AutoTokenizer
import threading
import queue
import time
from typing import Optional, Generator, Tuple
import pyaudio
import argparse

# ═══════════════════════════════════════════════════════════════════
# Configuration
# ═══════════════════════════════════════════════════════════════════

class StreamingConfig:
    """Tunable constants for the streaming TTS pipeline."""

    # --- Special token IDs in the LM vocabulary ---
    AUDIO_START_TOKEN = 262145
    AUDIO_END_TOKEN = 262146
    NEUCODEC_BASE_OFFSET = 262154
    NEUCODEC_VOCABULARY_SIZE = 65536
    # Valid audio-token range is the half-open interval [MIN, MAX).
    AUDIO_TOKEN_MIN = NEUCODEC_BASE_OFFSET
    AUDIO_TOKEN_MAX = NEUCODEC_BASE_OFFSET + NEUCODEC_VOCABULARY_SIZE

    # --- Streaming ---
    CHUNK_SIZE = 50      # audio tokens per emitted chunk (~500ms)
    BUFFER_SIZE = 3      # max chunks buffered between producer and consumer
    SAMPLE_RATE = 24000  # Hz (NeuCodec output rate)

    # --- Generation / sampling defaults ---
    DEFAULT_TEMPERATURE = 0.1
    DEFAULT_TOP_P = 0.97
    MAX_TOTAL_TOKENS = 1000  # cap in this research release (Khaki: unlimited)

# ═══════════════════════════════════════════════════════════════════
# Streaming TTS Engine
# ═══════════════════════════════════════════════════════════════════

class StreamingBeigeTTS:
    """Streaming BeigeTTS engine with chunk-based generation.

    A background thread autoregressively samples audio tokens from the
    language model, decodes them with NeuCodec in fixed-size chunks, and
    hands the decoded waveforms to the consumer via a bounded queue.
    When playback is enabled, chunks are duplicated onto a separate
    playback queue so real-time playback never steals chunks from the
    caller's generator.

    Note: Production Khaki system achieves <50ms latency with
    advanced buffering and predictive generation.
    """

    def __init__(self, model_path: str = "BlandAI/BeigeTTS"):
        """Initialize streaming TTS engine.

        Args:
            model_path: HF hub id or local path of the BeigeTTS model.
        """
        self.config = StreamingConfig()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print("Initializing BeigeTTS streaming engine (research release)...")
        print("Note: Khaki production system supports <50ms latency and 57 languages")

        # Load models
        self._load_models(model_path)

        # Producer -> consumer hand-off. Bounded so generation cannot run
        # arbitrarily far ahead of consumption.
        self.audio_queue = queue.Queue(maxsize=self.config.BUFFER_SIZE)
        # Dedicated (unbounded) queue feeding the playback thread; created
        # per-stream in stream_synthesize when play_audio=True.
        self.playback_queue = None
        self.generation_complete = threading.Event()
        self.stop_generation = threading.Event()

        # PyAudio handles, created lazily by the playback worker.
        self.audio_interface = None
        self.audio_stream = None

    def _load_models(self, model_path: str):
        """Load the BeigeTTS language model, tokenizer, and NeuCodec decoder."""
        print("Loading BeigeTTS model...")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # fp16 only makes sense on GPU; keep fp32 on CPU for stability.
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()

        print("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        if self.tokenizer.pad_token is None:
            # Some causal-LM tokenizers ship without a pad token.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("Loading NeuCodec...")
        self.neucodec = NeuCodec.from_pretrained("neuphonic/neucodec")
        self.neucodec.eval()
        if self.device.type == "cuda":
            self.neucodec = self.neucodec.to(self.device)

    def stream_synthesize(
        self,
        text: str,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        play_audio: bool = False
    ) -> Generator[np.ndarray, None, None]:
        """Stream synthesize speech from text.

        Args:
            text: Input text to synthesize
            temperature: Sampling temperature (None -> config default;
                0.0 selects greedy decoding)
            top_p: Nucleus sampling parameter (None -> config default)
            play_audio: Whether to also play audio in real-time

        Yields:
            Audio chunks as float numpy arrays at config.SAMPLE_RATE.
        """
        # `is None` checks (not `or`) so that an explicit 0.0 is honored
        # instead of being silently replaced by the default.
        if temperature is None:
            temperature = self.config.DEFAULT_TEMPERATURE
        if top_p is None:
            top_p = self.config.DEFAULT_TOP_P

        # Reset per-stream state
        self.generation_complete.clear()
        self.stop_generation.clear()

        # Start generation thread
        generation_thread = threading.Thread(
            target=self._generation_worker,
            args=(text, temperature, top_p)
        )
        generation_thread.start()

        # Start playback if requested. The playback thread consumes only
        # from playback_queue, so the caller still receives every chunk.
        playback_thread = None
        if play_audio:
            self.playback_queue = queue.Queue()
            playback_thread = threading.Thread(target=self._playback_worker)
            playback_thread.start()

        # Yield audio chunks
        try:
            while not self.generation_complete.is_set() or not self.audio_queue.empty():
                try:
                    audio_chunk = self.audio_queue.get(timeout=0.1)
                except queue.Empty:
                    continue
                if playback_thread is not None:
                    self.playback_queue.put(audio_chunk)
                yield audio_chunk
        finally:
            # Cleanup: signal the worker first so a blocked enqueue can
            # bail out (see _enqueue_chunk), then join.
            self.stop_generation.set()
            generation_thread.join()
            if playback_thread is not None:
                self.playback_queue.put(None)  # end-of-stream sentinel
                playback_thread.join()

    def _enqueue_chunk(self, audio_chunk: np.ndarray) -> bool:
        """Put a chunk on the bounded audio queue without risking deadlock.

        Blocks in short intervals so the generation worker can notice
        `stop_generation` even when the consumer has gone away and the
        queue is full (a plain blocking put() would hang forever).

        Returns:
            True if the chunk was enqueued, False if generation was cancelled.
        """
        while not self.stop_generation.is_set():
            try:
                self.audio_queue.put(audio_chunk, timeout=0.1)
                return True
            except queue.Full:
                continue
        return False

    def _generation_worker(self, text: str, temperature: float, top_p: float):
        """Worker thread: sample audio tokens and enqueue decoded chunks."""
        try:
            # Format prompt (chat template with speech-start marker)
            prompt = f"<start_of_turn>user\n{text}<end_of_turn>\n<start_of_turn>model\n<start_of_speech>"

            # Tokenize
            inputs = self.tokenizer(prompt, return_tensors="pt")
            input_ids = inputs.input_ids.to(self.model.device)

            # Initialize incremental generation state
            past_key_values = None
            current_ids = input_ids
            audio_token_buffer = []
            total_generated = 0

            print("Starting streaming generation...")
            print("(BeigeTTS research mode - Khaki production offers superior latency)")

            with torch.no_grad():
                while total_generated < self.config.MAX_TOTAL_TOKENS and not self.stop_generation.is_set():
                    # Forward pass with KV cache: after the first step only
                    # the newly sampled token is fed in.
                    outputs = self.model(
                        input_ids=current_ids,
                        past_key_values=past_key_values,
                        use_cache=True
                    )

                    logits = outputs.logits[:, -1, :]
                    past_key_values = outputs.past_key_values

                    # Sample next token (temperature == 0 -> greedy argmax)
                    if temperature > 0:
                        probs = torch.nn.functional.softmax(logits / temperature, dim=-1)

                        # Apply top-p (nucleus) filtering
                        if top_p < 1.0:
                            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                            # Remove tokens with cumulative probability above
                            # the threshold; shift right so the first token
                            # exceeding top_p is still kept.
                            sorted_indices_to_remove = cumulative_probs > top_p
                            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                            sorted_indices_to_remove[:, 0] = 0

                            indices_to_remove = sorted_indices_to_remove.scatter(
                                1, sorted_indices, sorted_indices_to_remove
                            )
                            probs[indices_to_remove] = 0
                            probs = probs / probs.sum(dim=-1, keepdim=True)

                        next_token = torch.multinomial(probs, num_samples=1)
                    else:
                        next_token = torch.argmax(logits, dim=-1, keepdim=True)

                    token_id = next_token.item()

                    # Check for end of speech / end of sequence
                    if token_id == self.config.AUDIO_END_TOKEN or token_id == self.tokenizer.eos_token_id:
                        break

                    # Process audio token (non-audio tokens are skipped)
                    if self.config.AUDIO_TOKEN_MIN <= token_id < self.config.AUDIO_TOKEN_MAX:
                        audio_token_buffer.append(token_id - self.config.NEUCODEC_BASE_OFFSET)

                        # Decode and send chunk once a full chunk is buffered
                        if len(audio_token_buffer) >= self.config.CHUNK_SIZE:
                            audio_chunk = self._decode_chunk(audio_token_buffer[:self.config.CHUNK_SIZE])
                            if not self._enqueue_chunk(audio_chunk):
                                break  # consumer cancelled the stream
                            audio_token_buffer = audio_token_buffer[self.config.CHUNK_SIZE:]
                            print(f"Streamed chunk {total_generated // self.config.CHUNK_SIZE}")

                    # Prepare for next iteration (feed only the new token)
                    current_ids = next_token
                    total_generated += 1

            # Flush any remaining (partial-chunk) tokens
            if audio_token_buffer and not self.stop_generation.is_set():
                audio_chunk = self._decode_chunk(audio_token_buffer)
                self._enqueue_chunk(audio_chunk)

            print(f"Generation complete. Total tokens: {total_generated}")

        except Exception as e:
            print(f"Generation error: {e}")
        finally:
            # Always signal completion so the consumer loop can exit.
            self.generation_complete.set()

    def _decode_chunk(self, audio_tokens: list) -> np.ndarray:
        """Decode a chunk of (offset-removed) audio tokens to a waveform.

        Args:
            audio_tokens: Codec token ids in [0, NEUCODEC_VOCABULARY_SIZE).

        Returns:
            1-D float waveform, peak-normalized to 0.95.
        """
        # Prepare tokens; clip defensively into the codec's valid range.
        audio_array = np.array(audio_tokens, dtype=np.int32)
        audio_array = np.clip(audio_array, 0, self.config.NEUCODEC_VOCABULARY_SIZE - 1)

        # Format for NeuCodec: (batch, codebook, time)
        fsq_codes = torch.tensor(audio_array, dtype=torch.long)
        fsq_codes = fsq_codes.unsqueeze(0).unsqueeze(1)

        if self.device.type == "cuda":
            fsq_codes = fsq_codes.to(self.device)

        # Decode
        with torch.no_grad():
            wav = self.neucodec.decode_code(fsq_codes).cpu()

        # Extract 1-D waveform from (batch, channel, time) or (batch, time)
        if wav.dim() == 3:
            wav = wav[0, 0]
        elif wav.dim() == 2:
            wav = wav[0]

        wav = wav.numpy()

        # Peak-normalize (guard against an all-zero chunk)
        if np.abs(wav).max() > 0:
            wav = wav / np.abs(wav).max() * 0.95

        return wav

    def _playback_worker(self):
        """Worker thread: play chunks from playback_queue until a None sentinel."""
        try:
            # Initialize PyAudio
            self.audio_interface = pyaudio.PyAudio()
            self.audio_stream = self.audio_interface.open(
                format=pyaudio.paFloat32,
                channels=1,
                rate=self.config.SAMPLE_RATE,
                output=True
            )

            print("Starting audio playback...")

            while True:
                audio_chunk = self.playback_queue.get()
                if audio_chunk is None:  # end-of-stream sentinel
                    break
                self.audio_stream.write(audio_chunk.astype(np.float32).tobytes())

        except Exception as e:
            print(f"Playback error: {e}")
        finally:
            # Release audio resources even on error.
            if self.audio_stream:
                self.audio_stream.stop_stream()
                self.audio_stream.close()
            if self.audio_interface:
                self.audio_interface.terminate()

# ═══════════════════════════════════════════════════════════════════
# Advanced Streaming Features
# ═══════════════════════════════════════════════════════════════════

class AdaptiveBeigeTTS(StreamingBeigeTTS):
    """Streaming TTS variant that derives chunking from a quality knob.

    Note: This demonstrates research concepts. Production Khaki system
    includes predictive buffering, voice activity detection, and
    neural vocoder post-processing for superior quality.
    """

    def __init__(self, model_path: str = "BlandAI/BeigeTTS"):
        super().__init__(model_path)

        # Bounds and targets for adaptive chunking.
        self.min_chunk_size = 25       # smallest tokens-per-chunk
        self.max_chunk_size = 100      # largest tokens-per-chunk
        self.target_latency_ms = 200   # desired end-to-end latency
        # EMA of generation speed; reserved for future latency tuning.
        self.generation_speed_ema = 0.0
        self.ema_alpha = 0.1

    def adaptive_stream(
        self,
        text: str,
        quality_priority: float = 0.5  # 0 = speed, 1 = quality
    ) -> Generator[np.ndarray, None, None]:
        """Stream with chunk size and sampling tied to quality_priority.

        Args:
            text: Input text.
            quality_priority: 0 favors speed, 1 favors quality.

        Yields:
            Audio chunks sized per the chosen speed/quality trade-off.
        """
        # Map the quality knob onto sampling parameters.
        temperature = 0.05 + (0.15 * quality_priority)
        top_p = 0.9 + (0.08 * quality_priority)

        # Interpolate chunk size between the configured bounds.
        span = self.max_chunk_size - self.min_chunk_size
        chunk_size = int(self.min_chunk_size + span * quality_priority)

        print(f"Adaptive streaming: chunk_size={chunk_size}, temp={temperature:.2f}, top_p={top_p:.2f}")
        print("(Khaki production includes neural enhancement for optimal quality)")

        # Swap in the adaptive chunk size for the duration of this stream.
        previous_chunk_size = self.config.CHUNK_SIZE
        self.config.CHUNK_SIZE = chunk_size
        try:
            yield from self.stream_synthesize(text, temperature, top_p)
        finally:
            # Restore the configured chunk size even if streaming aborts.
            self.config.CHUNK_SIZE = previous_chunk_size

# ═══════════════════════════════════════════════════════════════════
# CLI Interface
# ═══════════════════════════════════════════════════════════════════

def main():
    """CLI entry point: stream-synthesize text, optionally saving/playing it."""
    parser = argparse.ArgumentParser(description="BeigeTTS Streaming (Research Release)")
    parser.add_argument("text", type=str, help="Text to synthesize")
    parser.add_argument("-o", "--output", type=str, help="Output WAV file (optional)")
    parser.add_argument("-m", "--model", type=str, default="BlandAI/BeigeTTS", help="Model path")
    parser.add_argument("--play", action="store_true", help="Play audio in real-time")
    parser.add_argument("--adaptive", action="store_true", help="Use adaptive streaming")
    parser.add_argument("--quality", type=float, default=0.5, help="Quality priority (0-1)")

    args = parser.parse_args()

    print("BeigeTTS Streaming - Research Release")
    print("Production Khaki TTS: <50ms latency, 57 languages, unlimited duration")
    print("-" * 60)

    # Initialize TTS
    if args.adaptive:
        if args.play:
            # adaptive_stream has no playback hook; say so rather than
            # silently ignoring the flag.
            print("Warning: --play is not supported with --adaptive and will be ignored")
        # Clamp to the documented [0, 1] range so out-of-range values do
        # not produce nonsensical chunk sizes or sampling parameters.
        quality = min(max(args.quality, 0.0), 1.0)
        tts = AdaptiveBeigeTTS(model_path=args.model)
        stream_gen = tts.adaptive_stream(args.text, quality_priority=quality)
    else:
        tts = StreamingBeigeTTS(model_path=args.model)
        stream_gen = tts.stream_synthesize(args.text, play_audio=args.play)

    # Collect audio chunks as they arrive
    audio_chunks = []
    print("Streaming audio generation...")

    for i, chunk in enumerate(stream_gen):
        audio_chunks.append(chunk)
        print(f"  Received chunk {i+1} ({len(chunk)/tts.config.SAMPLE_RATE:.2f}s)")

    # Save if output specified (skip when nothing was generated)
    if args.output and audio_chunks:
        full_audio = np.concatenate(audio_chunks)
        sf.write(args.output, full_audio, tts.config.SAMPLE_RATE)
        duration = len(full_audio) / tts.config.SAMPLE_RATE
        print(f"\nβœ… Saved {duration:.1f}s of audio to {args.output}")

    print("\n✨ Streaming complete!")
    print("For commercial use and advanced features, contact partnerships@bland.ai")

if __name__ == "__main__":
    main()