| """ |
| ATR TTS API Server |
| |
| Start: python api_atri.py [-a 127.0.0.1] [-p 9880] |
| |
| API docs: http://127.0.0.1:9880/docs |
| """ |
|
|
| import os |
| import sys |
| import signal |
| import argparse |
| import subprocess |
| import threading |
| import wave |
| from io import BytesIO |
| from typing import Generator, Optional, Union |
|
|
| import numpy as np |
| import soundfile as sf |
| from fastapi import FastAPI |
| from fastapi.responses import JSONResponse, Response, StreamingResponse |
| from pydantic import BaseModel, Field |
| import uvicorn |
|
|
# Put the current working directory and its "GPT_SoVITS" sub-folder on the
# import path so the project import further down can be resolved.
now_dir = os.getcwd()
sys.path.extend((now_dir, os.path.join(now_dir, "GPT_SoVITS")))
|
|
| from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config |
|
|
| |
# Model checkpoints and the default reference clip used by the server.
# NOTE(review): the "/path/to/..." values look like placeholders that must be
# edited for a real deployment.
GPT_MODEL = "/path/to/GPT_SoVITS/pretrained_models/s1v3.ckpt"
SOVITS_MODEL = "/path/to/ATR_e8_s3952.pth"
REF_AUDIO = "/path/to/ref_audio.wav"
# Default prompt text sent with REF_AUDIO when a request supplies none.
# NOTE(review): this literal was mis-encoded (UTF-8 shown through a Thai code
# page) and has been restored from the surviving bytes; the first word could
# not be recovered unambiguously - confirm it matches the reference clip.
REF_TEXT = "わたしはマスターの所有物ですので、 勝手に売買するのは違法です"
REF_LANG = "ja"
VERSION = "v2Pro"
|
|
|
|
| |
class TTSRequest(BaseModel):
    """Request body for one-shot synthesis (``POST /tts``).

    Only ``text`` and ``text_lang`` are required; every other field has a
    default. The three reference fields may be left as ``None``.
    """

    text: str = Field(
        ...,
        description="要合成的文本",
        examples=["こんにちは、お元気ですか？"],
    )
    text_lang: str = Field(
        ...,
        description="文本语言: zh, en, ja, ko, yue",
        examples=["ja"],
    )
    ref_audio_path: Optional[str] = Field(None, description="参考音频路径 (留空使用默认)")
    prompt_text: Optional[str] = Field(None, description="参考音频的文本 (留空使用默认)")
    prompt_lang: Optional[str] = Field(None, description="参考音频的语言 (留空使用默认)")
    speed_factor: float = Field(1.0, ge=0.5, le=2.0, description="语速倍率")
    top_k: int = Field(15, ge=1, description="Top-K 采样")
    top_p: float = Field(1.0, ge=0.0, le=1.0, description="Top-P 采样")
    temperature: float = Field(1.0, ge=0.0, le=2.0, description="采样温度")
    seed: int = Field(-1, description="随机种子 (-1 为随机)")
    # The value is interpolated into the response Content-Type header, so it
    # is restricted to the four formats the description advertises instead of
    # accepting arbitrary client text.
    media_type: str = Field(
        "wav",
        pattern=r"^(wav|ogg|aac|raw)$",
        description="输出格式: wav, ogg, aac, raw",
    )
    text_split_method: str = Field("cut5", description="文本切分方式: cut0-cut5")
    batch_size: int = Field(1, ge=1, description="推理批大小")
    sample_steps: int = Field(32, ge=1, description="采样步数")

    # Example payloads shown in the generated OpenAPI schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {"text": "こんにちは、お元気ですか？", "text_lang": "ja"},
                {"text": "你好，很高兴认识你。", "text_lang": "zh"},
            ]
        }
    }
|
|
|
|
class TTSStreamRequest(TTSRequest):
    """Request body for streaming synthesis (``POST /tts/stream``).

    Adds a ``streaming_mode`` selector (1, 2 or 3) to the fields of
    ``TTSRequest``.
    """

    streaming_mode: int = Field(
        1,
        ge=1,
        le=3,
        description="流式模式: 1=分段(最高质量), 2=真流式(中等质量), 3=定长流式(最快响应)",
    )
|
|
|
|
class HealthResponse(BaseModel):
    """Response schema of the ``/health`` endpoint."""

    # Status marker; the health handler always reports "ok".
    status: str
    # Configured GPT checkpoint path.
    gpt_model: str
    # Configured SoVITS checkpoint path.
    sovits_model: str
    # Default reference audio path.
    ref_audio: str
    # Configured model version string.
    version: str
|
|
|
|
| |
def pack_wav(data: np.ndarray, rate: int) -> bytes:
    """Encode ``data`` as a complete in-memory WAV file.

    Args:
        data: Audio samples, passed to ``soundfile.write`` unchanged.
        rate: Sample rate in Hz.

    Returns:
        The full contents of the WAV file as bytes.
    """
    sink = BytesIO()
    sf.write(sink, data, rate, format="wav")
    return sink.getvalue()
|
|
|
|
def pack_ogg(data: np.ndarray, rate: int) -> bytes:
    """Encode ``data`` as a mono OGG file, writing from a large-stack thread.

    The encode runs in a helper thread created with a 16 MiB stack
    (``4096 * 4096`` bytes).
    NOTE(review): presumably this works around the encoder overflowing the
    default thread stack on long clips - confirm.

    Args:
        data: Audio samples, written as a single channel.
        rate: Sample rate in Hz.

    Returns:
        The encoded OGG file as bytes.

    Raises:
        Exception: whatever the encoder raised inside the helper thread is
            re-raised here instead of being lost with the thread.
    """
    buf = BytesIO()
    failures = []

    def _write():
        try:
            with sf.SoundFile(buf, mode="w", samplerate=rate, channels=1, format="ogg") as f:
                f.write(data)
        except Exception as exc:
            # Hand the error to the calling thread; otherwise the caller
            # would silently receive an empty or truncated buffer.
            failures.append(exc)

    # threading.stack_size() is process-wide and applies to threads started
    # afterwards, so remember the old value and put it back once the worker
    # has been started.
    previous_size = threading.stack_size()
    threading.stack_size(4096 * 4096)
    try:
        worker = threading.Thread(target=_write)
        worker.start()
    finally:
        threading.stack_size(previous_size)
    worker.join()

    if failures:
        raise failures[0]
    return buf.getvalue()
|
|
|
|
def pack_aac(data: np.ndarray, rate: int) -> bytes:
    """Encode ``data`` as an ADTS AAC stream by piping it through ``ffmpeg``.

    The raw buffer of ``data`` is sent to ffmpeg, which is told to read it as
    mono signed 16-bit little-endian PCM (``-f s16le -ac 1``).
    NOTE(review): this assumes ``data`` is an int16 array - confirm against
    the pipeline output.

    Args:
        data: Audio samples; ``data.tobytes()`` is fed to ffmpeg unchanged.
        rate: Sample rate in Hz.

    Returns:
        The encoded AAC (ADTS) bytes read from ffmpeg's stdout.

    Raises:
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    command = [
        "ffmpeg", "-f", "s16le", "-ar", str(rate), "-ac", "1",
        "-i", "pipe:0", "-c:a", "aac", "-b:a", "192k", "-vn", "-f", "adts", "pipe:1",
    ]
    result = subprocess.run(
        command,
        input=data.tobytes(),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if result.returncode != 0:
        # Surface the failure instead of returning empty or partial audio.
        detail = result.stderr.decode("utf-8", errors="replace").strip()
        raise RuntimeError(
            f"ffmpeg AAC encoding failed (exit code {result.returncode}): {detail[-500:]}"
        )
    return result.stdout
|
|
|
|
def pack_audio(data: np.ndarray, rate: int, media_type: str) -> bytes:
    """Encode ``data`` in the format named by ``media_type``.

    Args:
        data: Audio samples.
        rate: Sample rate in Hz (unused for the raw fallback).
        media_type: ``"wav"``, ``"ogg"`` or ``"aac"`` select an encoder; any
            other value returns the sample buffer as-is.

    Returns:
        The encoded audio, or ``data.tobytes()`` for the fallback.
    """
    if media_type == "wav":
        return pack_wav(data, rate)
    if media_type == "aac":
        return pack_aac(data, rate)
    if media_type == "ogg":
        return pack_ogg(data, rate)
    # "raw" and every unrecognised value: hand back the bytes untouched.
    return data.tobytes()
|
|
|
|
def wave_header_chunk(sample_rate: int = 32000) -> bytes:
    """Build a WAV header for mono 16-bit audio that contains no frames.

    Args:
        sample_rate: Frame rate written into the header, in Hz.

    Returns:
        The bytes of a WAV file with zero audio frames, i.e. just the header.
    """
    sink = BytesIO()
    with wave.open(sink, "wb") as writer:
        # (channels, sample width in bytes, frame rate, frames, compression)
        writer.setparams((1, 2, sample_rate, 0, "NONE", "not compressed"))
        writer.writeframes(b"")
    return sink.getvalue()
|
|
|
|
| |
# FastAPI application object; the route decorators in this module register
# their handlers on it.
app = FastAPI(
    title="ATR TTS API",
    description="ATR 角色语音合成接口。基于 GPT-SoVITS v2Pro 模型。",
    version="1.0.0",
)
# Shared inference pipeline. It is None until the startup hook has loaded the
# models, hence Optional.
tts_pipeline: Optional[TTS] = None
|
|
|
|
@app.on_event("startup")
def startup():
    """Load the TTS pipeline at application start and store it globally.

    Reads ``GPT_SoVITS/configs/tts_infer.yaml``, overrides the device
    (``"cpu"``), half-precision flag (``False``), version and both weight
    paths from the module constants, then builds the pipeline.
    """
    global tts_pipeline
    print("Loading models...")

    config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
    overrides = (
        ("device", "cpu"),
        ("is_half", False),
        ("version", VERSION),
        ("t2s_weights_path", GPT_MODEL),
        ("vits_weights_path", SOVITS_MODEL),
    )
    for attribute, value in overrides:
        setattr(config, attribute, value)

    tts_pipeline = TTS(config)

    print(f"Models loaded. Version: {VERSION}")
    print(f" GPT: {GPT_MODEL}")
    print(f" SoVITS: {SOVITS_MODEL}")
    print(f" Ref: {REF_AUDIO}")
|
|
|
|
| def _build_req(req: TTSRequest, streaming_mode: int = 0) -> dict: |
| return { |
| "text": req.text, |
| "text_lang": req.text_lang.lower(), |
| "ref_audio_path": req.ref_audio_path or REF_AUDIO, |
| "prompt_text": req.prompt_text if req.prompt_text is not None else REF_TEXT, |
| "prompt_lang": (req.prompt_lang or REF_LANG).lower(), |
| "top_k": req.top_k, |
| "top_p": req.top_p, |
| "temperature": req.temperature, |
| "text_split_method": req.text_split_method, |
| "batch_size": req.batch_size, |
| "batch_threshold": 0.75, |
| "split_bucket": True, |
| "speed_factor": req.speed_factor, |
| "fragment_interval": 0.3, |
| "seed": req.seed, |
| "media_type": req.media_type, |
| "streaming_mode": streaming_mode in (2, 3), |
| "return_fragment": streaming_mode == 1, |
| "fixed_length_chunk": streaming_mode == 3, |
| "parallel_infer": True, |
| "repetition_penalty": 1.35, |
| "sample_steps": req.sample_steps, |
| "super_sampling": False, |
| } |
|
|
|
|
| |
@app.get("/health", response_model=HealthResponse, summary="健康检查")
async def health():
    """Report a fixed "ok" status together with the configured model settings.

    The values come from the module-level constants; the handler does not
    inspect whether the pipeline has actually been loaded.
    """
    return HealthResponse(
        status="ok",
        gpt_model=GPT_MODEL,
        sovits_model=SOVITS_MODEL,
        ref_audio=REF_AUDIO,
        version=VERSION,
    )
|
|
|
|
@app.post(
    "/tts",
    summary="语音合成",
    description="输入文本，返回完整的音频文件。",
    responses={200: {"content": {"audio/wav": {}}}, 400: {"description": "合成失败"}},
)
async def tts_endpoint(request: TTSRequest):
    """Synthesize the whole text and return it as a single audio response.

    Takes the first ``(sample_rate, audio)`` item the pipeline yields, encodes
    it in the requested format and returns it with an ``audio/<media_type>``
    content type. Any failure is reported as HTTP 400 with a JSON body.
    """
    req = _build_req(request, streaming_mode=0)
    try:
        first = next(tts_pipeline.run(req), None)
        if first is None:
            # Give the client a real message instead of the empty text of a
            # StopIteration when the pipeline yields nothing.
            raise RuntimeError("pipeline produced no audio")
        sr, audio = first
        audio_bytes = pack_audio(audio, sr, request.media_type)
    except Exception as e:
        return JSONResponse(status_code=400, content={"message": "tts failed", "error": str(e)})
    return Response(audio_bytes, media_type=f"audio/{request.media_type}")
|
|
|
|
@app.post(
    "/tts/stream",
    summary="流式语音合成",
    description="输入文本，流式返回音频数据。适用于长文本实时播放。",
    responses={200: {"content": {"audio/wav": {}}}, 400: {"description": "合成失败"}},
)
async def tts_stream_endpoint(request: TTSStreamRequest):
    """Synthesize text and stream the audio back chunk by chunk.

    For ``wav`` output the stream is one WAV header followed by raw sample
    data for every chunk. For other formats each chunk is encoded on its own
    with ``pack_audio``. Failures raised while setting up the response are
    reported as HTTP 400 with a JSON body.
    """
    req = _build_req(request, streaming_mode=request.streaming_mode)
    try:
        gen = tts_pipeline.run(req)
        media = request.media_type

        def stream(gen: Generator):
            header_sent = False
            for sr, chunk in gen:
                if media != "wav":
                    yield pack_audio(chunk, sr, media)
                    continue
                if not header_sent:
                    # Exactly one header for the whole stream.
                    yield wave_header_chunk(sample_rate=sr)
                    header_sent = True
                # Every WAV chunk, not only the first, is sent as raw
                # samples; re-encoding later chunks as full WAV files would
                # put extra headers in the middle of the audio.
                yield pack_audio(chunk, sr, "raw")

        # NOTE(review): `stream` is consumed only after this function returns,
        # so errors raised during iteration are not caught by the except below.
        return StreamingResponse(stream(gen), media_type=f"audio/{media}")
    except Exception as e:
        return JSONResponse(status_code=400, content={"message": "tts failed", "error": str(e)})
|
|
|
|
| |
if __name__ == "__main__":
    # Command-line entry point: read the bind address and port, print the
    # docs URL, then serve the app with uvicorn using one worker.
    parser = argparse.ArgumentParser(description="ATR TTS API Server")
    parser.add_argument("-a", "--host", default="127.0.0.1", help="绑定地址 (默认 127.0.0.1)")
    parser.add_argument("-p", "--port", type=int, default=9880, help="端口号 (默认 9880)")
    args = parser.parse_args()

    print(f"\n API docs: http://{args.host}:{args.port}/docs\n")
    uvicorn.run(app=app, host=args.host, port=args.port, workers=1)
|
|