Replaced Encodec with Vocos
- app.py +56 -62
- requirements.txt +1 -0

app.py
CHANGED
@@ -1,4 +1,3 @@
-import argparse
 import logging
 import os
 import pathlib
@@ -19,7 +18,6 @@ langid.set_languages(['en', 'zh', 'ja'])
 
 import torch
 import torchaudio
-import random
 
 import numpy as np
 
@@ -35,7 +33,8 @@ from macros import *
 from examples import *
 
 import gradio as gr
-import whisper
+from vocos import Vocos
+from transformers import WhisperProcessor, WhisperForConditionalGeneration
 
 torch._C._jit_set_profiling_executor(False)
 torch._C._jit_set_profiling_mode(False)
@@ -72,8 +71,13 @@ model.eval()
 # Encodec model
 audio_tokenizer = AudioTokenizer(device)
 
+# Vocos decoder
+vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)
+
 # ASR
-whisper_model = whisper.load_model("medium")
+whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
+whisper = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium").to(device)
+whisper.config.forced_decoder_ids = None
 
 # Voice Presets
 preset_list = os.walk("./presets/").__next__()[2]
@@ -89,34 +93,33 @@ def clear_prompts():
                 endfiletime = time.time() - 60
                 if endfiletime > lastmodifytime:
                     os.remove(filename)
+        del path, filename, lastmodifytime, endfiletime
+        gc.collect()
     except:
         return
 
-def transcribe_one(model, audio_path):
-    # load audio and pad/trim it to fit 30 seconds
-    audio = whisper.load_audio(audio_path)
-    audio = whisper.pad_or_trim(audio)
-
-    # make log-Mel spectrogram and move to the same device as the model
-    mel = whisper.log_mel_spectrogram(audio).to(model.device)
-
-    # detect the spoken language
-    _, probs = model.detect_language(mel)
-    print(f"Detected language: {max(probs, key=probs.get)}")
-    lang = max(probs, key=probs.get)
-    # decode the audio
-    options = whisper.DecodingOptions(temperature=1.0, best_of=5, fp16=False if device == torch.device("cpu") else True, sample_len=150)
-    result = whisper.decode(model, mel, options)
+def transcribe_one(wav, sr):
+    if sr != 16000:
+        wav4trans = torchaudio.transforms.Resample(sr, 16000)(wav)
+    else:
+        wav4trans = wav
+
+    input_features = whisper_processor(wav4trans.squeeze(0), sampling_rate=16000, return_tensors="pt").input_features
+
+    # generate token ids
+    predicted_ids = whisper.generate(input_features.to(device))
+    lang = whisper_processor.batch_decode(predicted_ids[:, 1])[0].strip("<|>")
+    # decode token ids to text
+    text_pr = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
 
     # print the recognized text
-    print(result.text)
+    print(text_pr)
 
-    text_pr = result.text
     if text_pr.strip(" ")[-1] not in "?!.,。,?!。、":
         text_pr += "."
 
     # delete all variables
-    del audio, mel, probs
+    del wav4trans, input_features, predicted_ids
     gc.collect()
     return lang, text_pr
 
@@ -137,7 +140,7 @@ def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
     assert wav_pr.ndim and wav_pr.size(0) == 1
 
     if transcript_content == "":
-        text_pr, lang_pr = make_prompt(name, wav_pr, sr, save=False)
+        lang_pr, text_pr = transcribe_one(wav_pr, sr)
     else:
         lang_pr = langid.classify(str(transcript_content))[0]
         lang_token = lang2token[lang_pr]
@@ -147,6 +150,8 @@ def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
    audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
 
     # tokenize text
+    lang_token = lang2token[lang_pr]
+    text_pr = lang_token + text_pr + lang_token
     phonemes, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
     text_tokens, enroll_x_lens = text_collater(
         [
@@ -155,6 +160,8 @@ def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
     )
 
     message = f"Detected language: {lang_pr}\n Detected text {text_pr}\n"
+    if lang_pr not in ['ja', 'zh', 'en']:
+        return f"Prompt can only made with one of model-supported languages, got {lang_pr} instead", None
 
     # save as npz file
     np.savez(os.path.join(tempfile.gettempdir(), f"{name}.npz"),
@@ -166,30 +173,6 @@ def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
     return message, os.path.join(tempfile.gettempdir(), f"{name}.npz")
 
 
-def make_prompt(name, wav, sr, save=True):
-    if not isinstance(wav, torch.FloatTensor):
-        wav = torch.tensor(wav)
-    if wav.abs().max() > 1:
-        wav /= wav.abs().max()
-    if wav.size(-1) == 2:
-        wav = wav.mean(-1, keepdim=False)
-    if wav.ndim == 1:
-        wav = wav.unsqueeze(0)
-    assert wav.ndim and wav.size(0) == 1
-    torchaudio.save(f"./prompts/{name}.wav", wav, sr)
-    lang, text = transcribe_one(whisper_model, f"./prompts/{name}.wav")
-    lang_token = lang2token[lang]
-    text = lang_token + text + lang_token
-    with open(f"./prompts/{name}.txt", 'w') as f:
-        f.write(text)
-    if not save:
-        os.remove(f"./prompts/{name}.wav")
-        os.remove(f"./prompts/{name}.txt")
-    # delete all variables
-    del lang_token, wav, sr
-    gc.collect()
-    return text, lang
-
 @torch.no_grad()
 def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content):
     if len(text) > 150:
@@ -209,7 +192,7 @@ def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt,
     assert wav_pr.ndim and wav_pr.size(0) == 1
 
     if transcript_content == "":
-        text_pr, lang_pr = make_prompt(str(random.randint(0, 10000000)), wav_pr, sr, save=False)
+        lang_pr, text_pr = transcribe_one(wav_pr, sr)
     else:
         lang_pr = langid.classify(str(transcript_content))[0]
         lang_token = lang2token[lang_pr]
@@ -222,6 +205,9 @@ def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt,
     lang = token2lang[lang_token]
     text = lang_token + text + lang_token
 
+    if lang_pr not in ['ja', 'zh', 'en']:
+        return f"Reference audio must be a speech of one of model-supported languages, got {lang_pr} instead", None
+
     # tokenize audio
     encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
     audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
@@ -237,6 +223,8 @@ def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt,
 
     enroll_x_lens = None
     if text_pr:
+        lang_token = lang2token[lang_pr]
+        text_pr = lang_token + text_pr + lang_token
         text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
         text_prompts, enroll_x_lens = text_collater(
             [
@@ -256,15 +244,16 @@ def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt,
         prompt_language=lang_pr,
         text_language=langs if accent == "no-accent" else lang,
     )
-    samples = audio_tokenizer.decode(
-        [(encoded_frames.transpose(2, 1), None)]
-    )
+    # Decode with Vocos
+    frames = encoded_frames.permute(2,0,1)
+    features = vocos.codes_to_features(frames)
+    samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
 
     message = f"text prompt: {text_pr}\nsythesized text: {text}"
     # delete all variables
     del audio_prompts, text_tokens, text_prompts, phone_tokens, encoded_frames, wav_pr, sr, audio_prompt, record_audio_prompt, transcript_content
     gc.collect()
-    return message, (24000, samples[0][0].cpu().numpy())
+    return message, (24000, samples.squeeze(0).cpu().numpy())
 
 @torch.no_grad()
 def infer_from_prompt(text, language, accent, preset_prompt, prompt_file):
@@ -315,16 +304,17 @@ def infer_from_prompt(text, language, accent, preset_prompt, prompt_file):
         prompt_language=lang_pr,
         text_language=langs if accent == "no-accent" else lang,
     )
-    samples = audio_tokenizer.decode(
-        [(encoded_frames.transpose(2, 1), None)]
-    )
+    # Decode with Vocos
+    frames = encoded_frames.permute(2,0,1)
+    features = vocos.codes_to_features(frames)
+    samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
 
     message = f"sythesized text: {text}"
 
     # delete all variables
     del audio_prompts, text_tokens, text_prompts, phone_tokens, encoded_frames, prompt_file, preset_prompt
     gc.collect()
-    return message, (24000, samples[0][0].cpu().numpy())
+    return message, (24000, samples.squeeze(0).cpu().numpy())
 
 
 from utils.sentence_cutter import split_text_into_sentences
@@ -407,11 +397,13 @@ def infer_long_text(text, preset_prompt, prompt=None, language='auto', accent='n
             text_language=langs if accent == "no-accent" else lang,
         )
         complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
-        samples = audio_tokenizer.decode(
-            [(complete_tokens, None)]
-        )
+        # Decode with Vocos
+        frames = encoded_frames.permute(2, 0, 1)
+        features = vocos.codes_to_features(frames)
+        samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
+
         message = f"Cut into {len(sentences)} sentences"
-        return message, (24000, samples[0][0].cpu().numpy())
+        return message, (24000, samples.squeeze(0).cpu().numpy())
     elif mode == "sliding-window":
         complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
         original_audio_prompts = audio_prompts
@@ -453,12 +445,14 @@ def infer_long_text(text, preset_prompt, prompt=None, language='auto', accent='n
         else:
             audio_prompts = original_audio_prompts
             text_prompts = original_text_prompts
-        samples = audio_tokenizer.decode(
-            [(complete_tokens, None)]
-        )
+        # Decode with Vocos
+        frames = encoded_frames.permute(2, 0, 1)
+        features = vocos.codes_to_features(frames)
+        samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
+
        message = f"Cut into {len(sentences)} sentences"
 
-        return message, (24000, samples[0][0].cpu().numpy())
+        return message, (24000, samples.squeeze(0).cpu().numpy())
     else:
         raise ValueError(f"No such mode {mode}")
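
Note: Vocos acts here as a drop-in decoder for EnCodec token streams; the VALL-E X model still emits the same RVQ codes, and only the token-to-waveform stage changes. Below is a minimal sketch of the decode calls used above, with a dummy `codes` tensor just to exercise the shapes (VALL-E X's `encoded_frames` are [B, T, n_q], hence the `permute(2, 0, 1)` into the [n_q, B, T] layout that `codes_to_features` expects):

```python
import torch
from vocos import Vocos

vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz")

# dummy EnCodec codes: 8 codebooks, batch of 1, 75 frames (about 1 s of audio)
codes = torch.randint(0, 1024, (8, 1, 75))

# sum the codebook embeddings into a continuous feature sequence
features = vocos.codes_to_features(codes)

# bandwidth_id 2 selects the 6 kbps EnCodec setting, i.e. all 8 codebooks
audio = vocos.decode(features, bandwidth_id=torch.tensor([2]))
print(audio.shape)  # [1, num_samples]: a 24 kHz waveform
```

Random codes decode to noise, of course; in the app the codes come from the VALL-E X decoder.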
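The ASR path likewise moves from the openai-whisper package to the transformers Whisper checkpoint. A self-contained sketch of what `transcribe_one` now does, assuming a mono input file (`example.wav` and the local names are illustrative): with `forced_decoder_ids` cleared, the first token Whisper generates after `<|startoftranscript|>` is a language tag such as `<|en|>`, which the app strips down to a plain language id.

```python
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium")
model.config.forced_decoder_ids = None  # let Whisper predict the language token

wav, sr = torchaudio.load("example.wav")  # hypothetical mono file
if sr != 16000:
    wav = torchaudio.transforms.Resample(sr, 16000)(wav)  # Whisper expects 16 kHz

input_features = processor(wav.squeeze(0), sampling_rate=16000,
                           return_tensors="pt").input_features
predicted_ids = model.generate(input_features)

# position 0 is <|startoftranscript|>; position 1 holds the language tag
lang = processor.batch_decode(predicted_ids[:, 1])[0].strip("<|>")
text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
print(lang, text)  # e.g. "en" plus the transcript
```
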
requirements.txt
CHANGED
@@ -5,6 +5,7 @@ torchvision==0.15.2
 torchaudio
 tokenizers
 encodec
+vocos
 langid
 unidecode
 pyopenjtalk
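
Note that `encodec` remains a dependency even though Vocos now does the decoding: judging by the app.py diff, EnCodec is still used on the encode side (`AudioTokenizer` / `tokenize_audio`) to turn prompt audio into discrete codes, and `vocos` is the only new requirement.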