import logging
import multiprocessing
import os
import pathlib
import platform
import sys
import tempfile
import time

import gradio as gr
import langid
import nltk
import numpy as np
import spaces
import torch
import torchaudio
import whisper
from vocos import Vocos

from .data.collation import get_text_token_collater
from .data.tokenizer import (
    AudioTokenizer,
    tokenize_audio,
)
from .descriptions import infer_from_audio_ja_md, top_ja_md
from .examples import infer_from_audio_examples
from .g2p import PhonemeBpeTokenizer
from .macros import (
    N_DIM,
    NUM_HEAD,
    NUM_LAYERS,
    NUM_QUANTIZERS,
    PREFIX_MODE,
    lang2code,
    lang2token,
    langdropdown2token,
    token2lang,
)
from .models.vallex import VALLE

logger = logging.getLogger(__name__)

# set base directories
OUTPUT_BASE_DIR = os.getenv("HF_HOME", ".")
PREPARED_BASE_DIR = "."
print(f"Base directory: {OUTPUT_BASE_DIR}")
print(f"Prepared base directory: {PREPARED_BASE_DIR}")

# restrict language identification to the languages this app supports
langid.set_languages(["en", "zh", "ja"])

# set nltk data path
nltk.data.path = nltk.data.path + [os.path.join(os.getcwd(), "nltk_data")]
print(f"nltk_data path: {nltk.data.path}")

# report encodings
print(
    "default encoding is "
    f"{sys.getdefaultencoding()}, "
    f"file system encoding is {sys.getfilesystemencoding()}"
)

# check python version
print(f"You are using Python version {platform.python_version()}")
if sys.version_info < (3, 7):
    logger.warning("The Python version is too low and may cause problems")
if platform.system().lower() == "windows":
    temp = pathlib.PosixPath
    pathlib.PosixPath = pathlib.WindowsPath
else:
    temp = pathlib.WindowsPath
    pathlib.WindowsPath = pathlib.PosixPath

os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

# set torch threads (guarded for hot-reload)
thread_count = multiprocessing.cpu_count()
print(f"Use {thread_count} cpu cores for computing")
if not getattr(torch, "_vallex_threads_configured", False):
    torch.set_num_threads(thread_count)
    try:
        torch.set_num_interop_threads(thread_count)
    except RuntimeError as err:
        logger.warning("Skipping set_num_interop_threads: %s", err)
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_set_profiling_mode(False)
    torch._C._set_graph_executor_optimize(False)
    # Calling torch.set_num_interop_threads() again on a gradio hot-reload
    # raises a RuntimeError, so mark the threads as already configured.
    setattr(torch, "_vallex_threads_configured", True)
else:
    print("Torch threads already configured; skipping reconfiguration")

# set text tokenizer and collater
print("Setting text tokenizer and collater...")
tokenizer_path = os.path.join(
    PREPARED_BASE_DIR, "apps/audio_cloning/vallex/g2p/bpe_69.json"
)
text_tokenizer = PhonemeBpeTokenizer(tokenizer_path=tokenizer_path)
text_collater = get_text_token_collater()

# set device
print("Setting device...")
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
# if torch.backends.mps.is_available():
#     device = torch.device("mps")
print(f"Device set to {device}")

# Download VALL-E-X model weights if they don't exist yet
OUTPUT_DIR_CHECKPOINTS = os.path.join(OUTPUT_BASE_DIR, "models/checkpoints")
OUTPUT_FILENAME_CHECKPOINTS = "vallex-checkpoint.pt"
OUTPUT_PATH_CHECKPOINTS = os.path.join(
    OUTPUT_DIR_CHECKPOINTS, OUTPUT_FILENAME_CHECKPOINTS
)
if not os.path.exists(OUTPUT_DIR_CHECKPOINTS):
    os.makedirs(OUTPUT_DIR_CHECKPOINTS, exist_ok=True)
if not os.path.exists(OUTPUT_PATH_CHECKPOINTS):
    import wget

    logger.info(
        "Downloading model from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt ..."
    )
    try:
        wget.download(
            "https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt",
            out=OUTPUT_PATH_CHECKPOINTS,
            bar=wget.bar_adaptive,
        )
        print("Model weights downloaded successfully")
    except Exception as e:
        logger.error("Error downloading model weights: %s", e)
        raise Exception(
            "\n Model weights download failed; please download them manually from"
            "\n 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'"
            f"\n and put the file in {OUTPUT_DIR_CHECKPOINTS}: {str(e)}"
        ) from e
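
# NOTE: a minimal alternative sketch using huggingface_hub instead of wget
# (not part of the original flow; assumes the `huggingface_hub` package is
# installed). hf_hub_download resumes interrupted transfers and caches files:
#
#     from huggingface_hub import hf_hub_download
#     hf_hub_download(
#         repo_id="Plachta/VALL-E-X",
#         filename="vallex-checkpoint.pt",
#         local_dir=OUTPUT_DIR_CHECKPOINTS,
#     )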

# initialize the VALL-E-X model
model = VALLE(
    N_DIM,
    NUM_HEAD,
    NUM_LAYERS,
    norm_first=True,
    add_prenet=False,
    prefix_mode=PREFIX_MODE,
    share_embedding=True,
    nar_scale_factor=1.0,
    prepend_bos=True,
    num_quantizers=NUM_QUANTIZERS,
)
checkpoint = torch.load(
    OUTPUT_PATH_CHECKPOINTS, map_location=device, weights_only=False
)
# with strict=True, load_state_dict already raises on any mismatch,
# so the assert below is a belt-and-braces check
missing_keys, unexpected_keys = model.load_state_dict(checkpoint["model"], strict=True)
assert not missing_keys and not unexpected_keys
model.eval()

# Encodec-based tokenizer: converts reference audio into discrete conditioning tokens for VALLE
print("Initializing Encodec-based tokenizer...")
audio_tokenizer = AudioTokenizer(device)

# Vocos vocoder: decodes VALLE's discrete acoustic codes back into a 24 kHz waveform
vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(device)
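
# Decode path (used later in infer_from_audio): VALLE's code frames become
# Vocos features, then a 24 kHz waveform:
#     features = vocos.codes_to_features(frames)
#     samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))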

# initialize the ASR model
OUTPUT_DIR_WHISPER = os.path.join(PREPARED_BASE_DIR, "models/whisper")
if not os.path.exists(OUTPUT_DIR_WHISPER):
    os.makedirs(OUTPUT_DIR_WHISPER, exist_ok=True)
try:
    print("Loading Whisper model...")
    model_name = "tiny"
    whisper_model = whisper.load_model(
        model_name, device="cpu", download_root=OUTPUT_DIR_WHISPER
    )
    print("Whisper model loaded successfully")
except NotImplementedError as e:
    logger.error("Error on loading Whisper model: %s", e)
    raise Exception(
        f"Whisper model {model_name} is not supported on this platform."
    ) from e
except Exception as e:
    logger.error("Error on loading Whisper model: %s", e)
    raise Exception(
        "\n Whisper download failed or the file is damaged; please download the model manually from "
        f"'https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/{model_name}.pt'"
        f"\n and put it in {OUTPUT_DIR_WHISPER}."
    ) from e

# Initialize Voice Presets
print("Initializing Voice Presets...")
PRESETS_DIR = os.path.join(PREPARED_BASE_DIR, "apps/audio_cloning/vallex/presets")
preset_list = next(os.walk(PRESETS_DIR))[2]
preset_list = [preset[:-4] for preset in preset_list if preset.endswith(".npz")]
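
# Each preset is an .npz archive with the same keys that make_npz_prompt()
# below writes, e.g. (illustrative sketch, assuming at least one preset):
#
#     data = np.load(os.path.join(PRESETS_DIR, f"{preset_list[0]}.npz"))
#     data["audio_tokens"], data["text_tokens"], data["lang_code"]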


def clear_prompts():
    try:
        path = tempfile.gettempdir()
        for eachfile in os.listdir(path):
            filename = os.path.join(path, eachfile)
            if os.path.isfile(filename) and filename.endswith(".npz"):
                lastmodifytime = os.stat(filename).st_mtime
                endfiletime = time.time() - 60
                if endfiletime > lastmodifytime:
                    os.remove(filename)
    except Exception as e:
        logger.error("Error clearing prompts: %s", e)
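
# clear_prompts() is called at the top of make_npz_prompt() below to
# garbage-collect temporary .npz prompt files older than 60 seconds.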


def transcribe_one(model, audio_path):
    # load audio and pad/trim it to fit 30 seconds
    audio = whisper.load_audio(audio_path)
    audio = whisper.pad_or_trim(audio)
    # make the log-Mel spectrogram and move it to the same device as the model
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    # detect the spoken language
    _, probs = model.detect_language(mel)
    lang = max(probs, key=probs.get)
    print(f"Detected language: {lang}")
    # decode the audio
    options = whisper.DecodingOptions(
        temperature=1.0,
        best_of=5,
        fp16=device.type != "cpu",
        sample_len=150,
    )
    result = whisper.decode(model, mel, options)
    # print the recognized text
    print(result.text)
    text_pr = result.text
    # append a period if the text doesn't already end with punctuation
    stripped = text_pr.strip(" ")
    if stripped and stripped[-1] not in "?!.,。,?!。、":
        text_pr += "."
    return lang, text_pr
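
# Usage sketch (the path is illustrative):
#     lang, text = transcribe_one(whisper_model, "prompts/sample.wav")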


def transcribe_one_with_gpu(model, audio_path):
    model.eval()
    # On ZeroGPU, GPU initialization/movement must happen inside the function
    if torch.cuda.is_available():
        model = model.to("cuda", non_blocking=True)
        use_fp16 = True
        dev = torch.device("cuda")
    else:
        use_fp16 = False
        dev = torch.device("cpu")
    # run inference with gradients disabled (faster and lighter)
    with torch.inference_mode():
        # pad/trim the audio to 30 seconds
        audio = whisper.load_audio(audio_path)
        audio = whisper.pad_or_trim(audio)
        # build the log-Mel spectrogram (initially a dense CPU tensor)
        mel = whisper.log_mel_spectrogram(audio)
        mel = mel.to(dev, non_blocking=True)
        # language detection
        _, probs = model.detect_language(mel)
        lang = max(probs, key=probs.get)
        print(f"Detected language: {lang}")
        # decoding
        options = whisper.DecodingOptions(
            temperature=1.0,
            best_of=5,
            fp16=use_fp16,
            sample_len=150,
        )
        result = whisper.decode(model, mel, options)
    text_pr = result.text
    # append a period if the text doesn't already end with punctuation
    stripped = text_pr.strip(" ")
    if stripped and stripped[-1] not in "?!.,。,?!。、":
        text_pr += "."
    return lang, text_pr
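
# On Hugging Face ZeroGPU Spaces, a GPU-using entry point like this is
# typically decorated with @spaces.GPU (the `spaces` import above) so that a
# GPU is attached for the duration of the call; the decorator is assumed to
# be applied where these functions are wired into the app.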


def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
    global model, text_collater, text_tokenizer, audio_tokenizer
    clear_prompts()
    audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
    sr, wav_pr = audio_prompt
    if not isinstance(wav_pr, torch.FloatTensor):
        wav_pr = torch.FloatTensor(wav_pr)
    if wav_pr.abs().max() > 1:
        wav_pr /= wav_pr.abs().max()
    if wav_pr.size(-1) == 2:
        wav_pr = wav_pr[:, 0]
    if wav_pr.ndim == 1:
        wav_pr = wav_pr.unsqueeze(0)
    assert wav_pr.ndim == 2 and wav_pr.size(0) == 1
    if transcript_content == "":
        text_pr, lang_pr = make_prompt(name, wav_pr, sr, save=False)
    else:
        lang_pr = langid.classify(str(transcript_content))[0]
        lang_token = lang2token[lang_pr]
        text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"
    # tokenize audio
    encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
    audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
    # tokenize text
    phonemes, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
    text_tokens, enroll_x_lens = text_collater([phonemes])
    message = f"Detected language: {lang_pr}\nDetected text: {text_pr}\n"
    # save as an npz file
    np.savez(
        os.path.join(tempfile.gettempdir(), f"{name}.npz"),
        audio_tokens=audio_tokens,
        text_tokens=text_tokens,
        lang_code=lang2code[lang_pr],
    )
    return message, os.path.join(tempfile.gettempdir(), f"{name}.npz")
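
# Usage sketch (values are illustrative; gradio supplies audio as a
# (sample_rate, ndarray) tuple):
#     msg, npz_path = make_npz_prompt("my_voice", (sr, wav), None, "")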


def make_prompt(name, wav, sr, save=True):
    global whisper_model
    whisper_model.to(device)
    if not isinstance(wav, torch.FloatTensor):
        wav = torch.FloatTensor(wav)
    if wav.abs().max() > 1:
        wav /= wav.abs().max()
    if wav.size(-1) == 2:
        wav = wav.mean(-1, keepdim=False)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    assert wav.ndim == 2 and wav.size(0) == 1
    os.makedirs("./prompts", exist_ok=True)
    torchaudio.save(f"./prompts/{name}.wav", wav, sr)
    lang, text = transcribe_one_with_gpu(whisper_model, f"./prompts/{name}.wav")
    lang_token = lang2token[lang]
    text = lang_token + text + lang_token
    with open(f"./prompts/{name}.txt", "w", encoding="utf-8") as f:
        f.write(text)
    if not save:
        os.remove(f"./prompts/{name}.wav")
        os.remove(f"./prompts/{name}.txt")
    whisper_model.cpu()
    torch.cuda.empty_cache()
    return text, lang


def infer_from_audio(
    text, language, accent, audio_prompt, record_audio_prompt, transcript_content
):
    global model, text_collater, text_tokenizer, audio_tokenizer
    timings = []
    start_time = time.perf_counter()
    audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
    sr, wav_pr = audio_prompt
    if not isinstance(wav_pr, torch.FloatTensor):
        wav_pr = torch.FloatTensor(wav_pr)
    if wav_pr.abs().max() > 1:
        wav_pr /= wav_pr.abs().max()
    if wav_pr.size(-1) == 2:
        wav_pr = wav_pr[:, 0]
    if wav_pr.ndim == 1:
        wav_pr = wav_pr.unsqueeze(0)
    assert wav_pr.ndim == 2 and wav_pr.size(0) == 1
    timings.append(("preprocessing", time.perf_counter() - start_time))
    start_time = time.perf_counter()
    if transcript_content == "":
        text_pr, lang_pr = make_prompt("dummy", wav_pr, sr, save=False)
    else:
        lang_pr = langid.classify(str(transcript_content))[0]
        lang_token = lang2token[lang_pr]
        text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"
    if language == "auto-detect":
        lang_token = lang2token[langid.classify(text)[0]]
    else:
        lang_token = langdropdown2token[language]
    lang = token2lang[lang_token]
    text = lang_token + text + lang_token
    timings.append(("text preparation", time.perf_counter() - start_time))
    # move the model onto the compute device
    model.to(device)
    start_time = time.perf_counter()
    # tokenize audio
    encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
    audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
    timings.append(("speaker feature extraction", time.perf_counter() - start_time))
    start_time = time.perf_counter()
    # tokenize text
    logger.info(f"synthesize text: {text}")
    phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
    text_tokens, text_tokens_lens = text_collater([phone_tokens])
    enroll_x_lens = None
    if text_pr:
        text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
        text_prompts, enroll_x_lens = text_collater([text_prompts])
        text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
        text_tokens_lens += enroll_x_lens
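    # The enrollment (prompt) text tokens are prepended to the target text so
    # the autoregressive decoder conditions on the reference transcript;
    # enroll_x_lens marks where the prompt ends and the new text begins.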
    lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
    timings.append(("phonemization/tokenization", time.perf_counter() - start_time))
    start_time = time.perf_counter()
    encoded_frames = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens,
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        text_language=langs if accent == "no-accent" else lang,
        best_of=5,
    )
    timings.append(("acoustic model inference", time.perf_counter() - start_time))
    # Decode with Vocos
    start_time = time.perf_counter()
    frames = encoded_frames.permute(2, 0, 1)
    features = vocos.codes_to_features(frames)
    samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
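    # bandwidth_id=2 appears to select EnCodec's 6 kbps setting (the third of
    # the 1.5/3/6/12 kbps bandwidths); this is an assumption based on the
    # vocos-encodec-24khz docs, not something the original code states.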
| timings.append(("ボコーダ復号", time.perf_counter() - start_time)) | |
| for step, duration in timings: | |
| print(f"{step}:{duration:.4f} sec") | |
| timing_report = "\n↓\n".join( | |
| f"{step}:{duration:.4f} sec" for step, duration in timings | |
| ) | |
| print(f"推論ステップ計測結果\n{timing_report}") | |
| message = f"text prompt: {text_pr}\nsythesized text: {text}" | |
| return message, (24000, samples.squeeze(0).cpu().numpy()) | |


def main():
    app = gr.Blocks(title="VALL-E X")
    with app:
        gr.Markdown(top_ja_md)
        with gr.Tab("Infer from audio"):
            gr.Markdown(infer_from_audio_ja_md)
            with gr.Row():
                with gr.Column():
                    textbox = gr.TextArea(
                        label="音声合成で喋らせたいテキスト",
                        # placeholder="Type your sentence here",
                        placeholder="ここに音声合成で喋らせたいテキストを入力してください。",
                        value="Welcome back, Master. What can I do for you today?",
                        elem_id="tts-input",
                    )
                    language_dropdown = gr.Dropdown(
                        choices=["auto-detect", "English", "中文", "日本語"],
                        value="auto-detect",
                        label="language",
                    )
                    accent_dropdown = gr.Dropdown(
                        choices=["no-accent", "English", "中文", "日本語"],
                        value="no-accent",
                        label="accent",
                    )
                    textbox_transcript = gr.TextArea(
                        label="Transcript",
                        # placeholder="Write transcript here. (leave empty to use whisper)",
                        placeholder="アップロードした音声、または録音した音声のテキストを入力してください。(whisper を使用する場合は空のままにしてください。)",
                        value="",
                        elem_id="prompt-transcript",
                    )
                    upload_audio_prompt = gr.Audio(
                        label="音声アップロード",
                        sources=["upload"],
                        interactive=True,
                    )
                    record_audio_prompt = gr.Audio(
                        label="音声を録音する",
                        sources=["microphone"],
                        interactive=True,
                    )
                with gr.Column():
                    text_output = gr.Textbox(label="Message")
                    audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio")
                    btn = gr.Button("音声合成を開始する")
                    btn.click(
                        infer_from_audio,
                        inputs=[
                            textbox,
                            language_dropdown,
                            accent_dropdown,
                            upload_audio_prompt,
                            record_audio_prompt,
                            textbox_transcript,
                        ],
                        outputs=[text_output, audio_output],
                    )
                    textbox_mp = gr.TextArea(
                        label="Prompt name",
                        placeholder="Name your prompt here",
                        value="prompt_1",
                        elem_id="prompt-name",
                    )
                    btn_mp = gr.Button("Make prompt!")
                    prompt_output = gr.File(interactive=False)
                    btn_mp.click(
                        make_npz_prompt,
                        inputs=[
                            textbox_mp,
                            upload_audio_prompt,
                            record_audio_prompt,
                            textbox_transcript,
                        ],
                        outputs=[text_output, prompt_output],
                    )
            gr.Examples(
                examples=infer_from_audio_examples,
                inputs=[
                    textbox,
                    language_dropdown,
                    accent_dropdown,
                    upload_audio_prompt,
                    record_audio_prompt,
                    textbox_transcript,
                ],
                outputs=[text_output, audio_output],
                fn=infer_from_audio,
                cache_examples=False,
            )