| import io |
| import os |
| import re |
| import shutil |
| import tempfile |
| import time |
| from http import HTTPStatus |
| from pathlib import Path |
|
|
| import numpy as np |
| import ormsgpack |
| import soundfile as sf |
| import torch |
| from kui.asgi import ( |
| Body, |
| HTTPException, |
| HttpView, |
| JSONResponse, |
| Routes, |
| StreamResponse, |
| UploadFile, |
| request, |
| ) |
| from loguru import logger |
| from typing_extensions import Annotated |
|
|
| from fish_speech.utils.schema import ( |
| AddReferenceRequest, |
| AddReferenceResponse, |
| DeleteReferenceResponse, |
| ListReferencesResponse, |
| ServeTTSRequest, |
| ServeVQGANDecodeRequest, |
| ServeVQGANDecodeResponse, |
| ServeVQGANEncodeRequest, |
| ServeVQGANEncodeResponse, |
| UpdateReferenceResponse, |
| ) |
| from tools.server.api_utils import ( |
| buffer_to_async_generator, |
| format_response, |
| get_content_type, |
| inference_async, |
| ) |
| from tools.server.inference import inference_wrapper as inference |
| from tools.server.model_manager import ModelManager |
| from tools.server.model_utils import ( |
| batch_vqgan_decode, |
| cached_vqgan_batch_encode, |
| ) |
|
|
# Integer taken from the NUM_SAMPLES environment variable; 1 when it is unset.
MAX_NUM_SAMPLES = int(os.environ.get("NUM_SAMPLES", 1))

# Built WebUI entry page: three directory levels above this file, then
# awesome_webui/dist/index.html.
_WEBUI_HTML = Path(__file__).parent.parent.parent.joinpath(
    "awesome_webui", "dist", "index.html"
)

# Route table that every handler in this module registers itself on.
routes = Routes()
|
|
|
|
@routes.http("/ui")
class WebUI(HttpView):
    """Serve the built WebUI page from ``_WEBUI_HTML``."""

    @classmethod
    async def get(cls):
        """Return the page as HTML, or a 404 JSON hint when it has not been built."""
        from kui.asgi import HTMLResponse

        # Guard clause: without the built file there is nothing to serve.
        if not _WEBUI_HTML.exists():
            not_built = {
                "error": "WebUI not built. Run: cd awesome_webui && npm run build"
            }
            return JSONResponse(not_built, status_code=404)

        page = _WEBUI_HTML.read_text(encoding="utf-8")
        return HTMLResponse(page)
|
|
|
|
@routes.http("/v1/health")
class Health(HttpView):
    """Health-check endpoint; GET and POST return the same JSON body."""

    @classmethod
    async def get(cls):
        """Answer a GET probe with ``{"status": "ok"}``."""
        payload = {"status": "ok"}
        return JSONResponse(payload)

    @classmethod
    async def post(cls):
        """Answer a POST probe with ``{"status": "ok"}``."""
        payload = {"status": "ok"}
        return JSONResponse(payload)
|
|
|
|
@routes.http.post("/v1/vqgan/encode")
async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
    """
    Encode audio using VQGAN model.

    Hands ``req.audios`` and the model manager's ``decoder_model`` to
    ``cached_vqgan_batch_encode`` and returns the resulting tokens as a
    msgpack-serialized ``ServeVQGANEncodeResponse``.

    Raises:
        HTTPException: 500 with "Failed to encode audio" when any step fails.
    """
    try:
        model_manager: ModelManager = request.app.state.model_manager
        decoder_model = model_manager.decoder_model

        # perf_counter is monotonic, so the logged duration cannot be skewed by
        # wall-clock adjustments.
        start_time = time.perf_counter()
        tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        logger.info(f"[EXEC] VQGAN encode time: {elapsed_ms:.2f}ms")

        return ormsgpack.packb(
            ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
        )
    except Exception as e:
        # loguru has no `exc_info=` option: an extra keyword makes it call
        # str.format() on the message (which raises when the exception text
        # contains braces) and records no traceback. logger.exception() attaches
        # the traceback and only formats the literal template.
        logger.exception("Error in VQGAN encode: {}", e)
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to encode audio"
        ) from e
|
|
|
|
@routes.http.post("/v1/vqgan/decode")
async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
    """
    Decode tokens to audio using VQGAN model.

    Converts each entry of ``req.tokens`` to an int tensor, runs
    ``batch_vqgan_decode`` with the model manager's ``decoder_model``, and
    returns every decoded array as float16 raw bytes inside a
    msgpack-serialized ``ServeVQGANDecodeResponse``.

    Raises:
        HTTPException: 500 with "Failed to decode tokens to audio" when any
            step fails.
    """
    try:
        model_manager: ModelManager = request.app.state.model_manager
        decoder_model = model_manager.decoder_model

        tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]

        # perf_counter is monotonic, so the logged duration cannot be skewed by
        # wall-clock adjustments.
        start_time = time.perf_counter()
        audios = batch_vqgan_decode(decoder_model, tokens)
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        logger.info(f"[EXEC] VQGAN decode time: {elapsed_ms:.2f}ms")

        # Ship each decoded array as float16 bytes.
        audios = [audio.astype(np.float16).tobytes() for audio in audios]

        return ormsgpack.packb(
            ServeVQGANDecodeResponse(audios=audios),
            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
        )
    except Exception as e:
        # loguru has no `exc_info=` option: an extra keyword makes it call
        # str.format() on the message (which raises when the exception text
        # contains braces) and records no traceback. logger.exception() attaches
        # the traceback and only formats the literal template.
        logger.exception("Error in VQGAN decode: {}", e)
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to decode tokens to audio"
        ) from e
|
|
|
|
@routes.http.post("/v1/tts")
async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
    """
    Generate speech from text using TTS model.

    Rejects the request with 400 when ``app_state.max_text_length`` is positive
    and ``req.text`` is longer than it, or when streaming is requested with a
    format other than "wav". For streaming requests the response iterates
    ``inference_async(req, engine)``; otherwise the first item yielded by
    ``inference(req, engine)`` is written with soundfile at the decoder model's
    sample rate and the encoded bytes are returned.

    Raises:
        HTTPException: 400 for the validation failures above (re-raised
            unchanged); 500 with "Failed to generate speech" for anything else.
    """
    try:
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine
        sample_rate = engine.decoder_model.sample_rate

        if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
            raise HTTPException(
                HTTPStatus.BAD_REQUEST,
                content=f"Text is too long, max length is {app_state.max_text_length}",
            )

        if req.streaming and req.format != "wav":
            raise HTTPException(
                HTTPStatus.BAD_REQUEST,
                content="Streaming only supports WAV format",
            )

        # Both response paths share the same download header and content type.
        headers = {
            "Content-Disposition": f"attachment; filename=audio.{req.format}",
        }
        content_type = get_content_type(req.format)

        if req.streaming:
            return StreamResponse(
                iterable=inference_async(req, engine),
                headers=headers,
                content_type=content_type,
            )

        # Non-streaming: take the first item the inference generator yields and
        # encode it in memory.
        fake_audios = next(inference(req, engine))
        buffer = io.BytesIO()
        sf.write(
            buffer,
            fake_audios,
            sample_rate,
            format=req.format,
        )

        return StreamResponse(
            iterable=buffer_to_async_generator(buffer.getvalue()),
            headers=headers,
            content_type=content_type,
        )
    except HTTPException:
        # Validation errors raised above must reach the client as-is.
        raise
    except Exception as e:
        # loguru has no `exc_info=` option: an extra keyword makes it call
        # str.format() on the message (which raises when the exception text
        # contains braces) and records no traceback. logger.exception() attaches
        # the traceback and only formats the literal template.
        logger.exception("Error in TTS generation: {}", e)
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to generate speech"
        ) from e
|
|
|
|
@routes.http.post("/v1/references/add")
async def add_reference(
    id: str = Body(...), audio: UploadFile = Body(...), text: str = Body(...)
):
    """
    Add a new reference voice with audio file and text.

    The uploaded audio is written to a temporary ``.wav`` file whose path is
    passed to ``engine.add_reference`` together with ``id`` and ``text``; the
    temporary file is removed afterwards regardless of the outcome.

    Args:
        id: Reference ID. Must be non-blank, at most 255 characters, and consist
            only of letters, digits, ``-``, ``_`` and spaces.
        audio: Uploaded audio file; must not be empty.
        text: Reference text; must be non-blank.

    Returns:
        A formatted ``AddReferenceResponse``: 200 on success, 400 for invalid
        input, 409 when the ID already exists, 500 for file system or
        unexpected errors.
    """
    temp_file_path = None

    try:
        if not id or not id.strip():
            raise ValueError("Reference ID cannot be empty")

        # Enforce the same ID rules as the delete/update endpoints so a
        # client-supplied ID cannot carry path separators or other characters
        # those endpoints would later refuse. fullmatch (rather than match with
        # a `$` anchor) also rejects an ID that ends in a newline.
        if not re.fullmatch(r"[a-zA-Z0-9\-_ ]+", id) or len(id) > 255:
            raise ValueError("Reference ID contains invalid characters or is too long")

        if not text or not text.strip():
            raise ValueError("Reference text cannot be empty")

        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        audio_content = audio.read()
        if not audio_content:
            raise ValueError("Audio file is empty or could not be read")

        # delete=False: the file must outlive the `with` so the engine can open
        # it by path; the `finally` below removes it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file.write(audio_content)
            temp_file_path = temp_file.name

        engine.add_reference(id, temp_file_path, text)

        response = AddReferenceResponse(
            success=True,
            message=f"Reference voice '{id}' added successfully",
            reference_id=id,
        )
        return format_response(response)

    except FileExistsError as e:
        # Must precede the OSError handler: FileExistsError is an OSError.
        logger.warning(f"Reference ID '{id}' already exists: {e}")
        response = AddReferenceResponse(
            success=False,
            message=f"Reference ID '{id}' already exists",
            reference_id=id,
        )
        return format_response(response, status_code=409)

    except ValueError as e:
        logger.warning(f"Invalid input for reference '{id}': {e}")
        response = AddReferenceResponse(success=False, message=str(e), reference_id=id)
        return format_response(response, status_code=400)

    except OSError as e:
        # OSError already covers FileNotFoundError.
        logger.error(f"File system error for reference '{id}': {e}")
        response = AddReferenceResponse(
            success=False, message="File system error occurred", reference_id=id
        )
        return format_response(response, status_code=500)

    except Exception as e:
        # loguru has no `exc_info=` option: an extra keyword makes it call
        # str.format() on the message (which raises when the ID or exception
        # text contains braces) and records no traceback. logger.exception()
        # attaches the traceback and only formats the literal template.
        logger.exception("Unexpected error adding reference '{}': {}", id, e)
        response = AddReferenceResponse(
            success=False, message="Internal server error occurred", reference_id=id
        )
        return format_response(response, status_code=500)

    finally:
        # Best-effort cleanup: a file that is already gone is fine, any other
        # failure is only logged.
        if temp_file_path:
            try:
                os.unlink(temp_file_path)
            except FileNotFoundError:
                pass
            except OSError as e:
                logger.warning(
                    f"Failed to clean up temporary file {temp_file_path}: {e}"
                )
|
|
|
|
@routes.http.get("/v1/references/list")
async def list_references():
    """
    Get a list of all available reference voice IDs.

    Returns:
        A formatted ``ListReferencesResponse`` carrying the IDs returned by
        ``engine.list_reference_ids()``; on any error, a 500 response with an
        empty ID list.
    """
    try:
        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        reference_ids = engine.list_reference_ids()

        response = ListReferencesResponse(
            success=True,
            reference_ids=reference_ids,
            message=f"Found {len(reference_ids)} reference voices",
        )
        return format_response(response)

    except Exception as e:
        # loguru has no `exc_info=` option: an extra keyword makes it call
        # str.format() on the message (which raises when the exception text
        # contains braces) and records no traceback. logger.exception() attaches
        # the traceback and only formats the literal template.
        logger.exception("Unexpected error listing references: {}", e)
        response = ListReferencesResponse(
            success=False, reference_ids=[], message="Internal server error occurred"
        )
        return format_response(response, status_code=500)
|
|
|
|
@routes.http.delete("/v1/references/delete")
async def delete_reference(reference_id: str = Body(...)):
    """
    Delete a reference voice by ID.

    Args:
        reference_id: ID to delete. Must be non-blank, at most 255 characters,
            and consist only of letters, digits, ``-``, ``_`` and spaces.

    Returns:
        A formatted ``DeleteReferenceResponse``: 200 on success, 400 for an
        invalid ID, 404 when ``engine.delete_reference`` raises
        ``FileNotFoundError``, 500 for other file system or unexpected errors.
    """
    try:
        if not reference_id or not reference_id.strip():
            raise ValueError("Reference ID cannot be empty")

        # fullmatch instead of match + `$`: `$` also matches just before a
        # trailing newline, which let IDs such as "voice\n" through.
        if not re.fullmatch(r"[a-zA-Z0-9\-_ ]+", reference_id) or len(reference_id) > 255:
            raise ValueError("Reference ID contains invalid characters or is too long")

        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        engine.delete_reference(reference_id)

        response = DeleteReferenceResponse(
            success=True,
            message=f"Reference voice '{reference_id}' deleted successfully",
            reference_id=reference_id,
        )
        return format_response(response)

    except FileNotFoundError as e:
        # Must precede the OSError handler: FileNotFoundError is an OSError.
        logger.warning(f"Reference ID '{reference_id}' not found: {e}")
        response = DeleteReferenceResponse(
            success=False,
            message=f"Reference ID '{reference_id}' not found",
            reference_id=reference_id,
        )
        return format_response(response, status_code=404)

    except ValueError as e:
        logger.warning(f"Invalid input for reference '{reference_id}': {e}")
        response = DeleteReferenceResponse(
            success=False, message=str(e), reference_id=reference_id
        )
        return format_response(response, status_code=400)

    except OSError as e:
        logger.error(f"File system error deleting reference '{reference_id}': {e}")
        response = DeleteReferenceResponse(
            success=False,
            message="File system error occurred",
            reference_id=reference_id,
        )
        return format_response(response, status_code=500)

    except Exception as e:
        # loguru has no `exc_info=` option: an extra keyword makes it call
        # str.format() on the message (which raises when the ID or exception
        # text contains braces) and records no traceback. logger.exception()
        # attaches the traceback and only formats the literal template.
        logger.exception(
            "Unexpected error deleting reference '{}': {}", reference_id, e
        )
        response = DeleteReferenceResponse(
            success=False,
            message="Internal server error occurred",
            reference_id=reference_id,
        )
        return format_response(response, status_code=500)
|
|
|
|
@routes.http.post("/v1/references/update")
async def update_reference(
    old_reference_id: str = Body(...), new_reference_id: str = Body(...)
):
    """
    Rename a reference voice directory from old_reference_id to new_reference_id.

    The directory ``references/<old_reference_id>`` (relative to the current
    working directory) is renamed to ``references/<new_reference_id>``, and an
    entry cached under the old ID in ``engine.ref_by_id`` is moved to the new ID.

    Args:
        old_reference_id: Existing ID. Must be non-blank, at most 255
            characters, and consist only of letters, digits, ``-``, ``_`` and
            spaces.
        new_reference_id: Target ID; same rules, and must differ from the old ID.

    Returns:
        A formatted ``UpdateReferenceResponse``: 200 on success, 400 for invalid
        input, 404 when the old directory is missing, 409 when the new path
        already exists, 500 for other file system or unexpected errors.
    """

    def failure(message: str, status_code: int):
        # Every error path returns the same response shape; only the message
        # and the status code differ.
        return format_response(
            UpdateReferenceResponse(
                success=False,
                message=message,
                old_reference_id=old_reference_id,
                new_reference_id=new_reference_id,
            ),
            status_code=status_code,
        )

    try:
        if not old_reference_id or not old_reference_id.strip():
            raise ValueError("Old reference ID cannot be empty")
        if not new_reference_id or not new_reference_id.strip():
            raise ValueError("New reference ID cannot be empty")
        if old_reference_id == new_reference_id:
            raise ValueError("New reference ID must be different from old reference ID")

        # fullmatch instead of match + `$`: `$` also matches just before a
        # trailing newline, which let IDs such as "voice\n" through.
        id_pattern = r"[a-zA-Z0-9\-_ ]+"
        if not re.fullmatch(id_pattern, old_reference_id) or len(old_reference_id) > 255:
            raise ValueError(
                "Old reference ID contains invalid characters or is too long"
            )
        if not re.fullmatch(id_pattern, new_reference_id) or len(new_reference_id) > 255:
            raise ValueError(
                "New reference ID contains invalid characters or is too long"
            )

        app_state = request.app.state
        model_manager: ModelManager = app_state.model_manager
        engine = model_manager.tts_inference_engine

        # NOTE(review): the base directory is hard-coded here; confirm it is the
        # same location the engine stores references in.
        refs_base = Path("references")
        old_dir = refs_base / old_reference_id
        new_dir = refs_base / new_reference_id

        if not old_dir.exists() or not old_dir.is_dir():
            raise FileNotFoundError(f"Reference ID '{old_reference_id}' not found")
        if new_dir.exists():
            return failure(f"Reference ID '{new_reference_id}' already exists", 409)

        old_dir.rename(new_dir)

        # Keep the engine's in-memory entry reachable under the new ID.
        if old_reference_id in engine.ref_by_id:
            engine.ref_by_id[new_reference_id] = engine.ref_by_id.pop(old_reference_id)

        response = UpdateReferenceResponse(
            success=True,
            message=(
                f"Reference voice renamed from '{old_reference_id}' to '{new_reference_id}' successfully"
            ),
            old_reference_id=old_reference_id,
            new_reference_id=new_reference_id,
        )
        return format_response(response)

    except FileNotFoundError as e:
        # Must precede the OSError handler: FileNotFoundError is an OSError.
        logger.warning(str(e))
        return failure(str(e), 404)

    except ValueError as e:
        logger.warning(f"Invalid input for update reference: {e}")
        return failure(str(e), 400)

    except OSError as e:
        logger.error(f"File system error renaming reference: {e}")
        return failure("File system error occurred", 500)

    except Exception as e:
        # loguru has no `exc_info=` option: an extra keyword makes it call
        # str.format() on the message (which raises when the exception text
        # contains braces) and records no traceback. logger.exception() attaches
        # the traceback and only formats the literal template.
        logger.exception("Unexpected error updating reference: {}", e)
        return failure("Internal server error occurred", 500)
|
|