nukopy committed on
Commit
df70e48
·
1 Parent(s): ae18720

feat: 計測用のログの追加

Browse files
apps/audio_cloning/cheched_vallex.py CHANGED
@@ -7,6 +7,7 @@ from typing import List, Optional, Tuple
7
 
8
  import gradio as gr
9
  import numpy as np
 
10
  import torch
11
 
12
  from .vallex import main as vallex
@@ -132,6 +133,8 @@ def refresh_prompt_choices():
132
  )
133
 
134
 
 
 
135
  def infer_from_cached_prompt(
136
  text: str,
137
  language: str,
@@ -165,7 +168,7 @@ def infer_from_cached_prompt(
165
  except Exception as err: # pylint: disable=broad-except
166
  logger.exception("Failed to load cached prompt", exc_info=err)
167
  return (f"プロンプトの読み込みに失敗しました: {err}", None)
168
- timings.append(("プロンプト読込", time.perf_counter() - start_time))
169
 
170
  lang_pr = code2lang.get(lang_code, "en")
171
 
@@ -178,6 +181,9 @@ def infer_from_cached_prompt(
178
 
179
  conditioned_text = f"{lang_token}{text}{lang_token}"
180
 
 
 
 
181
  phone_tokens, langs = vallex.text_tokenizer.tokenize(
182
  text=f"_{conditioned_text}".strip()
183
  )
@@ -186,7 +192,7 @@ def infer_from_cached_prompt(
186
  enroll_x_lens = torch.IntTensor([text_prompts.shape[-1]])
187
  text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
188
  text_tokens_lens += enroll_x_lens
189
- timings.append(("テキスト準備", time.perf_counter() - start_time))
190
 
191
  vallex.model.to(vallex.device)
192
 
@@ -209,7 +215,7 @@ def infer_from_cached_prompt(
209
  else token2lang[langdropdown2token[accent]],
210
  best_of=5,
211
  )
212
- timings.append(("モデル推論", time.perf_counter() - start_time))
213
  logger.info("Inference completed")
214
 
215
  start_time = time.perf_counter()
@@ -228,6 +234,9 @@ def infer_from_cached_prompt(
228
  f"Synthesized text: {conditioned_text}"
229
  )
230
 
 
 
 
231
  timing_report = "\n↓\n".join(
232
  f"{step}:{duration:.4f} sec" for step, duration in timings
233
  )
 
7
 
8
  import gradio as gr
9
  import numpy as np
10
+ import spaces
11
  import torch
12
 
13
  from .vallex import main as vallex
 
133
  )
134
 
135
 
136
+ @spaces.GPU(duration=120)
137
+ @torch.no_grad()
138
  def infer_from_cached_prompt(
139
  text: str,
140
  language: str,
 
168
  except Exception as err: # pylint: disable=broad-except
169
  logger.exception("Failed to load cached prompt", exc_info=err)
170
  return (f"プロンプトの読み込みに失敗しました: {err}", None)
171
+ timings.append(("[cached] 話者特徴抽出", time.perf_counter() - start_time))
172
 
173
  lang_pr = code2lang.get(lang_code, "en")
174
 
 
181
 
182
  conditioned_text = f"{lang_token}{text}{lang_token}"
183
 
184
+ timings.append(("テキスト準備", time.perf_counter() - start_time))
185
+
186
+ start_time = time.perf_counter()
187
  phone_tokens, langs = vallex.text_tokenizer.tokenize(
188
  text=f"_{conditioned_text}".strip()
189
  )
 
192
  enroll_x_lens = torch.IntTensor([text_prompts.shape[-1]])
193
  text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
194
  text_tokens_lens += enroll_x_lens
195
+ timings.append(("音素化/トークナイズ", time.perf_counter() - start_time))
196
 
197
  vallex.model.to(vallex.device)
198
 
 
215
  else token2lang[langdropdown2token[accent]],
216
  best_of=5,
217
  )
218
+ timings.append(("音響モデル推論", time.perf_counter() - start_time))
219
  logger.info("Inference completed")
220
 
221
  start_time = time.perf_counter()
 
234
  f"Synthesized text: {conditioned_text}"
235
  )
236
 
237
+ for step, duration in timings:
238
+ logger.info("%s:%.4f sec", step, duration)
239
+
240
  timing_report = "\n↓\n".join(
241
  f"{step}:{duration:.4f} sec" for step, duration in timings
242
  )
apps/audio_cloning/vallex/main.py CHANGED
@@ -373,7 +373,7 @@ def infer_from_audio(
373
  if wav_pr.ndim == 1:
374
  wav_pr = wav_pr.unsqueeze(0)
375
  assert wav_pr.ndim and wav_pr.size(0) == 1
376
- timings.append(("音声前処理", time.perf_counter() - start_time))
377
 
378
  start_time = time.perf_counter()
379
  if transcript_content == "":
@@ -382,16 +382,14 @@ def infer_from_audio(
382
  lang_pr = langid.classify(str(transcript_content))[0]
383
  lang_token = lang2token[lang_pr]
384
  text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"
385
- timings.append(("プロンプト生成", time.perf_counter() - start_time))
386
 
387
- start_time = time.perf_counter()
388
  if language == "auto-detect":
389
  lang_token = lang2token[langid.classify(text)[0]]
390
  else:
391
  lang_token = langdropdown2token[language]
392
  lang = token2lang[lang_token]
393
  text = lang_token + text + lang_token
394
- timings.append(("言語設定", time.perf_counter() - start_time))
395
 
396
  # onload model
397
  model.to(device)
@@ -400,7 +398,7 @@ def infer_from_audio(
400
  # tokenize audio
401
  encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
402
  audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
403
- timings.append(("音声トークナイズ", time.perf_counter() - start_time))
404
 
405
  start_time = time.perf_counter()
406
  # tokenize text
@@ -415,7 +413,7 @@ def infer_from_audio(
415
  text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
416
  text_tokens_lens += enroll_x_lens
417
  lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
418
- timings.append(("テキストトークナイズ", time.perf_counter() - start_time))
419
 
420
  start_time = time.perf_counter()
421
  encoded_frames = model.inference(
@@ -429,7 +427,7 @@ def infer_from_audio(
429
  text_language=langs if accent == "no-accent" else lang,
430
  best_of=5,
431
  )
432
- timings.append(("モデル推論", time.perf_counter() - start_time))
433
  # Decode with Vocos
434
  start_time = time.perf_counter()
435
  frames = encoded_frames.permute(2, 0, 1)
@@ -437,6 +435,9 @@ def infer_from_audio(
437
  samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
438
  timings.append(("ボコーダ復号", time.perf_counter() - start_time))
439
 
 
 
 
440
  timing_report = "\n↓\n".join(
441
  f"{step}:{duration:.4f} sec" for step, duration in timings
442
  )
 
373
  if wav_pr.ndim == 1:
374
  wav_pr = wav_pr.unsqueeze(0)
375
  assert wav_pr.ndim and wav_pr.size(0) == 1
376
+ timings.append(("前処理", time.perf_counter() - start_time))
377
 
378
  start_time = time.perf_counter()
379
  if transcript_content == "":
 
382
  lang_pr = langid.classify(str(transcript_content))[0]
383
  lang_token = lang2token[lang_pr]
384
  text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"
 
385
 
 
386
  if language == "auto-detect":
387
  lang_token = lang2token[langid.classify(text)[0]]
388
  else:
389
  lang_token = langdropdown2token[language]
390
  lang = token2lang[lang_token]
391
  text = lang_token + text + lang_token
392
+ timings.append(("テキスト準備", time.perf_counter() - start_time))
393
 
394
  # onload model
395
  model.to(device)
 
398
  # tokenize audio
399
  encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
400
  audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
401
+ timings.append(("話者特徴抽出", time.perf_counter() - start_time))
402
 
403
  start_time = time.perf_counter()
404
  # tokenize text
 
413
  text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
414
  text_tokens_lens += enroll_x_lens
415
  lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
416
+ timings.append(("音素化/トークナイズ", time.perf_counter() - start_time))
417
 
418
  start_time = time.perf_counter()
419
  encoded_frames = model.inference(
 
427
  text_language=langs if accent == "no-accent" else lang,
428
  best_of=5,
429
  )
430
+ timings.append(("音響モデル推論", time.perf_counter() - start_time))
431
  # Decode with Vocos
432
  start_time = time.perf_counter()
433
  frames = encoded_frames.permute(2, 0, 1)
 
435
  samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
436
  timings.append(("ボコーダ復号", time.perf_counter() - start_time))
437
 
438
+ for step, duration in timings:
439
+ logger.info("%s:%.4f sec", step, duration)
440
+
441
  timing_report = "\n↓\n".join(
442
  f"{step}:{duration:.4f} sec" for step, duration in timings
443
  )