Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -16,6 +16,8 @@ from collections import OrderedDict
|
|
| 16 |
from onmt_modules.misc import sequence_mask
|
| 17 |
from model_autopst import Generator_2 as Predictor
|
| 18 |
from hparams_autopst import hparams
|
|
|
|
|
|
|
| 19 |
|
| 20 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 21 |
|
|
@@ -71,6 +73,10 @@ model = build_model().to(device)
|
|
| 71 |
checkpoint = torch.load(hf_hub_download(repo_id="jonathanjordan21/AutoPST", filename="checkpoint_step001000000_ema.pth"), map_location=torch.device('cpu'))
|
| 72 |
model.load_state_dict(checkpoint["state_dict"])
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
# for name, sp in spect_vc.items():
|
| 75 |
|
| 76 |
# print(name)
|
|
@@ -81,57 +87,164 @@ model.load_state_dict(checkpoint["state_dict"])
|
|
| 81 |
|
| 82 |
|
| 83 |
|
| 84 |
-
def respond(
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
):
|
| 92 |
-
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
|
| 100 |
-
|
| 101 |
|
| 102 |
-
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
|
| 116 |
"""
|
| 117 |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
|
| 118 |
"""
|
| 119 |
-
demo = gr.ChatInterface(
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
],
|
|
|
|
| 133 |
)
|
| 134 |
|
| 135 |
|
|
|
|
| 136 |
if __name__ == "__main__":
|
| 137 |
demo.launch()
|
|
|
|
| 16 |
from onmt_modules.misc import sequence_mask
|
| 17 |
from model_autopst import Generator_2 as Predictor
|
| 18 |
from hparams_autopst import hparams
|
| 19 |
+
from model_sea import Generator
|
| 20 |
+
from hparams_sea import hparams as sea_hparams
|
| 21 |
|
| 22 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
|
|
|
|
| 73 |
checkpoint = torch.load(hf_hub_download(repo_id="jonathanjordan21/AutoPST", filename="checkpoint_step001000000_ema.pth"), map_location=torch.device('cpu'))
|
| 74 |
model.load_state_dict(checkpoint["state_dict"])
|
| 75 |
|
| 76 |
+
# sea_checkpoint = torch.load(hf_hub_download(repo_id="jonathanjordan21/AutoPST", filename='sea.ckpt'), map_location=lambda storage, loc: storage)
|
| 77 |
+
# gen =Generator(sea_hparams)
|
| 78 |
+
# gen.load_state_dict(sea_checkpoint['model'], strict=True)
|
| 79 |
+
|
| 80 |
# for name, sp in spect_vc.items():
|
| 81 |
|
| 82 |
# print(name)
|
|
|
|
| 87 |
|
| 88 |
|
| 89 |
|
| 90 |
+
# def respond(
|
| 91 |
+
# message,
|
| 92 |
+
# history: list[tuple[str, str]],
|
| 93 |
+
# system_message,
|
| 94 |
+
# max_tokens,
|
| 95 |
+
# temperature,
|
| 96 |
+
# top_p,
|
| 97 |
+
# ):
|
| 98 |
+
# messages = [{"role": "system", "content": system_message}]
|
| 99 |
|
| 100 |
+
# for val in history:
|
| 101 |
+
# if val[0]:
|
| 102 |
+
# messages.append({"role": "user", "content": val[0]})
|
| 103 |
+
# if val[1]:
|
| 104 |
+
# messages.append({"role": "assistant", "content": val[1]})
|
| 105 |
|
| 106 |
+
# messages.append({"role": "user", "content": message})
|
| 107 |
|
| 108 |
+
# response = ""
|
| 109 |
|
| 110 |
+
# for message in client.chat_completion(
|
| 111 |
+
# messages,
|
| 112 |
+
# max_tokens=max_tokens,
|
| 113 |
+
# stream=True,
|
| 114 |
+
# temperature=temperature,
|
| 115 |
+
# top_p=top_p,
|
| 116 |
+
# ):
|
| 117 |
+
# token = message.choices[0].delta.content
|
| 118 |
|
| 119 |
+
# response += token
|
| 120 |
+
# yield response
|
| 121 |
|
| 122 |
"""
|
| 123 |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
|
| 124 |
"""
|
| 125 |
+
# demo = gr.ChatInterface(
|
| 126 |
+
# respond,
|
| 127 |
+
# additional_inputs=[
|
| 128 |
+
# gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
|
| 129 |
+
# gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
| 130 |
+
# gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
| 131 |
+
# gr.Slider(
|
| 132 |
+
# minimum=0.1,
|
| 133 |
+
# maximum=1.0,
|
| 134 |
+
# value=0.95,
|
| 135 |
+
# step=0.05,
|
| 136 |
+
# label="Top-p (nucleus sampling)",
|
| 137 |
+
# ),
|
| 138 |
+
# ],
|
| 139 |
+
# )
|
| 140 |
+
|
| 141 |
+
import os
|
| 142 |
+
import pickle
|
| 143 |
+
import numpy as np
|
| 144 |
+
import soundfile as sf
|
| 145 |
+
from scipy import signal
|
| 146 |
+
from scipy.signal import get_window
|
| 147 |
+
from librosa.filters import mel
|
| 148 |
+
from numpy.random import RandomState
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def butter_highpass(cutoff, fs, order=5):
|
| 152 |
+
nyq = 0.5 * fs
|
| 153 |
+
normal_cutoff = cutoff / nyq
|
| 154 |
+
b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
|
| 155 |
+
return b, a
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def pySTFT(x, fft_length=1024, hop_length=256):
|
| 159 |
+
|
| 160 |
+
x = np.pad(x, int(fft_length//2), mode='reflect')
|
| 161 |
+
|
| 162 |
+
noverlap = fft_length - hop_length
|
| 163 |
+
shape = x.shape[:-1]+((x.shape[-1]-noverlap)//hop_length, fft_length)
|
| 164 |
+
strides = x.strides[:-1]+(hop_length*x.strides[-1], x.strides[-1])
|
| 165 |
+
result = np.lib.stride_tricks.as_strided(x, shape=shape,
|
| 166 |
+
strides=strides)
|
| 167 |
+
|
| 168 |
+
fft_window = get_window('hann', fft_length, fftbins=True)
|
| 169 |
+
result = np.fft.rfft(fft_window * result, n=fft_length).T
|
| 170 |
+
|
| 171 |
+
return np.abs(result)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def create_sp(cep_real, spk_emb):
|
| 175 |
+
# cep_real, spk_emb = dict_test[uttr[0]][uttr[2]]
|
| 176 |
+
cep_real_A = torch.from_numpy(cep_real).unsqueeze(0).to(device)
|
| 177 |
+
len_real_A = torch.tensor(cep_real_A.size(1)).unsqueeze(0).to(device)
|
| 178 |
+
real_mask_A = sequence_mask(len_real_A, cep_real_A.size(1)).float()
|
| 179 |
+
|
| 180 |
+
# _, spk_emb = dict_test[uttr[1]][uttr[2]]
|
| 181 |
+
spk_emb_B = torch.from_numpy(spk_emb).unsqueeze(0).to(device)
|
| 182 |
+
|
| 183 |
+
with torch.no_grad():
|
| 184 |
+
spect_output, len_spect = P.infer_onmt(cep_real_A.transpose(2,1)[:,:14,:],
|
| 185 |
+
real_mask_A,
|
| 186 |
+
len_real_A,
|
| 187 |
+
spk_emb_B)
|
| 188 |
+
|
| 189 |
+
uttr_tgt = spect_output[:len_spect[0],0,:].cpu().numpy()
|
| 190 |
+
return uttr_tgt
|
| 191 |
+
|
| 192 |
+
def create_mel(x):
|
| 193 |
+
mel_basis = mel(sr=16000, n_fft=1024, fmin=90, fmax=7600, n_mels=80).T
|
| 194 |
+
min_level = np.exp(-100 / 20 * np.log(10))
|
| 195 |
+
b, a = butter_highpass(30, 16000, order=5)
|
| 196 |
+
|
| 197 |
+
mfcc_mean, mfcc_std, dctmx = pickle.load(open('assets/mfcc_stats.pkl', 'rb'))
|
| 198 |
+
spk2emb = pickle.load(open('assets/spk2emb_82.pkl', 'rb'))
|
| 199 |
+
|
| 200 |
+
if x.shape[0] % 256 == 0:
|
| 201 |
+
x = np.concatenate((x, np.array([1e-06])), axis=0)
|
| 202 |
+
y = signal.filtfilt(b, a, x)
|
| 203 |
+
D = pySTFT(y * 0.96).T
|
| 204 |
+
D_mel = np.dot(D, mel_basis)
|
| 205 |
+
D_db = 20 * np.log10(np.maximum(min_level, D_mel))
|
| 206 |
+
|
| 207 |
+
# mel sp
|
| 208 |
+
S = (D_db + 80) / 100
|
| 209 |
+
|
| 210 |
+
# mel cep
|
| 211 |
+
cc_tmp = S.dot(dctmx)
|
| 212 |
+
cc_norm = (cc_tmp - mfcc_mean) / mfcc_std
|
| 213 |
+
S = np.clip(S, 0, 1)
|
| 214 |
+
|
| 215 |
+
# teacher code
|
| 216 |
+
# cc_torch = torch.from_numpy(cc_norm[:,0:20].astype(np.float32)).unsqueeze(0).to(device)
|
| 217 |
+
# with torch.no_grad():
|
| 218 |
+
# codes = gen.encode(cc_torch, torch.ones_like(cc_torch[:,:,0])).squeeze(0)
|
| 219 |
+
return S, cc_norm
|
| 220 |
+
|
| 221 |
+
def transcribe(audio, spk):
|
| 222 |
+
sr, y = audio
|
| 223 |
+
y = librosa.resample(y, orig_sr=sr, target_sr=16000)
|
| 224 |
+
y = y.astype(np.float32)
|
| 225 |
+
y /= np.max(np.abs(y))
|
| 226 |
+
|
| 227 |
+
spk_emb = np.zeros((82,))
|
| 228 |
+
spk_emb[spk-1] = 1
|
| 229 |
+
|
| 230 |
+
mel_sp, mel_cep = create_mel(y)
|
| 231 |
+
sp = create_sp(mel_cep, spk_emb)
|
| 232 |
+
waveform = wavegen(model, c=sp)
|
| 233 |
+
return 16000, waveform.numpy()
|
| 234 |
+
|
| 235 |
+
# return transcriber({"sampling_rate": sr, "raw": y})["text"]
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
demo = gr.Interface(
|
| 239 |
+
transcribe,
|
| 240 |
+
[
|
| 241 |
+
gr.Audio(),
|
| 242 |
+
gr.Slider(1, 82, value=21, label="Count", info="Choose between 1 and 82")
|
| 243 |
],
|
| 244 |
+
"audio",
|
| 245 |
)
|
| 246 |
|
| 247 |
|
| 248 |
+
|
| 249 |
if __name__ == "__main__":
|
| 250 |
demo.launch()
|