Spaces:
Runtime error
Runtime error
| import torch | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| import numpy as np | |
| import base64 | |
| import io | |
| from scipy.io.wavfile import write | |
| import sounddevice as sd | |
| # 自定义模块 | |
| import commons | |
| import utils | |
| from models import SynthesizerTrn | |
| from text.symbols import symbols | |
| from text import text_to_sequence | |
# Startup diagnostics, printed once when the module is imported.
# Installed PyTorch version.
print(torch.__version__)
# Whether PyTorch can use CUDA on this machine.
print(torch.cuda.is_available())
# CUDA version reported by this PyTorch build.
print(torch.version.cuda)

# FastAPI application
app = FastAPI()
# Request body model
class TextRequest(BaseModel):
    """Request body carrying the text to be synthesized."""

    # The text that the synthesize handler passes to text_to_speech.
    text: str
# Load the configuration and build the model once, at import time.
config_path = "configs/steins_gate_base.json"
checkpoint_path = "G_265000.pth"
hps = utils.get_hparams_from_file(config_path)
# Positional sizes are derived from the loaded hyper-parameters; the remaining
# constructor arguments come straight from the config's "model" section.
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model,
).eval()  # switch to evaluation mode: this service only runs inference
# NOTE(review): the third argument is None -- presumably the optimizer slot,
# which is not needed for inference; confirm against utils.load_checkpoint.
utils.load_checkpoint(checkpoint_path, net_g, None)
# Text-to-speech synthesis
def text_to_speech(content: str):
    """Synthesize ``content`` with the module-level model ``net_g``.

    Args:
        content: Text to synthesize; it is converted to a symbol-id sequence
            with the cleaners listed in ``hps.data.text_cleaners``.

    Returns:
        A ``(sampling_rate, audio)`` tuple: ``hps.data.sampling_rate`` and the
        generated waveform as a float NumPy array.
    """
    stn_tst = text_to_sequence(content, hps.data.text_cleaners)
    if hps.data.add_blank:
        # Insert id 0 between the symbols when the config asks for it.
        stn_tst = commons.intersperse(stn_tst, 0)
    # Must run whether or not blanks were added: the code below needs a tensor.
    stn_tst = torch.LongTensor(stn_tst)

    with torch.no_grad():  # inference only, no autograd bookkeeping
        x_tst = stn_tst.unsqueeze(0)  # add a leading batch dimension of 1
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)])
        outputs = net_g.infer(
            x_tst,
            x_tst_lengths,
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=1,
        )
        # First returned item, indexed [0, 0].
        # NOTE(review): presumably (batch, channel, samples) -- confirm.
        audio = outputs[0][0, 0].data.float().numpy()

    return hps.data.sampling_rate, audio
# API route: text-to-speech
@app.post("/synthesize")
def synthesize(request: TextRequest):
    """Synthesize speech for ``request.text`` and return it Base64-encoded.

    The handler is registered on ``app`` as ``POST /synthesize``; without a
    path-operation decorator it would never be reachable over HTTP.

    Args:
        request: Request body; only ``request.text`` is read.

    Returns:
        A dict ``{"audio": <str>}`` whose value is the Base64 text of a
        16-bit PCM WAV file.
    """
    sampling_rate, audio = text_to_speech(request.text)

    # Clamp to [-1, 1] before scaling: a sample outside that range does not
    # fit in int16 once multiplied by 32767, and the cast would corrupt it.
    pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)

    # Build the WAV file in memory instead of touching the filesystem.
    wav_bytes = io.BytesIO()
    write(wav_bytes, sampling_rate, pcm)

    # getvalue() returns the whole buffer regardless of the stream position.
    audio_base64 = base64.b64encode(wav_bytes.getvalue()).decode("utf-8")
    return {"audio": audio_base64}
# Entry point: serve the app locally when the file is run as a script.
if __name__ == "__main__":
    # Imported here so uvicorn is only required when running as a script.
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)