#!/usr/bin/env python3
r"""
================================================================================
Priority 4: Sentence-Level Streaming TTS Server with Opus Output
================================================================================

Production TTS requires streaming audio, not waiting for full generation.

This server implements:
1. Sentence-level chunking: Split input at punctuation boundaries
2. Per-sentence generation: Generate mel-spectrogram for each sentence
3. Opus encoding: Stream compressed audio chunks as each sentence completes
4. Time-to-first-audio (TTFA): Sub-500ms for short sentences

Architecture:
  Client → HTTP POST /synthesize
    → Preprocess reference audio/text
    → Chunk into sentences
    → For each sentence:
        → Generate with EPSS(7) + BF16
        → Encode to Opus
        → Yield audio chunk
    → Client receives streaming audio

Opus encoding:
  - 24kHz sample rate (matches F5-TTS output)
  - 16-32kbps bitrate
  - ~10x smaller than WAV
  - Streaming-compatible (Ogg Opus container)

Dependencies:
  pip install fastapi uvicorn soundfile
  An `ffmpeg` binary built with libopus must be on PATH (it performs the
  Opus encoding; if it cannot be run the server falls back to WAV bytes).

Usage:
  # Start server
  python 04_streaming_server.py --host 0.0.0.0 --port 8000

  # Client request
  curl -X POST http://localhost:8000/synthesize \
    -H "Content-Type: application/json" \
    -d '{
      "text": "مرحبا بك. كيف حالك اليوم؟",
      "ref_audio": "reference.wav",
      "ref_text": "مرحبا"
    }' \
    --output output.opus
================================================================================
"""

import argparse
import asyncio
import io
import os
import re
import struct
import subprocess
import sys
import time
import warnings
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncGenerator, List, Optional

import numpy as np
import soundfile as sf
import torch
import torchaudio
from cached_path import cached_path
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from f5_tts.infer.utils_infer import load_vocoder, preprocess_ref_audio_text
from f5_tts.model import CFM
from f5_tts.model.utils import get_tokenizer
from habibi_tts.model.utils import dialect_id_map, text_list_formatter
from hydra.utils import get_class
from omegaconf import OmegaConf
from pydantic import BaseModel, Field

warnings.filterwarnings("ignore")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_CFG_PATH = str(Path(__file__).parent / "configs" / "F5TTS_v1_Base.yaml")
CKPT_URL = "hf://SWivid/Habibi-TTS/Specialized/ALG/model_100000.safetensors"
VOCAB_URL = "hf://SWivid/Habibi-TTS/Specialized/ALG/vocab.txt"

N_MEL_CHANNELS = 100
HOP_LENGTH = 256
WIN_LENGTH = 1024
N_FFT = 1024
TARGET_SAMPLE_RATE = 24000

# Output formats accepted by the API, mapped to the HTTP media type we send.
_MEDIA_TYPES = {"opus": "audio/ogg", "wav": "audio/wav"}
SUPPORTED_OUTPUT_FORMATS = tuple(_MEDIA_TYPES)

# Size in bytes of the canonical PCM WAV header built by _streaming_wav_header.
_WAV_HEADER_SIZE = 44

# Sentence boundaries used by chunk_text:
#   * after ASCII punctuation, but only when whitespace follows (the whitespace
#     is consumed), so "10:30" or "3.14" are never cut in half;
#   * directly after Arabic (U+061B ; U+060C , U+061F ?) or full-width CJK
#     punctuation, as a zero-width split that keeps any following whitespace.
_SENTENCE_SPLIT_RE = re.compile(
    r"(?<=[;:,.!?])\s+"
    r"|(?<=[\u061B\u060C\u061F\uFF1B\uFF1A\uFF0C\u3002\uFF01\uFF1F])"
)

# ---------------------------------------------------------------------------
# Global model state (loaded once at startup)
# ---------------------------------------------------------------------------
model_global = None
vocoder_global = None


# ---------------------------------------------------------------------------
# Model Loading
# ---------------------------------------------------------------------------
def load_production_model(device=DEVICE):
    """Load optimized Habibi-TTS ALG model for production.

    Builds the CFM model from the YAML config at MODEL_CFG_PATH, downloads
    (via ``cached_path``) the checkpoint and vocabulary, loads the weights and,
    on CUDA devices, converts the model to BF16 and compiles the transformer.

    Args:
        device: Torch device string the model is moved to.

    Returns:
        The CFM model in eval mode.
    """
    print(f"[LOAD] Loading production model on {device}...")

    model_cfg = OmegaConf.load(MODEL_CFG_PATH)
    model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
    model_arc = model_cfg.model.arch

    ckpt_file = str(cached_path(CKPT_URL))
    vocab_file = str(cached_path(VOCAB_URL))
    vocab_char_map, vocab_size = get_tokenizer(vocab_file, "custom")

    model = CFM(
        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=N_MEL_CHANNELS),
        mel_spec_kwargs=dict(
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
            win_length=WIN_LENGTH,
            n_mel_channels=N_MEL_CHANNELS,
            target_sample_rate=TARGET_SAMPLE_RATE,
            mel_spec_type="vocos",
        ),
        odeint_kwargs=dict(method="euler"),
        vocab_char_map=vocab_char_map,
    ).to(device)

    # Load checkpoint (kept as a local import: only needed at load time)
    from safetensors.torch import load_file

    # Strip the "ema_model." prefix and drop the bookkeeping entries so the
    # keys line up with the CFM module's own state dict.
    state_dict = {
        key.replace("ema_model.", ""): value
        for key, value in load_file(ckpt_file, device=device).items()
        if key not in ("initted", "step")
    }
    # These mel-spectrogram buffers are removed before loading, as in the
    # original loader.
    for key in ("mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"):
        state_dict.pop(key, None)

    model.load_state_dict(state_dict)
    del state_dict

    # Accept "cuda" as well as indexed forms such as "cuda:0".
    on_cuda = str(device).startswith("cuda")
    if on_cuda:
        torch.cuda.empty_cache()

        # BF16 optimization
        model = model.to(torch.bfloat16)
        print("[OPT] Model converted to BF16")

        # torch.compile for transformer backbone
        model.transformer = torch.compile(model.transformer, mode="reduce-overhead", fullgraph=False)
        print("[OPT] torch.compile applied")

    model.eval()
    return model


# ---------------------------------------------------------------------------
# Audio Processing
# ---------------------------------------------------------------------------
def chunk_text(text: str, max_chars: int = 135) -> List[str]:
    """Split text into sentence-level chunks.

    The text is cut at punctuation boundaries (see ``_SENTENCE_SPLIT_RE``) and
    consecutive pieces are merged greedily while they fit the budget.

    Args:
        text: Input text.
        max_chars: Budget per chunk, measured in UTF-8 *bytes* (not code
            points). A single piece longer than the budget becomes its own
            chunk; it is not split further.

    Returns:
        Non-empty, whitespace-stripped chunks in original order. Empty or
        whitespace-only input yields an empty list.
    """
    chunks: List[str] = []
    current = ""

    for sentence in _SENTENCE_SPLIT_RE.split(text):
        if not sentence.strip():
            continue

        # Pieces ending in an ASCII character lost their trailing whitespace to
        # the split, so re-add one space; pieces ending in non-ASCII
        # punctuation were split zero-width and the next piece still carries
        # its own leading whitespace.
        piece = sentence + " " if sentence[-1].isascii() else sentence

        if len(current.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
            current += piece
        else:
            if current:
                chunks.append(current.strip())
            current = piece

    if current:
        chunks.append(current.strip())

    return chunks


def _to_pcm16(wav: np.ndarray) -> np.ndarray:
    """Convert float audio in [-1, 1] to int16 PCM.

    Samples are clipped first: casting an out-of-range float product to int16
    would otherwise wrap around and produce loud clicks. NaN becomes silence.
    """
    safe = np.clip(np.nan_to_num(wav), -1.0, 1.0)
    return (safe * 32767).astype(np.int16)


def _streaming_wav_header(sample_rate: int) -> bytes:
    """Build a 44-byte mono 16-bit PCM WAV header for a stream of unknown length.

    Both size fields are set to 0xFFFFFFFF because the total length is not
    known when streaming starts; ``_finalize_wav`` patches them when the whole
    payload is available.
    """
    channels = 1
    bits_per_sample = 16
    block_align = channels * bits_per_sample // 8
    fmt_chunk = struct.pack(
        "<IHHIIHH",
        16,  # fmt chunk size
        1,  # PCM
        channels,
        sample_rate,
        sample_rate * block_align,  # byte rate
        block_align,
        bits_per_sample,
    )
    return (
        b"RIFF" + struct.pack("<I", 0xFFFFFFFF) + b"WAVE"
        + b"fmt " + fmt_chunk
        + b"data" + struct.pack("<I", 0xFFFFFFFF)
    )


def _finalize_wav(data: bytes) -> bytes:
    """Patch the RIFF/data sizes of a buffer that starts with a streaming header.

    Buffers that do not start with such a header are returned unchanged.
    """
    if len(data) < _WAV_HEADER_SIZE or data[:4] != b"RIFF":
        return data
    patched = bytearray(data)
    struct.pack_into("<I", patched, 4, min(len(patched) - 8, 0xFFFFFFFF))
    struct.pack_into("<I", patched, 40, min(len(patched) - _WAV_HEADER_SIZE, 0xFFFFFFFF))
    return bytes(patched)


def wav_to_opus_bytes(wav: np.ndarray, sr: int = 24000, bitrate: str = "24k") -> bytes:
    """Convert WAV numpy array to Opus-encoded bytes.

    The samples are written as an in-memory 16-bit PCM WAV file and piped
    through ``ffmpeg`` (libopus, Ogg container).

    Args:
        wav: Float audio samples, expected in [-1, 1] (values outside are clipped).
        sr: Sample rate of ``wav`` in Hz.
        bitrate: Value passed to ffmpeg's ``-b:a`` option.

    Returns:
        Ogg Opus bytes. Best-effort fallback: if ffmpeg cannot be started or
        exits with an error, the intermediate WAV bytes are returned instead
        and a warning is printed.
    """
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, _to_pcm16(wav), sr, format="WAV", subtype="PCM_16")
    wav_bytes = wav_buffer.getvalue()

    # Convert to Opus using ffmpeg (argument list, no shell involved)
    command = ["ffmpeg", "-i", "-", "-c:a", "libopus", "-b:a", bitrate, "-f", "ogg", "-"]
    try:
        proc = subprocess.run(command, input=wav_bytes, capture_output=True, check=False)
    except OSError as exc:
        # Typically FileNotFoundError: ffmpeg is not installed / not on PATH.
        print(f"[WARN] ffmpeg could not be started ({exc}); returning WAV instead of Opus")
        return wav_bytes

    if proc.returncode != 0:
        stderr_tail = proc.stderr.decode("utf-8", errors="replace").strip()[-200:]
        print(
            f"[WARN] ffmpeg exited with code {proc.returncode}; "
            f"returning WAV instead of Opus: {stderr_tail}"
        )
        return wav_bytes

    return proc.stdout


# ---------------------------------------------------------------------------
# Streaming Inference
# ---------------------------------------------------------------------------
def infer_sentence(
    ref_audio: torch.Tensor,
    ref_text: str,
    gen_text: str,
    model_obj,
    vocoder,
    nfe_step: int = 7,
    cfg_strength: float = 2.0,
    sway_sampling_coef: float = -1.0,
    speed: float = 1.0,
    device: str = DEVICE,
) -> np.ndarray:
    """Generate audio for a single sentence.

    Args:
        ref_audio: Reference waveform tensor; its last dimension is samples.
        ref_text: Transcript of the reference audio.
        gen_text: Text to synthesize.
        model_obj: Model exposing ``sample(cond=, text=, duration=, steps=, ...)``.
        vocoder: Vocoder exposing ``decode(mel)``.
        nfe_step: Number of sampling steps passed to ``model_obj.sample``.
        cfg_strength: Forwarded to ``model_obj.sample``.
        sway_sampling_coef: Forwarded to ``model_obj.sample``.
        speed: Speaking-rate factor; larger values shorten the estimated duration.
        device: Device the reference audio is moved to.

    Returns:
        The generated waveform as a numpy array (reference part removed).

    Raises:
        ValueError: If ``speed`` is not positive.
    """
    if speed <= 0:
        raise ValueError(f"speed must be positive, got {speed}")

    audio = ref_audio.to(device)
    ref_audio_len = audio.shape[-1] // HOP_LENGTH  # reference length in mel frames

    # Prepare text with dialect ID
    text_list = [ref_text + gen_text]
    final_text_list = text_list_formatter(text_list, dialect_id=dialect_id_map["ALG"])

    # Calculate duration: scale the reference's frames-per-byte rate by the
    # length of the new text. Guard against an empty reference transcript,
    # which would otherwise divide by zero.
    ref_text_len = max(len(ref_text.encode("utf-8")), 1)
    gen_text_len = len(gen_text.encode("utf-8"))
    duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

    with torch.inference_mode():
        generated, _ = model_obj.sample(
            cond=audio,
            text=final_text_list,
            duration=duration,
            steps=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
        )
        generated = generated.to(torch.float32)
        # Drop the frames that correspond to the reference prompt (dim 1),
        # then swap the last two dims before handing the mel to the vocoder.
        generated = generated[:, ref_audio_len:, :]
        generated = generated.permute(0, 2, 1)
        generated_wave = vocoder.decode(generated)

    return generated_wave.squeeze().cpu().numpy()


async def stream_synthesize(
    text: str,
    ref_audio_path: str,
    ref_text: str,
    model_obj,
    vocoder,
    nfe_step: int = 7,
    device: str = DEVICE,
    cfg_strength: float = 2.0,
    sway_sampling_coef: float = -1.0,
    speed: float = 1.0,
    output_format: str = "opus",
) -> AsyncGenerator[bytes, None]:
    """Stream synthesized audio in sentence-level chunks.

    The reference audio is loaded, down-mixed to mono, resampled to
    TARGET_SAMPLE_RATE and boosted to an RMS of 0.1 when quieter. The text is
    split with ``chunk_text`` and every chunk is synthesized and yielded as
    soon as it is ready.

    Args:
        text: Text to synthesize.
        ref_audio_path: Path of the reference audio file.
        ref_text: Transcript of the reference audio.
        model_obj: Model passed to ``infer_sentence``.
        vocoder: Vocoder passed to ``infer_sentence``.
        nfe_step: Sampling steps per sentence.
        device: Device used for inference.
        cfg_strength: Forwarded to ``infer_sentence``.
        sway_sampling_coef: Forwarded to ``infer_sentence``.
        speed: Forwarded to ``infer_sentence``.
        output_format: ``"opus"`` or ``"wav"``.

    Yields:
        For ``"opus"``: one independently encoded Ogg Opus stream per text
        chunk (each comes from its own ffmpeg run), sent back-to-back.
        For ``"wav"``: a WAV header with unknown-length size fields plus PCM
        for the first chunk, then raw 16-bit little-endian PCM only.

    Raises:
        ValueError: If ``output_format`` is not supported.
    """
    if output_format not in SUPPORTED_OUTPUT_FORMATS:
        raise ValueError(f"Unsupported output_format: {output_format!r}")

    # Preprocess reference
    ref_audio_path, ref_text = preprocess_ref_audio_text(ref_audio_path, ref_text)

    # Load reference audio
    audio, sr = torchaudio.load(ref_audio_path)
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)
    if sr != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(sr, TARGET_SAMPLE_RATE)
        audio = resampler(audio)

    # Normalize RMS. A fully silent reference (rms == 0) is left untouched:
    # dividing by it would fill the tensor with NaN/inf.
    rms = float(torch.sqrt(torch.mean(torch.square(audio))))
    target_rms = 0.1
    boosted = 0.0 < rms < target_rms
    if boosted:
        audio = audio * target_rms / rms

    # Chunk text
    sentences = chunk_text(text)
    print(f"[STREAM] Text split into {len(sentences)} chunks")

    loop = asyncio.get_running_loop()

    # Generate and stream each sentence
    for i, sentence in enumerate(sentences):
        t0 = time.perf_counter()
        # NOTE(review): inference deliberately stays on the event-loop thread;
        # whether the compiled model may be called from worker threads was not
        # verified. Only the ffmpeg encode below is off-loaded.
        wav = infer_sentence(
            audio,
            ref_text,
            sentence,
            model_obj,
            vocoder,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            device=device,
        )
        t1 = time.perf_counter()

        # Re-normalize if needed: undo the gain applied to the reference so the
        # output matches the reference's original loudness.
        if boosted:
            wav = wav * rms / target_rms

        if output_format == "opus":
            # Encode to Opus in a worker thread so the blocking ffmpeg
            # subprocess does not stall the event loop.
            payload = await loop.run_in_executor(
                None, wav_to_opus_bytes, wav, TARGET_SAMPLE_RATE, "24k"
            )
        else:
            pcm = _to_pcm16(wav).astype("<i2").tobytes()
            payload = _streaming_wav_header(TARGET_SAMPLE_RATE) + pcm if i == 0 else pcm

        print(f"[STREAM] Chunk {i+1}/{len(sentences)}: {len(sentence)} chars, "
              f"gen={t1-t0:.3f}s, audio={len(wav)/TARGET_SAMPLE_RATE:.2f}s")

        yield payload


# ---------------------------------------------------------------------------
# FastAPI Application
# ---------------------------------------------------------------------------
class SynthesizeRequest(BaseModel):
    """JSON body accepted by /synthesize and /synthesize_sync."""

    text: str
    # Path of the reference audio on the *server's* filesystem.
    ref_audio: str
    ref_text: str = ""
    nfe_step: int = Field(7, ge=1)
    cfg_strength: float = 2.0
    sway_sampling_coef: float = -1.0
    speed: float = Field(1.0, gt=0)
    output_format: str = "opus"  # opus, wav


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load model at startup and release it at shutdown."""
    global model_global, vocoder_global
    print("[STARTUP] Loading production model...")
    model_global = load_production_model(device=DEVICE)
    vocoder_global = load_vocoder("vocos", is_local=False, local_path="", device=DEVICE)
    print("[STARTUP] Model loaded. Server ready.")
    try:
        yield
    finally:
        print("[SHUTDOWN] Cleaning up...")
        # Rebind to None instead of `del`: deleting the global names would make
        # any later read (e.g. /health) raise NameError.
        model_global = None
        vocoder_global = None
        if DEVICE == "cuda":
            torch.cuda.empty_cache()


app = FastAPI(title="Habibi-TTS ALG Streaming Server", lifespan=lifespan)


def _validate_request(request: SynthesizeRequest) -> None:
    """Reject requests that cannot be served, before any audio is produced.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 for an unsupported
            output format, empty text or a missing reference audio file.
    """
    if model_global is None or vocoder_global is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")
    if request.output_format not in SUPPORTED_OUTPUT_FORMATS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported output_format: {request.output_format!r} "
                   f"(expected one of {', '.join(SUPPORTED_OUTPUT_FORMATS)})",
        )
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="text must not be empty")
    # SECURITY NOTE(review): ref_audio is a client-supplied server-side path.
    # As written, any file readable by the server process can be probed or
    # loaded. Consider restricting it to an allow-listed directory.
    if not os.path.isfile(request.ref_audio):
        raise HTTPException(status_code=400, detail=f"Reference audio not found: {request.ref_audio}")


def _synthesis_stream(request: SynthesizeRequest) -> AsyncGenerator[bytes, None]:
    """Create the audio generator for a request, forwarding all of its options."""
    return stream_synthesize(
        request.text,
        request.ref_audio,
        request.ref_text,
        model_global,
        vocoder_global,
        nfe_step=request.nfe_step,
        device=DEVICE,
        cfg_strength=request.cfg_strength,
        sway_sampling_coef=request.sway_sampling_coef,
        speed=request.speed,
        output_format=request.output_format,
    )


@app.post("/synthesize")
async def synthesize(request: SynthesizeRequest):
    """Stream synthesized audio for the given text."""
    _validate_request(request)
    return StreamingResponse(
        _synthesis_stream(request),
        media_type=_MEDIA_TYPES[request.output_format],
    )


@app.post("/synthesize_sync")
async def synthesize_sync(request: SynthesizeRequest):
    """Synchronous synthesis (full audio returned)."""
    _validate_request(request)

    chunks = [chunk async for chunk in _synthesis_stream(request)]
    full_audio = b"".join(chunks)
    if request.output_format == "wav":
        # The total length is known here, so fill in the real header sizes.
        full_audio = _finalize_wav(full_audio)

    # The payload is complete, so a plain Response is sufficient.
    return Response(content=full_audio, media_type=_MEDIA_TYPES[request.output_format])


@app.get("/health")
async def health():
    """Health check endpoint."""
    return {"status": "ok", "model_loaded": model_global is not None}


@app.get("/info")
async def info():
    """Server info."""
    return {
        "model": "Habibi-TTS ALG (Specialized)",
        "base_model": "F5-TTS v1 Base",
        "device": DEVICE,
        "optimizations": ["BF16", "torch.compile", "EPSS"],
        "default_nfe": 7,
        "sample_rate": TARGET_SAMPLE_RATE,
    }


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Parse command-line options and run the server with uvicorn."""
    parser = argparse.ArgumentParser(description="Habibi-TTS ALG Streaming Server")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--workers", type=int, default=1)
    args = parser.parse_args()

    import uvicorn

    target = app
    if args.workers > 1:
        # uvicorn only supports multiple workers when the application is given
        # as an import string. Every worker runs `lifespan` and therefore loads
        # its own copy of the model.
        script_path = Path(__file__).resolve()
        script_dir = str(script_path.parent)
        if script_dir not in sys.path:
            sys.path.insert(0, script_dir)
        target = f"{script_path.stem}:app"

    uvicorn.run(target, host=args.host, port=args.port, workers=args.workers)


if __name__ == "__main__":
    main()