File size: 7,261 Bytes
8929f9a c36e2bd 6de4ff2 8929f9a c36e2bd 8929f9a c36e2bd 9191ff8 c36e2bd 8929f9a 9191ff8 8929f9a c36e2bd 8929f9a 6de4ff2 c36e2bd 6de4ff2 8929f9a 6de4ff2 8929f9a 6de4ff2 8929f9a c36e2bd 8929f9a c36e2bd 8929f9a c36e2bd 8929f9a c36e2bd 8929f9a c36e2bd 8929f9a c36e2bd 8929f9a c36e2bd 6de4ff2 c36e2bd 6de4ff2 c36e2bd 6de4ff2 c36e2bd 8929f9a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 | """
MOSS-TTS-Nano ONNX - Docker 部署版本 (FastAPI)
提供 REST API 接口,支持流式和非流式输出,兼容 Qwen3TTS 接口格式
自动从 HuggingFace Hub 下载模型,启动时预加载并预热
"""
import os
import sys
import logging
from pathlib import Path
from typing import Generator, Optional
from queue import Queue
import threading
import time
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
import numpy as np
from starlette.concurrency import iterate_in_threadpool
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
# Logging: send INFO and above to stdout; force=True discards any handlers that
# were attached to the root logger before this module ran.
logging.basicConfig(
    force=True,
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Model directory: use /data/models when a /data mount is present, otherwise a
# local ./models directory. Either way, make sure it exists.
if os.path.exists("/data"):
    MODEL_DIR = Path("/data/models")
else:
    MODEL_DIR = Path("./models")
MODEL_DIR.mkdir(exist_ok=True, parents=True)
# Put the bundled moss_tts_service directory (next to this file) at the front of
# sys.path, once, so the import below can resolve it.
service_path = Path(__file__).resolve().parent.joinpath("moss_tts_service")
if not any(entry == str(service_path) for entry in sys.path):
    sys.path.insert(0, str(service_path))
from moss_tts_service import MossTtsService  # noqa: E402  (needs the sys.path tweak above)
# Process-wide TTS service instance; populated by _init_tts_service().
_tts_service: Optional[MossTtsService] = None

app = FastAPI(
    version="1.0.0",
    title="MOSS-TTS-Nano ONNX API",
)
# Serve static assets under /static. The directory is resolved next to this
# module (like service_path above) rather than against the current working
# directory, so StaticFiles does not fail at import time when the server is
# launched from a different directory.
app.mount(
    "/static",
    StaticFiles(directory=str(Path(__file__).resolve().parent / "static")),
    name="static",
)
@app.get("/")
async def root():
    """Serve the test page, static/index.html, located next to this module."""
    # Resolve against this file rather than the CWD so the page is found no
    # matter which directory the server was started from.
    return FileResponse(Path(__file__).resolve().parent / "static" / "index.html")
class TTSRequest(BaseModel):
    """Request body for POST /v1/audio/speech."""
    # Text to synthesize; an empty value is rejected with HTTP 400 by create_speech.
    input: str
    # Voice name forwarded to the TTS service.
    voice: str = "Junhao"
    # The streaming path answers with media type "audio/pcm" when this is "pcm"
    # and "audio/wav" otherwise; the non-streaming path always returns WAV.
    response_format: str = "wav"
    # NOTE(review): temperature / audio_temperature are accepted but not forwarded
    # to synthesis by the request handlers in this file — TODO wire through.
    temperature: float = 1.0
    audio_temperature: float = 0.8
    # True: stream audio chunks as they are generated; False: return one file.
    stream: bool = False
def _init_tts_service() -> MossTtsService:
    """Create the shared MossTtsService, store it in ``_tts_service`` and return it.

    The service is built from ``MODEL_DIR`` with 2 CPU threads. Any model warm-up
    is presumably done inside the MossTtsService constructor — not visible here.
    """
    global _tts_service
    logger.info("初始化 MossTtsService...")
    _tts_service = MossTtsService(cpu_threads=2, model_dir=str(MODEL_DIR))
    logger.info("MossTtsService 初始化完成!")
    return _tts_service
def _get_tts_service() -> MossTtsService:
    """Return the shared TTS service, creating it on first use."""
    if _tts_service is not None:
        return _tts_service
    return _init_tts_service()
@app.on_event("startup")
async def startup_event():
    """Load the TTS service when the application starts.

    ``_init_tts_service`` is a blocking call made directly from this coroutine,
    so startup does not complete until the model has been loaded.
    NOTE(review): the final log line reports warm-up; no explicit warm-up call is
    visible here — presumably MossTtsService performs it, confirm.
    """
    logger.info("系统启动,开始加载模型...")
    _init_tts_service()  # return value not needed; it is stored in _tts_service
    logger.info("模型预热完成,服务就绪!")
@app.get("/health")
async def health():
    """Health check: always reports ok; it does not probe the model itself."""
    payload = dict(status="ok", service="MOSS-TTS-Nano-ONNX")
    return payload
@app.get("/voices")
async def list_voices():
    """Return whatever the TTS service reports as its available voices."""
    return _get_tts_service().list_voices()
@app.post("/v1/audio/speech")
async def create_speech(req: TTSRequest):
    """
    Synthesize speech; the request shape follows the Qwen3TTS-compatible API.

    Non-streaming requests are delegated to ``_handle_non_stream``. Streaming
    requests are served from the synchronous ``_generate_audio_chunks_sync``
    generator, iterated via Starlette's ``iterate_in_threadpool`` so its blocking
    waits happen off the event loop.

    Raises:
        HTTPException: 400 if ``input`` is empty or contains only whitespace.
    """
    # Whitespace-only text has nothing to speak; reject it the same way as an
    # empty string instead of handing it to the model.
    if not req.input.strip():
        raise HTTPException(status_code=400, detail="Missing 'input' parameter")

    logger.info(f"MossTTS request: text='{req.input[:50]}...', voice={req.voice}, stream={req.stream}")

    if not req.stream:
        return await _handle_non_stream(req)

    logger.info("Starting streaming response...")
    # NOTE(review): req.temperature / req.audio_temperature are not forwarded, so
    # the generator's own defaults are used — TODO wire them through.
    sync_gen = _generate_audio_chunks_sync(req.input, req.voice)
    # NOTE(review): the stream body is raw PCM without a WAV header; it is still
    # labelled audio/wav unless response_format == "pcm" — kept for existing
    # clients, confirm they expect this.
    media_type = "audio/pcm" if req.response_format == "pcm" else "audio/wav"
    return StreamingResponse(iterate_in_threadpool(sync_gen), media_type=media_type)
async def _handle_non_stream(req: TTSRequest):
    """
    Non-streaming path: synthesize the whole utterance to a WAV file and return it.

    - ``synthesize_to_file`` is a blocking call, so it runs in Starlette's
      threadpool; calling it inline would stall every other request (including
      /health) for the duration of synthesis.
    - Each request writes to its own uniquely named temp file. The previous name
      used ``int(time.time())``, so two requests in the same second shared a file.
    - The temp file is removed after the response has been sent (background
      task), or right away if synthesis fails, so the temp dir does not fill up.

    Raises:
        HTTPException: 500 with the error text if synthesis fails.
    """
    import tempfile
    import uuid

    from starlette.background import BackgroundTask
    from starlette.concurrency import run_in_threadpool

    def discard(path: str) -> None:
        """Best-effort removal of a temp file; a missing file is not an error."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError:
            logger.warning("Could not remove temp file %s", path, exc_info=True)

    service = _get_tts_service()
    temp_file = os.path.join(tempfile.gettempdir(), f"tts_output_{uuid.uuid4().hex}.wav")
    try:
        output_path = await run_in_threadpool(
            service.synthesize_to_file,
            text=req.input,
            output_path=temp_file,
            voice=req.voice,
        )
        logger.info(f"Non-stream synthesis completed: {output_path}")
        return FileResponse(
            path=output_path,
            media_type="audio/wav",
            filename="output.wav",
            # Runs only after the response body has been sent.
            background=BackgroundTask(discard, temp_file),
        )
    except Exception as e:
        logger.exception("Non-stream synthesis failed")
        discard(temp_file)
        raise HTTPException(status_code=500, detail=str(e)) from e
def _waveform_to_pcm16(waveform, sample_rate: int, target_rate: int) -> bytes:
    """Convert a float waveform to native-endian 16-bit PCM bytes at ``target_rate``.

    The samples are cast to float32 and resampled with resampy when
    ``sample_rate`` differs. They are then clipped to [-1.0, 1.0] and scaled by
    32767. Assumes a mono 1-D waveform — TODO confirm against MossTtsService.
    """
    audio = np.asarray(waveform, dtype=np.float32)
    if sample_rate != target_rate:
        import resampy  # lazy: only needed when the model rate differs from the target
        # NOTE(review): each chunk is resampled independently, which can leave
        # small discontinuities at chunk boundaries.
        audio = resampy.resample(x=audio, sr_orig=sample_rate, sr_new=target_rate)
    # Without clipping, samples slightly outside [-1, 1] wrap around in the int16
    # cast and come out as loud clicks.
    audio = np.clip(audio, -1.0, 1.0)
    return (audio * 32767).astype(np.int16).tobytes()


def _generate_audio_chunks_sync(
    text: str,
    voice: str,
    *,
    temperature: float = 1.0,
    audio_temperature: float = 0.8,
) -> Generator[bytes, None, None]:
    """
    Synchronously stream synthesized speech as raw PCM (16000 Hz, 16-bit, mono).

    A background daemon thread runs ``synthesize_stream`` and feeds a bounded
    queue. This generator drains the queue and yields fixed 20 ms (640-byte)
    chunks as soon as they are available, with no real-time pacing. The last
    chunk may be shorter.

    Args:
        text: Text to synthesize.
        voice: Voice name passed to the TTS service.
        temperature: Forwarded to ``synthesize_stream`` (default matches the value
            previously hard-coded here).
        audio_temperature: Forwarded to ``synthesize_stream`` (same default as before).

    Yields:
        PCM byte chunks.

    Raises:
        Exception: any exception raised in the producer thread is re-raised here.
    """
    from queue import Empty, Full

    service = _get_tts_service()
    logger.info(f"Starting streaming synthesis: text='{text[:50]}...'")
    start_time = time.perf_counter()

    target_rate = 16000
    chunk_samples = 320  # 20 ms at 16 kHz
    chunk_bytes = chunk_samples * 2  # 2 bytes per int16 sample -> 640 bytes
    poll_interval = 0.5  # seconds between liveness / cancellation checks

    audio_queue = Queue(maxsize=50)
    stop_event = threading.Event()  # set once the consumer is finished or gone
    _DONE = object()

    def put(item) -> bool:
        """Enqueue ``item``; return False (dropping it) once the consumer has stopped."""
        while not stop_event.is_set():
            try:
                audio_queue.put(item, timeout=poll_interval)
                return True
            except Full:
                continue
        return False

    def producer():
        """Background thread: synthesize audio and enqueue fixed-size PCM chunks."""
        buffer = bytearray()
        try:
            for chunk in service.synthesize_stream(
                text=text,
                voice=voice,
                temperature=temperature,
                audio_temperature=audio_temperature,
            ):
                if stop_event.is_set():
                    break
                if chunk.waveform is None or len(chunk.waveform) == 0:
                    continue
                buffer += _waveform_to_pcm16(chunk.waveform, chunk.sample_rate, target_rate)
                # Emit every complete chunk, then drop the emitted prefix in one
                # step (repeated slicing of an immutable bytes buffer is quadratic).
                usable = len(buffer) - len(buffer) % chunk_bytes
                for offset in range(0, usable, chunk_bytes):
                    if not put(bytes(buffer[offset:offset + chunk_bytes])):
                        return
                del buffer[:usable]
        except Exception as e:
            logger.exception("Producer error")
            put(e)
        finally:
            if buffer:
                put(bytes(buffer))
            put(_DONE)

    producer_thread = threading.Thread(target=producer, daemon=True)
    producer_thread.start()

    # Consumer: yield chunks as they arrive (no rate control).
    chunk_count = 0
    first_chunk_time = None
    try:
        while True:
            try:
                item = audio_queue.get(timeout=poll_interval)
            except Empty:
                # A slow model is not end-of-stream: keep waiting while the
                # producer runs. Once it has exited, nothing more can be added,
                # so stop only when the queue is drained too.
                if producer_thread.is_alive() or not audio_queue.empty():
                    continue
                logger.warning("Producer exited without end-of-stream marker")
                break
            if item is _DONE:
                break
            if isinstance(item, Exception):
                raise item
            if first_chunk_time is None:
                first_chunk_time = time.perf_counter()
                logger.info(f"MossTTS Time to first chunk: {first_chunk_time - start_time:.4f}s")
            yield item
            chunk_count += 1
    finally:
        # Also runs when this generator is closed or garbage-collected before it
        # is exhausted, so the producer stops instead of blocking forever on a
        # full queue.
        stop_event.set()

    total_time = time.perf_counter() - start_time
    logger.info(f"MossTTS streaming completed in {total_time:.4f}s, chunks sent: {chunk_count}")
if __name__ == "__main__":
    # Run the API with uvicorn on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")
|