# Source listing metadata: file size 5,046 bytes, revision 35bb6f4.
from __future__ import annotations
import base64
import json
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from loguru import logger
from api.src.core.config import settings
from api.src.inference.model_manager import ModelManager
from api.src.inference.voice_manager import VoiceManager
from api.src.services.streaming_audio_writer import StreamingAudioWriter
from api.src.structures.websocket_schemas import (
WSResponseMessage,
WSStartMessage,
)
# Router that the WebSocket endpoint(s) in this module are registered on; its routes carry the "WebSocket" tag.
router = APIRouter(tags=["WebSocket"])
@router.websocket("/v1/audio/speech/stream")
async def websocket_tts(ws: WebSocket) -> None:
    """WebSocket endpoint for real-time TTS streaming.

    Protocol:
        1. Client sends: {"type": "start", "model": "...", "voice": "...", "response_format": "pcm"}
        2. Client sends: {"type": "text", "text": "Hello world"}
        3. Server streams: {"type": "audio", "data": "<base64>", "format": "pcm"}
        4. Server sends: {"type": "done"}
        5. Client sends: {"type": "stop"} or more text messages

    A {"type": "ping"} message is answered with {"type": "pong"} at any time.
    Messages with any other unrecognised "type" are ignored.

    Recoverable client mistakes (malformed JSON, a non-object payload, an
    invalid "start" message, an unloaded model, an unknown voice, "text" sent
    before "start", missing/empty/non-string text) are answered with
    {"type": "error", "message": "..."} and the connection stays open. A
    rejected "start" never replaces a previously accepted session.

    Any other exception is logged, reported to the client on a best-effort
    basis, and ends the handler.

    Args:
        ws: The WebSocket connection for this client.
    """
    await ws.accept()
    logger.info("WebSocket connection accepted")

    model_manager = ModelManager.get_instance()
    voice_manager = VoiceManager.get_instance()

    # Only ever holds a config whose model and voice passed validation, so
    # the "text" branch can trust it.
    session_config: WSStartMessage | None = None

    try:
        while True:
            raw = await ws.receive_text()

            # A bad frame is the client's mistake, not a reason to drop the
            # whole session: report it and keep listening.
            try:
                msg = json.loads(raw)
            except json.JSONDecodeError as exc:
                await _send_error(ws, f"Invalid JSON: {exc}")
                continue
            if not isinstance(msg, dict):
                # Valid JSON such as `[]` or `"x"` has no .get().
                await _send_error(ws, "Message must be a JSON object")
                continue

            msg_type = msg.get("type")

            if msg_type == "ping":
                await ws.send_text(json.dumps({"type": "pong"}))
                continue

            if msg_type == "start":
                # NOTE(review): assumes WSStartMessage signals bad or missing
                # fields with TypeError/ValueError (or a subclass) -- confirm
                # against its definition. Anything else still reaches the
                # outer handler.
                try:
                    candidate = WSStartMessage(**msg)
                except (TypeError, ValueError) as exc:
                    await _send_error(ws, f"Invalid start message: {exc}")
                    continue

                # Validate before committing: a rejected config must not
                # become the active session.
                if not model_manager.is_loaded(candidate.model):
                    await _send_error(ws, f"Model '{candidate.model}' is not loaded")
                    continue
                if not voice_manager.voice_exists(candidate.voice):
                    await _send_error(ws, f"Voice '{candidate.voice}' not found")
                    continue

                session_config = candidate
                logger.info(
                    f"WS session started: model={session_config.model} "
                    f"voice={session_config.voice} format={session_config.response_format}"
                )
                continue

            if msg_type == "text":
                if session_config is None:
                    await _send_error(ws, "Send a 'start' message first")
                    continue

                text = msg.get("text", "")
                if not isinstance(text, str):
                    # e.g. {"text": 42} or {"text": null}
                    await _send_error(ws, "'text' must be a string")
                    continue
                text = text.strip()
                if not text:
                    await _send_error(ws, "Empty text")
                    continue

                await _handle_text(ws, session_config, text, model_manager, voice_manager)
            elif msg_type == "stop":
                logger.info("WS session stopped by client")
                break
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        # Top-level boundary: keep the traceback in the log.
        logger.exception(f"WebSocket error: {e}")
        try:
            await _send_error(ws, str(e))
        except Exception as send_exc:
            # Best effort only -- the socket may already be closed.
            logger.debug(f"Could not deliver error to client: {send_exc}")
async def _handle_text(
    ws: WebSocket,
    config: WSStartMessage,
    text: str,
    model_manager: ModelManager,
    voice_manager: VoiceManager,
) -> None:
    """Synthesize *text* for the session described by *config* and send it over *ws*.

    Audio is sent as {"type": "audio", "data": "<base64>", "format": ...}
    text frames, followed by a single {"type": "done"} frame. When the
    backbone info for the model reports ``supports_streaming``, frames are
    sent as chunks arrive from ``infer_stream``; otherwise the result of one
    ``infer`` call is written in one go. The audio writer is closed on every
    exit path.

    Args:
        ws: Connection the frames are written to.
        config: Session settings (model, voice, response format).
        text: Text to synthesize.
        model_manager: Provides the loaded model entry and inference calls.
        voice_manager: Provides reference codes and reference text for the voice.
    """

    async def push_audio(payload) -> None:
        # Send one audio frame; falsy payloads are skipped.
        if not payload:
            return
        frame = {
            "type": "audio",
            "data": base64.b64encode(payload).decode(),
            "format": config.response_format,
        }
        await ws.send_text(json.dumps(frame))

    model_entry = model_manager.loaded_models[config.model]
    ref_codes = await voice_manager.get_or_encode_ref_codes(
        config.voice, model_entry.codec_id, model_manager, config.model
    )
    ref_text = voice_manager.get_ref_text(config.voice)

    # Imported at call time, as in the original.
    # NOTE(review): presumably to avoid an import cycle -- confirm before hoisting.
    from api.src.core.model_config import get_backbone_info

    backbone = get_backbone_info(config.model)
    writer = StreamingAudioWriter(config.response_format, settings.sample_rate)
    try:
        use_streaming = backbone and backbone.supports_streaming
        if use_streaming:
            stream = model_manager.infer_stream(config.model, text, ref_codes, ref_text)
            async for piece in stream:
                await push_audio(writer.write_chunk(piece))
        else:
            waveform = await model_manager.infer(config.model, text, ref_codes, ref_text)
            await push_audio(writer.write_chunk(waveform))

        # Send whatever finalize() returns before signalling completion.
        await push_audio(writer.finalize())
        await ws.send_text(json.dumps({"type": "done"}))
    finally:
        writer.close()
async def _send_error(ws: WebSocket, message: str) -> None:
    """Send *message* to the client as a {"type": "error", "message": ...} JSON text frame."""
    error_frame = {"type": "error", "message": message}
    serialized = json.dumps(error_frame)
    await ws.send_text(serialized)
|