Spaces:
Runtime error
Runtime error
| # Set inference model | |
| # export MODEL_DIR=pretrained_models/CosyVoice-300M-Instruct | |
| # For development | |
| # fastapi dev --port 6006 fastapi_server.py | |
| # For production deployment | |
| # fastapi run --port 6006 fastapi_server.py | |
| import os | |
| import sys | |
| import io,time | |
| from fastapi import FastAPI, Response, File, UploadFile, Form | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.middleware.cors import CORSMiddleware #引入 CORS中间件模块 | |
| from contextlib import asynccontextmanager | |
| ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| sys.path.append('{}/../../..'.format(ROOT_DIR)) | |
| sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR)) | |
| from cosyvoice.cli.cosyvoice import CosyVoice | |
| from cosyvoice.utils.file_utils import load_wav | |
| import numpy as np | |
| import torch | |
| import torchaudio | |
| import logging | |
# Raise the matplotlib logger's threshold to WARNING so its DEBUG/INFO records
# are dropped from the server log.
# NOTE(review): matplotlib is not imported in this file; presumably it is pulled
# in transitively by cosyvoice / Matcha-TTS -- confirm.
logging.getLogger('matplotlib').setLevel(logging.WARNING)
class LaunchFailed(Exception):
    """Startup failure: raised by ``lifespan`` when MODEL_DIR resolves to an empty value."""
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the CosyVoice model at application startup and attach it to ``app``.

    The model directory is read from the ``MODEL_DIR`` environment variable and
    defaults to ``pretrained_models/CosyVoice-300M-SFT``. The loaded model is
    stored as ``app.cosyvoice``, where the request handlers look it up.

    Raises:
        LaunchFailed: If ``MODEL_DIR`` is set to an empty string.
    """
    # Decorated with @asynccontextmanager (already imported): Starlette accepts
    # a bare async generator only through a deprecated compatibility path.
    model_dir = os.getenv("MODEL_DIR", "pretrained_models/CosyVoice-300M-SFT")
    if not model_dir:
        # Only reachable when MODEL_DIR is explicitly set to "" -- an unset
        # variable falls back to the default above.
        raise LaunchFailed("MODEL_DIR environment must set")
    # logging formats messages with %-style args; "{}" placeholders are never
    # filled and make the handler raise a formatting error at emit time.
    logging.info("MODEL_DIR is %s", model_dir)
    app.cosyvoice = CosyVoice(model_dir)
    # sft usage: log the speaker ids the loaded model exposes.
    logging.info("Avaliable speakers %s", app.cosyvoice.list_avaliable_spks())
    yield
app = FastAPI(lifespan=lifespan)

# Cross-origin (CORS) policy. "*" admits requests from every origin; replace it
# with an explicit list of allowed origins / IPs to restrict access.
origins = ["*"]
# NOTE(review): with a wildcard origin plus allow_credentials=True, Starlette's
# CORSMiddleware reflects the caller's Origin back, effectively permitting
# credentialed requests from any site. Nothing visible here uses cookies or
# auth -- verify that credentials are really needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,   # origins permitted to call the API
    allow_credentials=True,
    allow_methods=["*"],     # HTTP methods allowed cross-origin (GET, POST, PUT, ...)
    allow_headers=["*"],     # request headers allowed cross-origin
)
def buildResponse(output, sample_rate=22050):
    """Encode a waveform as an in-memory WAV file and wrap it in an HTTP response.

    Args:
        output: Waveform tensor handed to ``torchaudio.save``; presumably
            (channels, samples) float audio from the model -- TODO confirm.
        sample_rate: Rate written into the WAV header. Defaults to 22050, the
            value this server previously hard-coded; pass the model's own rate
            if it differs.

    Returns:
        A ``Response`` whose body is the complete WAV file (``audio/wav``).
    """
    buffer = io.BytesIO()
    torchaudio.save(buffer, output, sample_rate, format="wav")
    # getvalue() returns the whole buffer regardless of the stream position,
    # so no seek(0)/read() pair is needed.
    return Response(content=buffer.getvalue(), media_type="audio/wav")
async def sft(tts: str = Form(), role: str = Form()):
    """Synthesize ``tts`` with one of the model's built-in speakers.

    Args:
        tts: Text to synthesize (form field).
        role: Speaker id (form field); presumably one of the ids returned by
            ``app.cosyvoice.list_avaliable_spks()``.

    Returns:
        A WAV ``Response`` built from the model's ``tts_speech`` output.
    """
    # NOTE(review): inference is synchronous inside an async handler, so it
    # blocks the event loop for its whole duration.
    # perf_counter measures wall-clock latency; process_time counted only this
    # process's CPU time, which misreports GPU-bound or multi-threaded work.
    start = time.perf_counter()
    output = app.cosyvoice.inference_sft(tts, role)
    elapsed = time.perf_counter() - start
    # logging uses %-style args; the old "{}" placeholder raised at emit time.
    logging.info("infer time is %.3f seconds", elapsed)
    return buildResponse(output['tts_speech'])
async def zeroShot(tts: str = Form(), prompt: str = Form(), audio: UploadFile = File()):
    """Speak ``tts`` in the voice of an uploaded prompt recording (zero-shot).

    Args:
        tts: Text to synthesize (form field).
        prompt: Text passed to the model alongside the prompt audio (form
            field); presumably its transcript -- confirm against CosyVoice.
        audio: Uploaded prompt recording, loaded via ``load_wav(..., 16000)``
            (presumably resampled to 16 kHz).

    Returns:
        A WAV ``Response`` built from the model's ``tts_speech`` output.
    """
    # NOTE(review): inference is synchronous and blocks the event loop.
    start = time.perf_counter()
    prompt_speech = load_wav(audio.file, 16000)
    # Quantize the prompt to 16-bit PCM and back, as before. Clip first: a
    # full-scale +1.0 sample scales to 32768, which is outside int16 range,
    # and the unchecked cast is platform-dependent (typically -32768).
    pcm16 = np.clip(prompt_speech.numpy() * (2 ** 15), -(2 ** 15), 2 ** 15 - 1).astype(np.int16)
    # Flatten into a single (1, n_samples) row, matching the old bytes round-trip.
    prompt_speech_16k = torch.from_numpy(pcm16.reshape(1, -1)).float() / (2 ** 15)
    output = app.cosyvoice.inference_zero_shot(tts, prompt, prompt_speech_16k)
    # Wall-clock latency; process_time would count only this process's CPU time.
    elapsed = time.perf_counter() - start
    logging.info("infer time is %.3f seconds", elapsed)
    return buildResponse(output['tts_speech'])
async def crossLingual(tts: str = Form(), audio: UploadFile = File()):
    """Speak ``tts`` in the voice of an uploaded prompt recording (cross-lingual).

    Args:
        tts: Text to synthesize (form field).
        audio: Uploaded prompt recording, loaded via ``load_wav(..., 16000)``
            (presumably resampled to 16 kHz).

    Returns:
        A WAV ``Response`` built from the model's ``tts_speech`` output.
    """
    # NOTE(review): inference is synchronous and blocks the event loop.
    start = time.perf_counter()
    prompt_speech = load_wav(audio.file, 16000)
    # Quantize the prompt to 16-bit PCM and back, as before. Clip first: a
    # full-scale +1.0 sample scales to 32768, which is outside int16 range,
    # and the unchecked cast is platform-dependent (typically -32768).
    pcm16 = np.clip(prompt_speech.numpy() * (2 ** 15), -(2 ** 15), 2 ** 15 - 1).astype(np.int16)
    # Flatten into a single (1, n_samples) row, matching the old bytes round-trip.
    prompt_speech_16k = torch.from_numpy(pcm16.reshape(1, -1)).float() / (2 ** 15)
    output = app.cosyvoice.inference_cross_lingual(tts, prompt_speech_16k)
    # Wall-clock latency; process_time would count only this process's CPU time.
    elapsed = time.perf_counter() - start
    logging.info("infer time is %.3f seconds", elapsed)
    return buildResponse(output['tts_speech'])
async def instruct(tts: str = Form(), role: str = Form(), instruct: str = Form()):
    """Synthesize ``tts`` with a built-in speaker, steered by an instruction text.

    Args:
        tts: Text to synthesize (form field).
        role: Speaker id (form field); presumably one of the ids returned by
            ``app.cosyvoice.list_avaliable_spks()``.
        instruct: Instruction text forwarded to ``inference_instruct``.

    Returns:
        A WAV ``Response`` built from the model's ``tts_speech`` output.
    """
    # NOTE(review): inference is synchronous and blocks the event loop.
    # perf_counter gives wall-clock latency; process_time counted CPU time only.
    start = time.perf_counter()
    output = app.cosyvoice.inference_instruct(tts, role, instruct)
    elapsed = time.perf_counter() - start
    # logging uses %-style args; the old "{}" placeholder raised at emit time.
    logging.info("infer time is %.3f seconds", elapsed)
    return buildResponse(output['tts_speech'])
async def roles():
    """Report the loaded model's speaker ids as ``{"roles": <speaker ids>}``."""
    speakers = app.cosyvoice.list_avaliable_spks()
    return dict(roles=speakers)
async def root():
    """Return a minimal English landing page that points users to the API docs.

    Returns:
        An HTML document as a string. NOTE(review): FastAPI JSON-encodes a plain
        ``str`` unless the route declares ``response_class=HTMLResponse`` --
        confirm the route registration does so.
    """
    # The page is English, so declare lang="en" (it previously claimed zh-cn,
    # which misleads screen readers and translation tools).
    return """
<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8">
    <title>API information</title>
  </head>
  <body>
    First fetch the supported speaker roles from the Roles API, then submit a
    role together with the text to synthesize to the TTS API.
    <a href='./docs'>API documentation</a>
  </body>
</html>
"""