buchi-stdesign committed on
Commit
94766d1
·
verified ·
1 Parent(s): f6f2d6e

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +55 -43
inference.py CHANGED
@@ -1,77 +1,89 @@
1
  import os
2
  import torch
3
  import numpy as np
4
- import json
5
- from scipy.io import wavfile
 
6
  from huggingface_hub import hf_hub_download
7
- from src.sbv2.synthesizer_trn import SynthesizerTrn
8
- from src.sbv2 import commons
9
  from src.sbv2 import utils
 
10
  from src.sbv2.text import text_to_sequence
11
 
12
- device = "cuda" if torch.cuda.is_available() else "cpu"
13
- model = None
14
- hps = None
15
-
16
  MODEL_REPO = os.getenv("MODEL_REPO")
17
  HF_TOKEN = os.getenv("HF_TOKEN")
18
  CACHE_DIR = "/tmp/models"
19
 
 
 
 
 
20
def load_model():
    """Download the model artifacts from the Hub and build the synthesizer.

    Populates the module-level ``model`` and ``hps`` globals as a side
    effect; callers use those globals rather than a return value.
    """
    global model, hps

    # Fetch config, weights and style vectors into the local cache dir.
    cfg_file = hf_hub_download(repo_id=MODEL_REPO, filename="config.json", token=HF_TOKEN, cache_dir=CACHE_DIR)
    weights_file = hf_hub_download(repo_id=MODEL_REPO, filename="model.safetensors", token=HF_TOKEN, cache_dir=CACHE_DIR)
    style_file = hf_hub_download(repo_id=MODEL_REPO, filename="style_vectors.npy", token=HF_TOKEN, cache_dir=CACHE_DIR)

    with open(cfg_file, "r", encoding="utf-8") as fh:
        hps = json.load(fh)

    # Fixed symbol-table size for this particular voice model.
    n_vocab = 77

    model = SynthesizerTrn(
        n_vocab,
        hps["model"]["inter_channels"],
        hps["model"]["hidden_channels"],
        hps["model"]["filter_channels"],
        hps["model"]["n_heads"],
        hps["model"]["n_layers"],
        hps["model"]["kernel_size"],
        hps["model"]["p_dropout"],
        resblock=hps["model"]["resblock"],
        resblock_kernel_sizes=hps["model"]["resblock_kernel_sizes"],
        resblock_dilation_sizes=hps["model"]["resblock_dilation_sizes"],
        upsample_rates=hps["model"]["upsample_rates"],
        upsample_initial_channel=hps["model"]["upsample_initial_channel"],
        upsample_kernel_sizes=hps["model"]["upsample_kernel_sizes"],
        gin_channels=hps["model"].get("gin_channels", 0),
    ).to(device)

    # NOTE(review): style_file is downloaded but not consumed in this
    # function — presumably used elsewhere; confirm before removing.
    _ = utils.load_checkpoint(weights_file, model, None, strict=True)
    model.eval()

    print("✅ Model loaded successfully (strict=True).")
54
-
55
-
56
def synthesize(text):
    """Synthesize *text* into a waveform using the globally loaded model.

    Returns a 1-D numpy float array of audio samples.
    Raises RuntimeError if load_model() has not been called yet.
    """
    global model, hps

    if model is None or hps is None:
        raise RuntimeError("Model not loaded!")

    cleaned = hps["data"].get("cleaned_text", True)
    token_ids = text_to_sequence(text, hps["data"]["text_cleaners"], cleaned)
    inputs = torch.LongTensor(token_ids).unsqueeze(0).to(device)

    with torch.no_grad():
        lengths = torch.LongTensor([inputs.size(1)]).to(device)
        # Multi-speaker checkpoints require a speaker id; default to 0.
        speaker = torch.LongTensor([0]).to(device) if hps["data"].get("n_speakers", 0) > 0 else None

        out = model.infer(
            inputs,
            lengths,
            sid=speaker,
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=1.0
        )
        audio = out[0][0, 0].data.cpu().float().numpy()

    return audio
 
 
 
 
 
1
  import os
2
  import torch
3
  import numpy as np
4
+ import soundfile as sf
5
+
6
+ from fastapi import FastAPI
7
  from huggingface_hub import hf_hub_download
8
+
 
9
  from src.sbv2 import utils
10
+ from src.sbv2.synthesizer_trn import SynthesizerTrn
11
  from src.sbv2.text import text_to_sequence
12
 
 
 
 
 
13
  MODEL_REPO = os.getenv("MODEL_REPO")
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
  CACHE_DIR = "/tmp/models"
16
 
17
+ app = FastAPI()
18
+
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
def load_model():
    """Download config.json, model.safetensors and style_vectors.npy and
    build the synthesizer on the module-level ``device``.

    Populates the ``model`` and ``hps`` globals as a side effect.
    """
    import json

    global model, hps

    # Pull the three artifacts into the local cache.
    config_path = hf_hub_download(repo_id=MODEL_REPO, filename="config.json", token=HF_TOKEN, cache_dir=CACHE_DIR)
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename="model.safetensors", token=HF_TOKEN, cache_dir=CACHE_DIR)
    style_path = hf_hub_download(repo_id=MODEL_REPO, filename="style_vectors.npy", token=HF_TOKEN, cache_dir=CACHE_DIR)

    # Load the hyper-parameter config.
    with open(config_path, "r", encoding="utf-8") as f:
        hps = json.load(f)

    n_vocab = 77  # symbol count for the Koharune Ami voice

    # Aliases purely for readability of the long positional argument list.
    m = hps["model"]
    a = hps["audio"]

    model = SynthesizerTrn(
        n_vocab,
        m["p_dropout"],
        hps["data"]["segment_size"] // 2,
        m["inter_channels"],
        m["out_channels"],
        m["hidden_channels"],
        m["filter_channels"],
        m["dec_kernel_size"],
        m["enc_channels"],
        m["enc_out_channels"],
        m["enc_kernel_size"],
        m["enc_dilation_rate"],
        m["enc_n_layers"],
        m["flow_hidden_channels"],
        m["flow_kernel_size"],
        m["flow_n_layers"],
        m["flow_n_flows"],
        m["sdp_hidden_channels"],
        m["sdp_kernel_size"],
        m["sdp_n_layers"],
        m["sdp_dropout"],
        a["sampling_rate"],
        a["filter_length"],
        a["hop_length"],
        a["win_length"],
        m["resblock"],
        m["resblock_kernel_sizes"],
        m["resblock_dilation_sizes"],
        m["upsample_rates"],
        m["upsample_initial_channel"],
        m["upsample_kernel_sizes"],
        m.get("gin_channels", 0)
    ).to(device)

    # Load the safetensors weights into the freshly built module.
    utils.load_checkpoint(model_path, model, strict=True)
    model.eval()
74
 
75
@app.get("/voice")
def synthesize(text: str):
    """Synthesize *text* to speech and return the path of the written WAV.

    Query params:
        text: input text to synthesize.

    Returns:
        {"audio_path": <path>} pointing at a per-request WAV file in /tmp.

    Raises:
        RuntimeError: if load_model() has not populated the globals yet.
    """
    import tempfile

    # Fail fast with a clear error instead of a NameError from the
    # globals that load_model() is supposed to have set.
    if globals().get("model") is None or globals().get("hps") is None:
        raise RuntimeError("Model not loaded!")

    # Convert the text to a phoneme-id sequence.
    sequence = np.array(text_to_sequence(text, hps["data"]["text_cleaners"]), dtype=np.int64)
    sequence = torch.LongTensor(sequence).unsqueeze(0).to(device)

    # Inference.
    with torch.no_grad():
        audio = model.infer(sequence, noise_scale=0.667, noise_scale_w=0.8, length_scale=1.0)[0][0, 0].data.cpu().numpy()

    # Write to a unique temp file: a fixed path ("/tmp/output.wav") would be
    # clobbered by concurrent requests.
    with tempfile.NamedTemporaryFile(prefix="output_", suffix=".wav", dir="/tmp", delete=False) as tmp:
        output_path = tmp.name
    sf.write(output_path, audio, hps["audio"]["sampling_rate"])

    return {"audio_path": output_path}