"""Gradio UI for VALL-E X zero-shot voice cloning with cached voice prompts.

Provenance (recovered from scraped header): vallex-prototyping /
apps/audio_cloning/cheched_vallex.py — commit 3355fb8 ("fix: logging
format") by nukopy.
"""
import logging
import os
import re
import shutil
import time
from typing import List, Optional, Tuple
import gradio as gr
import numpy as np
import spaces
import torch
from .vallex import main as vallex
from .vallex.descriptions import infer_from_audio_ja_md, top_ja_md
from .vallex.examples import infer_from_audio_examples
from .vallex.macros import code2lang, lang2token, langdropdown2token, token2lang
logger = logging.getLogger(__name__)
PROMPTS_DIR = "./models/prompts"
PROMPT_ID_PATTERN = re.compile(r"^[A-Za-z0-9_-]+$")
def _ensure_prompt_dir() -> str:
    """Make sure the on-disk prompt cache directory exists.

    Returns:
        The path of the prompt cache directory (``PROMPTS_DIR``).
    """
    directory = PROMPTS_DIR
    os.makedirs(directory, exist_ok=True)
    return directory
def _list_saved_prompts() -> List[str]:
    """Return the file names of all cached ``.npz`` prompts, sorted."""
    directory = _ensure_prompt_dir()
    return sorted(
        name for name in os.listdir(directory) if name.endswith(".npz")
    )
def _format_prompt_list() -> str:
    """Render the cached-prompt listing as display text for the UI box."""
    prompts = _list_saved_prompts()
    if not prompts:
        # Shown when the cache directory holds no .npz prompts.
        return "保存済みプロンプトはありません。"
    return "\n".join(prompts)
def save_prompt_to_cache(
    prompt_id: str,
    upload_audio_prompt: Optional[Tuple[int, np.ndarray]],
    record_audio_prompt: Optional[Tuple[int, np.ndarray]],
    transcript_content: str,
):
    """Create an ``.npz`` voice prompt and persist it under ``PROMPTS_DIR``.

    Args:
        prompt_id: User-chosen identifier; a trailing ``.npz`` is stripped
            and only ``[A-Za-z0-9_-]`` characters are accepted.
        upload_audio_prompt: ``(sample_rate, samples)`` from the upload
            widget, or ``None``. Takes precedence over the recorded audio.
        record_audio_prompt: ``(sample_rate, samples)`` from the microphone
            widget, or ``None``.
        transcript_content: Transcript of the prompt audio (may be empty —
            the UI says whisper is used in that case; verify in vallex).

    Returns:
        A 4-tuple matching the click handler's outputs:
        (status message, cached file path or None, dropdown ``gr.update``,
        list-box ``gr.update``).
    """

    def _failure(status: str):
        # Uniform error response: message, no file, widgets refreshed from disk.
        return (
            status,
            None,
            gr.update(choices=_list_saved_prompts(), value=None),
            gr.update(value=_format_prompt_list()),
        )

    prompt_id = prompt_id.strip()
    if prompt_id.lower().endswith(".npz"):
        prompt_id = prompt_id[: -len(".npz")]
    if not prompt_id:
        return _failure("プロンプト ID を入力してください。")
    if not PROMPT_ID_PATTERN.match(prompt_id):
        return _failure(
            "プロンプト ID には英数字・ハイフン・アンダースコアのみ使用できます。"
        )
    # At least one audio source is required; upload wins if both are present.
    audio_prompt = (
        upload_audio_prompt if upload_audio_prompt is not None else record_audio_prompt
    )
    if audio_prompt is None:
        return _failure("音声をアップロードするか録音してください。")
    try:
        message, temp_path = vallex.make_npz_prompt(
            prompt_id,
            upload_audio_prompt,
            record_audio_prompt,
            transcript_content,
        )
    except Exception as err:  # pylint: disable=broad-except
        # logger.exception already records the active exception's traceback.
        logger.exception("Failed to create prompt")
        return _failure(f"プロンプト作成に失敗しました: {err}")
    _ensure_prompt_dir()
    cached_filename = f"{prompt_id}.npz"
    cached_path = os.path.join(PROMPTS_DIR, cached_filename)
    try:
        shutil.copy(temp_path, cached_path)
    except OSError as err:
        logger.exception("Failed to copy prompt to cache")
        return _failure(f"プロンプトの保存に失敗しました: {err}")
    finally:
        # The temporary prompt file is no longer needed, success or failure.
        try:
            os.remove(temp_path)
        except OSError:
            pass
    choices = _list_saved_prompts()
    message = (
        f"{message}\nSaved cached prompt to {cached_path}"
        if message
        else f"Saved cached prompt to {cached_path}"
    )
    return (
        message,
        cached_path,
        gr.update(choices=choices, value=cached_filename),
        gr.update(value=_format_prompt_list()),
    )
def refresh_prompt_choices():
    """Reload the prompt dropdown and listing box from the on-disk cache.

    Returns:
        ``(dropdown gr.update, list-box gr.update)`` — the first cached
        prompt is pre-selected when any exist.
    """
    choices = _list_saved_prompts()
    selected = choices[0] if choices else None
    dropdown_update = gr.update(choices=choices, value=selected)
    listing_update = gr.update(value=_format_prompt_list())
    return dropdown_update, listing_update
@spaces.GPU(duration=120)
@torch.no_grad()
def infer_from_cached_prompt(
    text: str,
    language: str,
    accent: str,
    prompt_filename: Optional[str],
):
    """Synthesize speech from a previously cached ``.npz`` voice prompt.

    Args:
        text: Text to synthesize.
        language: ``"auto-detect"`` or one of the UI language choices.
        accent: ``"no-accent"`` or one of the UI accent choices.
        prompt_filename: File name (inside ``PROMPTS_DIR``) of the cached
            prompt selected in the dropdown, or ``None``.

    Returns:
        ``(message, (24000, samples))`` for the ``gr.Audio`` output, or
        ``(message, None)`` on validation/load failure.
    """
    # --- input validation -------------------------------------------------
    if not text:
        return "テキストを入力してください。", None
    if not prompt_filename:
        return "プロンプトを選択してください。", None
    prompt_path = os.path.join(_ensure_prompt_dir(), prompt_filename)
    if not os.path.exists(prompt_path):
        return f"プロンプトが見つかりません: {prompt_path}", None
    # Per-step wall-clock timings, printed at the end.
    timings: List[Tuple[str, float]] = []
    start_time = time.perf_counter()
    # --- load the cached prompt ------------------------------------------
    try:
        print(f"Loading cached prompt from: {prompt_path}")
        prompt_data = np.load(prompt_path)
        audio_tokens = torch.from_numpy(prompt_data["audio_tokens"]).to(
            dtype=torch.long
        )
        text_prompts = torch.from_numpy(prompt_data["text_tokens"]).to(dtype=torch.long)
        # "lang_code" may be stored as a 0-d scalar or a 1-element array;
        # handle both layouts.
        lang_code = (
            int(prompt_data["lang_code"])
            if prompt_data["lang_code"].shape == ()
            else int(prompt_data["lang_code"][0])
        )
    except Exception as err:  # pylint: disable=broad-except
        logger.exception("Failed to load cached prompt", exc_info=err)
        return (f"プロンプトの読み込みに失敗しました: {err}", None)
    timings.append(("[cached] 話者特徴抽出", time.perf_counter() - start_time))
    # Unknown stored codes fall back to English as the prompt language.
    lang_pr = code2lang.get(lang_code, "en")
    # --- prepare the text -------------------------------------------------
    start_time = time.perf_counter()
    if language == "auto-detect":
        detected_lang = vallex.langid.classify(text)[0]
        lang_token = lang2token.get(detected_lang, "[EN]")
    else:
        lang_token = langdropdown2token[language]
    # Wrap the text in language tokens, e.g. "[EN]...[EN]".
    conditioned_text = f"{lang_token}{text}{lang_token}"
    timings.append(("テキスト準備", time.perf_counter() - start_time))
    # --- tokenize and prepend the prompt's text tokens --------------------
    start_time = time.perf_counter()
    phone_tokens, langs = vallex.text_tokenizer.tokenize(
        text=f"_{conditioned_text}".strip()
    )
    text_tokens, text_tokens_lens = vallex.text_collater([phone_tokens])
    enroll_x_lens = torch.IntTensor([text_prompts.shape[-1]])
    text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
    text_tokens_lens += enroll_x_lens
    timings.append(("音素化/トークナイズ", time.perf_counter() - start_time))
    # --- acoustic model inference -----------------------------------------
    vallex.model.to(vallex.device)
    audio_prompts = audio_tokens.to(vallex.device)
    # The model expects a batch dimension; add one if the stored tokens
    # were saved as 2-D.
    if audio_prompts.dim() == 2:
        audio_prompts = audio_prompts.unsqueeze(0)
    start_time = time.perf_counter()
    print(f"Start inferring from cached prompt: {prompt_path}")
    encoded_frames = vallex.model.inference(
        text_tokens.to(vallex.device),
        text_tokens_lens.to(vallex.device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens.to(vallex.device),
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        # "no-accent" keeps the tokenizer-detected languages; otherwise the
        # accent dropdown overrides the text language.
        text_language=langs
        if accent == "no-accent"
        else token2lang[langdropdown2token[accent]],
        best_of=5,
    )
    timings.append(("音響モデル推論", time.perf_counter() - start_time))
    print("Inference completed")
    # --- vocoder decoding -------------------------------------------------
    start_time = time.perf_counter()
    print("Decoding with Vocos...")
    frames = encoded_frames.permute(2, 0, 1)
    features = vallex.vocos.codes_to_features(frames)
    samples = vallex.vocos.decode(
        features, bandwidth_id=torch.tensor([2], device=vallex.device)
    )
    timings.append(("ボコーダ復号", time.perf_counter() - start_time))
    print("Decoding completed")
    message = (
        f"Loaded cached prompt: {prompt_filename}\n"
        f"Prompt language: {lang_pr}\n"
        f"Synthesized text: {conditioned_text}"
    )
    # NOTE(review): step and duration are concatenated with no separator in
    # these two format strings — looks like a formatting oversight; confirm
    # the intended log format before changing it.
    for step, duration in timings:
        print(f"{step}{duration:.4f} sec")
    timing_report = "\n↓\n".join(
        f"{step}{duration:.4f} sec" for step, duration in timings
    )
    print(f"推論ステップ計測結果\n{timing_report}")
    # 24000 is the output sample rate expected by the gr.Audio component.
    return message, (24000, samples.squeeze(0).cpu().numpy())
def main():
    """Build the Gradio UI for cached-prompt zero-shot voice cloning.

    Wires three actions: direct inference from uploaded/recorded audio,
    saving a reusable ``.npz`` prompt to the cache, and inference from a
    cached prompt.

    NOTE(review): components and ``.click()`` handlers are created without
    a visible ``with gr.Blocks():`` context in this chunk — presumably the
    caller wraps ``main()`` in one; verify against the launcher.
    """
    prompt_choices = _list_saved_prompts()
    gr.Markdown(top_ja_md)
    gr.Markdown(infer_from_audio_ja_md)
    gr.Markdown("[Cached] Zero-shot 音声クローニング")
    with gr.Row():
        with gr.Column():
            # --- inputs: text, language/accent, transcript, audio ---------
            textbox = gr.TextArea(
                label="音声合成で喋らせたいテキスト",
                placeholder="ここに音声合成で喋らせたいテキストを入力してください。",
                value="Welcome back, Master. What can I do for you today?",
                elem_id="tts-input-cached",
            )
            language_dropdown = gr.Dropdown(
                choices=["auto-detect", "English", "中文", "日本語"],
                value="auto-detect",
                label="language",
            )
            accent_dropdown = gr.Dropdown(
                choices=["no-accent", "English", "中文", "日本語"],
                value="no-accent",
                label="accent",
            )
            textbox_transcript = gr.TextArea(
                label="Transcript",
                placeholder="アップロードした音声、または録音した音声のテキストを入力してください。(whisper を使用する場合は空のままにしてください。)",
                value="",
            )
            upload_audio_prompt = gr.Audio(
                label="音声アップロード",
                sources=["upload"],
                interactive=True,
            )
            record_audio_prompt = gr.Audio(
                label="音声を録音する",
                sources=["microphone"],
                interactive=True,
            )
            # --- prompt-cache controls ------------------------------------
            prompt_id_box = gr.Textbox(
                label="Prompt ID",
                placeholder="例: my_speaker01",
                value="",
            )
            cached_prompt_dropdown = gr.Dropdown(
                label="Cached prompts",
                choices=prompt_choices,
                value=prompt_choices[0] if prompt_choices else None,
                interactive=True,
            )
            prompt_list_box = gr.Textbox(
                label="保存済みプロンプト一覧",
                value=_format_prompt_list(),
                interactive=False,
                lines=6,
            )
            refresh_btn = gr.Button("キャッシュ一覧を更新")
        with gr.Column():
            # --- outputs --------------------------------------------------
            text_output = gr.Textbox(label="Message")
            audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio-cached")
    # Direct synthesis from the uploaded/recorded audio (no cache).
    btn_infer = gr.Button("音声合成を開始する")
    btn_infer.click(
        vallex.infer_from_audio,
        inputs=[
            textbox,
            language_dropdown,
            accent_dropdown,
            upload_audio_prompt,
            record_audio_prompt,
            textbox_transcript,
        ],
        outputs=[text_output, audio_output],
    )
    # Persist the current audio as a cached .npz prompt.
    prompt_output = gr.File(label="Generated prompt", interactive=False)
    btn_save = gr.Button("./models/prompts に保存")
    btn_save.click(
        save_prompt_to_cache,
        inputs=[
            prompt_id_box,
            upload_audio_prompt,
            record_audio_prompt,
            textbox_transcript,
        ],
        outputs=[
            text_output,
            prompt_output,
            cached_prompt_dropdown,
            prompt_list_box,
        ],
    )
    # Synthesis driven by a previously cached prompt.
    btn_cached_infer = gr.Button("キャッシュしたプロンプトで合成")
    btn_cached_infer.click(
        infer_from_cached_prompt,
        inputs=[
            textbox,
            language_dropdown,
            accent_dropdown,
            cached_prompt_dropdown,
        ],
        outputs=[text_output, audio_output],
    )
    refresh_btn.click(
        refresh_prompt_choices,
        inputs=None,
        outputs=[cached_prompt_dropdown, prompt_list_box],
    )
    gr.Examples(
        examples=infer_from_audio_examples,
        inputs=[
            textbox,
            language_dropdown,
            accent_dropdown,
            upload_audio_prompt,
            record_audio_prompt,
            textbox_transcript,
        ],
        outputs=[text_output, audio_output],
        fn=vallex.infer_from_audio,
        cache_examples=False,
    )