File size: 7,261 Bytes
8929f9a
 
c36e2bd
6de4ff2
8929f9a
 
 
 
 
 
c36e2bd
 
 
 
8929f9a
 
 
 
 
 
 
c36e2bd
9191ff8
 
c36e2bd
8929f9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9191ff8
 
 
 
 
 
 
 
8929f9a
 
 
 
 
 
 
 
c36e2bd
8929f9a
 
6de4ff2
 
 
 
 
 
 
 
 
 
 
 
c36e2bd
6de4ff2
8929f9a
 
6de4ff2
8929f9a
 
 
6de4ff2
 
 
 
 
 
 
 
8929f9a
 
 
 
 
 
 
 
 
c36e2bd
8929f9a
 
 
 
 
 
c36e2bd
8929f9a
c36e2bd
8929f9a
 
c36e2bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8929f9a
 
 
 
 
 
c36e2bd
8929f9a
 
 
 
 
 
c36e2bd
8929f9a
 
 
c36e2bd
 
 
6de4ff2
c36e2bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6de4ff2
c36e2bd
 
 
 
 
 
 
6de4ff2
c36e2bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8929f9a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""
MOSS-TTS-Nano ONNX - Docker 部署版本 (FastAPI)
提供 REST API 接口,支持流式和非流式输出,兼容 Qwen3TTS 接口格式
自动从 HuggingFace Hub 下载模型,启动时预加载并预热
"""

import os
import sys
import logging
from pathlib import Path
from typing import Generator, Optional
from queue import Queue
import threading
import time

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
import numpy as np

from starlette.concurrency import iterate_in_threadpool
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse

# Configure logging: INFO level, timestamped records written to stdout.
# force=True makes basicConfig remove any handlers already attached to the
# root logger before installing the handler below.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True
)
logger = logging.getLogger(__name__)

# Model directory: /data/models when a /data path exists, otherwise ./models
# (relative to the current working directory). The directory is created at
# import time if it is missing.
MODEL_DIR = Path("/data/models") if os.path.exists("/data") else Path("./models")
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Put the "moss_tts_service" directory that sits next to this file at the
# front of sys.path (only once), before importing from it below.
service_path = Path(__file__).resolve().parent / "moss_tts_service"
if str(service_path) not in sys.path:
    sys.path.insert(0, str(service_path))

# NOTE(review): this import is deliberately placed after the sys.path tweak
# above; presumably it relies on it -- confirm against the package layout
# before moving it to the top-of-file import block.
from moss_tts_service import MossTtsService

# Module-wide TTS service instance; None until _init_tts_service() assigns it.
_tts_service: Optional[MossTtsService] = None

app = FastAPI(title="MOSS-TTS-Nano ONNX API", version="1.0.0")

# Serve files from the "static" directory under the /static URL prefix.
# The directory is given as a relative path, so it is resolved against the
# process working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")

@app.get("/")
async def root():
    """Serve the test page stored at ``static/index.html``."""
    index_page = "static/index.html"
    return FileResponse(path=index_page)


class TTSRequest(BaseModel):
    """Request body accepted by the ``/v1/audio/speech`` endpoint."""
    # Text to synthesize; the only field without a default.
    input: str
    # Voice name handed to the TTS service.
    voice: str = "Junhao"
    # Requested output format; the streaming path checks it against "pcm"
    # when choosing the response media type.
    response_format: str = "wav"
    # Sampling temperatures.
    # NOTE(review): the handlers visible in this file do not read these two
    # fields -- confirm whether they are meant to be forwarded to the service.
    temperature: float = 1.0
    audio_temperature: float = 0.8
    # True selects the chunked streaming response, False a single file.
    stream: bool = False


def _init_tts_service() -> MossTtsService:
    """Create a ``MossTtsService`` and store it as the module-wide instance.

    The service is built with ``MODEL_DIR`` as its model directory and two
    CPU threads. No separate warm-up call is made here; whatever loading the
    constructor performs is the only initialization done.

    Returns:
        The newly created service, which is also assigned to ``_tts_service``.
    """
    global _tts_service
    logger.info("初始化 MossTtsService...")
    instance = MossTtsService(model_dir=str(MODEL_DIR), cpu_threads=2)
    _tts_service = instance
    logger.info("MossTtsService 初始化完成!")
    return instance


def _get_tts_service() -> MossTtsService:
    """Return the module-wide TTS service, creating it on first use.

    Returns:
        The instance held in ``_tts_service``; ``_init_tts_service()`` is
        called first when that global is still ``None``.
    """
    if _tts_service is not None:
        return _tts_service
    # Not built yet: the initializer assigns the global as a side effect.
    _init_tts_service()
    return _tts_service


@app.on_event("startup")
async def startup_event():
    """Build the TTS service once when the application starts.

    Calls ``_init_tts_service()`` so the module-wide instance exists before
    the first request arrives. The helper's return value is intentionally not
    kept: request handlers obtain the instance through ``_get_tts_service()``.
    """
    logger.info("系统启动,开始加载模型...")
    _init_tts_service()
    logger.info("模型预热完成,服务就绪!")


@app.get("/health")
async def health():
    """Health-check endpoint: returns a fixed status payload."""
    payload = {"status": "ok", "service": "MOSS-TTS-Nano-ONNX"}
    return payload


@app.get("/voices")
async def list_voices():
    """Return the result of the TTS service's ``list_voices()`` call."""
    return _get_tts_service().list_voices()


@app.post("/v1/audio/speech")
async def create_speech(req: TTSRequest):
    """Synthesize speech for ``req.input`` (Qwen3TTS-compatible endpoint).

    With ``req.stream`` false the request is delegated to
    ``_handle_non_stream``; otherwise the chunks produced by
    ``_generate_audio_chunks_sync`` are streamed back, iterated in the
    thread pool.

    Raises:
        HTTPException: 400 when ``req.input`` is empty.
    """
    if not req.input:
        raise HTTPException(status_code=400, detail="Missing 'input' parameter")

    logger.info(f"MossTTS request: text='{req.input[:50]}...', voice={req.voice}, stream={req.stream}")

    if req.stream:
        logger.info("Starting streaming response...")
        # NOTE(review): every format other than "pcm" is labelled audio/wav,
        # while the chunk generator's docstring describes its output as raw
        # PCM -- confirm that clients do not expect a WAV header here.
        if req.response_format == "pcm":
            content_type = "audio/pcm"
        else:
            content_type = "audio/wav"
        # Only the text and voice are handed to the generator.
        chunk_iter = iterate_in_threadpool(
            _generate_audio_chunks_sync(req.input, req.voice)
        )
        return StreamingResponse(chunk_iter, media_type=content_type)

    return await _handle_non_stream(req)


async def _handle_non_stream(req: TTSRequest):
    """Synthesize the whole utterance to a WAV file and return it.

    Args:
        req: The parsed request; ``input`` and ``voice`` are passed to the
            service's ``synthesize_to_file``.

    Returns:
        A ``FileResponse`` (``audio/wav``, download name ``output.wav``) for
        the path returned by the service. The temporary file is removed by a
        background task after the response has been sent.

    Raises:
        HTTPException: 500 with the error text when synthesis fails.
    """
    # Local imports keep this block self-contained; starlette is already a
    # dependency of this module.
    import tempfile
    import uuid

    from starlette.background import BackgroundTask
    from starlette.concurrency import run_in_threadpool

    service = _get_tts_service()

    # A random name per request: a timestamp with one-second resolution lets
    # two requests arriving in the same second overwrite each other's output.
    temp_file = str(
        Path(tempfile.gettempdir()) / f"tts_output_{uuid.uuid4().hex}.wav"
    )

    def _remove_temp_file() -> None:
        """Best-effort removal of this request's temporary file."""
        try:
            Path(temp_file).unlink(missing_ok=True)
        except OSError:
            logger.warning("Could not remove temporary file %s", temp_file)

    try:
        # Synthesis is blocking; run it in the thread pool so the event loop
        # keeps serving other requests (e.g. /health) in the meantime.
        output_path = await run_in_threadpool(
            service.synthesize_to_file,
            text=req.input,
            output_path=temp_file,
            voice=req.voice,
        )
        logger.info(f"Non-stream synthesis completed: {output_path}")
        return FileResponse(
            path=output_path,
            media_type="audio/wav",
            filename="output.wav",
            # Runs after the body has been sent, so the file is not leaked.
            background=BackgroundTask(_remove_temp_file),
        )
    except Exception as e:
        logger.exception("Non-stream synthesis failed")
        _remove_temp_file()
        raise HTTPException(status_code=500, detail=str(e)) from e


def _generate_audio_chunks_sync(
    text: str,
    voice: str,
    temperature: float = 1.0,
    audio_temperature: float = 0.8,
) -> Generator[bytes, None, None]:
    """Stream synthesized audio as 16 kHz, 16-bit PCM byte chunks.

    A background thread runs ``service.synthesize_stream`` and pushes
    fixed-size chunks into a bounded queue; this generator pops them and
    yields them as soon as they are available (no pacing).

    Args:
        text: Text to synthesize.
        voice: Voice name passed to the service.
        temperature: Forwarded to ``synthesize_stream``.
        audio_temperature: Forwarded to ``synthesize_stream``.

    Yields:
        ``bytes`` objects of 640 bytes (320 int16 samples, 20 ms at 16 kHz);
        the final chunk may be shorter.

    Raises:
        Exception: Whatever the producer thread raised, re-raised here after
            the chunks queued before the failure have been yielded.
    """
    from queue import Empty, Full

    service = _get_tts_service()

    logger.info(f"Starting streaming synthesis: text='{text[:50]}...'")
    start_time = time.perf_counter()

    target_rate = 16000
    chunk_samples = 320  # 20 ms of audio at 16 kHz
    chunk_bytes = chunk_samples * 2  # int16 samples -> 640 bytes per chunk

    # Bounded queue between the producer thread and this generator.
    audio_queue = Queue(maxsize=50)
    _DONE = object()
    # Set when the consumer stops (finished, failed, or was closed) so the
    # producer does not stay blocked on a full queue forever.
    stop_requested = threading.Event()

    def _enqueue(item) -> bool:
        """Put ``item`` on the queue; return False once the consumer is gone."""
        while not stop_requested.is_set():
            try:
                audio_queue.put(item, timeout=0.5)
                return True
            except Full:
                continue
        return False

    def producer():
        """Background thread: synthesize, convert to PCM, and queue chunks."""
        buffer = bytearray()
        try:
            for chunk in service.synthesize_stream(
                text=text,
                voice=voice,
                temperature=temperature,
                audio_temperature=audio_temperature,
            ):
                if stop_requested.is_set():
                    return
                if chunk.waveform is None or len(chunk.waveform) == 0:
                    continue

                # Resample to the target rate when the model uses another one.
                audio = chunk.waveform.astype(np.float32)
                if chunk.sample_rate != target_rate:
                    import resampy
                    audio = resampy.resample(
                        x=audio, sr_orig=chunk.sample_rate, sr_new=target_rate
                    )

                # Clip before scaling: values outside [-1, 1] would otherwise
                # wrap around when cast to int16.
                audio_int16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
                buffer += audio_int16.tobytes()

                # Emit as many full chunks as the buffer currently holds.
                while len(buffer) >= chunk_bytes:
                    if not _enqueue(bytes(buffer[:chunk_bytes])):
                        return
                    del buffer[:chunk_bytes]

        except Exception as e:
            logger.exception("Producer error")
            _enqueue(e)
        finally:
            # Flush the trailing partial chunk, then signal completion.
            if buffer:
                _enqueue(bytes(buffer))
            _enqueue(_DONE)

    producer_thread = threading.Thread(target=producer, daemon=True)
    producer_thread.start()

    chunk_count = 0
    first_chunk_time = None

    try:
        while True:
            try:
                item = audio_queue.get(timeout=1)
            except Empty:
                # A quiet second is not the end of the stream: the model may
                # simply need longer to produce the next chunk.
                if producer_thread.is_alive():
                    continue
                # Producer has exited; pick up anything it queued just before.
                try:
                    item = audio_queue.get_nowait()
                except Empty:
                    break

            if item is _DONE:
                break

            if isinstance(item, Exception):
                raise item

            if first_chunk_time is None:
                first_chunk_time = time.perf_counter()
                logger.info(f"MossTTS Time to first chunk: {first_chunk_time - start_time:.4f}s")

            yield item
            chunk_count += 1
    finally:
        # Also reached when the generator is closed early by the caller.
        stop_requested.set()

    total_time = time.perf_counter() - start_time
    logger.info(f"MossTTS streaming completed in {total_time:.4f}s, chunks sent: {chunk_count}")


# When executed as a script, serve the app with uvicorn on all interfaces,
# port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)