import logging
import multiprocessing
import os
import pathlib
import platform
import sys
import tempfile
import time
import gradio as gr
import langid
import nltk
import numpy as np
import spaces
import torch
import torchaudio
import whisper
from vocos import Vocos
from .data.collation import get_text_token_collater
from .data.tokenizer import (
AudioTokenizer,
tokenize_audio,
)
from .descriptions import infer_from_audio_ja_md, top_ja_md
from .examples import infer_from_audio_examples
from .g2p import PhonemeBpeTokenizer
from .macros import (
N_DIM,
NUM_HEAD,
NUM_LAYERS,
NUM_QUANTIZERS,
PREFIX_MODE,
lang2code,
lang2token,
langdropdown2token,
token2lang,
)
from .models.vallex import VALLE
logger = logging.getLogger(__name__)
# set base directory
OUTPUT_BASE_DIR = os.getenv("HF_HOME", ".")
PREPARED_BASE_DIR = "."
print(f"Base directory: {OUTPUT_BASE_DIR}")
print(f"Prepared base directory: {PREPARED_BASE_DIR}")
# set languages
langid.set_languages(["en", "zh", "ja"])
# set nltk data path
nltk.data.path = nltk.data.path + [os.path.join(os.getcwd(), "nltk_data")]
print(f"nltk_data path: {nltk.data.path}")
# get encoding
print(
    "default encoding is "
    f"{sys.getdefaultencoding()}, "
    f"file system encoding is {sys.getfilesystemencoding()}"
)
# check python version
print(f"You are using Python version {platform.python_version()}")
if sys.version_info < (3, 7):
    logger.warning("The Python version is too low and may cause problems")
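# checkpoints may embed pathlib objects pickled on another OS; aliasing the foreign
# pathlib class to the native one lets torch.load unpickle them (presumed intent of
# the swap below)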
if platform.system().lower() == "windows":
temp = pathlib.PosixPath
pathlib.PosixPath = pathlib.WindowsPath
else:
temp = pathlib.WindowsPath
pathlib.WindowsPath = pathlib.PosixPath
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
# set torch threads (guarded for hot-reload)
thread_count = multiprocessing.cpu_count()
print(f"Use {thread_count} cpu cores for computing")
if not getattr(torch, "_vallex_threads_configured", False):
torch.set_num_threads(thread_count)
try:
torch.set_num_interop_threads(thread_count)
except RuntimeError as err:
logger.warning("Skipping set_num_interop_threads: %s", err)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
    # calling torch.set_num_interop_threads again on a Gradio hot reload raises an
    # error, so record that the thread settings have already been applied
    setattr(torch, "_vallex_threads_configured", True)
else:
print("Torch threads already configured; skipping reconfiguration")
# set text tokenizer and collater
print("Setting text tokenizer and collater...")
tokenizer_path = os.path.join(
PREPARED_BASE_DIR, "apps/audio_cloning/vallex/g2p/bpe_69.json"
)
text_tokenizer = PhonemeBpeTokenizer(tokenizer_path=tokenizer_path)
text_collater = get_text_token_collater()
# set device
print("Setting device...")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
# if torch.backends.mps.is_available():
# device = torch.device("mps")
print(f"Device set to {device}")
# download VALL-E-X model weights if not already present
OUTPUT_DIR_CHECKPOINTS = os.path.join(OUTPUT_BASE_DIR, "models/checkpoints")
OUTPUT_FILENAME_CHECKPOINTS = "vallex-checkpoint.pt"
OUTPUT_PATH_CHECKPOINTS = os.path.join(
OUTPUT_DIR_CHECKPOINTS, OUTPUT_FILENAME_CHECKPOINTS
)
if not os.path.exists(OUTPUT_DIR_CHECKPOINTS):
os.makedirs(OUTPUT_DIR_CHECKPOINTS, exist_ok=True)
if not os.path.exists(OUTPUT_PATH_CHECKPOINTS):
import wget
    logger.info(
        "Downloading model from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt ..."
    )
try:
wget.download(
"https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt",
out=OUTPUT_PATH_CHECKPOINTS,
bar=wget.bar_adaptive,
)
print("Model weights downloaded successfully")
    except Exception as e:
        logger.error("Error downloading model weights: %s", e)
        raise Exception(
            "\n Model weights download failed, please go to 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'"
            f"\n manually download the model weights and put them in {OUTPUT_DIR_CHECKPOINTS}: {str(e)}"
        ) from e
# initialize VALL-E-X model
model = VALLE(
N_DIM,
NUM_HEAD,
NUM_LAYERS,
norm_first=True,
add_prenet=False,
prefix_mode=PREFIX_MODE,
share_embedding=True,
nar_scale_factor=1.0,
prepend_bos=True,
num_quantizers=NUM_QUANTIZERS,
)
checkpoint = torch.load(
OUTPUT_PATH_CHECKPOINTS, map_location=device, weights_only=False
)
missing_keys, unexpected_keys = model.load_state_dict(
    checkpoint["model"], strict=True
)
assert not missing_keys and not unexpected_keys
model.eval()
# Encodec-based tokenizer: converts reference audio into discrete conditioning tokens for VALLE
print("Initializing Encodec-based tokenizer...")
audio_tokenizer = AudioTokenizer(device)
# Vocos vocoder: decodes VALLE's discrete acoustic codes back into a 24 kHz waveform
vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(device)
# initialize ASR model
OUTPUT_DIR_WHISPER = os.path.join(PREPARED_BASE_DIR, "models/whisper")
if not os.path.exists(OUTPUT_DIR_WHISPER):
os.makedirs(OUTPUT_DIR_WHISPER, exist_ok=True)
try:
print("Loading Whisper model...")
model_name = "tiny"
whisper_model = whisper.load_model(
model_name, device="cpu", download_root=OUTPUT_DIR_WHISPER
)
print("Whisper model loaded successfully")
except NotImplementedError as e:
logger.error("Error on loading Whisper model: %s", e)
raise Exception(
f"Whisper model {model_name} is not supported on this platform."
) from e
except Exception as e:
logger.error("Error on loading Whisper model: %s", e)
    raise Exception(
        "\n Whisper download failed or damaged, please go to "
        f"'https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/{model_name}.pt'"
        f"\n manually download the model and put it in {OUTPUT_DIR_WHISPER}."
    ) from e
# Initialize Voice Presets
print("Initializing Voice Presets...")
PRESETS_DIR = os.path.join(PREPARED_BASE_DIR, "apps/audio_cloning/vallex/presets")
# collect preset names from the bundled .npz prompt files
preset_list = next(os.walk(PRESETS_DIR))[2]
preset_list = [preset[:-4] for preset in preset_list if preset.endswith(".npz")]
def clear_prompts():
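    """Delete .npz prompt files older than 60 seconds from the system temp directory."""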
try:
path = tempfile.gettempdir()
for eachfile in os.listdir(path):
filename = os.path.join(path, eachfile)
if os.path.isfile(filename) and filename.endswith(".npz"):
lastmodifytime = os.stat(filename).st_mtime
endfiletime = time.time() - 60
if endfiletime > lastmodifytime:
os.remove(filename)
except Exception as e:
logger.error("Error clearing prompts: %s", e)
return
@spaces.GPU(duration=120)
def transcribe_one(model, audio_path):
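    """Transcribe an audio file with Whisper on the module-level device; returns (language, text)."""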
# load audio and pad/trim it to fit 30 seconds
audio = whisper.load_audio(audio_path)
audio = whisper.pad_or_trim(audio)
# make log-Mel spectrogram and move to the same device as the model
mel = whisper.log_mel_spectrogram(audio).to(model.device)
# detect the spoken language
_, probs = model.detect_language(mel)
print(f"Detected language: {max(probs, key=probs.get)}")
lang = max(probs, key=probs.get)
# decode the audio
options = whisper.DecodingOptions(
temperature=1.0,
best_of=5,
        fp16=device.type != "cpu",
sample_len=150,
)
result = whisper.decode(model, mel, options)
# print the recognized text
print(result.text)
    text_pr = result.text.strip()
    # append a period if the transcript does not already end with punctuation
    if not text_pr or text_pr[-1] not in "?!.,。,?!。、":
        text_pr += "."
return lang, text_pr
@spaces.GPU(duration=120)
def transcribe_one_with_gpu(model, audio_path):
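    """Transcribe an audio file with Whisper, moving the model to CUDA inside the
    function (required on ZeroGPU); returns (language, text)."""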
model.eval()
    # on ZeroGPU, CUDA initialization and model moves must happen inside the function
if torch.cuda.is_available():
model = model.to("cuda", non_blocking=True)
use_fp16 = True
dev = torch.device("cuda")
else:
use_fp16 = False
dev = torch.device("cpu")
    # run inference with gradients disabled (faster and lighter)
with torch.inference_mode():
        # load the audio and pad/trim it to 30 seconds
audio = whisper.load_audio(audio_path)
audio = whisper.pad_or_trim(audio)
        # build the log-Mel spectrogram (initially assumed to be a dense CPU tensor)
mel = whisper.log_mel_spectrogram(audio)
mel = mel.to(dev, non_blocking=True)
        # detect the spoken language
_, probs = model.detect_language(mel)
lang = max(probs, key=probs.get)
print(f"Detected language: {lang}")
        # decode the audio
options = whisper.DecodingOptions(
temperature=1.0,
best_of=5,
fp16=use_fp16,
sample_len=150,
)
result = whisper.decode(model, mel, options)
        text_pr = result.text.strip()
        # append a period if the transcript does not already end with punctuation
        if not text_pr or text_pr[-1] not in "?!.,。,?!。、":
            text_pr += "."
return lang, text_pr
@spaces.GPU(duration=120)
def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
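    """Build a reusable .npz voice prompt (audio tokens, text tokens, language code)
    from uploaded or recorded reference audio, transcribing with Whisper when no
    transcript is given."""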
global model, text_collater, text_tokenizer, audio_tokenizer
clear_prompts()
audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
sr, wav_pr = audio_prompt
if not isinstance(wav_pr, torch.FloatTensor):
wav_pr = torch.FloatTensor(wav_pr)
if wav_pr.abs().max() > 1:
wav_pr /= wav_pr.abs().max()
if wav_pr.size(-1) == 2:
wav_pr = wav_pr[:, 0]
if wav_pr.ndim == 1:
wav_pr = wav_pr.unsqueeze(0)
    assert wav_pr.ndim == 2 and wav_pr.size(0) == 1
if transcript_content == "":
text_pr, lang_pr = make_prompt(name, wav_pr, sr, save=False)
else:
lang_pr = langid.classify(str(transcript_content))[0]
lang_token = lang2token[lang_pr]
text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"
# tokenize audio
encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
# tokenize text
phonemes, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
text_tokens, enroll_x_lens = text_collater([phonemes])
message = f"Detected language: {lang_pr}\n Detected text {text_pr}\n"
# save as npz file
np.savez(
os.path.join(tempfile.gettempdir(), f"{name}.npz"),
audio_tokens=audio_tokens,
text_tokens=text_tokens,
lang_code=lang2code[lang_pr],
)
return message, os.path.join(tempfile.gettempdir(), f"{name}.npz")
@spaces.GPU(duration=120)
def make_prompt(name, wav, sr, save=True):
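    """Transcribe reference audio with Whisper and wrap the text in language tokens;
    returns (tagged_text, language)."""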
global whisper_model
whisper_model.to(device)
if not isinstance(wav, torch.FloatTensor):
wav = torch.tensor(wav)
if wav.abs().max() > 1:
wav /= wav.abs().max()
if wav.size(-1) == 2:
wav = wav.mean(-1, keepdim=False)
if wav.ndim == 1:
wav = wav.unsqueeze(0)
    assert wav.ndim == 2 and wav.size(0) == 1
    os.makedirs("./prompts", exist_ok=True)
    torchaudio.save(f"./prompts/{name}.wav", wav, sr)
lang, text = transcribe_one_with_gpu(whisper_model, f"./prompts/{name}.wav")
lang_token = lang2token[lang]
text = lang_token + text + lang_token
with open(f"./prompts/{name}.txt", "w", encoding="utf-8") as f:
f.write(text)
if not save:
os.remove(f"./prompts/{name}.wav")
os.remove(f"./prompts/{name}.txt")
whisper_model.cpu()
torch.cuda.empty_cache()
return text, lang
@spaces.GPU(duration=120)
@torch.no_grad()
def infer_from_audio(
text, language, accent, audio_prompt, record_audio_prompt, transcript_content
):
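    """Synthesize `text` in the voice of the given reference audio and return
    (status_message, (sample_rate, waveform))."""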
global model, text_collater, text_tokenizer, audio_tokenizer
timings = []
start_time = time.perf_counter()
audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
sr, wav_pr = audio_prompt
if not isinstance(wav_pr, torch.FloatTensor):
wav_pr = torch.FloatTensor(wav_pr)
if wav_pr.abs().max() > 1:
wav_pr /= wav_pr.abs().max()
if wav_pr.size(-1) == 2:
wav_pr = wav_pr[:, 0]
if wav_pr.ndim == 1:
wav_pr = wav_pr.unsqueeze(0)
    assert wav_pr.ndim == 2 and wav_pr.size(0) == 1
timings.append(("前処理", time.perf_counter() - start_time))
start_time = time.perf_counter()
if transcript_content == "":
text_pr, lang_pr = make_prompt("dummy", wav_pr, sr, save=False)
else:
lang_pr = langid.classify(str(transcript_content))[0]
lang_token = lang2token[lang_pr]
text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"
if language == "auto-detect":
lang_token = lang2token[langid.classify(text)[0]]
else:
lang_token = langdropdown2token[language]
lang = token2lang[lang_token]
text = lang_token + text + lang_token
timings.append(("テキスト準備", time.perf_counter() - start_time))
    # move the model onto the inference device
model.to(device)
start_time = time.perf_counter()
# tokenize audio
encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
timings.append(("話者特徴抽出", time.perf_counter() - start_time))
start_time = time.perf_counter()
# tokenize text
logging.info(f"synthesize text: {text}")
phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
text_tokens, text_tokens_lens = text_collater([phone_tokens])
enroll_x_lens = None
if text_pr:
text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
text_prompts, enroll_x_lens = text_collater([text_prompts])
text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
text_tokens_lens += enroll_x_lens
lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
timings.append(("音素化/トークナイズ", time.perf_counter() - start_time))
start_time = time.perf_counter()
encoded_frames = model.inference(
text_tokens.to(device),
text_tokens_lens.to(device),
audio_prompts,
enroll_x_lens=enroll_x_lens,
top_k=-100,
temperature=1,
prompt_language=lang_pr,
text_language=langs if accent == "no-accent" else lang,
best_of=5,
)
timings.append(("音響モデル推論", time.perf_counter() - start_time))
# Decode with Vocos
start_time = time.perf_counter()
frames = encoded_frames.permute(2, 0, 1)
features = vocos.codes_to_features(frames)
samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
timings.append(("ボコーダ復号", time.perf_counter() - start_time))
    timing_report = "\n↓\n".join(
        f"{step}: {duration:.4f} sec" for step, duration in timings
    )
    print(f"Inference step timings:\n{timing_report}")
message = f"text prompt: {text_pr}\nsythesized text: {text}"
return message, (24000, samples.squeeze(0).cpu().numpy())
def main():
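    """Build the Gradio Blocks UI and wire up the inference and prompt-making callbacks."""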
app = gr.Blocks(title="VALL-E X")
with app:
gr.Markdown(top_ja_md)
with gr.Tab("Infer from audio"):
gr.Markdown(infer_from_audio_ja_md)
with gr.Row():
with gr.Column():
textbox = gr.TextArea(
label="音声合成で喋らせたいテキスト",
# placeholder="Type your sentence here",
placeholder="ここに音声合成で喋らせたいテキストを入力してください。",
value="Welcome back, Master. What can I do for you today?",
elem_id="tts-input",
)
language_dropdown = gr.Dropdown(
choices=["auto-detect", "English", "中文", "日本語"],
value="auto-detect",
label="language",
)
accent_dropdown = gr.Dropdown(
choices=["no-accent", "English", "中文", "日本語"],
value="no-accent",
label="accent",
)
textbox_transcript = gr.TextArea(
label="Transcript",
# placeholder="Write transcript here. (leave empty to use whisper)",
placeholder="アップロードした音声、または録音した音声のテキストを入力してください。(whisper を使用する場合は空のままにしてください。)",
value="",
elem_id="prompt-name",
)
upload_audio_prompt = gr.Audio(
label="音声アップロード",
sources=["upload"],
interactive=True,
)
record_audio_prompt = gr.Audio(
label="音声を録音する",
sources=["microphone"],
interactive=True,
)
with gr.Column():
text_output = gr.Textbox(label="Message")
audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio")
btn = gr.Button("音声合成を開始する")
btn.click(
infer_from_audio,
inputs=[
textbox,
language_dropdown,
accent_dropdown,
upload_audio_prompt,
record_audio_prompt,
textbox_transcript,
],
outputs=[text_output, audio_output],
)
textbox_mp = gr.TextArea(
label="Prompt name",
placeholder="Name your prompt here",
value="prompt_1",
elem_id="prompt-name",
)
btn_mp = gr.Button("Make prompt!")
prompt_output = gr.File(interactive=False)
btn_mp.click(
make_npz_prompt,
inputs=[
textbox_mp,
upload_audio_prompt,
record_audio_prompt,
textbox_transcript,
],
outputs=[text_output, prompt_output],
)
gr.Examples(
examples=infer_from_audio_examples,
inputs=[
textbox,
language_dropdown,
accent_dropdown,
upload_audio_prompt,
record_audio_prompt,
textbox_transcript,
],
outputs=[text_output, audio_output],
fn=infer_from_audio,
cache_examples=False,
)
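
    # launch the UI (assumed entry point: the original launch call is not shown in
    # this section, so this minimal guard just makes the module runnable as a script)
    app.launch()


if __name__ == "__main__":
    main()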