Approximetal commited on
Commit
b90bc68
·
verified ·
1 Parent(s): bec636c

Update inference_gradio.py

Browse files
Files changed (1) hide show
  1. inference_gradio.py +14 -15
inference_gradio.py CHANGED
@@ -35,12 +35,11 @@ device = (
35
 
36
  REPO_ROOT = Path(__file__).resolve().parent
37
  # Local pretrained root (used when running from a repo / Space that bundles weights)
38
- PRETRAINED_ROOT = str(REPO_ROOT / "pretrained_models")
39
- CKPTS_ROOT = os.path.join(PRETRAINED_ROOT, "ckpts")
40
 
41
  # HF location for pretrained assets (used as a fallback when local files are missing)
42
- # HF_PRETRAINED_ROOT = "hf://LEMAS-Project/LEMAS-TTS/blob/main/pretrained_models"
43
-
44
 
45
  # 1) 指向你仓库里的 libespeak-ng.so
46
  ESPEAK_LIB = os.path.join(PRETRAINED_ROOT, "espeak-ng-lib", "libespeak-ng.so")
@@ -91,8 +90,8 @@ class UVR5:
91
  return output_audio.squeeze().T.numpy(), 44100
92
 
93
  denoise_model = UVR5(
94
- model_dir=os.path.join(PRETRAINED_ROOT, "uvr5"),
95
- code_dir=str(REPO_ROOT / "uvr5"),
96
  )
97
 
98
  def load_wav(audio_info, sr=16000, channel=1):
@@ -181,7 +180,7 @@ def get_checkpoints_project(project_name=None, is_gradio=True):
181
  def get_available_projects():
182
  """Get available project names from data directory"""
183
  data_paths = [
184
- str(PRETRAINED_ROOT / "data"),
185
  ]
186
 
187
  project_list = []
@@ -238,13 +237,13 @@ def infer(
238
  use_prosody_encoder = True if "prosody" in str(ckpt_resolved) else False
239
 
240
  # Resolve vocab file (local or HF)
241
- local_vocab = Path(PRETRAINED_ROOT) / "data" / project / "vocab.txt"
242
  if local_vocab.is_file():
243
  vocab_file = str(local_vocab)
244
  else:
245
  remote_vocab_map = {
246
- "multilingual_grl": f"{PRETRAINED_ROOT}/data/multilingual_grl/vocab.txt",
247
- "multilingual_prosody": f"{PRETRAINED_ROOT}/data/multilingual_prosody/vocab.txt",
248
  }
249
  remote_vocab = remote_vocab_map.get(project)
250
  if remote_vocab is None:
@@ -482,13 +481,13 @@ with gr.Blocks(title="LEMAS-TTS Inference") as app:
482
 
483
  # Examples
484
  def _resolve_example(name: str) -> str:
485
- local = os.path.join(PRETRAINED_ROOT, "data", "test_examples", name)
486
  if os.path.isfile(local):
487
  return local
488
  remote_map = {
489
- "en.wav": os.path.join(PRETRAINED_ROOT, "data", "test_examples", "en.wav"),
490
- "es.wav": os.path.join(PRETRAINED_ROOT, "data", "test_examples", "es.wav"),
491
- "pt.wav": os.path.join(PRETRAINED_ROOT, "data", "test_examples", "pt.wav"),
492
  }
493
  url = remote_map.get(name)
494
  return str(cached_path(url)) if url is not None else ""
@@ -599,7 +598,7 @@ def main(port, host, share, api):
599
  server_port=port,
600
  share=share,
601
  show_api=api,
602
- allowed_paths=[str(os.path.join(PRETRAINED_ROOT, "data"))],
603
  )
604
 
605
 
 
35
 
36
  REPO_ROOT = Path(__file__).resolve().parent
37
  # Local pretrained root (used when running from a repo / Space that bundles weights)
38
+ # PRETRAINED_ROOT = str(REPO_ROOT / "pretrained_models")
 
39
 
40
  # HF location for pretrained assets (used as a fallback when local files are missing)
41
+ PRETRAINED_ROOT = "hf://LEMAS-Project/LEMAS-TTS/pretrained_models"
42
+ CKPTS_ROOT = os.path.join(PRETRAINED_ROOT, "ckpts")
43
 
44
  # 1) 指向你仓库里的 libespeak-ng.so
45
  ESPEAK_LIB = os.path.join(PRETRAINED_ROOT, "espeak-ng-lib", "libespeak-ng.so")
 
90
  return output_audio.squeeze().T.numpy(), 44100
91
 
92
  denoise_model = UVR5(
93
+ model_dir=cached_path(os.path.join(PRETRAINED_ROOT, "uvr5")),
94
+ code_dir=cached_path(str(REPO_ROOT / "uvr5")),
95
  )
96
 
97
  def load_wav(audio_info, sr=16000, channel=1):
 
180
  def get_available_projects():
181
  """Get available project names from data directory"""
182
  data_paths = [
183
+ cached_path(str(PRETRAINED_ROOT / "data")),
184
  ]
185
 
186
  project_list = []
 
237
  use_prosody_encoder = True if "prosody" in str(ckpt_resolved) else False
238
 
239
  # Resolve vocab file (local or HF)
240
+ local_vocab = cached_path(str(PRETRAINED_ROOT / "data" / project / "vocab.txt"))
241
  if local_vocab.is_file():
242
  vocab_file = str(local_vocab)
243
  else:
244
  remote_vocab_map = {
245
+ "multilingual_grl": cached_path(f"{PRETRAINED_ROOT}/data/multilingual_grl/vocab.txt"),
246
+ "multilingual_prosody": cached_path(f"{PRETRAINED_ROOT}/data/multilingual_prosody/vocab.txt"),
247
  }
248
  remote_vocab = remote_vocab_map.get(project)
249
  if remote_vocab is None:
 
481
 
482
  # Examples
483
  def _resolve_example(name: str) -> str:
484
+ local = cached_path(os.path.join(PRETRAINED_ROOT, "data", "test_examples", name))
485
  if os.path.isfile(local):
486
  return local
487
  remote_map = {
488
+ "en.wav": cached_path(os.path.join(PRETRAINED_ROOT, "data", "test_examples", "en.wav")),
489
+ "es.wav": cached_path(os.path.join(PRETRAINED_ROOT, "data", "test_examples", "es.wav")),
490
+ "pt.wav": cached_path(os.path.join(PRETRAINED_ROOT, "data", "test_examples", "pt.wav")),
491
  }
492
  url = remote_map.get(name)
493
  return str(cached_path(url)) if url is not None else ""
 
598
  server_port=port,
599
  share=share,
600
  show_api=api,
601
+ allowed_paths=[str(cached_path(os.path.join(PRETRAINED_ROOT, "data")))],
602
  )
603
 
604