buchi-stdesign committed on
Commit
44571ac
·
verified ·
1 Parent(s): 1ee91f8

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +10 -6
  2. inference.py +81 -81
app.py CHANGED
@@ -1,16 +1,20 @@
1
- from fastapi import FastAPI
2
  from fastapi.responses import StreamingResponse
3
- from inference import synthesize_voice, load_model
4
  import io
 
 
 
5
 
6
  app = FastAPI()
7
 
8
- # 🛠 サーバ起動時にモデルをロードする
9
  @app.on_event("startup")
10
  async def startup_event():
11
  load_model()
12
 
13
  @app.get("/voice")
14
- async def voice_endpoint(text: str):
15
- wav_bytes = synthesize_voice(text)
16
- return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")
 
 
 
 
1
+ from fastapi import FastAPI, Query
2
  from fastapi.responses import StreamingResponse
 
3
  import io
4
+ import numpy as np
5
+ import soundfile as sf
6
+ from inference import load_model, synthesize
7
 
8
  app = FastAPI()
9
 
 
10
@app.on_event("startup")
async def startup_event():
    """Load the TTS model into process memory once, before any request is served."""
    load_model()
13
 
14
@app.get("/voice")
def voice(text: str = Query(..., description="Text to synthesize")):
    """Synthesize `text` to speech and stream the result back as a WAV file.

    Declared as a plain `def` (not `async def`) on purpose: `synthesize()` is
    CPU/GPU-bound torch inference, and FastAPI runs sync endpoints in a worker
    threadpool instead of blocking the event loop for the whole request.

    Query params:
        text: the text to synthesize (required).

    Returns:
        StreamingResponse with media type audio/wav.
    """
    audio = synthesize(text)  # float waveform as a numpy array
    buf = io.BytesIO()
    # NOTE(review): 24000 Hz is hard-coded here, but inference.py loads the
    # checkpoint's own config (which carries a sampling_rate) — confirm the
    # two agree, otherwise playback pitch/speed will be wrong.
    sf.write(buf, audio, 24000, format="WAV")
    buf.seek(0)
    return StreamingResponse(buf, media_type="audio/wav")
inference.py CHANGED
@@ -1,81 +1,81 @@
1
- import torch
2
- import os
3
- from huggingface_hub import hf_hub_download
4
- from src.sbv2.synthesizer_trn import SynthesizerTrn
5
- from src.sbv2.text import text_to_sequence
6
- from src.sbv2.commons import get_hparams_from_file
7
-
8
- # 環境変数から取得
9
- MODEL_REPO = os.getenv("MODEL_REPO")
10
- HF_TOKEN = os.getenv("HF_TOKEN")
11
- CACHE_DIR = "/tmp/hf_cache"
12
-
13
- # モデルとデバイスをグローバル変数として用意
14
- model = None
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
-
17
- def load_model():
18
- global model
19
- # Hugging Faceからモデルファイルをダウンロード
20
- config_path = hf_hub_download(repo_id=MODEL_REPO, filename="config.json", token=HF_TOKEN, cache_dir=CACHE_DIR)
21
- model_path = hf_hub_download(repo_id=MODEL_REPO, filename="model.safetensors", token=HF_TOKEN, cache_dir=CACHE_DIR)
22
- style_path = hf_hub_download(repo_id=MODEL_REPO, filename="style_vectors.npy", token=HF_TOKEN, cache_dir=CACHE_DIR)
23
-
24
- # configをロード
25
- hps = get_hparams_from_file(config_path)
26
-
27
- # モデルを初期化
28
- model = SynthesizerTrn(
29
- n_vocab=70, # 仮設定(※symbolsが無いため一般的な日本語TTS想定)
30
- spec_channels=hps["model"].get("spec_channels", 80),
31
- segment_size=None,
32
- inter_channels=hps["model"]["hidden_channels"],
33
- hidden_channels=hps["model"]["hidden_channels"],
34
- filter_channels=hps["model"]["filter_channels"],
35
- n_heads=hps["model"]["n_heads"],
36
- n_layers=int(hps["model"]["encoder_n_layers"]),
37
- kernel_size=hps["model"]["encoder_kernel_size"],
38
- p_dropout=hps["model"]["dropout"],
39
- resblock=str(hps["model"].get("resblock", 2)),
40
- resblock_kernel_sizes=hps["model"]["resblock_kernel_sizes"],
41
- resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
42
- upsample_rates=hps["model"]["upsample_rates"],
43
- upsample_initial_channel=512, # 通常512固定
44
- upsample_kernel_sizes=hps["model"]["upsample_kernel_sizes"],
45
- gin_channels=hps["model"]["gin_channels"],
46
- out_channels=hps["model"].get("spec_channels", 80),
47
- dec_kernel_size=hps["model"]["encoder_kernel_size"],
48
- enc_channels=hps["model"]["encoder_hidden"],
49
- enc_out_channels=hps["model"]["encoder_hidden"] * 2,
50
- enc_kernel_size=hps["model"]["encoder_kernel_size"],
51
- enc_dilation_rate=hps["model"].get("enc_dilation_rate", 1),
52
- enc_n_layers=int(hps["model"]["encoder_n_layers"]),
53
- flow_hidden_channels=hps["model"]["hidden_channels"],
54
- flow_kernel_size=hps["model"]["flow_kernel_size"],
55
- flow_n_layers=int(hps["model"]["flow_n_layers"]),
56
- flow_n_flows=int(hps["model"]["flow_n_flows"]),
57
- sdp_hidden_channels=hps["model"]["sdp_filter_channels"],
58
- sdp_kernel_size=hps["model"]["sdp_kernel_size"],
59
- sdp_n_layers=int(hps["model"]["sdp_n_layers"]),
60
- sdp_dropout=hps["model"]["sdp_dropout"],
61
- sampling_rate=hps["data"]["sampling_rate"],
62
- filter_length=1024,
63
- hop_length=256,
64
- win_length=1024,
65
- ).to(device)
66
-
67
- # safetensorsで重み読み込み
68
- from safetensors.torch import load_file
69
- model_sd = load_file(model_path)
70
- model.load_state_dict(model_sd, strict=True)
71
- model.eval()
72
-
73
- def synthesize_voice(text):
74
- # 推論を実行
75
- x = torch.LongTensor(text_to_sequence(text, ['basic_cleaners'])).unsqueeze(0).to(device)
76
- x_lengths = torch.LongTensor([x.size(1)]).to(device)
77
- sid = torch.LongTensor([0]).to(device)
78
-
79
- with torch.no_grad():
80
- audio = model.infer(x, x_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.8, length_scale=1.0)[0][0, 0].cpu().numpy()
81
- return audio
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import json
5
+ from scipy.io import wavfile
6
+ from huggingface_hub import hf_hub_download
7
+ from src.sbv2.synthesizer_trn import SynthesizerTrn
8
+ from src.sbv2 import commons
9
+ from src.sbv2 import utils
10
+ from src.sbv2.text import text_to_sequence
11
+
12
# Run inference on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Populated by load_model() at startup; both stay None until then.
model = None  # SynthesizerTrn instance (see load_model)
hps = None    # parsed config.json dict (see load_model)

# Hugging Face repo holding the model artifacts; token is needed for private repos.
MODEL_REPO = os.getenv("MODEL_REPO")
HF_TOKEN = os.getenv("HF_TOKEN")
# Local cache directory for hf_hub_download.
CACHE_DIR = "./models"
19
+
20
def load_model():
    """Download the model artifacts from the Hugging Face Hub and build the
    module-level `model` / `hps` globals used by `synthesize()`.

    Reads MODEL_REPO / HF_TOKEN from the environment (module level) and
    caches all downloads under CACHE_DIR.
    """
    global model, hps

    # Fetch config, weights and style vectors from the Hub (cached locally).
    config_path = hf_hub_download(repo_id=MODEL_REPO, filename="config.json", token=HF_TOKEN, cache_dir=CACHE_DIR)
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename="model.safetensors", token=HF_TOKEN, cache_dir=CACHE_DIR)
    # NOTE(review): style_path is downloaded but never used below — confirm
    # whether style vectors should feed into inference or the download can go.
    style_path = hf_hub_download(repo_id=MODEL_REPO, filename="style_vectors.npy", token=HF_TOKEN, cache_dir=CACHE_DIR)

    with open(config_path, "r", encoding="utf-8") as f:
        hps = json.load(f)

    # The symbol table comes from the checkpoint's own config, so the text
    # embedding size below matches the trained weights exactly.
    symbols = hps["symbols"]

    n_speakers = hps["data"].get("n_speakers", 0)

    # Build the synthesizer with hyper-parameters taken verbatim from config.json.
    model = SynthesizerTrn(
        len(symbols),
        hps["model"]["inter_channels"],
        hps["model"]["hidden_channels"],
        hps["model"]["filter_channels"],
        hps["model"]["n_heads"],
        hps["model"]["n_layers"],
        hps["model"]["kernel_size"],
        hps["model"]["p_dropout"],
        resblock=hps["model"]["resblock"],
        resblock_kernel_sizes=hps["model"]["resblock_kernel_sizes"],
        resblock_dilation_sizes=hps["model"]["resblock_dilation_sizes"],
        upsample_rates=hps["model"]["upsample_rates"],
        upsample_initial_channel=hps["model"]["upsample_initial_channel"],
        upsample_kernel_sizes=hps["model"]["upsample_kernel_sizes"],
        gin_channels=hps["model"].get("gin_channels", 0),
        n_speakers=n_speakers,
        use_spectral_norm=hps["model"].get("use_spectral_norm", False)
    ).to(device)

    # Load the trained weights; strict=True so any key mismatch fails loudly.
    # NOTE(review): model_path is a .safetensors file — confirm that
    # utils.load_checkpoint can read safetensors (a plain torch.load cannot).
    _ = utils.load_checkpoint(model_path, model, None, strict=True)
    model.eval()

    print("✅ Model loaded successfully (strict=True).")
58
+
59
+
60
def synthesize(text):
    """Run TTS inference for `text` and return the raw waveform.

    Args:
        text: input text to synthesize.

    Returns:
        A float32 numpy array containing the audio samples
        (model.infer output indexed to the first waveform).

    Raises:
        RuntimeError: if called before load_model() has populated the globals.
    """
    global model, hps

    if model is None or hps is None:
        raise RuntimeError("Model not loaded!")

    # Convert text to a (1, T) LongTensor of symbol IDs using the cleaners
    # declared in the checkpoint's config.
    # NOTE(review): a third `cleaned_text` argument is passed to
    # text_to_sequence — confirm the project's signature actually accepts it.
    stn_tst = torch.LongTensor(text_to_sequence(text, hps["data"]["text_cleaners"], hps["data"].get("cleaned_text", True))).unsqueeze(0).to(device)

    with torch.no_grad():
        x_tst_lengths = torch.LongTensor([stn_tst.size(1)]).to(device)
        # Speaker id 0 for multi-speaker checkpoints; None for single-speaker.
        sid = torch.LongTensor([0]).to(device) if hps["data"].get("n_speakers", 0) > 0 else None

        audio = model.infer(
            stn_tst,
            x_tst_lengths,
            sid=sid,
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=1.0
        )[0][0, 0].data.cpu().float().numpy()

    return audio