"""TTS engine wrapper for Qwen3-TTS.""" from __future__ import annotations import io import threading import time import wave from abc import ABC, abstractmethod from collections.abc import Iterator from dataclasses import dataclass from typing import TYPE_CHECKING if TYPE_CHECKING: import numpy as np import numpy.typing as npt class TTSEngineProtocol(ABC): """Protocol for TTS engines, enabling dependency injection and mocking.""" @abstractmethod def synthesize(self, text: str) -> Iterator[bytes]: """Synthesize text to audio. Args: text: Text to synthesize. Yields: WAV audio data chunks. """ ... @property @abstractmethod def sample_rate(self) -> int: """Return the sample rate of generated audio.""" ... @property def batch_size(self) -> int: """Return the batch size for parallel processing (default: 1).""" return 1 @dataclass class TTSStyle: """Defines a TTS speaking style with its configuration.""" id: str # Unique identifier (e.g., "technical", "narrative") name: str # Display name (e.g., "Technical Documentation") icon: str # Font Awesome icon class (e.g., "fa-gear") description: str # Short description for tooltips prompt: str # The instruct prompt for the TTS model # === TTS STYLES === # Each style provides a different speaking approach optimized for specific content types STYLE_TECHNICAL = TTSStyle( id="technical", name="Technical", icon="fa-microchip", description="Clear, precise reading for code and technical documentation", prompt=( "You are a technical speech engine reading engineering documents. " "Your task is to convert text into clear, accurate spoken output. " "Read in a neutral, controlled, professional voice. " "Do not sound expressive, emotional, or conversational. " "Do not use audiobook, storytelling, or presenter intonation. " "Prioritize intelligibility and correctness over naturalness. " "Maintain steady pacing and flat prosody appropriate for scientific material. " "Pronounce all acronyms as individual letters unless they are standard spoken words. " "Pronounce symbols, operators, and punctuation when they affect meaning. " "Preserve capitalization, parentheses, and formatting as part of the spoken output. " "When reading code, equations, or identifiers, slow down and speak every token clearly. " "Insert short pauses at commas and longer pauses at periods and line breaks. " "Do not summarize, interpret, or rephrase. " "Read exactly what is written." ), ) STYLE_NARRATIVE = TTSStyle( id="narrative", name="Narrative", icon="fa-book-open", description="Natural, engaging reading for articles and stories", prompt=( "You are a professional narrative voice reading long-form text. " "Your task is to tell a story in a clear, engaging, and natural way. " "Use a warm, expressive, and fluid voice. " "Vary intonation and rhythm to reflect meaning, emotion, and emphasis. " "Sound human and immersive, not robotic or monotone. " "Maintain smooth pacing, slowing for important moments, speeding up for transitions. " "Use natural pauses at punctuation and paragraph breaks. " "Pronounce all words clearly, but do not over-articulate symbols or formatting. " "Read acronyms as spoken words when they are commonly pronounced that way. " "Preserve the narrative flow and emotional tone of the text. " "Do not flatten or neutralize the delivery." ), ) STYLE_CHILD_NARRATIVE = TTSStyle( id="child_narrative", name="Child Narrative", icon="fa-child", description="Playful, expressive reading for children's stories", prompt=( "You are a storyteller reading aloud to young children. " "Your task is to tell a story in a friendly, gentle, and engaging way. " "Use a warm, soft, and expressive voice. " "Sound kind, calm, and reassuring. " "Vary intonation to match emotions and actions in the story. " "Maintain a slow to moderate pace with clear articulation. " "Insert natural pauses so children can follow along. " "Pronounce words simply and clearly. " "Read acronyms and difficult words in their most familiar spoken form. " "Keep the tone playful but soothing. " "Do not sound technical, formal, or adult-oriented." ), ) STYLE_NEWS = TTSStyle( id="news", name="News", icon="fa-newspaper", description="Authoritative, clear delivery for news and reports", prompt=( "You are a professional news anchor delivering broadcast news. " "Your task is to read information clearly, confidently, and with authority. " "Use a neutral, composed, and trustworthy voice. " "Avoid emotional or dramatic delivery. " "Do not sound conversational or casual. " "Maintain a steady, moderate pace with crisp articulation. " "Use controlled intonation to mark headlines, key facts, and transitions. " "Pronounce names, numbers, acronyms, and places carefully and accurately. " "Pause briefly at commas and longer at periods and topic changes. " "Sound factual, objective, and broadcast-ready at all times." ), ) STYLE_ACADEMIC = TTSStyle( id="academic", name="Academic", icon="fa-graduation-cap", description="Measured, scholarly reading for papers and research", prompt=( "You are an academic speech engine reading peer-reviewed scientific papers. " "Your task is to render complex scholarly text into clear, precise spoken language. " "Use a neutral, formal, and controlled voice. " "Do not sound expressive, emotional, or conversational. " "Do not use audiobook or presenter intonation. " "Maintain steady pacing suitable for dense technical material. " "Favor clarity and accuracy over naturalness. " "Pronounce technical terminology, Greek letters, acronyms, and units correctly. " "Read acronyms as individual letters unless they are standard spoken words. " "Preserve capitalization, punctuation, and structure when they affect meaning. " "Insert short pauses at commas and longer pauses at periods and section breaks. " "Slow down slightly for equations, symbols, gene names, and references. " "Do not summarize, interpret, or simplify the text. " "Read exactly what is written." ), ) # Registry of all available styles TTS_STYLES: dict[str, TTSStyle] = { style.id: style for style in [ STYLE_TECHNICAL, STYLE_NARRATIVE, STYLE_CHILD_NARRATIVE, STYLE_NEWS, STYLE_ACADEMIC, ] } # Default style DEFAULT_STYLE = STYLE_TECHNICAL def get_style(style_id: str) -> TTSStyle: """Get a TTS style by ID, falling back to default if not found.""" return TTS_STYLES.get(style_id, DEFAULT_STYLE) # Language to default voice mapping LANGUAGE_VOICES: dict[str, str] = { "english": "Ryan", "chinese": "Vivian", "japanese": "Ono_Anna", "korean": "Sohee", } # Default chunk size for streaming # Larger chunks = more stable voice, fewer artifacts at boundaries # Smaller chunks = faster first audio but potential voice instability # 1800 chars provides good balance for natural speech flow DEFAULT_CHUNK_SIZE = 1800 # Idle timeout before unloading model from GPU (seconds) # Set to 0 to disable auto-unloading IDLE_TIMEOUT = 300 # 5 minutes class QwenTTSEngine(TTSEngineProtocol): """TTS engine using Qwen3-TTS model with automatic GPU memory management.""" # Available voices for CustomVoice model: # Chinese: Vivian, Serena, Uncle_Fu, Dylan (Beijing), Eric (Sichuan) # English: Ryan, Aiden # Japanese: Ono_Anna # Korean: Sohee AVAILABLE_VOICES = [ "Vivian", "Serena", "Uncle_Fu", "Dylan", "Eric", "Ryan", "Aiden", "Ono_Anna", "Sohee", ] def __init__( self, voice: str | None = None, language: str = "english", device: str = "cuda", chunk_size: int = DEFAULT_CHUNK_SIZE, model_name: str = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", idle_timeout: int = IDLE_TIMEOUT, ) -> None: """Initialize the TTS engine. Args: voice: Voice name to use for synthesis. If None, uses default for language. Available voices: Chinese: Vivian, Serena, Uncle_Fu, Dylan, Eric English: Ryan, Aiden Japanese: Ono_Anna Korean: Sohee language: Language for TTS. One of: english, chinese, japanese, korean. Sets default voice if voice is None. device: Device to run the model on ('cuda' or 'cpu'). chunk_size: Maximum characters per chunk (smaller = faster streaming start). model_name: HuggingFace model identifier. """ import logging import warnings import torch # Suppress the pad_token_id warning from transformers logging.getLogger("transformers.generation.utils").setLevel(logging.ERROR) warnings.filterwarnings("ignore", message=".*pad_token_id.*") self.language = language.lower() self.voice = voice or LANGUAGE_VOICES.get(self.language, "Ryan") self.device = device self.chunk_size = chunk_size self._sample_rate = 24000 self._batch_size = 1 # Will be calculated after model loads self._model_name = model_name self._dtype = torch.bfloat16 if device == "cuda" else torch.float32 self._attn_impl = "flash_attention_2" if device == "cuda" else "eager" # Idle timeout management self._idle_timeout = idle_timeout self._last_activity = time.time() self._model_loaded = False self._model_state = "unloaded" # unloaded, loading, loaded, unloading self._lock = threading.Lock() self._unload_timer: threading.Timer | None = None # Calibrated seconds per character (measured and updated over time) self._seconds_per_char: float | None = None # Cumulative stats for running average self._total_chars_processed: int = 0 self._total_time_spent: float = 0.0 # Current style for TTS self._style: TTSStyle = DEFAULT_STYLE # Model will be loaded on first request (lazy loading) self.model = None # Load model immediately if no idle timeout (always keep loaded) if idle_timeout == 0: self._load_model() @property def style(self) -> TTSStyle: """Return the current TTS style.""" return self._style def set_style(self, style_id: str) -> None: """Set the TTS style by ID. Args: style_id: Style identifier (technical, narrative, news, casual, academic). """ self._style = get_style(style_id) @property def model_state(self) -> str: """Return the current model state: unloaded, loading, loaded, or unloading.""" return self._model_state @property def seconds_per_char(self) -> float | None: """Return calibrated seconds per character, or None if not yet measured.""" return self._seconds_per_char @property def total_chars_processed(self) -> int: """Return total characters processed since startup.""" return self._total_chars_processed def _update_timing_stats(self, chars: int, elapsed: float) -> None: """Update cumulative timing statistics. Args: chars: Number of characters processed. elapsed: Time taken in seconds. """ self._total_chars_processed += chars self._total_time_spent += elapsed if self._total_chars_processed > 0: self._seconds_per_char = self._total_time_spent / self._total_chars_processed def calibrate(self, test_text: str = "Hello, this is a calibration test.") -> float: """Run a calibration test to measure seconds per character. Args: test_text: Short text to use for calibration. Returns: Measured seconds per character. """ self._ensure_model_loaded() start = time.time() # Consume the generator to complete synthesis for _ in self.synthesize(test_text): pass elapsed = time.time() - start self._seconds_per_char = elapsed / len(test_text) print(f"⏱️ Calibrated: {self._seconds_per_char:.4f}s per character") return self._seconds_per_char def _load_model(self) -> None: """Load the model onto GPU or CPU.""" if self._model_loaded: return import torch from qwen_tts import Qwen3TTSModel self._model_state = "loading" device_name = "GPU" if self.device == "cuda" else "CPU" print(f"🔄 Loading TTS model onto {device_name}...") start = time.time() # Check if CUDA is actually available when requested if self.device == "cuda" and not torch.cuda.is_available(): print("⚠️ CUDA requested but not available, falling back to CPU") self.device = "cpu" self._dtype = torch.float32 self._attn_impl = "eager" device_name = "CPU" try: self.model = Qwen3TTSModel.from_pretrained( self._model_name, device_map=self.device, dtype=self._dtype, attn_implementation=self._attn_impl, ) except Exception: # Fallback without flash attention self.model = Qwen3TTSModel.from_pretrained( self._model_name, device_map=self.device, dtype=self._dtype, ) self._model_loaded = True self._model_state = "loaded" # Calculate optimal batch size based on available VRAM if self.device == "cuda": self._batch_size = self._calculate_batch_size() print(f" Batch size: {self._batch_size} (based on available VRAM)") elapsed = time.time() - start print(f"✅ Model loaded in {elapsed:.1f}s") def _unload_model(self) -> None: """Unload the model from GPU to free memory.""" with self._lock: if not self._model_loaded or self.model is None: return import gc import torch self._model_state = "unloading" print("💤 Unloading TTS model from GPU (idle timeout)...") # Delete model and clear references del self.model self.model = None self._model_loaded = False # Force garbage collection and clear CUDA cache gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() self._model_state = "unloaded" print("✅ GPU memory freed") def _schedule_unload(self) -> None: """Schedule model unload after idle timeout.""" if self._idle_timeout <= 0: return # Cancel existing timer if self._unload_timer is not None: self._unload_timer.cancel() # Schedule new unload self._unload_timer = threading.Timer(self._idle_timeout, self._unload_model) self._unload_timer.daemon = True self._unload_timer.start() def _ensure_model_loaded(self) -> None: """Ensure model is loaded before use.""" with self._lock: self._last_activity = time.time() # Cancel any pending unload if self._unload_timer is not None: self._unload_timer.cancel() self._unload_timer = None # Load model if not loaded if not self._model_loaded: self._load_model() def _calculate_batch_size(self) -> int: """Calculate optimal batch size based on available GPU memory. Returns: Recommended batch size for parallel chunk processing. """ import torch if not torch.cuda.is_available(): return 1 try: # Get GPU memory info gpu_mem = torch.cuda.get_device_properties(0).total_memory allocated = torch.cuda.memory_allocated(0) reserved = torch.cuda.memory_reserved(0) # Available memory (conservative estimate) available = gpu_mem - max(allocated, reserved) # Model uses ~6GB, each batch item needs ~2-3GB for generation # Use conservative 3GB per batch item estimate mem_per_batch = 3 * 1024 * 1024 * 1024 # 3GB # Calculate batch size, minimum 1, cap at 8 batch_size = max(1, min(8, int(available / mem_per_batch))) return batch_size except Exception: return 1 @property def sample_rate(self) -> int: """Return the sample rate of generated audio.""" return self._sample_rate @property def batch_size(self) -> int: """Return the current batch size.""" return self._batch_size def synthesize(self, text: str) -> Iterator[bytes]: """Synthesize text to WAV audio using batched GPU inference. Args: text: Text to synthesize. Yields: WAV audio data chunks. """ if not text.strip(): return # Ensure model is loaded (lazy loading with idle timeout) self._ensure_model_loaded() # Type guard - model is guaranteed to be loaded after _ensure_model_loaded assert self.model is not None, "Model failed to load" # Track timing for this synthesis synthesis_start = time.time() chars_in_text = len(text) try: # Split text into chunks for streaming chunks = self._split_text(text) # First chunk includes WAV header first_chunk = True # Process chunks in batches for GPU efficiency batch_size = self._batch_size for i in range(0, len(chunks), batch_size): batch = chunks[i : i + batch_size] # Filter empty chunks batch = [c for c in batch if c.strip()] if not batch: continue # Always use batched call for consistent GPU memory allocation # Use the current style's prompt for delivery style_prompt = self._style.prompt batch_instruct = [style_prompt] * len(batch) if len(batch) > 1 else style_prompt audios, sr = self.model.generate_custom_voice( text=batch if len(batch) > 1 else batch[0], speaker=[self.voice] * len(batch) if len(batch) > 1 else self.voice, instruct=batch_instruct, # Use lower temperature for more stable, consistent voice temperature=0.7, repetition_penalty=1.1, ) # Ensure audios is a list for consistent iteration if len(batch) == 1: audios = [audios] # Yield each audio chunk in order for audio in audios: wav_bytes = self._audio_to_wav(audio, sr, include_header=first_chunk) first_chunk = False yield wav_bytes finally: # Update timing stats for future estimates elapsed = time.time() - synthesis_start self._update_timing_stats(chars_in_text, elapsed) # Schedule model unload after idle timeout self._schedule_unload() def _split_text(self, text: str, max_chars: int | None = None) -> list[str]: """Split text into chunks suitable for TTS. Splits on sentence boundaries when possible. Args: text: Text to split. max_chars: Maximum characters per chunk. Uses self.chunk_size if None. Returns: List of text chunks. """ import re if max_chars is None: max_chars = self.chunk_size # Split on sentence boundaries sentences = re.split(r"(?<=[.!?])\s+", text) chunks: list[str] = [] current_chunk: list[str] = [] current_length = 0 for sentence in sentences: sentence = sentence.strip() if not sentence: continue if current_length + len(sentence) > max_chars and current_chunk: chunks.append(" ".join(current_chunk)) current_chunk = [] current_length = 0 current_chunk.append(sentence) current_length += len(sentence) + 1 if current_chunk: chunks.append(" ".join(current_chunk)) return chunks def _audio_to_wav( self, audio: npt.NDArray[np.float32] | list[float], sample_rate: int, include_header: bool = True, ) -> bytes: """Convert audio array to WAV bytes. Args: audio: Audio data as numpy array or list. sample_rate: Sample rate of the audio. include_header: Whether to include WAV header. Returns: WAV audio data as bytes. """ import numpy as np # Convert to numpy array if needed if isinstance(audio, list): audio = np.array(audio, dtype=np.float32) # Ensure audio is 1D if audio.ndim > 1: audio = audio.flatten() # Normalize and convert to 16-bit PCM audio = np.clip(audio, -1.0, 1.0) audio_int16 = (audio * 32767).astype(np.int16) if include_header: # Write full WAV file buffer = io.BytesIO() with wave.open(buffer, "wb") as wav_file: wav_file.setnchannels(1) wav_file.setsampwidth(2) # 16-bit wav_file.setframerate(sample_rate) wav_file.writeframes(audio_int16.tobytes()) result: bytes = buffer.getvalue() return result else: # Return raw PCM data pcm_data: bytes = audio_int16.tobytes() return pcm_data class MockTTSEngine(TTSEngineProtocol): """Mock TTS engine for testing.""" def __init__(self, sample_rate: int = 24000) -> None: """Initialize the mock TTS engine. Args: sample_rate: Sample rate for generated audio. """ self._sample_rate = sample_rate @property def sample_rate(self) -> int: """Return the sample rate of generated audio.""" return self._sample_rate def synthesize(self, text: str) -> Iterator[bytes]: """Generate silent WAV audio for testing. Args: text: Text to synthesize (used to determine duration). Yields: WAV audio data with silence. """ if not text.strip(): return # Generate ~0.1 seconds of silence per word words = len(text.split()) duration_samples = int(self._sample_rate * 0.1 * max(1, words)) # Create silent audio silence = b"\x00\x00" * duration_samples # Write WAV header + silence buffer = io.BytesIO() with wave.open(buffer, "wb") as wav_file: wav_file.setnchannels(1) wav_file.setsampwidth(2) wav_file.setframerate(self._sample_rate) wav_file.writeframes(silence) yield buffer.getvalue()