Approximetal committed on
Commit
2a1e401
·
verified ·
1 Parent(s): b90bc68

Update inference_gradio.py

Browse files
Files changed (1) hide show
  1. inference_gradio.py +31 -82
inference_gradio.py CHANGED
@@ -12,9 +12,7 @@ import torchaudio
12
  import soundfile as sf
13
  from pathlib import Path
14
 
15
- from cached_path import cached_path
16
-
17
- from lemas_tts.api import TTS
18
 
19
  # Global variables
20
  tts_api = None
@@ -34,21 +32,15 @@ device = (
34
  )
35
 
36
  REPO_ROOT = Path(__file__).resolve().parent
37
- # Local pretrained root (used when running from a repo / Space that bundles weights)
38
- # PRETRAINED_ROOT = str(REPO_ROOT / "pretrained_models")
39
-
40
- # HF location for pretrained assets (used as a fallback when local files are missing)
41
- PRETRAINED_ROOT = "hf://LEMAS-Project/LEMAS-TTS/pretrained_models"
42
- CKPTS_ROOT = os.path.join(PRETRAINED_ROOT, "ckpts")
43
 
44
- # 1) 指向你仓库里的 libespeak-ng.so
45
- ESPEAK_LIB = os.path.join(PRETRAINED_ROOT, "espeak-ng-lib", "libespeak-ng.so")
46
  os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = str(ESPEAK_LIB)
47
 
48
- # 2) 指向你仓库里的 espeak-ng-data
49
- ESPEAK_DATA_DIR = os.path.join(PRETRAINED_ROOT, "espeak-ng-data")
50
- os.environ["ESPEAK_DATA_PATH"] = ESPEAK_DATA_DIR
51
- os.environ["ESPEAKNG_DATA_PATH"] = ESPEAK_DATA_DIR
52
 
53
 
54
  class UVR5:
@@ -68,7 +60,8 @@ class UVR5:
68
 
69
  model_path = os.path.join(model_dir, "Kim_Vocal_1.onnx")
70
  config_path = os.path.join(model_dir, "MDX-Net-Kim-Vocal1.json")
71
- configs = json.loads(open(config_path, "r", encoding="utf-8").read())
 
72
  model_data = ModelData(
73
  model_path=model_path,
74
  audio_path=model_dir,
@@ -90,8 +83,8 @@ class UVR5:
90
  return output_audio.squeeze().T.numpy(), 44100
91
 
92
  denoise_model = UVR5(
93
- model_dir=cached_path(os.path.join(PRETRAINED_ROOT, "uvr5")),
94
- code_dir=cached_path(str(REPO_ROOT / "uvr5")),
95
  )
96
 
97
  def load_wav(audio_info, sr=16000, channel=1):
@@ -124,11 +117,6 @@ def cancel_denoise(audio_info):
124
  def get_checkpoints_project(project_name=None, is_gradio=True):
125
  """Get available checkpoint files"""
126
  checkpoint_dir = [str(CKPTS_ROOT)]
127
- # Remote ckpt locations on HF (used if local ckpts are not present)
128
- remote_ckpts = {
129
- "multilingual_grl": f"{PRETRAINED_ROOT}/ckpts/multilingual_grl/multilingual_grl.safetensors",
130
- "multilingual_prosody": f"{PRETRAINED_ROOT}/ckpts/multilingual_prosody/multilingual_prosody.safetensors",
131
- }
132
 
133
  if project_name is None:
134
  # Look for checkpoints in local directory
@@ -138,16 +126,12 @@ def get_checkpoints_project(project_name=None, is_gradio=True):
138
  files_checkpoints.extend(glob(os.path.join(path, "**/*.pt"), recursive=True))
139
  files_checkpoints.extend(glob(os.path.join(path, "**/*.safetensors"), recursive=True))
140
  break
141
- # Fallback: use HF ckpts
142
- if not files_checkpoints:
143
- files_checkpoints = list(remote_ckpts.values())
144
  else:
145
  if os.path.isdir(checkpoint_dir[0]):
146
  files_checkpoints = glob(os.path.join(checkpoint_dir[0], project_name, "*.pt"))
147
  files_checkpoints.extend(glob(os.path.join(checkpoint_dir[0], project_name, "*.safetensors")))
148
  else:
149
- ckpt = remote_ckpts.get(project_name)
150
- files_checkpoints = [ckpt] if ckpt is not None else []
151
  print("files_checkpoints:", project_name, files_checkpoints)
152
  # Separate pretrained and regular checkpoints
153
  pretrained_checkpoints = [f for f in files_checkpoints if "pretrained_" in os.path.basename(f)]
@@ -180,7 +164,7 @@ def get_checkpoints_project(project_name=None, is_gradio=True):
180
  def get_available_projects():
181
  """Get available project names from data directory"""
182
  data_paths = [
183
- cached_path(str(PRETRAINED_ROOT / "data")),
184
  ]
185
 
186
  project_list = []
@@ -204,16 +188,9 @@ def infer(
204
  ):
205
  global last_checkpoint, last_device, tts_api, last_ema
206
 
207
- # Resolve checkpoint path (local or HF)
208
  ckpt_path = file_checkpoint
209
- if isinstance(ckpt_path, str) and ckpt_path.startswith("hf://"):
210
- try:
211
- ckpt_resolved = str(cached_path(ckpt_path))
212
- except Exception as e:
213
- traceback.print_exc()
214
- return None, f"Error downloading checkpoint: {str(e)}", ""
215
- else:
216
- ckpt_resolved = ckpt_path
217
 
218
  if not os.path.isfile(ckpt_resolved):
219
  return None, "Checkpoint not found!", ""
@@ -236,39 +213,19 @@ def infer(
236
  # Automatically enable prosody encoder when using the prosody checkpoint
237
  use_prosody_encoder = True if "prosody" in str(ckpt_resolved) else False
238
 
239
- # Resolve vocab file (local or HF)
240
- local_vocab = cached_path(str(PRETRAINED_ROOT / "data" / project / "vocab.txt"))
241
- if local_vocab.is_file():
242
- vocab_file = str(local_vocab)
243
- else:
244
- remote_vocab_map = {
245
- "multilingual_grl": cached_path(f"{PRETRAINED_ROOT}/data/multilingual_grl/vocab.txt"),
246
- "multilingual_prosody": cached_path(f"{PRETRAINED_ROOT}/data/multilingual_prosody/vocab.txt"),
247
- }
248
- remote_vocab = remote_vocab_map.get(project)
249
- if remote_vocab is None:
250
- return None, "Vocab file not found!", ""
251
- try:
252
- vocab_file = str(cached_path(remote_vocab))
253
- except Exception as e:
254
- traceback.print_exc()
255
- return None, f"Error downloading vocab: {str(e)}", ""
256
-
257
- # Resolve prosody encoder config & weights
258
- local_prosody_cfg = CKPTS_ROOT / "prosody_encoder" / "pretssel_cfg.json"
259
- local_prosody_ckpt = CKPTS_ROOT / "prosody_encoder" / "prosody_encoder_UnitY2.pt"
260
- if local_prosody_cfg.is_file():
261
- prosody_cfg_path = str(local_prosody_cfg)
262
- else:
263
- prosody_cfg_path = str(
264
- cached_path(f"{PRETRAINED_ROOT}/ckpts/prosody_encoder/pretssel_cfg.json")
265
- )
266
- if local_prosody_ckpt.is_file():
267
- prosody_ckpt_path = str(local_prosody_ckpt)
268
- else:
269
- prosody_ckpt_path = str(
270
- cached_path(f"{PRETRAINED_ROOT}/ckpts/prosody_encoder/prosody_encoder_UnitY2.pt")
271
- )
272
 
273
  try:
274
  tts_api = TTS(
@@ -481,16 +438,8 @@ with gr.Blocks(title="LEMAS-TTS Inference") as app:
481
 
482
  # Examples
483
  def _resolve_example(name: str) -> str:
484
- local = cached_path(os.path.join(PRETRAINED_ROOT, "data", "test_examples", name))
485
- if os.path.isfile(local):
486
- return local
487
- remote_map = {
488
- "en.wav": cached_path(os.path.join(PRETRAINED_ROOT, "data", "test_examples", "en.wav")),
489
- "es.wav": cached_path(os.path.join(PRETRAINED_ROOT, "data", "test_examples", "es.wav")),
490
- "pt.wav": cached_path(os.path.join(PRETRAINED_ROOT, "data", "test_examples", "pt.wav")),
491
- }
492
- url = remote_map.get(name)
493
- return str(cached_path(url)) if url is not None else ""
494
 
495
  examples = gr.Examples(
496
  examples=[
@@ -598,7 +547,7 @@ def main(port, host, share, api):
598
  server_port=port,
599
  share=share,
600
  show_api=api,
601
- allowed_paths=[str(cached_path(os.path.join(PRETRAINED_ROOT, "data")))],
602
  )
603
 
604
 
 
12
  import soundfile as sf
13
  from pathlib import Path
14
 
15
+ from lemas_tts.api import TTS, PRETRAINED_ROOT, CKPTS_ROOT
 
 
16
 
17
  # Global variables
18
  tts_api = None
 
32
  )
33
 
34
  REPO_ROOT = Path(__file__).resolve().parent
 
 
 
 
 
 
35
 
36
+ # 1) 指向 `pretrained_models` 里的 libespeak-ng.so(本地路径)
37
+ ESPEAK_LIB = Path(PRETRAINED_ROOT) / "espeak-ng-lib" / "libespeak-ng.so"
38
  os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = str(ESPEAK_LIB)
39
 
40
+ # 2) 指向 `pretrained_models` 里的 espeak-ng-data(本地路径)
41
+ ESPEAK_DATA_DIR = Path(PRETRAINED_ROOT) / "espeak-ng-data"
42
+ os.environ["ESPEAK_DATA_PATH"] = str(ESPEAK_DATA_DIR)
43
+ os.environ["ESPEAKNG_DATA_PATH"] = str(ESPEAK_DATA_DIR)
44
 
45
 
46
  class UVR5:
 
60
 
61
  model_path = os.path.join(model_dir, "Kim_Vocal_1.onnx")
62
  config_path = os.path.join(model_dir, "MDX-Net-Kim-Vocal1.json")
63
+ with open(config_path, "r", encoding="utf-8") as f:
64
+ configs = json.load(f)
65
  model_data = ModelData(
66
  model_path=model_path,
67
  audio_path=model_dir,
 
83
  return output_audio.squeeze().T.numpy(), 44100
84
 
85
  denoise_model = UVR5(
86
+ model_dir=str(Path(PRETRAINED_ROOT) / "uvr5"),
87
+ code_dir=str(REPO_ROOT / "uvr5"),
88
  )
89
 
90
  def load_wav(audio_info, sr=16000, channel=1):
 
117
  def get_checkpoints_project(project_name=None, is_gradio=True):
118
  """Get available checkpoint files"""
119
  checkpoint_dir = [str(CKPTS_ROOT)]
 
 
 
 
 
120
 
121
  if project_name is None:
122
  # Look for checkpoints in local directory
 
126
  files_checkpoints.extend(glob(os.path.join(path, "**/*.pt"), recursive=True))
127
  files_checkpoints.extend(glob(os.path.join(path, "**/*.safetensors"), recursive=True))
128
  break
 
 
 
129
  else:
130
  if os.path.isdir(checkpoint_dir[0]):
131
  files_checkpoints = glob(os.path.join(checkpoint_dir[0], project_name, "*.pt"))
132
  files_checkpoints.extend(glob(os.path.join(checkpoint_dir[0], project_name, "*.safetensors")))
133
  else:
134
+ files_checkpoints = []
 
135
  print("files_checkpoints:", project_name, files_checkpoints)
136
  # Separate pretrained and regular checkpoints
137
  pretrained_checkpoints = [f for f in files_checkpoints if "pretrained_" in os.path.basename(f)]
 
164
  def get_available_projects():
165
  """Get available project names from data directory"""
166
  data_paths = [
167
+ str(Path(PRETRAINED_ROOT) / "data"),
168
  ]
169
 
170
  project_list = []
 
188
  ):
189
  global last_checkpoint, last_device, tts_api, last_ema
190
 
191
+ # Resolve checkpoint path (local or HF-style, though we now rely on local PRETRAINED_ROOT)
192
  ckpt_path = file_checkpoint
193
+ ckpt_resolved = ckpt_path
 
 
 
 
 
 
 
194
 
195
  if not os.path.isfile(ckpt_resolved):
196
  return None, "Checkpoint not found!", ""
 
213
  # Automatically enable prosody encoder when using the prosody checkpoint
214
  use_prosody_encoder = True if "prosody" in str(ckpt_resolved) else False
215
 
216
+ # Resolve vocab file (local)
217
+ local_vocab = Path(PRETRAINED_ROOT) / "data" / project / "vocab.txt"
218
+ if not local_vocab.is_file():
219
+ return None, "Vocab file not found!", ""
220
+ vocab_file = str(local_vocab)
221
+
222
+ # Resolve prosody encoder config & weights (local)
223
+ local_prosody_cfg = Path(CKPTS_ROOT) / "prosody_encoder" / "pretssel_cfg.json"
224
+ local_prosody_ckpt = Path(CKPTS_ROOT) / "prosody_encoder" / "prosody_encoder_UnitY2.pt"
225
+ if not local_prosody_cfg.is_file() or not local_prosody_ckpt.is_file():
226
+ return None, "Prosody encoder files not found!", ""
227
+ prosody_cfg_path = str(local_prosody_cfg)
228
+ prosody_ckpt_path = str(local_prosody_ckpt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
  try:
231
  tts_api = TTS(
 
438
 
439
  # Examples
440
  def _resolve_example(name: str) -> str:
441
+ local = Path(PRETRAINED_ROOT) / "data" / "test_examples" / name
442
+ return str(local) if local.is_file() else ""
 
 
 
 
 
 
 
 
443
 
444
  examples = gr.Examples(
445
  examples=[
 
547
  server_port=port,
548
  share=share,
549
  show_api=api,
550
+ allowed_paths=[str(Path(PRETRAINED_ROOT) / "data")],
551
  )
552
 
553