| |
| """ |
| ================================================================================ |
| Priority 4: Sentence-Level Streaming TTS Server with Opus Output |
| ================================================================================ |
| |
| Production TTS requires streaming audio, not waiting for full generation. |
| This server implements: |
| |
| 1. Sentence-level chunking: Split input at punctuation boundaries |
| 2. Per-sentence generation: Generate mel-spectrogram for each sentence |
| 3. Opus encoding: Stream compressed audio chunks as each sentence completes |
| 4. Time-to-first-audio (TTFA): Sub-500ms for short sentences |
| |
| Architecture: |
| Client → HTTP POST /synthesize |
| → Preprocess (diacritize, normalize) |
| → Chunk into sentences |
| → For each sentence: |
| → Generate with EPSS(7) + BF16 |
| → Encode to Opus |
| → Yield audio chunk |
| → Client receives streaming audio |
| |
| Opus encoding: |
| - 24kHz sample rate (matches F5-TTS output) |
| - 16-32kbps bitrate |
| - ~10x smaller than WAV |
| - Streaming-compatible (Ogg Opus container) |
| |
| Dependencies: |
| pip install fastapi uvicorn soundfile |
| ffmpeg built with libopus must be on PATH (Opus encoding shells out to it) |
| |
| Usage: |
| # Start server |
| python 04_streaming_server.py --host 0.0.0.0 --port 8000 |
| |
| # Client request |
| curl -X POST http://localhost:8000/synthesize \ |
| -H "Content-Type: application/json" \ |
| -d '{ |
| "text": "مرحبا بك. كيف حالك اليوم؟", |
| "ref_audio": "reference.wav", |
| "ref_text": "مرحبا" |
| }' \ |
| --output output.opus |
| |
| ================================================================================ |
| """ |
|
|
import argparse
import asyncio
import io
import os
import sys
import time
import warnings
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncGenerator, List, Optional

import numpy as np
import soundfile as sf
import torch
import torchaudio
from cached_path import cached_path
from f5_tts.infer.utils_infer import load_vocoder, preprocess_ref_audio_text
from f5_tts.model import CFM
from f5_tts.model.utils import get_tokenizer
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from habibi_tts.model.utils import dialect_id_map, text_list_formatter
from hydra.utils import get_class
from omegaconf import OmegaConf
from pydantic import BaseModel, Field
|
|
# Blanket-suppress Python warnings so library chatter stays out of the server log.
warnings.filterwarnings("ignore")

# Inference device: use the GPU whenever torch reports CUDA support.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Model config lives in ./configs next to this file; checkpoint and vocabulary
# are resolved through cached_path() when the model is loaded.
MODEL_CFG_PATH = str(Path(__file__).parent.joinpath("configs", "F5TTS_v1_Base.yaml"))
CKPT_URL = "hf://SWivid/Habibi-TTS/Specialized/ALG/model_100000.safetensors"
VOCAB_URL = "hf://SWivid/Habibi-TTS/Specialized/ALG/vocab.txt"

# Mel-spectrogram settings handed to CFM through mel_spec_kwargs.
N_MEL_CHANNELS = 100
HOP_LENGTH = 256
WIN_LENGTH = 1024
N_FFT = 1024
TARGET_SAMPLE_RATE = 24000

# Filled in by the FastAPI lifespan hook at startup.
model_global = None
vocoder_global = None
|
|
|
|
| |
| |
| |
|
|
|
|
def load_production_model(device=DEVICE):
    """Load the Habibi-TTS ALG checkpoint and prepare it for inference.

    Builds the CFM model described by MODEL_CFG_PATH, loads the safetensors
    weights fetched from CKPT_URL and, on CUDA devices, converts the model to
    BF16 and wraps its transformer with ``torch.compile``.

    Args:
        device: Torch device spec ("cpu", "cuda", "cuda:0", or a torch.device).

    Returns:
        The CFM model, in eval mode, on ``device``.
    """
    print(f"[LOAD] Loading production model on {device}...")

    from safetensors.torch import load_file

    # Compare device *types* so indexed specs such as "cuda:1" still receive
    # the GPU-only optimizations below (a plain == "cuda" test skips them).
    on_cuda = torch.device(device).type == "cuda"

    model_cfg = OmegaConf.load(MODEL_CFG_PATH)
    model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
    model_arc = model_cfg.model.arch

    ckpt_file = str(cached_path(CKPT_URL))
    vocab_file = str(cached_path(VOCAB_URL))
    vocab_char_map, vocab_size = get_tokenizer(vocab_file, "custom")

    model = CFM(
        transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=N_MEL_CHANNELS),
        mel_spec_kwargs=dict(
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
            win_length=WIN_LENGTH,
            n_mel_channels=N_MEL_CHANNELS,
            target_sample_rate=TARGET_SAMPLE_RATE,
            mel_spec_type="vocos",
        ),
        odeint_kwargs=dict(method="euler"),
        vocab_char_map=vocab_char_map,
    ).to(device)

    # The file stores EMA weights: strip the "ema_model." prefix and drop the
    # bookkeeping entries so the keys line up with the bare model.
    state_dict = {
        key.replace("ema_model.", ""): tensor
        for key, tensor in load_file(ckpt_file, device=str(device)).items()
        if key not in ("initted", "step")
    }
    # These mel-spectrogram buffers are removed before loading; presumably the
    # model's own mel_spec module provides them -- TODO confirm.
    for key in ("mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"):
        state_dict.pop(key, None)
    model.load_state_dict(state_dict)
    del state_dict

    if on_cuda:
        torch.cuda.empty_cache()

        model = model.to(torch.bfloat16)
        print("[OPT] Model converted to BF16")

        model.transformer = torch.compile(model.transformer, mode="reduce-overhead", fullgraph=False)
        print("[OPT] torch.compile applied")

    model.eval()
    return model
|
|
|
|
| |
| |
| |
|
|
|
|
def chunk_text(text: str, max_chars: int = 135) -> List[str]:
    """Split text at punctuation and pack the pieces into size-limited chunks.

    Pieces are cut after ASCII punctuation followed by whitespace, and directly
    after Arabic/CJK punctuation. Consecutive pieces are then merged greedily
    while the running size stays within ``max_chars``.

    Args:
        text: Input text to split.
        max_chars: Budget per chunk, measured in UTF-8 *bytes* (so an Arabic
            letter counts as 2). A single piece longer than the budget is kept
            whole rather than cut mid-sentence.

    Returns:
        The chunks in order, stripped of surrounding whitespace. Empty or
        whitespace-only input yields an empty list.
    """
    import re

    # Second alternative (zero-width, no whitespace required):
    #   U+061B Arabic semicolon, ':', U+060C Arabic comma, U+3002 ideographic
    #   full stop, '!', '?', and U+061F Arabic question mark.
    # U+061F was missing, so Arabic questions never produced a chunk boundary.
    boundary = r"(?<=[;:,.!?])\s+|(?<=[\u061b:\u060c\u3002!?\u061f])"

    chunks = []
    current = ""
    for piece in re.split(boundary, text):
        if not piece.strip():
            continue
        # The ASCII branch consumed the whitespace after the punctuation, so
        # put a separator back; non-ASCII endings keep their original spacing.
        addition = piece + " " if piece[-1].isascii() else piece
        if len(current.encode("utf-8")) + len(piece.encode("utf-8")) <= max_chars:
            current += addition
        else:
            if current:
                chunks.append(current.strip())
            current = addition

    if current:
        chunks.append(current.strip())

    return chunks
|
|
|
|
def wav_to_opus_bytes(wav: np.ndarray, sr: int = 24000, bitrate: str = "24k") -> bytes:
    """Encode a float waveform as Ogg/Opus, falling back to WAV if ffmpeg is unusable.

    The samples are written to an in-memory 16-bit PCM WAV which is piped
    through ``ffmpeg`` (libopus, Ogg container).

    Args:
        wav: Waveform samples, treated as floats nominally within [-1.0, 1.0].
        sr: Sample rate in Hz written into the intermediate WAV.
        bitrate: Value passed to ffmpeg's ``-b:a`` option.

    Returns:
        The Ogg/Opus bytes on success; otherwise the 16-bit PCM WAV bytes that
        would have been fed to ffmpeg (best-effort fallback, logged).
    """
    import subprocess

    # Clip before the int16 cast: astype() does not saturate, so samples beyond
    # +/-1.0 would overflow the 16-bit range and come out as loud artefacts.
    pcm16 = (np.clip(wav, -1.0, 1.0) * 32767).astype(np.int16)
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, pcm16, sr, format="WAV", subtype="PCM_16")
    wav_bytes = wav_buffer.getvalue()

    try:
        proc = subprocess.run(
            ["ffmpeg", "-i", "-", "-c:a", "libopus", "-b:a", bitrate, "-f", "ogg", "-"],
            input=wav_bytes,
            capture_output=True,
        )
    except OSError as exc:
        # Raised when the ffmpeg executable cannot be launched (e.g. not installed).
        print(f"[OPUS] ffmpeg could not be started ({exc}); returning WAV bytes")
        return wav_bytes

    if proc.returncode != 0:
        print(f"[OPUS] ffmpeg exited with code {proc.returncode}; returning WAV bytes")
        return wav_bytes
    return proc.stdout
|
|
|
|
| |
| |
| |
|
|
|
|
def infer_sentence(
    ref_audio: torch.Tensor,
    ref_text: str,
    gen_text: str,
    model_obj,
    vocoder,
    nfe_step: int = 7,
    cfg_strength: float = 2.0,
    sway_sampling_coef: float = -1.0,
    speed: float = 1.0,
    device: str = DEVICE,
) -> np.ndarray:
    """Generate the waveform for one text chunk, conditioned on a reference clip.

    Args:
        ref_audio: Reference waveform tensor; its last dimension is the sample axis.
        ref_text: Transcript of the reference audio.
        gen_text: Text to synthesize.
        model_obj: Model exposing ``sample(cond=, text=, duration=, steps=, ...)``.
        vocoder: Object exposing ``decode(mel)``.
        nfe_step: Number of sampling steps passed to ``model_obj.sample``.
        cfg_strength: Guidance strength passed to ``model_obj.sample``.
        sway_sampling_coef: Sway-sampling coefficient passed to ``model_obj.sample``.
        speed: Speaking-rate factor; larger values shorten the estimated duration.
        device: Device the reference audio is moved to.

    Returns:
        The generated waveform as a NumPy array (reference portion removed).

    Raises:
        ValueError: If ``speed`` is not positive.
    """
    if speed <= 0:
        # Zero would divide by zero below; negative values would request fewer
        # frames than the reference itself occupies.
        raise ValueError(f"speed must be positive, got {speed}")

    audio = ref_audio.to(device)
    ref_audio_len = audio.shape[-1] // HOP_LENGTH  # reference length in mel frames

    text_list = [ref_text + gen_text]
    final_text_list = text_list_formatter(text_list, dialect_id=dialect_id_map["ALG"])

    # Estimate the total frame count by scaling the reference frames with the
    # UTF-8 byte-length ratio of generated text to reference text.
    ref_text_len = max(len(ref_text.encode("utf-8")), 1)  # empty ref_text must not divide by zero
    gen_text_len = len(gen_text.encode("utf-8"))
    duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

    with torch.inference_mode():
        generated, _ = model_obj.sample(
            cond=audio,
            text=final_text_list,
            duration=duration,
            steps=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
        )

        generated = generated.to(torch.float32)
        # Drop the leading frames that correspond to the reference audio.
        # Assumes the output is laid out (batch, frames, mel) -- TODO confirm.
        generated = generated[:, ref_audio_len:, :]
        generated = generated.permute(0, 2, 1)
        generated_wave = vocoder.decode(generated)
        generated_wave = generated_wave.squeeze().cpu().numpy()

    return generated_wave
|
|
|
|
async def stream_synthesize(
    text: str,
    ref_audio_path: str,
    ref_text: str,
    model_obj,
    vocoder,
    nfe_step: int = 7,
    device: str = DEVICE,
    cfg_strength: float = 2.0,
    sway_sampling_coef: float = -1.0,
    speed: float = 1.0,
) -> AsyncGenerator[bytes, None]:
    """Yield encoded audio for ``text``, one chunk at a time.

    The reference clip is loaded, mixed down to mono, resampled to
    TARGET_SAMPLE_RATE and boosted to a target RMS when it is quieter than
    that. Each chunk from ``chunk_text`` is then synthesized and encoded with
    ``wav_to_opus_bytes``.

    Args:
        text: Full text to synthesize.
        ref_audio_path: Path to the reference audio file.
        ref_text: Transcript of the reference audio.
        model_obj: Model forwarded to ``infer_sentence``.
        vocoder: Vocoder forwarded to ``infer_sentence``.
        nfe_step: Sampling steps per chunk.
        device: Device forwarded to ``infer_sentence``.
        cfg_strength: Forwarded to ``infer_sentence``.
        sway_sampling_coef: Forwarded to ``infer_sentence``.
        speed: Forwarded to ``infer_sentence``.

    Yields:
        The bytes returned by ``wav_to_opus_bytes`` for each chunk, in order.
    """
    ref_audio_path, ref_text = preprocess_ref_audio_text(ref_audio_path, ref_text)

    audio, sr = torchaudio.load(ref_audio_path)
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)
    if sr != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(sr, TARGET_SAMPLE_RATE)
        audio = resampler(audio)

    rms = torch.sqrt(torch.mean(torch.square(audio)))
    target_rms = 0.1
    # Boost only quiet, non-silent references. With rms == 0 the scaling
    # would divide by zero and turn the whole reference into NaNs.
    boosted = bool(0 < rms < target_rms)
    if boosted:
        audio = audio * target_rms / rms
        restore_gain = rms.item() / target_rms

    sentences = chunk_text(text)
    print(f"[STREAM] Text split into {len(sentences)} chunks")

    # NOTE(review): infer_sentence and wav_to_opus_bytes are plain synchronous
    # calls, so the event loop is blocked while each chunk is produced.
    for i, sentence in enumerate(sentences):
        t0 = time.perf_counter()
        wav = infer_sentence(
            audio, ref_text, sentence, model_obj, vocoder,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            device=device,
        )
        t1 = time.perf_counter()

        # Undo the reference boost so the output matches the original loudness.
        if boosted:
            wav = wav * restore_gain

        opus_bytes = wav_to_opus_bytes(wav, sr=TARGET_SAMPLE_RATE, bitrate="24k")

        print(f"[STREAM] Chunk {i+1}/{len(sentences)}: {len(sentence)} chars, "
              f"gen={t1-t0:.3f}s, audio={len(wav)/TARGET_SAMPLE_RATE:.2f}s")

        yield opus_bytes
|
|
|
|
| |
| |
| |
|
|
|
|
class SynthesizeRequest(BaseModel):
    """JSON body accepted by the /synthesize and /synthesize_sync endpoints."""

    # Text to synthesize.
    text: str
    # Path to the reference audio; the endpoints check it on the server's filesystem.
    ref_audio: str
    # Transcript of the reference audio.
    ref_text: str = ""
    # Number of sampling steps; at least one step is required.
    nfe_step: int = Field(default=7, ge=1)
    cfg_strength: float = 2.0
    sway_sampling_coef: float = -1.0
    # Speaking-rate factor; it is used as a divisor, so it must be positive.
    speed: float = Field(default=1.0, gt=0)
    # Requested output format.
    output_format: str = "opus"
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model and vocoder at startup and release them at shutdown."""
    global model_global, vocoder_global
    print("[STARTUP] Loading production model...")
    model_global = load_production_model(device=DEVICE)
    vocoder_global = load_vocoder("vocos", is_local=False, local_path="", device=DEVICE)
    print("[STARTUP] Model loaded. Server ready.")
    try:
        yield
    finally:
        print("[SHUTDOWN] Cleaning up...")
        # Rebind to None instead of `del`: deleting the module globals would
        # make any later read (e.g. in the /health handler) raise NameError.
        model_global = None
        vocoder_global = None
        if DEVICE == "cuda":
            torch.cuda.empty_cache()


app = FastAPI(title="Habibi-TTS ALG Streaming Server", lifespan=lifespan)
|
|
|
|
@app.post("/synthesize")
async def synthesize(request: SynthesizeRequest):
    """Stream synthesized audio for the given text, chunk by chunk.

    Raises:
        HTTPException: 400 if ``ref_audio`` is not an existing file,
            503 if the model or vocoder has not been loaded.
    """
    # NOTE(review): ref_audio is a client-supplied path on the server's
    # filesystem; consider restricting it to an allow-listed directory.
    if not os.path.isfile(request.ref_audio):
        raise HTTPException(status_code=400, detail=f"Reference audio not found: {request.ref_audio}")
    if model_global is None or vocoder_global is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    # NOTE(review): only nfe_step is forwarded; request.cfg_strength,
    # request.sway_sampling_coef and request.speed have no effect here.
    audio_chunks = stream_synthesize(
        request.text,
        request.ref_audio,
        request.ref_text,
        model_global,
        vocoder_global,
        nfe_step=request.nfe_step,
        device=DEVICE,
    )

    # stream_synthesize passes every chunk through the Opus encoder whatever
    # output_format says, so advertising audio/wav would mislabel the payload.
    return StreamingResponse(audio_chunks, media_type="audio/ogg")
|
|
|
|
@app.post("/synthesize_sync")
async def synthesize_sync(request: SynthesizeRequest):
    """Synthesize the whole text and return the audio in a single response.

    Raises:
        HTTPException: 400 if ``ref_audio`` is not an existing file,
            503 if the model or vocoder has not been loaded.
    """
    # NOTE(review): ref_audio is a client-supplied path on the server's
    # filesystem; consider restricting it to an allow-listed directory.
    if not os.path.isfile(request.ref_audio):
        raise HTTPException(status_code=400, detail=f"Reference audio not found: {request.ref_audio}")
    if model_global is None or vocoder_global is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    # NOTE(review): only nfe_step is forwarded; request.cfg_strength,
    # request.sway_sampling_coef and request.speed have no effect here.
    chunks = [
        chunk
        async for chunk in stream_synthesize(
            request.text,
            request.ref_audio,
            request.ref_text,
            model_global,
            vocoder_global,
            nfe_step=request.nfe_step,
            device=DEVICE,
        )
    ]

    # The body is fully buffered, so a plain Response (with Content-Length) is
    # used rather than a StreamingResponse. The per-chunk encodings are simply
    # concatenated; every chunk went through the Opus encoder whatever
    # output_format says, hence the fixed media type.
    return Response(content=b"".join(chunks), media_type="audio/ogg")
|
|
|
|
@app.get("/health")
async def health():
    """Report liveness and whether the model global has been populated."""
    model_loaded = model_global is not None
    return {"status": "ok", "model_loaded": model_loaded}
|
|
|
|
@app.get("/info")
async def info():
    """Describe the served model, its device and its default settings."""
    optimizations = ["BF16", "torch.compile", "EPSS"]
    return dict(
        model="Habibi-TTS ALG (Specialized)",
        base_model="F5-TTS v1 Base",
        device=DEVICE,
        optimizations=optimizations,
        default_nfe=7,
        sample_rate=TARGET_SAMPLE_RATE,
    )
|
|
|
|
| |
| |
| |
|
|
|
|
def main():
    """Parse command-line options and run the server with uvicorn."""
    parser = argparse.ArgumentParser(description="Habibi-TTS ALG Streaming Server")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--workers", type=int, default=1)
    args = parser.parse_args()

    import uvicorn

    if args.workers > 1:
        # uvicorn refuses to start multiple workers from an app *object*; it
        # needs an import string so each worker process can import the app.
        # Every worker runs the lifespan hook and therefore loads the model.
        script = Path(__file__).resolve()
        uvicorn.run(
            f"{script.stem}:app",
            host=args.host,
            port=args.port,
            workers=args.workers,
            app_dir=str(script.parent),  # make this module importable by the workers
        )
    else:
        uvicorn.run(app, host=args.host, port=args.port, workers=args.workers)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|