Approximetal committed on
Commit
bec636c
·
verified ·
1 Parent(s): 4fb7ead

Update inference_gradio.py

Browse files
Files changed (1) hide show
  1. inference_gradio.py +21 -18
inference_gradio.py CHANGED
@@ -34,17 +34,20 @@ device = (
34
  )
35
 
36
  REPO_ROOT = Path(__file__).resolve().parent
 
 
 
37
 
38
  # HF location for pretrained assets (used as a fallback when local files are missing)
39
- HF_PRETRAINED_ROOT = "hf://LEMAS-Project/LEMAS-TTS/blob/main/pretrained_models"
40
- CKPTS_ROOT = os.path.join(HF_PRETRAINED_ROOT, "ckpts")
41
 
42
  # 1) Point to the libespeak-ng.so in your repo
43
- ESPEAK_LIB = os.path.join(HF_PRETRAINED_ROOT, "espeak-ng-lib", "libespeak-ng.so")
44
  os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = str(ESPEAK_LIB)
45
 
46
  # 2) Point to the espeak-ng-data in your repo
47
- ESPEAK_DATA_DIR = os.path.join(HF_PRETRAINED_ROOT, "espeak-ng-data")
48
  os.environ["ESPEAK_DATA_PATH"] = ESPEAK_DATA_DIR
49
  os.environ["ESPEAKNG_DATA_PATH"] = ESPEAK_DATA_DIR
50
 
@@ -88,7 +91,7 @@ class UVR5:
88
  return output_audio.squeeze().T.numpy(), 44100
89
 
90
  denoise_model = UVR5(
91
- model_dir=os.path.join(HF_PRETRAINED_ROOT, "uvr5"),
92
  code_dir=str(REPO_ROOT / "uvr5"),
93
  )
94
 
@@ -124,8 +127,8 @@ def get_checkpoints_project(project_name=None, is_gradio=True):
124
  checkpoint_dir = [str(CKPTS_ROOT)]
125
  # Remote ckpt locations on HF (used if local ckpts are not present)
126
  remote_ckpts = {
127
- "multilingual_grl": f"{HF_PRETRAINED_ROOT}/ckpts/multilingual_grl/multilingual_grl.safetensors",
128
- "multilingual_prosody": f"{HF_PRETRAINED_ROOT}/ckpts/multilingual_prosody/multilingual_prosody.safetensors",
129
  }
130
 
131
  if project_name is None:
@@ -178,7 +181,7 @@ def get_checkpoints_project(project_name=None, is_gradio=True):
178
  def get_available_projects():
179
  """Get available project names from data directory"""
180
  data_paths = [
181
- str(HF_PRETRAINED_ROOT / "data"),
182
  ]
183
 
184
  project_list = []
@@ -235,13 +238,13 @@ def infer(
235
  use_prosody_encoder = True if "prosody" in str(ckpt_resolved) else False
236
 
237
  # Resolve vocab file (local or HF)
238
- local_vocab = Path(HF_PRETRAINED_ROOT) / "data" / project / "vocab.txt"
239
  if local_vocab.is_file():
240
  vocab_file = str(local_vocab)
241
  else:
242
  remote_vocab_map = {
243
- "multilingual_grl": f"{HF_PRETRAINED_ROOT}/data/multilingual_grl/vocab.txt",
244
- "multilingual_prosody": f"{HF_PRETRAINED_ROOT}/data/multilingual_prosody/vocab.txt",
245
  }
246
  remote_vocab = remote_vocab_map.get(project)
247
  if remote_vocab is None:
@@ -259,13 +262,13 @@ def infer(
259
  prosody_cfg_path = str(local_prosody_cfg)
260
  else:
261
  prosody_cfg_path = str(
262
- cached_path(f"{HF_PRETRAINED_ROOT}/ckpts/prosody_encoder/pretssel_cfg.json")
263
  )
264
  if local_prosody_ckpt.is_file():
265
  prosody_ckpt_path = str(local_prosody_ckpt)
266
  else:
267
  prosody_ckpt_path = str(
268
- cached_path(f"{HF_PRETRAINED_ROOT}/ckpts/prosody_encoder/prosody_encoder_UnitY2.pt")
269
  )
270
 
271
  try:
@@ -479,13 +482,13 @@ with gr.Blocks(title="LEMAS-TTS Inference") as app:
479
 
480
  # Examples
481
  def _resolve_example(name: str) -> str:
482
- local = os.path.join(HF_PRETRAINED_ROOT, "data", "test_examples", name)
483
  if os.path.isfile(local):
484
  return local
485
  remote_map = {
486
- "en.wav": os.path.join(HF_PRETRAINED_ROOT, "data", "test_examples", "en.wav"),
487
- "es.wav": os.path.join(HF_PRETRAINED_ROOT, "data", "test_examples", "es.wav"),
488
- "pt.wav": os.path.join(HF_PRETRAINED_ROOT, "data", "test_examples", "pt.wav"),
489
  }
490
  url = remote_map.get(name)
491
  return str(cached_path(url)) if url is not None else ""
@@ -596,7 +599,7 @@ def main(port, host, share, api):
596
  server_port=port,
597
  share=share,
598
  show_api=api,
599
- allowed_paths=[str(os.path.join(HF_PRETRAINED_ROOT, "data"))],
600
  )
601
 
602
 
 
34
  )
35
 
36
  REPO_ROOT = Path(__file__).resolve().parent
37
+ # Local pretrained root (used when running from a repo / Space that bundles weights)
38
+ PRETRAINED_ROOT = str(REPO_ROOT / "pretrained_models")
39
+ CKPTS_ROOT = os.path.join(PRETRAINED_ROOT, "ckpts")
40
 
41
  # HF location for pretrained assets (used as a fallback when local files are missing)
42
+ # HF_PRETRAINED_ROOT = "hf://LEMAS-Project/LEMAS-TTS/blob/main/pretrained_models"
43
+
44
 
45
  # 1) Point to the libespeak-ng.so in your repo
46
+ ESPEAK_LIB = os.path.join(PRETRAINED_ROOT, "espeak-ng-lib", "libespeak-ng.so")
47
  os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = str(ESPEAK_LIB)
48
 
49
  # 2) Point to the espeak-ng-data in your repo
50
+ ESPEAK_DATA_DIR = os.path.join(PRETRAINED_ROOT, "espeak-ng-data")
51
  os.environ["ESPEAK_DATA_PATH"] = ESPEAK_DATA_DIR
52
  os.environ["ESPEAKNG_DATA_PATH"] = ESPEAK_DATA_DIR
53
 
 
91
  return output_audio.squeeze().T.numpy(), 44100
92
 
93
  denoise_model = UVR5(
94
+ model_dir=os.path.join(PRETRAINED_ROOT, "uvr5"),
95
  code_dir=str(REPO_ROOT / "uvr5"),
96
  )
97
 
 
127
  checkpoint_dir = [str(CKPTS_ROOT)]
128
  # Remote ckpt locations on HF (used if local ckpts are not present)
129
  remote_ckpts = {
130
+ "multilingual_grl": f"{PRETRAINED_ROOT}/ckpts/multilingual_grl/multilingual_grl.safetensors",
131
+ "multilingual_prosody": f"{PRETRAINED_ROOT}/ckpts/multilingual_prosody/multilingual_prosody.safetensors",
132
  }
133
 
134
  if project_name is None:
 
181
  def get_available_projects():
182
  """Get available project names from data directory"""
183
  data_paths = [
184
+ str(PRETRAINED_ROOT / "data"),
185
  ]
186
 
187
  project_list = []
 
238
  use_prosody_encoder = True if "prosody" in str(ckpt_resolved) else False
239
 
240
  # Resolve vocab file (local or HF)
241
+ local_vocab = Path(PRETRAINED_ROOT) / "data" / project / "vocab.txt"
242
  if local_vocab.is_file():
243
  vocab_file = str(local_vocab)
244
  else:
245
  remote_vocab_map = {
246
+ "multilingual_grl": f"{PRETRAINED_ROOT}/data/multilingual_grl/vocab.txt",
247
+ "multilingual_prosody": f"{PRETRAINED_ROOT}/data/multilingual_prosody/vocab.txt",
248
  }
249
  remote_vocab = remote_vocab_map.get(project)
250
  if remote_vocab is None:
 
262
  prosody_cfg_path = str(local_prosody_cfg)
263
  else:
264
  prosody_cfg_path = str(
265
+ cached_path(f"{PRETRAINED_ROOT}/ckpts/prosody_encoder/pretssel_cfg.json")
266
  )
267
  if local_prosody_ckpt.is_file():
268
  prosody_ckpt_path = str(local_prosody_ckpt)
269
  else:
270
  prosody_ckpt_path = str(
271
+ cached_path(f"{PRETRAINED_ROOT}/ckpts/prosody_encoder/prosody_encoder_UnitY2.pt")
272
  )
273
 
274
  try:
 
482
 
483
  # Examples
484
  def _resolve_example(name: str) -> str:
485
+ local = os.path.join(PRETRAINED_ROOT, "data", "test_examples", name)
486
  if os.path.isfile(local):
487
  return local
488
  remote_map = {
489
+ "en.wav": os.path.join(PRETRAINED_ROOT, "data", "test_examples", "en.wav"),
490
+ "es.wav": os.path.join(PRETRAINED_ROOT, "data", "test_examples", "es.wav"),
491
+ "pt.wav": os.path.join(PRETRAINED_ROOT, "data", "test_examples", "pt.wav"),
492
  }
493
  url = remote_map.get(name)
494
  return str(cached_path(url)) if url is not None else ""
 
599
  server_port=port,
600
  share=share,
601
  show_api=api,
602
+ allowed_paths=[str(os.path.join(PRETRAINED_ROOT, "data"))],
603
  )
604
 
605