"""
MOSS-TTS-Nano ONNX - Docker 部署版本 (FastAPI)
提供 REST API 接口,支持流式和非流式输出,兼容 Qwen3TTS 接口格式
自动从 HuggingFace Hub 下载模型,启动时预加载并预热
"""


import os
import sys
import logging
import tempfile
import threading
import time
import uuid
from pathlib import Path
from queue import Empty, Full, Queue
from typing import Generator, Optional

import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool


# Send INFO-and-above records to stdout in a single plain-text format.
# force=True drops any root handlers that were installed before this module ran.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format=_LOG_FORMAT,
    level=logging.INFO,
    force=True,
)
logger = logging.getLogger(__name__)


# Model directory: prefer the /data volume when it exists, otherwise ./models
# (relative to the current working directory). Created eagerly at import time.
if os.path.exists("/data"):
    MODEL_DIR = Path("/data/models")
else:
    MODEL_DIR = Path("./models")
MODEL_DIR.mkdir(exist_ok=True, parents=True)


# Put the bundled "moss_tts_service" directory (next to this file) at the front
# of sys.path, once, so the import below resolves regardless of the CWD.
service_path = Path(__file__).resolve().parent.joinpath("moss_tts_service")
_service_path_str = str(service_path)
if _service_path_str not in sys.path:
    sys.path.insert(0, _service_path_str)

from moss_tts_service import MossTtsService  # noqa: E402 - needs the sys.path entry above


# Process-wide TTS service; set by _init_tts_service() and read via _get_tts_service().
_tts_service: Optional[MossTtsService] = None

app = FastAPI(title="MOSS-TTS-Nano ONNX API", version="1.0.0")

# Resolve the static directory relative to this file rather than the current
# working directory: StaticFiles raises at construction time if the directory
# is missing, so a CWD-relative path crashes the app when started elsewhere.
app.mount(
    "/static",
    StaticFiles(directory=Path(__file__).resolve().parent / "static"),
    name="static",
)


@app.get("/")
async def root():
    """Serve the bundled browser test page (``static/index.html``).

    The path is resolved relative to this file, not the current working
    directory, so the page is found no matter where the server is started.
    """
    return FileResponse(Path(__file__).resolve().parent / "static" / "index.html")


class TTSRequest(BaseModel):
    """Request body for ``POST /v1/audio/speech`` (TTS request parameters).

    Field names follow the Qwen3TTS-compatible request shape.
    """
    # Text to synthesize; the endpoint handler rejects an empty string with HTTP 400.
    input: str
    # Voice name handed to MossTtsService; presumably one of the names
    # returned by GET /voices — TODO confirm.
    voice: str = "Junhao"
    # For streaming responses, "pcm" selects the audio/pcm content type and any
    # other value selects audio/wav. NOTE(review): the non-streaming path
    # always returns a WAV file regardless of this field.
    response_format: str = "wav"
    # Sampling temperatures requested by the client.
    # NOTE(review): confirm which synthesis paths forward these to MossTtsService.
    temperature: float = 1.0
    audio_temperature: float = 0.8
    # True -> chunked streaming response; False -> one complete audio file.
    stream: bool = False


def _init_tts_service() -> MossTtsService:
    """Create the shared MossTtsService and cache it in ``_tts_service``.

    The service is built from ``MODEL_DIR`` with two CPU threads. Any model
    download or warm-up is whatever the MossTtsService constructor does —
    TODO confirm.

    Returns:
        The newly created service (the same object now held in ``_tts_service``).
    """
    global _tts_service
    logger.info("初始化 MossTtsService...")
    service = MossTtsService(model_dir=str(MODEL_DIR), cpu_threads=2)
    _tts_service = service
    logger.info("MossTtsService 初始化完成!")
    return service


def _get_tts_service() -> MossTtsService:
    """Return the shared TTS service, creating it on first use.

    NOTE(review): the lazy creation is not guarded by a lock; two threads that
    both see ``None`` would each construct a service.
    """
    if _tts_service is not None:
        return _tts_service
    return _init_tts_service()


@app.on_event("startup")
async def startup_event():
    """Build the TTS service once at application startup.

    Loading happens here so the first request does not pay for it. The
    constructor call is synchronous and blocks the loop until it returns.
    """
    logger.info("系统启动,开始加载模型...")
    _init_tts_service()
    logger.info("模型预热完成,服务就绪!")


@app.get("/health")
async def health():
    """Liveness probe: returns a fixed OK payload without touching the model."""
    payload = {"status": "ok", "service": "MOSS-TTS-Nano-ONNX"}
    return payload


@app.get("/voices")
async def list_voices():
    """Return whatever ``MossTtsService.list_voices()`` reports.

    The service is created first if it does not exist yet.
    """
    return _get_tts_service().list_voices()


@app.post("/v1/audio/speech")
async def create_speech(req: TTSRequest):
    """Synthesize speech for ``req.input`` (Qwen3TTS-compatible endpoint).

    Non-streaming requests return a complete WAV file. Streaming requests
    return the headerless 16 kHz / 16-bit PCM produced by
    ``_generate_audio_chunks_sync``.

    Raises:
        HTTPException: 400 when ``input`` is empty or whitespace-only; 500 when
            non-streaming synthesis fails.
    """
    # Whitespace-only text has nothing to synthesize; treat it like an empty string.
    if not req.input.strip():
        raise HTTPException(status_code=400, detail="Missing 'input' parameter")

    logger.info(f"MossTTS request: text='{req.input[:50]}...', voice={req.voice}, stream={req.stream}")

    if not req.stream:
        return await _handle_non_stream(req)

    logger.info("Starting streaming response...")
    # Forward the client's sampling settings instead of silently ignoring them.
    sync_gen = _generate_audio_chunks_sync(
        req.input,
        req.voice,
        temperature=req.temperature,
        audio_temperature=req.audio_temperature,
    )
    # NOTE(review): the streamed body is raw PCM with no WAV header, yet any
    # response_format other than "pcm" is labelled audio/wav. Kept unchanged
    # for client compatibility — confirm what the frontend expects.
    media_type = "audio/pcm" if req.response_format == "pcm" else "audio/wav"
    # The generator blocks on a queue, so drive it from the threadpool to keep
    # the event loop responsive.
    return StreamingResponse(iterate_in_threadpool(sync_gen), media_type=media_type)


def _remove_temp_files(*paths) -> None:
    """Best-effort deletion of temporary output files; missing files are ignored."""
    for path in {os.fspath(p) for p in paths if p}:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError:
            logger.warning(f"Could not remove temporary file {path}", exc_info=True)


async def _handle_non_stream(req: TTSRequest):
    """Synthesize the whole utterance to a temporary WAV file and return it.

    The blocking synthesis call runs in Starlette's threadpool, so the event
    loop keeps serving other requests (including /health) meanwhile. Each
    request gets a uniquely named temporary file, which is deleted after the
    response has been sent (or immediately if synthesis fails).

    NOTE(review): ``req.temperature``, ``req.audio_temperature`` and
    ``req.response_format`` are not forwarded here — confirm whether
    ``MossTtsService.synthesize_to_file`` accepts sampling parameters.

    Raises:
        HTTPException: 500 wrapping any synthesis error.
    """
    service = _get_tts_service()

    # A per-second timestamp would make concurrent requests share and clobber
    # one file; a random uuid keeps every request's output separate.
    temp_file = os.path.join(tempfile.gettempdir(), f"tts_output_{uuid.uuid4().hex}.wav")
    try:
        output_path = await run_in_threadpool(
            service.synthesize_to_file,
            text=req.input,
            output_path=temp_file,
            voice=req.voice,
        )
    except Exception as e:
        logger.exception("Non-stream synthesis failed")
        _remove_temp_files(temp_file)
        raise HTTPException(status_code=500, detail=str(e)) from e

    logger.info(f"Non-stream synthesis completed: {output_path}")
    return FileResponse(
        path=output_path,
        media_type="audio/wav",
        filename="output.wav",
        # Runs after the file has been sent; otherwise every request would
        # leave a WAV file behind in the temp directory.
        background=BackgroundTask(_remove_temp_files, temp_file, output_path),
    )


def _generate_audio_chunks_sync(
    text: str,
    voice: str,
    temperature: float = 1.0,
    audio_temperature: float = 0.8,
    idle_timeout: Optional[float] = 60.0,
) -> Generator[bytes, None, None]:
    """Stream synthesized speech as raw PCM byte chunks, sent as soon as produced.

    Output is headerless 16 kHz signed 16-bit PCM in native byte order, cut
    into 640-byte (320-sample, 20 ms) pieces; the final piece may be shorter.
    Assumes the service yields mono waveforms — TODO confirm.

    A background thread runs ``service.synthesize_stream`` and feeds a bounded
    queue; this generator drains it. When the generator finishes or is closed
    early, the producer is told to stop, so it neither keeps synthesizing nor
    blocks forever on a full queue.

    Args:
        text: Text to synthesize.
        voice: Voice name passed to the service.
        temperature: Forwarded to ``synthesize_stream`` as ``temperature``.
        audio_temperature: Forwarded to ``synthesize_stream`` as ``audio_temperature``.
        idle_timeout: Seconds to wait for the next piece before giving up and
            ending the stream; ``None`` waits indefinitely.

    Raises:
        Exception: any error raised by the synthesizer is re-raised here.
    """
    service = _get_tts_service()

    logger.info(f"Starting streaming synthesis: text='{text[:50]}...'")
    start_time = time.perf_counter()

    chunk_samples = 320  # 20 ms at 16 kHz
    chunk_bytes = chunk_samples * 2  # int16 -> 2 bytes per sample

    # Bounded so a fast producer cannot buffer unlimited audio ahead of a slow client.
    audio_queue: Queue = Queue(maxsize=50)
    done = object()
    stop = threading.Event()

    def offer(item) -> bool:
        """Queue ``item``, waiting while the queue is full; give up once stopped."""
        while not stop.is_set():
            try:
                audio_queue.put(item, timeout=0.5)
            except Full:
                continue
            return True
        return False

    def producer() -> None:
        """Background thread: synthesize, convert to int16 PCM, enqueue fixed-size pieces."""
        pending = bytearray()
        try:
            for chunk in service.synthesize_stream(
                text=text,
                voice=voice,
                temperature=temperature,
                audio_temperature=audio_temperature,
            ):
                if stop.is_set():
                    return  # consumer is gone; stop spending CPU on synthesis
                if chunk.waveform is None or len(chunk.waveform) == 0:
                    continue

                audio = chunk.waveform.astype(np.float32)
                if chunk.sample_rate != 16000:
                    import resampy
                    audio = resampy.resample(x=audio, sr_orig=chunk.sample_rate, sr_new=16000)

                # Clip first: float samples beyond [-1, 1] would otherwise wrap
                # around when cast to int16 and produce loud clicks.
                audio = np.clip(audio, -1.0, 1.0)
                pending += (audio * 32767).astype(np.int16).tobytes()

                # Emit every complete piece, then drop them from the buffer in one step.
                complete = len(pending) - len(pending) % chunk_bytes
                for offset in range(0, complete, chunk_bytes):
                    if not offer(bytes(pending[offset:offset + chunk_bytes])):
                        return
                del pending[:complete]

        except Exception as e:
            logger.exception("Producer error")
            offer(e)
        finally:
            # Flush the trailing partial piece, then signal end of stream.
            # Both calls return immediately if the consumer has already stopped.
            if pending:
                offer(bytes(pending))
            offer(done)

    producer_thread = threading.Thread(target=producer, daemon=True)
    producer_thread.start()

    chunk_count = 0
    first_chunk_time = None
    try:
        while True:
            try:
                item = audio_queue.get(timeout=idle_timeout)
            except Empty:
                logger.error(f"MossTTS stream stalled: no audio for {idle_timeout}s, ending stream")
                break

            if item is done:
                break

            if isinstance(item, Exception):
                raise item

            if first_chunk_time is None:
                first_chunk_time = time.perf_counter()
                logger.info(f"MossTTS Time to first chunk: {first_chunk_time - start_time:.4f}s")

            yield item
            chunk_count += 1
    finally:
        # Runs on completion, on error, and when the generator is closed early
        # (e.g. GeneratorExit): lets the producer exit instead of blocking forever.
        stop.set()

    total_time = time.perf_counter() - start_time
    logger.info(f"MossTTS streaming completed in {total_time:.4f}s, chunks sent: {chunk_count}")


if __name__ == "__main__":
    # Direct-run entry point: serve the app on every interface, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")