Spaces:
Sleeping
Sleeping
nukopy committed on
Commit ·
df70e48
1
Parent(s): ae18720
feat: 計測用のログの追加
Browse files
apps/audio_cloning/cheched_vallex.py
CHANGED
|
@@ -7,6 +7,7 @@ from typing import List, Optional, Tuple
|
|
| 7 |
|
| 8 |
import gradio as gr
|
| 9 |
import numpy as np
|
|
|
|
| 10 |
import torch
|
| 11 |
|
| 12 |
from .vallex import main as vallex
|
|
@@ -132,6 +133,8 @@ def refresh_prompt_choices():
|
|
| 132 |
)
|
| 133 |
|
| 134 |
|
|
|
|
|
|
|
| 135 |
def infer_from_cached_prompt(
|
| 136 |
text: str,
|
| 137 |
language: str,
|
|
@@ -165,7 +168,7 @@ def infer_from_cached_prompt(
|
|
| 165 |
except Exception as err: # pylint: disable=broad-except
|
| 166 |
logger.exception("Failed to load cached prompt", exc_info=err)
|
| 167 |
return (f"プロンプトの読み込みに失敗しました: {err}", None)
|
| 168 |
-
timings.append(("
|
| 169 |
|
| 170 |
lang_pr = code2lang.get(lang_code, "en")
|
| 171 |
|
|
@@ -178,6 +181,9 @@ def infer_from_cached_prompt(
|
|
| 178 |
|
| 179 |
conditioned_text = f"{lang_token}{text}{lang_token}"
|
| 180 |
|
|
|
|
|
|
|
|
|
|
| 181 |
phone_tokens, langs = vallex.text_tokenizer.tokenize(
|
| 182 |
text=f"_{conditioned_text}".strip()
|
| 183 |
)
|
|
@@ -186,7 +192,7 @@ def infer_from_cached_prompt(
|
|
| 186 |
enroll_x_lens = torch.IntTensor([text_prompts.shape[-1]])
|
| 187 |
text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
|
| 188 |
text_tokens_lens += enroll_x_lens
|
| 189 |
-
timings.append(("
|
| 190 |
|
| 191 |
vallex.model.to(vallex.device)
|
| 192 |
|
|
@@ -209,7 +215,7 @@ def infer_from_cached_prompt(
|
|
| 209 |
else token2lang[langdropdown2token[accent]],
|
| 210 |
best_of=5,
|
| 211 |
)
|
| 212 |
-
timings.append(("
|
| 213 |
logger.info("Inference completed")
|
| 214 |
|
| 215 |
start_time = time.perf_counter()
|
|
@@ -228,6 +234,9 @@ def infer_from_cached_prompt(
|
|
| 228 |
f"Synthesized text: {conditioned_text}"
|
| 229 |
)
|
| 230 |
|
|
|
|
|
|
|
|
|
|
| 231 |
timing_report = "\n↓\n".join(
|
| 232 |
f"{step}:{duration:.4f} sec" for step, duration in timings
|
| 233 |
)
|
|
|
|
| 7 |
|
| 8 |
import gradio as gr
|
| 9 |
import numpy as np
|
| 10 |
+
import spaces
|
| 11 |
import torch
|
| 12 |
|
| 13 |
from .vallex import main as vallex
|
|
|
|
| 133 |
)
|
| 134 |
|
| 135 |
|
| 136 |
+
@spaces.GPU(duration=120)
|
| 137 |
+
@torch.no_grad()
|
| 138 |
def infer_from_cached_prompt(
|
| 139 |
text: str,
|
| 140 |
language: str,
|
|
|
|
| 168 |
except Exception as err: # pylint: disable=broad-except
|
| 169 |
logger.exception("Failed to load cached prompt", exc_info=err)
|
| 170 |
return (f"プロンプトの読み込みに失敗しました: {err}", None)
|
| 171 |
+
timings.append(("[cached] 話者特徴抽出", time.perf_counter() - start_time))
|
| 172 |
|
| 173 |
lang_pr = code2lang.get(lang_code, "en")
|
| 174 |
|
|
|
|
| 181 |
|
| 182 |
conditioned_text = f"{lang_token}{text}{lang_token}"
|
| 183 |
|
| 184 |
+
timings.append(("テキスト準備", time.perf_counter() - start_time))
|
| 185 |
+
|
| 186 |
+
start_time = time.perf_counter()
|
| 187 |
phone_tokens, langs = vallex.text_tokenizer.tokenize(
|
| 188 |
text=f"_{conditioned_text}".strip()
|
| 189 |
)
|
|
|
|
| 192 |
enroll_x_lens = torch.IntTensor([text_prompts.shape[-1]])
|
| 193 |
text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
|
| 194 |
text_tokens_lens += enroll_x_lens
|
| 195 |
+
timings.append(("音素化/トークナイズ", time.perf_counter() - start_time))
|
| 196 |
|
| 197 |
vallex.model.to(vallex.device)
|
| 198 |
|
|
|
|
| 215 |
else token2lang[langdropdown2token[accent]],
|
| 216 |
best_of=5,
|
| 217 |
)
|
| 218 |
+
timings.append(("音響モデル推論", time.perf_counter() - start_time))
|
| 219 |
logger.info("Inference completed")
|
| 220 |
|
| 221 |
start_time = time.perf_counter()
|
|
|
|
| 234 |
f"Synthesized text: {conditioned_text}"
|
| 235 |
)
|
| 236 |
|
| 237 |
+
for step, duration in timings:
|
| 238 |
+
logger.info("%s:%.4f sec", step, duration)
|
| 239 |
+
|
| 240 |
timing_report = "\n↓\n".join(
|
| 241 |
f"{step}:{duration:.4f} sec" for step, duration in timings
|
| 242 |
)
|
apps/audio_cloning/vallex/main.py
CHANGED
|
@@ -373,7 +373,7 @@ def infer_from_audio(
|
|
| 373 |
if wav_pr.ndim == 1:
|
| 374 |
wav_pr = wav_pr.unsqueeze(0)
|
| 375 |
assert wav_pr.ndim and wav_pr.size(0) == 1
|
| 376 |
-
timings.append(("
|
| 377 |
|
| 378 |
start_time = time.perf_counter()
|
| 379 |
if transcript_content == "":
|
|
@@ -382,16 +382,14 @@ def infer_from_audio(
|
|
| 382 |
lang_pr = langid.classify(str(transcript_content))[0]
|
| 383 |
lang_token = lang2token[lang_pr]
|
| 384 |
text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"
|
| 385 |
-
timings.append(("プロンプト生成", time.perf_counter() - start_time))
|
| 386 |
|
| 387 |
-
start_time = time.perf_counter()
|
| 388 |
if language == "auto-detect":
|
| 389 |
lang_token = lang2token[langid.classify(text)[0]]
|
| 390 |
else:
|
| 391 |
lang_token = langdropdown2token[language]
|
| 392 |
lang = token2lang[lang_token]
|
| 393 |
text = lang_token + text + lang_token
|
| 394 |
-
timings.append(("
|
| 395 |
|
| 396 |
# onload model
|
| 397 |
model.to(device)
|
|
@@ -400,7 +398,7 @@ def infer_from_audio(
|
|
| 400 |
# tokenize audio
|
| 401 |
encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
|
| 402 |
audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
|
| 403 |
-
timings.append(("
|
| 404 |
|
| 405 |
start_time = time.perf_counter()
|
| 406 |
# tokenize text
|
|
@@ -415,7 +413,7 @@ def infer_from_audio(
|
|
| 415 |
text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
|
| 416 |
text_tokens_lens += enroll_x_lens
|
| 417 |
lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
|
| 418 |
-
timings.append(("
|
| 419 |
|
| 420 |
start_time = time.perf_counter()
|
| 421 |
encoded_frames = model.inference(
|
|
@@ -429,7 +427,7 @@ def infer_from_audio(
|
|
| 429 |
text_language=langs if accent == "no-accent" else lang,
|
| 430 |
best_of=5,
|
| 431 |
)
|
| 432 |
-
timings.append(("
|
| 433 |
# Decode with Vocos
|
| 434 |
start_time = time.perf_counter()
|
| 435 |
frames = encoded_frames.permute(2, 0, 1)
|
|
@@ -437,6 +435,9 @@ def infer_from_audio(
|
|
| 437 |
samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
|
| 438 |
timings.append(("ボコーダ復号", time.perf_counter() - start_time))
|
| 439 |
|
|
|
|
|
|
|
|
|
|
| 440 |
timing_report = "\n↓\n".join(
|
| 441 |
f"{step}:{duration:.4f} sec" for step, duration in timings
|
| 442 |
)
|
|
|
|
| 373 |
if wav_pr.ndim == 1:
|
| 374 |
wav_pr = wav_pr.unsqueeze(0)
|
| 375 |
assert wav_pr.ndim and wav_pr.size(0) == 1
|
| 376 |
+
timings.append(("前処理", time.perf_counter() - start_time))
|
| 377 |
|
| 378 |
start_time = time.perf_counter()
|
| 379 |
if transcript_content == "":
|
|
|
|
| 382 |
lang_pr = langid.classify(str(transcript_content))[0]
|
| 383 |
lang_token = lang2token[lang_pr]
|
| 384 |
text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"
|
|
|
|
| 385 |
|
|
|
|
| 386 |
if language == "auto-detect":
|
| 387 |
lang_token = lang2token[langid.classify(text)[0]]
|
| 388 |
else:
|
| 389 |
lang_token = langdropdown2token[language]
|
| 390 |
lang = token2lang[lang_token]
|
| 391 |
text = lang_token + text + lang_token
|
| 392 |
+
timings.append(("テキスト準備", time.perf_counter() - start_time))
|
| 393 |
|
| 394 |
# onload model
|
| 395 |
model.to(device)
|
|
|
|
| 398 |
# tokenize audio
|
| 399 |
encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
|
| 400 |
audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
|
| 401 |
+
timings.append(("話者特徴抽出", time.perf_counter() - start_time))
|
| 402 |
|
| 403 |
start_time = time.perf_counter()
|
| 404 |
# tokenize text
|
|
|
|
| 413 |
text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
|
| 414 |
text_tokens_lens += enroll_x_lens
|
| 415 |
lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
|
| 416 |
+
timings.append(("音素化/トークナイズ", time.perf_counter() - start_time))
|
| 417 |
|
| 418 |
start_time = time.perf_counter()
|
| 419 |
encoded_frames = model.inference(
|
|
|
|
| 427 |
text_language=langs if accent == "no-accent" else lang,
|
| 428 |
best_of=5,
|
| 429 |
)
|
| 430 |
+
timings.append(("音響モデル推論", time.perf_counter() - start_time))
|
| 431 |
# Decode with Vocos
|
| 432 |
start_time = time.perf_counter()
|
| 433 |
frames = encoded_frames.permute(2, 0, 1)
|
|
|
|
| 435 |
samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
|
| 436 |
timings.append(("ボコーダ復号", time.perf_counter() - start_time))
|
| 437 |
|
| 438 |
+
for step, duration in timings:
|
| 439 |
+
logger.info("%s:%.4f sec", step, duration)
|
| 440 |
+
|
| 441 |
timing_report = "\n↓\n".join(
|
| 442 |
f"{step}:{duration:.4f} sec" for step, duration in timings
|
| 443 |
)
|