Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| import re | |
| import shutil | |
| import time | |
| from typing import List, Optional, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| from .vallex import main as vallex | |
| from .vallex.descriptions import infer_from_audio_ja_md, top_ja_md | |
| from .vallex.examples import infer_from_audio_examples | |
| from .vallex.macros import code2lang, lang2token, langdropdown2token, token2lang | |
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)

# On-disk directory where generated .npz voice prompts are cached.
PROMPTS_DIR = "./models/prompts"
# Prompt IDs may contain only ASCII letters, digits, hyphens and underscores.
PROMPT_ID_PATTERN = re.compile(r"^[A-Za-z0-9_-]+$")
def _ensure_prompt_dir() -> str:
    """Create the prompt cache directory if it is missing and return its path."""
    prompt_dir = PROMPTS_DIR
    os.makedirs(prompt_dir, exist_ok=True)
    return prompt_dir
def _list_saved_prompts() -> List[str]:
    """Return the sorted ``.npz`` filenames currently present in the cache."""
    prompt_dir = _ensure_prompt_dir()
    return sorted(
        name for name in os.listdir(prompt_dir) if name.endswith(".npz")
    )
def _format_prompt_list() -> str:
    """Render cached prompt names one per line.

    Returns a Japanese "no saved prompts" message when the cache is empty.
    """
    saved = _list_saved_prompts()
    if not saved:
        return "保存済みプロンプトはありません。"
    return "\n".join(saved)
def _save_error(message: str):
    """Build the 4-tuple ``save_prompt_to_cache`` returns on failure.

    Keeps the dropdown and listing widgets in sync with the on-disk cache
    even when saving fails.
    """
    return (
        message,
        None,
        gr.update(choices=_list_saved_prompts(), value=None),
        gr.update(value=_format_prompt_list()),
    )


def save_prompt_to_cache(
    prompt_id: str,
    upload_audio_prompt: Optional[Tuple[int, np.ndarray]],
    record_audio_prompt: Optional[Tuple[int, np.ndarray]],
    transcript_content: str,
):
    """Create an ``.npz`` voice prompt and persist it under ``PROMPTS_DIR``.

    Args:
        prompt_id: User-chosen identifier. A trailing ``.npz`` is stripped;
            only characters matching ``PROMPT_ID_PATTERN`` are accepted.
        upload_audio_prompt: ``(sample_rate, samples)`` from the upload
            widget, or ``None``. Takes precedence over the recorded audio.
        record_audio_prompt: ``(sample_rate, samples)`` from the microphone
            widget, or ``None``.
        transcript_content: Transcript of the prompt audio; may be empty
            (the backend then transcribes the audio itself).

    Returns:
        ``(message, cached_path_or_None, dropdown_update, listing_update)``.
    """
    prompt_id = prompt_id.strip()
    # Accept "my_id.npz" as equivalent to "my_id".
    if prompt_id.lower().endswith(".npz"):
        prompt_id = prompt_id[:-4]
    if not prompt_id:
        return _save_error("プロンプト ID を入力してください。")
    if not PROMPT_ID_PATTERN.match(prompt_id):
        return _save_error(
            "プロンプト ID には英数字・ハイフン・アンダースコアのみ使用できます。"
        )
    audio_prompt = (
        upload_audio_prompt if upload_audio_prompt is not None else record_audio_prompt
    )
    if audio_prompt is None:
        return _save_error("音声をアップロードするか録音してください。")
    try:
        message, temp_path = vallex.make_npz_prompt(
            prompt_id,
            upload_audio_prompt,
            record_audio_prompt,
            transcript_content,
        )
    except Exception as err:  # pylint: disable=broad-except
        logger.exception("Failed to create prompt", exc_info=err)
        return _save_error(f"プロンプト作成に失敗しました: {err}")
    _ensure_prompt_dir()
    cached_filename = f"{prompt_id}.npz"
    cached_path = os.path.join(PROMPTS_DIR, cached_filename)
    try:
        shutil.copy(temp_path, cached_path)
    except OSError as err:
        logger.exception("Failed to copy prompt to cache", exc_info=err)
        return _save_error(f"プロンプトの保存に失敗しました: {err}")
    finally:
        # The backend wrote a temporary file; always clean it up, whether or
        # not the copy into the cache succeeded.
        try:
            os.remove(temp_path)
        except OSError:
            pass
    choices = _list_saved_prompts()
    note = f"Saved cached prompt to {cached_path}"
    message = f"{message}\n{note}" if message else note
    return (
        message,
        cached_path,
        gr.update(choices=choices, value=cached_filename),
        gr.update(value=_format_prompt_list()),
    )
def refresh_prompt_choices():
    """Re-scan the prompt cache and refresh the dropdown and the listing box."""
    prompts = _list_saved_prompts()
    selected = prompts[0] if prompts else None
    dropdown_update = gr.update(choices=prompts, value=selected)
    listing_update = gr.update(value=_format_prompt_list())
    return dropdown_update, listing_update
def infer_from_cached_prompt(
    text: str,
    language: str,
    accent: str,
    prompt_filename: Optional[str],
):
    """Synthesize speech for ``text`` using a previously cached ``.npz`` prompt.

    Args:
        text: Text to synthesize.
        language: ``"auto-detect"`` or a language name from the dropdown.
        accent: ``"no-accent"`` or a language name to force the accent.
        prompt_filename: Filename (inside ``PROMPTS_DIR``) of the cached prompt.

    Returns:
        ``(message, (sample_rate, waveform))`` on success, or
        ``(error_message, None)`` on failure.
    """
    if not text:
        return "テキストを入力してください。", None
    if not prompt_filename:
        return "プロンプトを選択してください。", None
    prompt_path = os.path.join(_ensure_prompt_dir(), prompt_filename)
    if not os.path.exists(prompt_path):
        return f"プロンプトが見つかりません: {prompt_path}", None
    timings: List[Tuple[str, float]] = []
    start_time = time.perf_counter()
    try:
        print(f"Loading cached prompt from: {prompt_path}")
        # np.load on an .npz returns an NpzFile that holds an open file
        # handle; use it as a context manager so the handle is always closed
        # (the original left it open — a per-call file-descriptor leak).
        with np.load(prompt_path) as prompt_data:
            audio_tokens = torch.from_numpy(prompt_data["audio_tokens"]).to(
                dtype=torch.long
            )
            text_prompts = torch.from_numpy(prompt_data["text_tokens"]).to(
                dtype=torch.long
            )
            # lang_code may have been stored as a 0-d scalar or a 1-element
            # array depending on how the prompt was written.
            lang_code = (
                int(prompt_data["lang_code"])
                if prompt_data["lang_code"].shape == ()
                else int(prompt_data["lang_code"][0])
            )
    except Exception as err:  # pylint: disable=broad-except
        logger.exception("Failed to load cached prompt", exc_info=err)
        return (f"プロンプトの読み込みに失敗しました: {err}", None)
    timings.append(("[cached] 話者特徴抽出", time.perf_counter() - start_time))
    # Fall back to English when the stored code is unknown.
    lang_pr = code2lang.get(lang_code, "en")
    start_time = time.perf_counter()
    if language == "auto-detect":
        detected_lang = vallex.langid.classify(text)[0]
        lang_token = lang2token.get(detected_lang, "[EN]")
    else:
        lang_token = langdropdown2token[language]
    # Wrap the text in language tokens, e.g. "[JA]こんにちは[JA]".
    conditioned_text = f"{lang_token}{text}{lang_token}"
    timings.append(("テキスト準備", time.perf_counter() - start_time))
    start_time = time.perf_counter()
    phone_tokens, langs = vallex.text_tokenizer.tokenize(
        text=f"_{conditioned_text}".strip()
    )
    text_tokens, text_tokens_lens = vallex.text_collater([phone_tokens])
    # Prepend the cached prompt's text tokens and extend the lengths to match.
    enroll_x_lens = torch.IntTensor([text_prompts.shape[-1]])
    text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
    text_tokens_lens += enroll_x_lens
    timings.append(("音素化/トークナイズ", time.perf_counter() - start_time))
    vallex.model.to(vallex.device)
    audio_prompts = audio_tokens.to(vallex.device)
    # The model expects a batch dimension on the audio prompt.
    if audio_prompts.dim() == 2:
        audio_prompts = audio_prompts.unsqueeze(0)
    start_time = time.perf_counter()
    print(f"Start inferring from cached prompt: {prompt_path}")
    encoded_frames = vallex.model.inference(
        text_tokens.to(vallex.device),
        text_tokens_lens.to(vallex.device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens.to(vallex.device),
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        text_language=langs
        if accent == "no-accent"
        else token2lang[langdropdown2token[accent]],
        best_of=5,
    )
    timings.append(("音響モデル推論", time.perf_counter() - start_time))
    print("Inference completed")
    start_time = time.perf_counter()
    print("Decoding with Vocos...")
    frames = encoded_frames.permute(2, 0, 1)
    features = vallex.vocos.codes_to_features(frames)
    samples = vallex.vocos.decode(
        features, bandwidth_id=torch.tensor([2], device=vallex.device)
    )
    timings.append(("ボコーダ復号", time.perf_counter() - start_time))
    print("Decoding completed")
    message = (
        f"Loaded cached prompt: {prompt_filename}\n"
        f"Prompt language: {lang_pr}\n"
        f"Synthesized text: {conditioned_text}"
    )
    # Per-step timing log, then a joined summary report.
    for step, duration in timings:
        print(f"{step}:{duration:.4f} sec")
    timing_report = "\n↓\n".join(
        f"{step}:{duration:.4f} sec" for step, duration in timings
    )
    print(f"推論ステップ計測結果\n{timing_report}")
    # Vocos decodes at 24 kHz.
    return message, (24000, samples.squeeze(0).cpu().numpy())
def main():
    """Build the Gradio UI for cached-prompt VALL-E X voice cloning.

    NOTE(review): the original file's indentation was lost in extraction;
    the layout nesting below is reconstructed from Gradio conventions —
    verify the placement of the components created after the Row (the
    prompt File output and the save/infer buttons) against the intended
    layout. Presumably this function is called inside an active
    ``gr.Blocks`` context — confirm at the call site.
    """
    # Seed the dropdown with whatever prompts are already cached on disk.
    prompt_choices = _list_saved_prompts()
    gr.Markdown(top_ja_md)
    gr.Markdown(infer_from_audio_ja_md)
    gr.Markdown("[Cached] Zero-shot 音声クローニング")
    with gr.Row():
        with gr.Column():
            # Input side: text, language options, prompt audio and cache controls.
            textbox = gr.TextArea(
                label="音声合成で喋らせたいテキスト",
                placeholder="ここに音声合成で喋らせたいテキストを入力してください。",
                value="Welcome back, Master. What can I do for you today?",
                elem_id="tts-input-cached",
            )
            language_dropdown = gr.Dropdown(
                choices=["auto-detect", "English", "中文", "日本語"],
                value="auto-detect",
                label="language",
            )
            accent_dropdown = gr.Dropdown(
                choices=["no-accent", "English", "中文", "日本語"],
                value="no-accent",
                label="accent",
            )
            textbox_transcript = gr.TextArea(
                label="Transcript",
                placeholder="アップロードした音声、または録音した音声のテキストを入力してください。(whisper を使用する場合は空のままにしてください。)",
                value="",
            )
            upload_audio_prompt = gr.Audio(
                label="音声アップロード",
                sources=["upload"],
                interactive=True,
            )
            record_audio_prompt = gr.Audio(
                label="音声を録音する",
                sources=["microphone"],
                interactive=True,
            )
            prompt_id_box = gr.Textbox(
                label="Prompt ID",
                placeholder="例: my_speaker01",
                value="",
            )
            cached_prompt_dropdown = gr.Dropdown(
                label="Cached prompts",
                choices=prompt_choices,
                value=prompt_choices[0] if prompt_choices else None,
                interactive=True,
            )
            prompt_list_box = gr.Textbox(
                label="保存済みプロンプト一覧",
                value=_format_prompt_list(),
                interactive=False,
                lines=6,
            )
            refresh_btn = gr.Button("キャッシュ一覧を更新")
        with gr.Column():
            # Output side: status message and synthesized audio.
            text_output = gr.Textbox(label="Message")
            audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio-cached")
            btn_infer = gr.Button("音声合成を開始する")
    # Direct synthesis from the uploaded/recorded audio (no cache involved).
    btn_infer.click(
        vallex.infer_from_audio,
        inputs=[
            textbox,
            language_dropdown,
            accent_dropdown,
            upload_audio_prompt,
            record_audio_prompt,
            textbox_transcript,
        ],
        outputs=[text_output, audio_output],
    )
    prompt_output = gr.File(label="Generated prompt", interactive=False)
    btn_save = gr.Button("./models/prompts に保存")
    # Create an .npz prompt and persist it to the on-disk cache, updating
    # the dropdown and listing widgets with the new state.
    btn_save.click(
        save_prompt_to_cache,
        inputs=[
            prompt_id_box,
            upload_audio_prompt,
            record_audio_prompt,
            textbox_transcript,
        ],
        outputs=[
            text_output,
            prompt_output,
            cached_prompt_dropdown,
            prompt_list_box,
        ],
    )
    btn_cached_infer = gr.Button("キャッシュしたプロンプトで合成")
    # Synthesis path that reuses a cached prompt instead of raw audio.
    btn_cached_infer.click(
        infer_from_cached_prompt,
        inputs=[
            textbox,
            language_dropdown,
            accent_dropdown,
            cached_prompt_dropdown,
        ],
        outputs=[text_output, audio_output],
    )
    # Manual refresh of the cached-prompt dropdown and listing.
    refresh_btn.click(
        refresh_prompt_choices,
        inputs=None,
        outputs=[cached_prompt_dropdown, prompt_list_box],
    )
    gr.Examples(
        examples=infer_from_audio_examples,
        inputs=[
            textbox,
            language_dropdown,
            accent_dropdown,
            upload_audio_prompt,
            record_audio_prompt,
            textbox_transcript,
        ],
        outputs=[text_output, audio_output],
        fn=vallex.infer_from_audio,
        cache_examples=False,
    )