Spaces:
Running
Running
File size: 5,211 Bytes
c441d2c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 | import asyncio
import os
from typing import AsyncIterator, Optional, Callable, Union, Dict
import logging
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from .Audio.ReferenceAudio import ReferenceAudio
from .Core.TTSPlayer import tts_player
from .ModelManager import model_manager
from .Utils.Shared import context
from .Utils.Language import normalize_language
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Reference-audio settings keyed by character name. Each entry is the dict
# written by /set_reference_audio: 'audio_path', 'audio_text' and 'language'.
_reference_audios: Dict[str, dict] = {}
# File extensions accepted for reference audio; compared against the
# lower-cased extension of the submitted path.
SUPPORTED_AUDIO_EXTS = {'.wav', '.flac', '.ogg', '.aiff', '.aif'}
# FastAPI application on which every endpoint in this module is registered.
app = FastAPI()
class CharacterPayload(BaseModel):
    """Request body for ``POST /load_character``."""

    # Name under which the character is registered with the model manager.
    character_name: str
    # Forwarded to ``model_manager.load_character`` as ``model_dir``.
    onnx_model_dir: str
    # Passed through ``normalize_language`` before being handed to the manager.
    language: str
class UnloadCharacterPayload(BaseModel):
    """Request body for ``POST /unload_character``."""

    # Character handed to ``model_manager.remove_character``.
    character_name: str
class ReferenceAudioPayload(BaseModel):
    """Request body for ``POST /set_reference_audio``."""

    # Character the reference audio is stored under.
    character_name: str
    # Path whose extension must be one of SUPPORTED_AUDIO_EXTS.
    audio_path: str
    # Stored alongside the path and later passed to ReferenceAudio as
    # ``prompt_text``.
    audio_text: str
    # Passed through ``normalize_language`` before being stored.
    language: str
class TTSPayload(BaseModel):
    """Request body for ``POST /tts``."""

    # Must already have reference audio registered, otherwise /tts answers 404.
    character_name: str
    # Text fed to the TTS player.
    text: str
    # Forwarded to ``tts_player.start_session`` as ``split``.
    split_sentence: bool = False
    # Forwarded to ``tts_player.start_session`` as ``save_path``; may be omitted.
    save_path: Optional[str] = None
@app.post("/load_character")
def load_character_endpoint(payload: CharacterPayload):
    """Load a character's model through the model manager.

    Returns a success dict on completion. Any exception raised while
    normalising the language or loading the model is reported as HTTP 500
    with the exception text as the detail.
    """
    try:
        normalized_language = normalize_language(payload.language)
        model_manager.load_character(
            character_name=payload.character_name,
            model_dir=payload.onnx_model_dir,
            language=normalized_language,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

    return {"status": "success", "message": f"Character '{payload.character_name}' loaded."}
@app.post("/unload_character")
def unload_character_endpoint(payload: UnloadCharacterPayload):
    """Remove a character from the model manager.

    Returns a success dict on completion. Any exception raised by the
    manager is reported as HTTP 500 with the exception text as the detail.
    """
    try:
        model_manager.remove_character(character_name=payload.character_name)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

    return {"status": "success", "message": f"Character '{payload.character_name}' unloaded."}
@app.post("/set_reference_audio")
def set_reference_audio_endpoint(payload: ReferenceAudioPayload):
    """Register the reference audio used when synthesising for a character.

    The extension of ``audio_path`` is compared case-insensitively against
    SUPPORTED_AUDIO_EXTS; an unsupported extension yields HTTP 400. On
    success the path, text and normalised language are stored under the
    character name, replacing any earlier entry. The file itself is not
    opened here.
    """
    _, raw_ext = os.path.splitext(payload.audio_path)
    ext = raw_ext.lower()
    if ext not in SUPPORTED_AUDIO_EXTS:
        raise HTTPException(
            status_code=400,
            detail=f"Audio format '{ext}' is not supported. Supported formats: {SUPPORTED_AUDIO_EXTS}",
        )

    reference_entry = {
        'audio_path': payload.audio_path,
        'audio_text': payload.audio_text,
        'language': normalize_language(payload.language),
    }
    _reference_audios[payload.character_name] = reference_entry
    return {"status": "success", "message": f"Reference audio for '{payload.character_name}' set."}
def run_tts_in_background(
    character_name: str,
    text: str,
    split_sentence: bool,
    save_path: Optional[str],
    chunk_callback: Callable[[Optional[bytes]], None],
):
    """Run one TTS session for ``character_name`` and push audio to a callback.

    ``tts_endpoint`` submits this function to the event loop's default
    executor. It selects the speaker and reference audio on the shared
    ``context``, then drives ``tts_player`` through a complete session.

    Args:
        character_name: Key into ``_reference_audios``; also stored as the
            current speaker.
        text: Text fed to the player.
        split_sentence: Forwarded to ``start_session`` as ``split``.
        save_path: Forwarded to ``start_session`` as ``save_path``.
        chunk_callback: Receives audio chunks. A ``None`` argument is the
            end-of-stream marker that ``audio_stream_generator`` stops on.

    Exceptions are logged and never propagate to the caller. On failure the
    ``None`` marker is sent explicitly so the streaming response terminates.
    """
    try:
        context.current_speaker = character_name
        # Read the entry once so all three fields come from the same
        # registration, even if it is replaced concurrently.
        reference = _reference_audios[character_name]
        context.current_prompt_audio = ReferenceAudio(
            prompt_wav=reference['audio_path'],
            prompt_text=reference['audio_text'],
            language=reference['language'],
        )
        tts_player.start_session(
            play=False,
            split=split_sentence,
            save_path=save_path,
            chunk_callback=chunk_callback,
        )
        tts_player.feed(text)
        tts_player.end_session()
        tts_player.wait_for_tts_completion()
    except Exception as e:
        logger.error(f"Error in TTS background task: {e}", exc_info=True)
        # The consumer blocks on the queue until it sees None. If the
        # failure happened before the player could deliver that marker, the
        # HTTP response would otherwise hang forever. A duplicate None is
        # harmless because the consumer stops at the first one.
        chunk_callback(None)
async def audio_stream_generator(queue: asyncio.Queue) -> AsyncIterator[bytes]:
    """Yield chunks taken from ``queue`` until a ``None`` item is received.

    The ``None`` item is consumed but not yielded; anything queued after it
    is left in the queue.
    """
    item = await queue.get()
    while item is not None:
        yield item
        item = await queue.get()
@app.post("/tts")
async def tts_endpoint(payload: TTSPayload):
    """Start synthesis in an executor thread and stream the audio back.

    Responds with HTTP 404 when no reference audio is registered for the
    requested character. Otherwise ``run_tts_in_background`` is submitted to
    the loop's default executor. Its chunks are handed to this loop through
    an ``asyncio.Queue`` and served as an ``audio/wav`` streaming response.
    """
    if payload.character_name not in _reference_audios:
        raise HTTPException(status_code=404, detail="Character not found or reference audio not set.")

    event_loop = asyncio.get_running_loop()
    chunk_queue: asyncio.Queue[Union[bytes, None]] = asyncio.Queue()

    def enqueue_from_worker(chunk: Optional[bytes]):
        # Called from the executor thread, so the put is scheduled on the
        # event loop instead of touching the queue directly.
        event_loop.call_soon_threadsafe(chunk_queue.put_nowait, chunk)

    worker_args = (
        payload.character_name,
        payload.text,
        payload.split_sentence,
        payload.save_path,
        enqueue_from_worker,
    )
    event_loop.run_in_executor(None, run_tts_in_background, *worker_args)

    audio_chunks = audio_stream_generator(chunk_queue)
    return StreamingResponse(audio_chunks, media_type="audio/wav")
@app.post("/stop")
def stop_endpoint():
    """Call ``tts_player.stop()`` and report the outcome.

    Returns a success dict on completion. Any exception raised by the
    player is reported as HTTP 500 with the exception text as the detail.
    """
    try:
        tts_player.stop()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

    return {"status": "success", "message": "TTS stopped."}
@app.post("/clear_reference_audio_cache")
def clear_reference_audio_cache_endpoint():
    """Call ``ReferenceAudio.clear_cache()`` and report the outcome.

    Returns a success dict on completion. Any exception raised while
    clearing is reported as HTTP 500 with the exception text as the detail.
    """
    try:
        ReferenceAudio.clear_cache()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

    return {"status": "success", "message": "Reference audio cache cleared."}
def start_server(host: str = "127.0.0.1", port: int = 8000, workers: int = 1):
    """Serve ``app`` with uvicorn.

    Args:
        host: Interface to bind; defaults to loopback.
        port: TCP port to listen on.
        workers: Forwarded to ``uvicorn.run`` as ``workers``.
    """
    # NOTE(review): uvicorn's documentation says multi-worker mode needs the
    # application as an import string. An app object is passed here, so
    # workers > 1 may not take effect. Confirm before relying on it.
    server_options = {"host": host, "port": port, "workers": workers}
    uvicorn.run(app, **server_options)
# Script entry point: bind to every interface (0.0.0.0) on port 8000 with a
# single worker.
if __name__ == "__main__":
    start_server(host="0.0.0.0", port=8000, workers=1)
|