Spaces:
Runtime error
Runtime error
import os
import tempfile

import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI
from huggingface_hub import hf_hub_download

from src.sbv2 import utils
from src.sbv2.synthesizer_trn import SynthesizerTrn
from src.sbv2.text import text_to_sequence
# Hub repository id and auth token come from the environment; both may be
# None when unset — hf_hub_download in load_model() would then fail.
MODEL_REPO = os.getenv("MODEL_REPO")
HF_TOKEN = os.getenv("HF_TOKEN")
# Local download cache for model artifacts.
CACHE_DIR = "/tmp/models"
app = FastAPI()
# Prefer GPU when available; model weights and input tensors are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model():
    """Download the TTS model artifacts from the Hugging Face Hub and build it.

    Side effects: sets the module-level globals ``model`` (a ``SynthesizerTrn``
    on ``device``, in eval mode) and ``hps`` (the parsed ``config.json`` dict).
    """
    global model, hps
    # Download config.json, model.safetensors and style_vectors.npy.
    config_path = hf_hub_download(repo_id=MODEL_REPO, filename="config.json", token=HF_TOKEN, cache_dir=CACHE_DIR)
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename="model.safetensors", token=HF_TOKEN, cache_dir=CACHE_DIR)
    # NOTE(review): style_path is downloaded but never used in this function —
    # confirm whether style vectors should be fed to the model somewhere.
    style_path = hf_hub_download(repo_id=MODEL_REPO, filename="style_vectors.npy", token=HF_TOKEN, cache_dir=CACHE_DIR)
    # Load the hyperparameter config.
    import json
    with open(config_path, "r", encoding="utf-8") as f:
        hps = json.load(f)
    n_vocab = 77  # symbol count for the Koharune-Ami voice
    segment_size = 8192  # fixed value, recommended by Style-BERT-VITS2
    # NOTE(review): the constructor arguments below are passed positionally, so
    # their order must match src.sbv2.synthesizer_trn.SynthesizerTrn exactly —
    # verify against that signature (upstream VITS takes spec_channels, not
    # p_dropout, as its second positional argument).
    model = SynthesizerTrn(
        n_vocab,
        hps["model"]["p_dropout"],
        segment_size // 2,
        hps["model"]["inter_channels"],
        hps["model"]["out_channels"],
        hps["model"]["hidden_channels"],
        hps["model"]["filter_channels"],
        hps["model"]["dec_kernel_size"],
        hps["model"]["enc_channels"],
        hps["model"]["enc_out_channels"],
        hps["model"]["enc_kernel_size"],
        hps["model"]["enc_dilation_rate"],
        hps["model"]["enc_n_layers"],
        hps["model"]["flow_hidden_channels"],
        hps["model"]["flow_kernel_size"],
        hps["model"]["flow_n_layers"],
        hps["model"]["flow_n_flows"],
        hps["model"]["sdp_hidden_channels"],
        hps["model"]["sdp_kernel_size"],
        hps["model"]["sdp_n_layers"],
        hps["model"]["sdp_dropout"],
        hps["audio"]["sampling_rate"],
        hps["audio"]["filter_length"],
        hps["audio"]["hop_length"],
        hps["audio"]["win_length"],
        hps["model"]["resblock"],
        hps["model"]["resblock_kernel_sizes"],
        hps["model"]["resblock_dilation_sizes"],
        hps["model"]["upsample_rates"],
        hps["model"]["upsample_initial_channel"],
        hps["model"]["upsample_kernel_sizes"],
        hps["model"].get("gin_channels", 0)  # default 0 when the key is absent
    ).to(device)
    # Load the safetensors weights into the freshly built module.
    utils.load_checkpoint(model_path, model, strict=True)
    model.eval()
def synthesize(text: str):
    """Run TTS inference for *text* and write the audio to a WAV file.

    Requires ``load_model()`` to have populated the globals ``model`` and
    ``hps`` beforehand.

    Args:
        text: Input text; converted to a symbol-id sequence via
            ``text_to_sequence`` using the cleaners named in the config.

    Returns:
        dict with key ``"audio_path"`` — the path of the written WAV file.
    """
    # Convert the text to a symbol-id sequence and shape it (1, T) for the model.
    sequence = np.array(text_to_sequence(text, hps["data"]["text_cleaners"]), dtype=np.int64)
    sequence = torch.LongTensor(sequence).unsqueeze(0).to(device)
    # Inference — no gradients needed.
    with torch.no_grad():
        audio = model.infer(sequence, noise_scale=0.667, noise_scale_w=0.8, length_scale=1.0)[0][0, 0].data.cpu().numpy()
    # BUG FIX: was a fixed "/tmp/output.wav", so concurrent requests clobbered
    # each other's output. Write to a unique temp file instead.
    fd, output_path = tempfile.mkstemp(suffix=".wav", dir="/tmp")
    os.close(fd)  # sf.write reopens by path; don't leak the raw descriptor
    sf.write(output_path, audio, hps["audio"]["sampling_rate"])
    return {"audio_path": output_path}