# tts/api.py — 张明 — "Add application file" (9191ff8)
"""
MOSS-TTS-Nano ONNX - Docker 部署版本 (FastAPI)
提供 REST API 接口,支持流式和非流式输出,兼容 Qwen3TTS 接口格式
自动从 HuggingFace Hub 下载模型,启动时预加载并预热
"""
import logging
import os
import sys
import tempfile
import threading
import time
import uuid
from pathlib import Path
from queue import Empty, Full
from queue import Queue
from typing import Generator, Optional

import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.concurrency import run_in_threadpool
# Logging: INFO and above, written to stdout. force=True replaces any
# handlers already attached to the root logger (e.g. by an imported library).
logging.basicConfig(
    force=True,
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Model directory: use /data/models when a /data directory exists
# (presumably a mounted volume in the container — confirm), else ./models.
# The directory is created eagerly so later code can write into it.
if os.path.exists("/data"):
    MODEL_DIR = Path("/data/models")
else:
    MODEL_DIR = Path("./models")
MODEL_DIR.mkdir(exist_ok=True, parents=True)
# Put the bundled "moss_tts_service" directory (next to this file) at the
# front of sys.path, at most once, before importing the service class from it.
service_path = Path(__file__).resolve().parent.joinpath("moss_tts_service")
if not any(entry == str(service_path) for entry in sys.path):
    sys.path.insert(0, str(service_path))
from moss_tts_service import MossTtsService  # noqa: E402  (needs the sys.path tweak above)
# Process-wide TTS service instance; assigned by _init_tts_service().
_tts_service: Optional[MossTtsService] = None
app = FastAPI(title="MOSS-TTS-Nano ONNX API", version="1.0.0")
# Serve the "static" directory that sits next to this file. Resolving it from
# __file__ (instead of the current working directory) keeps the app importable
# when it is launched from another directory; StaticFiles raises at
# construction time if the directory does not exist.
app.mount(
    "/static",
    StaticFiles(directory=Path(__file__).resolve().parent / "static"),
    name="static",
)
@app.get("/")
async def root():
    """Return the bundled test page (static/index.html next to this file)."""
    # Resolve relative to this file rather than the working directory so the
    # page is found regardless of where the server was started from.
    return FileResponse(Path(__file__).resolve().parent / "static" / "index.html")
class TTSRequest(BaseModel):
    """Request body for POST /v1/audio/speech (Qwen3TTS-compatible field names)."""
    input: str  # text to synthesize; the handler rejects an empty value with HTTP 400
    voice: str = "Junhao"  # voice name forwarded to the TTS service
    response_format: str = "wav"  # handler only distinguishes "pcm" (streaming media type) from everything else
    temperature: float = 1.0  # sampling temperature (exact semantics defined by MossTtsService — confirm)
    audio_temperature: float = 0.8  # audio sampling temperature (semantics defined by MossTtsService — confirm)
    stream: bool = False  # True: chunked raw PCM stream; False: one complete WAV file
def _init_tts_service() -> MossTtsService:
    """Construct the MossTtsService, store it in the module global and return it.

    Uses MODEL_DIR as the model directory and 2 CPU threads. Any loading or
    warm-up work happens inside the MossTtsService constructor (not visible
    here). If construction raises, the global is left unchanged.
    """
    global _tts_service
    logger.info("初始化 MossTtsService...")
    service = MossTtsService(model_dir=str(MODEL_DIR), cpu_threads=2)
    _tts_service = service
    logger.info("MossTtsService 初始化完成!")
    return service
def _get_tts_service() -> MossTtsService:
    """Return the shared TTS service, creating it on first use if needed."""
    if _tts_service is not None:
        return _tts_service
    return _init_tts_service()
@app.on_event("startup")
async def startup_event():
    """Create the TTS service at application startup so the first request does not pay for it."""
    logger.info("系统启动,开始加载模型...")
    # Construction is synchronous and runs on the event loop; acceptable here
    # because it happens once, during startup.
    _init_tts_service()
    logger.info("模型预热完成,服务就绪!")
@app.get("/health")
async def health():
    """Health check: always reports "ok" with the service name (no model check)."""
    payload = {"status": "ok", "service": "MOSS-TTS-Nano-ONNX"}
    return payload
@app.get("/voices")
async def list_voices():
    """Return whatever the TTS service reports as its available voices."""
    return _get_tts_service().list_voices()
@app.post("/v1/audio/speech")
async def create_speech(req: TTSRequest):
    """Synthesize speech for ``req.input`` (Qwen3TTS-compatible request format).

    Non-streaming requests return one complete WAV file. Streaming requests
    return raw 16 kHz / 16-bit / mono PCM bytes as soon as they are produced,
    using the request's ``temperature`` and ``audio_temperature``.

    Raises:
        HTTPException: 400 if ``input`` is empty or whitespace-only;
            500 if non-streaming synthesis fails.
    """
    # Whitespace-only text has nothing to synthesize; treat it like empty input.
    if not req.input or not req.input.strip():
        raise HTTPException(status_code=400, detail="Missing 'input' parameter")
    logger.info(f"MossTTS request: text='{req.input[:50]}...', voice={req.voice}, stream={req.stream}")
    if not req.stream:
        return await _handle_non_stream(req)
    logger.info("Starting streaming response...")
    sync_gen = _generate_audio_chunks_sync(
        req.input,
        req.voice,
        temperature=req.temperature,
        audio_temperature=req.audio_temperature,
    )
    # NOTE(review): the stream body is headerless PCM even when the media type
    # says audio/wav; kept as-is for compatibility with existing clients.
    return StreamingResponse(
        iterate_in_threadpool(sync_gen),
        media_type="audio/pcm" if req.response_format == "pcm" else "audio/wav",
    )


def _remove_files_quietly(*paths) -> None:
    """Best-effort deletion of temporary files; missing files are ignored, other errors logged."""
    for path in dict.fromkeys(os.fspath(p) for p in paths if p):
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError:
            logger.warning("Failed to remove temporary file %s", path, exc_info=True)


async def _handle_non_stream(req: TTSRequest):
    """Synthesize the whole utterance to a WAV file and return it as a download.

    The output goes to a uniquely named temp file that is deleted after the
    response has been sent (or immediately, if synthesis fails).

    NOTE(review): ``response_format``, ``temperature`` and ``audio_temperature``
    are not forwarded here; the visible code does not show whether
    ``synthesize_to_file`` accepts them.

    Raises:
        HTTPException: 500 with the underlying error message if synthesis fails.
    """
    service = _get_tts_service()
    # Unique per request: a timestamp-based name collides when two requests
    # arrive within the same second and would serve/overwrite the wrong audio.
    temp_file = os.path.join(tempfile.gettempdir(), f"tts_output_{uuid.uuid4().hex}.wav")
    try:
        # Synthesis is blocking; run it in a worker thread so the event loop
        # keeps serving other requests meanwhile.
        output_path = await run_in_threadpool(
            service.synthesize_to_file,
            text=req.input,
            output_path=temp_file,
            voice=req.voice,
        )
    except Exception as e:
        _remove_files_quietly(temp_file)
        logger.exception("Non-stream synthesis failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    logger.info(f"Non-stream synthesis completed: {output_path}")
    return FileResponse(
        path=output_path,
        media_type="audio/wav",
        filename="output.wav",
        # Clean up once the file has been streamed to the client.
        background=BackgroundTask(_remove_files_quietly, output_path, temp_file),
    )


def _generate_audio_chunks_sync(
    text: str,
    voice: str,
    temperature: float = 1.0,
    audio_temperature: float = 0.8,
) -> Generator[bytes, None, None]:
    """Stream synthesized speech as raw PCM chunks (16000 Hz, 16-bit, mono).

    A background thread runs ``service.synthesize_stream``, resamples each
    chunk to 16 kHz if necessary, converts it to int16 PCM (native byte order)
    and pushes fixed 20 ms pieces (640 bytes) into a bounded queue. This
    generator yields them as soon as they arrive, with no rate control. The
    final piece may be shorter than 640 bytes.

    Args:
        text: Text to synthesize.
        voice: Voice name forwarded to the service.
        temperature: Forwarded to ``synthesize_stream`` (default matches the
            value previously hard-coded).
        audio_temperature: Forwarded to ``synthesize_stream`` (default matches
            the value previously hard-coded).

    Raises:
        Exception: whatever the producer thread raised (synthesis/resampling).
    """
    service = _get_tts_service()
    logger.info(f"Starting streaming synthesis: text='{text[:50]}...'")
    start_time = time.perf_counter()
    target_rate = 16000
    chunk_samples = 320  # 20 ms of audio at 16 kHz
    chunk_bytes = chunk_samples * 2  # int16 = 2 bytes/sample -> 640 bytes per chunk
    audio_queue: Queue = Queue(maxsize=50)
    done = object()
    # Set when the consumer stops (normal end, error, or client disconnect)
    # so a producer waiting on a full queue can exit instead of hanging forever.
    stop_event = threading.Event()

    def put(item) -> bool:
        """Enqueue item, waiting while the queue is full; return False once the consumer has stopped."""
        while not stop_event.is_set():
            try:
                audio_queue.put(item, timeout=0.5)
                return True
            except Full:
                continue
        return False

    def producer() -> None:
        """Background thread: synthesize, resample, convert to int16 PCM and enqueue fixed-size chunks."""
        buffer = bytearray()
        try:
            for chunk in service.synthesize_stream(
                text=text,
                voice=voice,
                temperature=temperature,
                audio_temperature=audio_temperature,
            ):
                if stop_event.is_set():
                    return
                if chunk.waveform is None or len(chunk.waveform) == 0:
                    continue
                audio = chunk.waveform.astype(np.float32)
                if chunk.sample_rate != target_rate:
                    import resampy
                    audio = resampy.resample(x=audio, sr_orig=chunk.sample_rate, sr_new=target_rate)
                # Clip first: samples outside [-1, 1] would otherwise wrap
                # around in the int16 cast and produce loud clicks.
                audio_int16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
                buffer += audio_int16.tobytes()
                # Emit every complete chunk, then drop them all at once
                # (re-slicing the buffer per chunk is quadratic).
                complete = len(buffer) - len(buffer) % chunk_bytes
                for offset in range(0, complete, chunk_bytes):
                    if not put(bytes(buffer[offset:offset + chunk_bytes])):
                        return
                del buffer[:complete]
        except Exception as e:
            logger.exception("Producer error")
            put(e)
        finally:
            if buffer:
                put(bytes(buffer))
            put(done)

    producer_thread = threading.Thread(target=producer, daemon=True)
    producer_thread.start()
    chunk_count = 0
    first_chunk_time = None
    try:
        while True:
            try:
                item = audio_queue.get(timeout=1)
            except Empty:
                # The model may need more than a second per chunk (especially
                # the first). Keep waiting while the producer is alive; the old
                # bare `except: break` could end the stream before any audio.
                if producer_thread.is_alive() or not audio_queue.empty():
                    continue
                logger.warning("Producer thread exited without an end-of-stream marker")
                break
            if item is done:
                break
            if isinstance(item, Exception):
                raise item
            if first_chunk_time is None:
                first_chunk_time = time.perf_counter()
                logger.info(f"MossTTS Time to first chunk: {first_chunk_time - start_time:.4f}s")
            yield item
            chunk_count += 1
        total_time = time.perf_counter() - start_time
        logger.info(f"MossTTS streaming completed in {total_time:.4f}s, chunks sent: {chunk_count}")
    finally:
        # Runs on normal completion, on error and when the generator is closed
        # (client disconnect), releasing a producer blocked on the queue.
        stop_event.set()
if __name__ == "__main__":
    # Direct execution: serve the app on every interface at port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")