| from fastapi import Depends, HTTPException, Query |
| from fastapi.responses import StreamingResponse |
|
|
| import io |
| from pydantic import BaseModel |
| import soundfile as sf |
| from fastapi.responses import FileResponse |
|
|
|
|
| from modules.normalization import text_normalize |
|
|
| from modules import generate_audio as generate |
|
|
| from modules.api import utils as api_utils |
| from modules.api.Api import APIManager |
| from modules.synthesize_audio import synthesize_audio |
|
|
|
|
class TTSParams(BaseModel):
    """Query parameters accepted by the text-to-speech endpoint.

    Every field is declared with a FastAPI ``Query`` default so that, when the
    model is injected through ``Depends()``, each attribute is read from the
    request's query string and carries its description.
    """

    # Required: the text that will be synthesized.
    text: str = Query(default=..., description="Text to synthesize")

    # Voice selection.
    spk: str = Query(
        default="female2",
        description="Specific speaker by speaker name or speaker seed",
    )
    style: str = Query(default="chat", description="Specific style by style name")

    # Sampling controls.
    temperature: float = Query(
        default=0.3,
        description="Temperature for sampling (may be overridden by style or spk)",
    )
    top_P: float = Query(
        default=0.5,
        description="Top P for sampling (may be overridden by style or spk)",
    )
    top_K: int = Query(
        default=20,
        description="Top K for sampling (may be overridden by style or spk)",
    )
    seed: int = Query(
        default=42,
        description="Seed for generate (may be overridden by style or spk)",
    )

    # Container format of the returned audio.
    format: str = Query(default="mp3", description="Response audio format: [mp3,wav]")

    # Optional prompt / prefix strings.
    prompt1: str = Query(default="", description="Text prompt for inference")
    prompt2: str = Query(default="", description="Text prompt for inference")
    prefix: str = Query(default="", description="Text prefix for inference")

    # Declared as strings; the request handler converts both with int().
    bs: str = Query(default="8", description="Batch size for inference")
    thr: str = Query(default="100", description="Threshold for sentence spliter")
|
|
|
|
async def synthesize_tts(params: TTSParams = Depends()):
    """Synthesize speech for ``params.text`` and return it as an audio stream.

    The text is normalized, the speaker/style pair is resolved through
    ``api_utils.calc_spk_style``, and the audio is produced by
    ``synthesize_audio``. The result is encoded as WAV in memory and, when
    ``params.format`` is ``"mp3"``, converted with ``api_utils.wav_to_mp3``.

    Args:
        params: Query parameters of the request (see ``TTSParams``).

    Returns:
        A ``StreamingResponse`` over the encoded audio, with media type
        ``audio/mpeg`` for MP3 output and ``audio/wav`` otherwise.

    Raises:
        HTTPException: With status 500 and the error message as detail when
            any step fails. An ``HTTPException`` raised by a helper is passed
            through unchanged.
    """
    try:
        text = text_normalize(params.text, is_end=False)

        calc_params = api_utils.calc_spk_style(spk=params.spk, style=params.style)

        def pick(name):
            # A truthy request value wins; a falsy one (0, 0.0, "") falls back
            # to the value derived from the speaker/style, if there is one.
            requested = getattr(params, name)
            return requested or calc_params.get(name, requested)

        spk = calc_params.get("spk", params.spk)
        seed = pick("seed")
        temperature = pick("temperature")
        prefix = pick("prefix")
        prompt1 = pick("prompt1")
        prompt2 = pick("prompt2")

        # These two arrive as strings in the query and are converted here.
        batch_size = int(params.bs)
        threshold = int(params.thr)

        sample_rate, audio_data = synthesize_audio(
            text,
            temperature=temperature,
            top_P=params.top_P,
            top_K=params.top_K,
            spk=spk,
            infer_seed=seed,
            prompt1=prompt1,
            prompt2=prompt2,
            prefix=prefix,
            batch_size=batch_size,
            spliter_threshold=threshold,
        )

        # Encode to WAV in memory first; MP3 is derived from this buffer.
        buffer = io.BytesIO()
        sf.write(buffer, audio_data, sample_rate, format="wav")
        buffer.seek(0)

        # Use the request's ``format`` field. The bare name ``format`` is the
        # builtin function and never equals "mp3".
        media_type = "audio/wav"
        if params.format == "mp3":
            buffer = api_utils.wav_to_mp3(buffer)
            media_type = "audio/mpeg"

        return StreamingResponse(buffer, media_type=media_type)

    except HTTPException:
        # Keep the status code of an HTTP error raised deliberately by a
        # helper instead of rewrapping it as a 500.
        raise
    except Exception as e:
        import logging

        logging.exception(e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
def setup(api_manager: APIManager):
    """Attach ``synthesize_tts`` to the ``/v1/tts`` path via ``api_manager.get``.

    Args:
        api_manager: Manager on which the route is registered.
    """
    register_route = api_manager.get("/v1/tts", response_class=FileResponse)
    register_route(synthesize_tts)
|
|