# Deployed via GitHub Actions from commit 0bf18943d192a2812c57599f6c25bf9739d523bf (1c7725b)
"""TTS engine wrapper for Qwen3-TTS."""
from __future__ import annotations
import io
import threading
import time
import wave
from abc import ABC, abstractmethod
from collections.abc import Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
class TTSEngineProtocol(ABC):
    """Abstract contract implemented by every TTS engine.

    Having both the real engine and the mock implement this interface lets
    callers depend on the protocol alone (dependency injection / testing).
    """

    @abstractmethod
    def synthesize(self, text: str) -> Iterator[bytes]:
        """Convert *text* into spoken audio.

        Args:
            text: The text to speak.

        Yields:
            Chunks of WAV audio data.
        """
        ...

    @property
    @abstractmethod
    def sample_rate(self) -> int:
        """Sample rate (Hz) of the audio this engine produces."""
        ...

    @property
    def batch_size(self) -> int:
        """Number of chunks processed in parallel; subclasses may override (default 1)."""
        return 1
@dataclass
class TTSStyle:
    """A named speaking style plus the model prompt that produces it."""

    # Stable identifier used as the registry key (e.g. "technical", "narrative").
    id: str
    # Human-readable label for display (e.g. "Technical Documentation").
    name: str
    # Font Awesome icon class shown next to the style (e.g. "fa-gear").
    icon: str
    # One-line blurb used for tooltips.
    description: str
    # Instruction prompt handed to the TTS model to shape delivery.
    prompt: str
# === TTS STYLES ===
# Each style provides a different speaking approach optimized for specific
# content types. The prompt text is sent verbatim to the TTS model as its
# delivery instruction, so these strings are part of runtime behavior.

STYLE_TECHNICAL = TTSStyle(
    id="technical",
    name="Technical",
    icon="fa-microchip",
    description="Clear, precise reading for code and technical documentation",
    prompt=(
        "You are a technical speech engine reading engineering documents. "
        "Your task is to convert text into clear, accurate spoken output. "
        "Read in a neutral, controlled, professional voice. "
        "Do not sound expressive, emotional, or conversational. "
        "Do not use audiobook, storytelling, or presenter intonation. "
        "Prioritize intelligibility and correctness over naturalness. "
        "Maintain steady pacing and flat prosody appropriate for scientific material. "
        "Pronounce all acronyms as individual letters unless they are standard spoken words. "
        "Pronounce symbols, operators, and punctuation when they affect meaning. "
        "Preserve capitalization, parentheses, and formatting as part of the spoken output. "
        "When reading code, equations, or identifiers, slow down and speak every token clearly. "
        "Insert short pauses at commas and longer pauses at periods and line breaks. "
        "Do not summarize, interpret, or rephrase. "
        "Read exactly what is written."
    ),
)

STYLE_NARRATIVE = TTSStyle(
    id="narrative",
    name="Narrative",
    icon="fa-book-open",
    description="Natural, engaging reading for articles and stories",
    prompt=(
        "You are a professional narrative voice reading long-form text. "
        "Your task is to tell a story in a clear, engaging, and natural way. "
        "Use a warm, expressive, and fluid voice. "
        "Vary intonation and rhythm to reflect meaning, emotion, and emphasis. "
        "Sound human and immersive, not robotic or monotone. "
        "Maintain smooth pacing, slowing for important moments, speeding up for transitions. "
        "Use natural pauses at punctuation and paragraph breaks. "
        "Pronounce all words clearly, but do not over-articulate symbols or formatting. "
        "Read acronyms as spoken words when they are commonly pronounced that way. "
        "Preserve the narrative flow and emotional tone of the text. "
        "Do not flatten or neutralize the delivery."
    ),
)

STYLE_CHILD_NARRATIVE = TTSStyle(
    id="child_narrative",
    name="Child Narrative",
    icon="fa-child",
    description="Playful, expressive reading for children's stories",
    prompt=(
        "You are a storyteller reading aloud to young children. "
        "Your task is to tell a story in a friendly, gentle, and engaging way. "
        "Use a warm, soft, and expressive voice. "
        "Sound kind, calm, and reassuring. "
        "Vary intonation to match emotions and actions in the story. "
        "Maintain a slow to moderate pace with clear articulation. "
        "Insert natural pauses so children can follow along. "
        "Pronounce words simply and clearly. "
        "Read acronyms and difficult words in their most familiar spoken form. "
        "Keep the tone playful but soothing. "
        "Do not sound technical, formal, or adult-oriented."
    ),
)

STYLE_NEWS = TTSStyle(
    id="news",
    name="News",
    icon="fa-newspaper",
    description="Authoritative, clear delivery for news and reports",
    prompt=(
        "You are a professional news anchor delivering broadcast news. "
        "Your task is to read information clearly, confidently, and with authority. "
        "Use a neutral, composed, and trustworthy voice. "
        "Avoid emotional or dramatic delivery. "
        "Do not sound conversational or casual. "
        "Maintain a steady, moderate pace with crisp articulation. "
        "Use controlled intonation to mark headlines, key facts, and transitions. "
        "Pronounce names, numbers, acronyms, and places carefully and accurately. "
        "Pause briefly at commas and longer at periods and topic changes. "
        "Sound factual, objective, and broadcast-ready at all times."
    ),
)

STYLE_ACADEMIC = TTSStyle(
    id="academic",
    name="Academic",
    icon="fa-graduation-cap",
    description="Measured, scholarly reading for papers and research",
    prompt=(
        "You are an academic speech engine reading peer-reviewed scientific papers. "
        "Your task is to render complex scholarly text into clear, precise spoken language. "
        "Use a neutral, formal, and controlled voice. "
        "Do not sound expressive, emotional, or conversational. "
        "Do not use audiobook or presenter intonation. "
        "Maintain steady pacing suitable for dense technical material. "
        "Favor clarity and accuracy over naturalness. "
        "Pronounce technical terminology, Greek letters, acronyms, and units correctly. "
        "Read acronyms as individual letters unless they are standard spoken words. "
        "Preserve capitalization, punctuation, and structure when they affect meaning. "
        "Insert short pauses at commas and longer pauses at periods and section breaks. "
        "Slow down slightly for equations, symbols, gene names, and references. "
        "Do not summarize, interpret, or simplify the text. "
        "Read exactly what is written."
    ),
)

# Registry of all available styles, keyed by their stable style ID.
TTS_STYLES: dict[str, TTSStyle] = {
    style.id: style
    for style in [
        STYLE_TECHNICAL,
        STYLE_NARRATIVE,
        STYLE_CHILD_NARRATIVE,
        STYLE_NEWS,
        STYLE_ACADEMIC,
    ]
}

# Default style used when no style (or an unknown style ID) is requested.
DEFAULT_STYLE = STYLE_TECHNICAL
def get_style(style_id: str) -> TTSStyle:
    """Look up a TTS style by its ID.

    Unknown IDs silently fall back to ``DEFAULT_STYLE`` rather than raising.
    """
    try:
        return TTS_STYLES[style_id]
    except KeyError:
        return DEFAULT_STYLE
# Language -> default voice mapping, consulted when the caller does not pass
# an explicit voice name.
LANGUAGE_VOICES: dict[str, str] = {
    "english": "Ryan",
    "chinese": "Vivian",
    "japanese": "Ono_Anna",
    "korean": "Sohee",
}

# Default chunk size (characters) for streaming synthesis.
# Larger chunks = more stable voice, fewer artifacts at chunk boundaries.
# Smaller chunks = faster first audio but potential voice instability.
# 1800 chars provides a good balance for natural speech flow.
DEFAULT_CHUNK_SIZE = 1800

# Idle timeout before unloading the model from the GPU (seconds).
# Set to 0 to disable auto-unloading (model stays resident).
IDLE_TIMEOUT = 300  # 5 minutes
class QwenTTSEngine(TTSEngineProtocol):
    """TTS engine using the Qwen3-TTS model with automatic GPU memory management.

    The model is loaded lazily on first use and unloaded again after an idle
    timeout (via a daemon ``threading.Timer``), so GPU memory is only held
    while synthesis is actually happening. Timing statistics are accumulated
    across calls to estimate seconds-per-character for progress estimates.
    """

    # Available voices for the CustomVoice model:
    # Chinese: Vivian, Serena, Uncle_Fu, Dylan (Beijing), Eric (Sichuan)
    # English: Ryan, Aiden
    # Japanese: Ono_Anna
    # Korean: Sohee
    AVAILABLE_VOICES = [
        "Vivian",
        "Serena",
        "Uncle_Fu",
        "Dylan",
        "Eric",
        "Ryan",
        "Aiden",
        "Ono_Anna",
        "Sohee",
    ]

    def __init__(
        self,
        voice: str | None = None,
        language: str = "english",
        device: str = "cuda",
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        model_name: str = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
        idle_timeout: int = IDLE_TIMEOUT,
    ) -> None:
        """Initialize the TTS engine.

        Args:
            voice: Voice name to use for synthesis. If None, uses the default
                for the given language (falling back to "Ryan").
                Available voices:
                    Chinese: Vivian, Serena, Uncle_Fu, Dylan, Eric
                    English: Ryan, Aiden
                    Japanese: Ono_Anna
                    Korean: Sohee
            language: Language for TTS. One of: english, chinese, japanese,
                korean. Sets the default voice if voice is None.
            device: Device to run the model on ('cuda' or 'cpu').
            chunk_size: Maximum characters per chunk (smaller = faster
                streaming start).
            model_name: HuggingFace model identifier.
            idle_timeout: Seconds of inactivity before the model is unloaded
                from the GPU. 0 disables auto-unloading and loads the model
                eagerly in the constructor.
        """
        import logging
        import warnings

        import torch

        # Suppress the pad_token_id warning from transformers
        logging.getLogger("transformers.generation.utils").setLevel(logging.ERROR)
        warnings.filterwarnings("ignore", message=".*pad_token_id.*")

        self.language = language.lower()
        # Explicit voice wins; otherwise pick the per-language default.
        self.voice = voice or LANGUAGE_VOICES.get(self.language, "Ryan")
        self.device = device
        self.chunk_size = chunk_size
        self._sample_rate = 24000
        self._batch_size = 1  # Will be calculated after model loads
        self._model_name = model_name
        # bf16 + flash attention on GPU; fp32 + eager attention on CPU.
        self._dtype = torch.bfloat16 if device == "cuda" else torch.float32
        self._attn_impl = "flash_attention_2" if device == "cuda" else "eager"

        # Idle timeout management
        self._idle_timeout = idle_timeout
        self._last_activity = time.time()
        self._model_loaded = False
        self._model_state = "unloaded"  # unloaded, loading, loaded, unloading
        self._lock = threading.Lock()
        self._unload_timer: threading.Timer | None = None

        # Calibrated seconds per character (measured and updated over time)
        self._seconds_per_char: float | None = None
        # Cumulative stats for the running average
        self._total_chars_processed: int = 0
        self._total_time_spent: float = 0.0

        # Current style for TTS
        self._style: TTSStyle = DEFAULT_STYLE

        # Model will be loaded on first request (lazy loading)
        self.model = None
        # Load model immediately if no idle timeout (always keep loaded)
        if idle_timeout == 0:
            self._load_model()

    @property
    def style(self) -> TTSStyle:
        """Return the current TTS style."""
        return self._style

    def set_style(self, style_id: str) -> None:
        """Set the TTS style by ID.

        Unknown IDs fall back to the default style (see ``get_style``).

        Args:
            style_id: Style identifier (technical, narrative, child_narrative,
                news, academic).
        """
        self._style = get_style(style_id)

    @property
    def model_state(self) -> str:
        """Return the current model state: unloaded, loading, loaded, or unloading."""
        return self._model_state

    @property
    def seconds_per_char(self) -> float | None:
        """Return calibrated seconds per character, or None if not yet measured."""
        return self._seconds_per_char

    @property
    def total_chars_processed(self) -> int:
        """Return total characters processed since startup."""
        return self._total_chars_processed

    def _update_timing_stats(self, chars: int, elapsed: float) -> None:
        """Update cumulative timing statistics.

        Maintains a running average across all synthesis calls rather than
        using only the most recent measurement.

        Args:
            chars: Number of characters processed.
            elapsed: Time taken in seconds.
        """
        self._total_chars_processed += chars
        self._total_time_spent += elapsed
        if self._total_chars_processed > 0:
            self._seconds_per_char = self._total_time_spent / self._total_chars_processed

    def calibrate(self, test_text: str = "Hello, this is a calibration test.") -> float:
        """Run a calibration test to measure seconds per character.

        Args:
            test_text: Short text to use for calibration.
                NOTE(review): an empty string would raise ZeroDivisionError
                below (synthesize returns immediately for blank text) —
                callers are expected to pass non-empty text.

        Returns:
            Measured seconds per character.
        """
        self._ensure_model_loaded()
        start = time.time()
        # Consume the generator to complete synthesis
        for _ in self.synthesize(test_text):
            pass
        elapsed = time.time() - start
        self._seconds_per_char = elapsed / len(test_text)
        print(f"⏱️ Calibrated: {self._seconds_per_char:.4f}s per character")
        return self._seconds_per_char

    def _load_model(self) -> None:
        """Load the model onto GPU or CPU.

        Called either from __init__ (without the lock) or from
        _ensure_model_loaded (while holding self._lock). No-op if the model
        is already loaded.
        """
        if self._model_loaded:
            return

        import torch
        from qwen_tts import Qwen3TTSModel

        self._model_state = "loading"
        device_name = "GPU" if self.device == "cuda" else "CPU"
        print(f"🔄 Loading TTS model onto {device_name}...")
        start = time.time()

        # Check if CUDA is actually available when requested
        if self.device == "cuda" and not torch.cuda.is_available():
            print("⚠️ CUDA requested but not available, falling back to CPU")
            self.device = "cpu"
            self._dtype = torch.float32
            self._attn_impl = "eager"
            device_name = "CPU"

        try:
            self.model = Qwen3TTSModel.from_pretrained(
                self._model_name,
                device_map=self.device,
                dtype=self._dtype,
                attn_implementation=self._attn_impl,
            )
        except Exception:
            # Fallback without flash attention (e.g. flash-attn not installed).
            # NOTE(review): if this second attempt also raises, _model_state
            # is left as "loading" — confirm whether callers handle that.
            self.model = Qwen3TTSModel.from_pretrained(
                self._model_name,
                device_map=self.device,
                dtype=self._dtype,
            )

        self._model_loaded = True
        self._model_state = "loaded"

        # Calculate optimal batch size based on available VRAM
        if self.device == "cuda":
            self._batch_size = self._calculate_batch_size()
            print(f" Batch size: {self._batch_size} (based on available VRAM)")

        elapsed = time.time() - start
        print(f"✅ Model loaded in {elapsed:.1f}s")

    def _unload_model(self) -> None:
        """Unload the model from GPU to free memory.

        Runs on the timer thread created by _schedule_unload; takes the lock
        so it cannot race with _ensure_model_loaded.
        """
        with self._lock:
            if not self._model_loaded or self.model is None:
                return

            import gc

            import torch

            self._model_state = "unloading"
            print("💤 Unloading TTS model from GPU (idle timeout)...")

            # Delete model and clear references
            del self.model
            self.model = None
            self._model_loaded = False

            # Force garbage collection and clear CUDA cache
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

            self._model_state = "unloaded"
            print("✅ GPU memory freed")

    def _schedule_unload(self) -> None:
        """Schedule model unload after the idle timeout.

        Cancels any previously scheduled unload first, so repeated synthesis
        calls keep pushing the deadline back. Daemon timer: it will not keep
        the process alive at shutdown.
        """
        if self._idle_timeout <= 0:
            return

        # Cancel existing timer
        if self._unload_timer is not None:
            self._unload_timer.cancel()

        # Schedule new unload
        self._unload_timer = threading.Timer(self._idle_timeout, self._unload_model)
        self._unload_timer.daemon = True
        self._unload_timer.start()

    def _ensure_model_loaded(self) -> None:
        """Ensure the model is loaded before use.

        Also records activity time and cancels any pending idle-unload timer
        so the model is not torn down mid-request.
        """
        with self._lock:
            self._last_activity = time.time()

            # Cancel any pending unload
            if self._unload_timer is not None:
                self._unload_timer.cancel()
                self._unload_timer = None

            # Load model if not loaded
            if not self._model_loaded:
                self._load_model()

    def _calculate_batch_size(self) -> int:
        """Calculate optimal batch size based on available GPU memory.

        Returns:
            Recommended batch size for parallel chunk processing
            (1 on failure or when CUDA is unavailable, capped at 8).
        """
        import torch

        if not torch.cuda.is_available():
            return 1

        try:
            # Get GPU memory info for device 0
            gpu_mem = torch.cuda.get_device_properties(0).total_memory
            allocated = torch.cuda.memory_allocated(0)
            reserved = torch.cuda.memory_reserved(0)

            # Available memory (conservative estimate)
            available = gpu_mem - max(allocated, reserved)

            # Model uses ~6GB, each batch item needs ~2-3GB for generation.
            # Use a conservative 3GB-per-batch-item estimate.
            mem_per_batch = 3 * 1024 * 1024 * 1024  # 3GB

            # Calculate batch size, minimum 1, cap at 8
            batch_size = max(1, min(8, int(available / mem_per_batch)))
            return batch_size
        except Exception:
            # Any CUDA query failure falls back to serial processing.
            return 1

    @property
    def sample_rate(self) -> int:
        """Return the sample rate of generated audio."""
        return self._sample_rate

    @property
    def batch_size(self) -> int:
        """Return the current batch size."""
        return self._batch_size

    def synthesize(self, text: str) -> Iterator[bytes]:
        """Synthesize text to WAV audio using batched GPU inference.

        The first yielded chunk carries a WAV header; subsequent chunks are
        raw PCM so the stream concatenates into one continuous WAV.

        Args:
            text: Text to synthesize. Blank/whitespace-only text yields nothing.

        Yields:
            WAV audio data chunks.
        """
        if not text.strip():
            return

        # Ensure model is loaded (lazy loading with idle timeout)
        self._ensure_model_loaded()

        # Type guard - model is guaranteed to be loaded after _ensure_model_loaded
        assert self.model is not None, "Model failed to load"

        # Track timing for this synthesis
        synthesis_start = time.time()
        chars_in_text = len(text)

        try:
            # Split text into chunks for streaming
            chunks = self._split_text(text)

            # First chunk includes WAV header
            first_chunk = True

            # Process chunks in batches for GPU efficiency
            batch_size = self._batch_size
            for i in range(0, len(chunks), batch_size):
                batch = chunks[i : i + batch_size]
                # Filter empty chunks
                batch = [c for c in batch if c.strip()]
                if not batch:
                    continue

                # Use the current style's prompt for delivery; the model API
                # takes scalars for a single item and lists for a batch.
                style_prompt = self._style.prompt
                batch_instruct = [style_prompt] * len(batch) if len(batch) > 1 else style_prompt
                audios, sr = self.model.generate_custom_voice(
                    text=batch if len(batch) > 1 else batch[0],
                    speaker=[self.voice] * len(batch) if len(batch) > 1 else self.voice,
                    instruct=batch_instruct,
                    # Use lower temperature for more stable, consistent voice
                    temperature=0.7,
                    repetition_penalty=1.1,
                )

                # Ensure audios is a list for consistent iteration
                if len(batch) == 1:
                    audios = [audios]

                # Yield each audio chunk in order
                for audio in audios:
                    wav_bytes = self._audio_to_wav(audio, sr, include_header=first_chunk)
                    first_chunk = False
                    yield wav_bytes
        finally:
            # Runs even if the consumer abandons the generator: update timing
            # stats for future estimates and re-arm the idle-unload timer.
            elapsed = time.time() - synthesis_start
            self._update_timing_stats(chars_in_text, elapsed)
            # Schedule model unload after idle timeout
            self._schedule_unload()

    def _split_text(self, text: str, max_chars: int | None = None) -> list[str]:
        """Split text into chunks suitable for TTS.

        Splits on sentence boundaries when possible. Note: a single sentence
        longer than max_chars is kept intact (never split mid-sentence), so
        chunks can exceed max_chars in that case.

        Args:
            text: Text to split.
            max_chars: Maximum characters per chunk. Uses self.chunk_size if None.

        Returns:
            List of text chunks.
        """
        import re

        if max_chars is None:
            max_chars = self.chunk_size

        # Split after sentence-ending punctuation followed by whitespace.
        sentences = re.split(r"(?<=[.!?])\s+", text)

        chunks: list[str] = []
        current_chunk: list[str] = []
        current_length = 0

        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue
            # Flush the current chunk before it would overflow max_chars.
            if current_length + len(sentence) > max_chars and current_chunk:
                chunks.append(" ".join(current_chunk))
                current_chunk = []
                current_length = 0
            current_chunk.append(sentence)
            current_length += len(sentence) + 1  # +1 for the joining space

        if current_chunk:
            chunks.append(" ".join(current_chunk))

        return chunks

    def _audio_to_wav(
        self,
        audio: npt.NDArray[np.float32] | list[float],
        sample_rate: int,
        include_header: bool = True,
    ) -> bytes:
        """Convert an audio array to WAV bytes (mono, 16-bit PCM).

        Args:
            audio: Audio data as a numpy array or list of floats in [-1, 1].
                NOTE(review): assumes the model returns a numpy array or a
                plain list — confirm against generate_custom_voice's actual
                return type.
            sample_rate: Sample rate of the audio.
            include_header: Whether to include the WAV header (True for the
                first streamed chunk, False for raw PCM continuation chunks).

        Returns:
            WAV audio data as bytes.
        """
        import numpy as np

        # Convert to numpy array if needed
        if isinstance(audio, list):
            audio = np.array(audio, dtype=np.float32)

        # Ensure audio is 1D
        if audio.ndim > 1:
            audio = audio.flatten()

        # Clamp to valid range, then scale to 16-bit PCM
        audio = np.clip(audio, -1.0, 1.0)
        audio_int16 = (audio * 32767).astype(np.int16)

        if include_header:
            # Write full WAV file
            buffer = io.BytesIO()
            with wave.open(buffer, "wb") as wav_file:
                wav_file.setnchannels(1)
                wav_file.setsampwidth(2)  # 16-bit
                wav_file.setframerate(sample_rate)
                wav_file.writeframes(audio_int16.tobytes())
            result: bytes = buffer.getvalue()
            return result
        else:
            # Return raw PCM data
            pcm_data: bytes = audio_int16.tobytes()
            return pcm_data
class MockTTSEngine(TTSEngineProtocol):
    """Fake TTS engine that emits silence, intended for tests."""

    def __init__(self, sample_rate: int = 24000) -> None:
        """Create the mock engine.

        Args:
            sample_rate: Sample rate to stamp on the generated audio.
        """
        self._sample_rate = sample_rate

    @property
    def sample_rate(self) -> int:
        """Sample rate (Hz) of the silent audio this engine emits."""
        return self._sample_rate

    def synthesize(self, text: str) -> Iterator[bytes]:
        """Yield one WAV blob of silence sized by the input's word count.

        Args:
            text: Text to "synthesize" — only its word count matters
                (~0.1 seconds of silence per word, minimum one word).

        Yields:
            A single complete WAV file containing 16-bit mono silence.
        """
        if not text.strip():
            return

        # ~0.1 seconds of silence per word, never less than one word's worth.
        word_count = len(text.split())
        n_samples = int(self._sample_rate * 0.1 * max(1, word_count))
        pcm_silence = b"\x00\x00" * n_samples  # 16-bit zero samples

        # Wrap the silence in a proper WAV container.
        out = io.BytesIO()
        with wave.open(out, "wb") as wav:
            wav.setnchannels(1)
            wav.setsampwidth(2)
            wav.setframerate(self._sample_rate)
            wav.writeframes(pcm_silence)
        yield out.getvalue()