Approximetal commited on
Commit
a73b16c
·
verified ·
1 Parent(s): 9e98ef8

Update inference_gradio.py

Browse files
Files changed (1) hide show
  1. inference_gradio.py +25 -3
inference_gradio.py CHANGED
@@ -12,6 +12,8 @@ import torchaudio
12
  import soundfile as sf
13
  from pathlib import Path
14
 
 
 
15
  from lemas_tts.api import TTS, PRETRAINED_ROOT, CKPTS_ROOT
16
 
17
  # Global variables
@@ -33,6 +35,9 @@ device = (
33
 
34
  REPO_ROOT = Path(__file__).resolve().parent
35
 
 
 
 
36
  # 1) 指向 `pretrained_models` 里的 libespeak-ng.so(本地路径)
37
  ESPEAK_LIB = Path(PRETRAINED_ROOT) / "espeak-ng-lib" / "libespeak-ng.so"
38
  os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = str(ESPEAK_LIB)
@@ -117,6 +122,11 @@ def cancel_denoise(audio_info):
117
  def get_checkpoints_project(project_name=None, is_gradio=True):
118
  """Get available checkpoint files"""
119
  checkpoint_dir = [str(CKPTS_ROOT)]
 
 
 
 
 
120
 
121
  if project_name is None:
122
  # Look for checkpoints in local directory
@@ -126,12 +136,17 @@ def get_checkpoints_project(project_name=None, is_gradio=True):
126
  files_checkpoints.extend(glob(os.path.join(path, "**/*.pt"), recursive=True))
127
  files_checkpoints.extend(glob(os.path.join(path, "**/*.safetensors"), recursive=True))
128
  break
 
 
 
129
  else:
130
  if os.path.isdir(checkpoint_dir[0]):
131
  files_checkpoints = glob(os.path.join(checkpoint_dir[0], project_name, "*.pt"))
132
  files_checkpoints.extend(glob(os.path.join(checkpoint_dir[0], project_name, "*.safetensors")))
133
  else:
134
- files_checkpoints = []
 
 
135
  print("files_checkpoints:", project_name, files_checkpoints)
136
  # Separate pretrained and regular checkpoints
137
  pretrained_checkpoints = [f for f in files_checkpoints if "pretrained_" in os.path.basename(f)]
@@ -188,9 +203,16 @@ def infer(
188
  ):
189
  global last_checkpoint, last_device, tts_api, last_ema
190
 
191
- # Resolve checkpoint path (local or HF-style, though we now rely on local PRETRAINED_ROOT)
192
  ckpt_path = file_checkpoint
193
- ckpt_resolved = ckpt_path
 
 
 
 
 
 
 
194
 
195
  if not os.path.isfile(ckpt_resolved):
196
  return None, "Checkpoint not found!", ""
 
12
  import soundfile as sf
13
  from pathlib import Path
14
 
15
+ from cached_path import cached_path
16
+
17
  from lemas_tts.api import TTS, PRETRAINED_ROOT, CKPTS_ROOT
18
 
19
  # Global variables
 
35
 
36
  REPO_ROOT = Path(__file__).resolve().parent
37
 
38
+ # HF location for large TTS checkpoints (too big for Space storage)
39
+ HF_PRETRAINED_ROOT = "hf://LEMAS-Project/LEMAS-TTS/pretrained_models"
40
+
41
  # 1) 指向 `pretrained_models` 里的 libespeak-ng.so(本地路径)
42
  ESPEAK_LIB = Path(PRETRAINED_ROOT) / "espeak-ng-lib" / "libespeak-ng.so"
43
  os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = str(ESPEAK_LIB)
 
122
  def get_checkpoints_project(project_name=None, is_gradio=True):
123
  """Get available checkpoint files"""
124
  checkpoint_dir = [str(CKPTS_ROOT)]
125
+ # Remote ckpt locations on HF (used when local ckpts are not present)
126
+ remote_ckpts = {
127
+ "multilingual_grl": f"{HF_PRETRAINED_ROOT}/ckpts/multilingual_grl/multilingual_grl.safetensors",
128
+ "multilingual_prosody": f"{HF_PRETRAINED_ROOT}/ckpts/multilingual_prosody/multilingual_prosody.safetensors",
129
+ }
130
 
131
  if project_name is None:
132
  # Look for checkpoints in local directory
 
136
  files_checkpoints.extend(glob(os.path.join(path, "**/*.pt"), recursive=True))
137
  files_checkpoints.extend(glob(os.path.join(path, "**/*.safetensors"), recursive=True))
138
  break
139
+ # Fallback to remote ckpts if none found locally
140
+ if not files_checkpoints:
141
+ files_checkpoints = list(remote_ckpts.values())
142
  else:
143
  if os.path.isdir(checkpoint_dir[0]):
144
  files_checkpoints = glob(os.path.join(checkpoint_dir[0], project_name, "*.pt"))
145
  files_checkpoints.extend(glob(os.path.join(checkpoint_dir[0], project_name, "*.safetensors")))
146
  else:
147
+ # No local ckpts for this project, try remote mapping
148
+ ckpt = remote_ckpts.get(project_name)
149
+ files_checkpoints = [ckpt] if ckpt is not None else []
150
  print("files_checkpoints:", project_name, files_checkpoints)
151
  # Separate pretrained and regular checkpoints
152
  pretrained_checkpoints = [f for f in files_checkpoints if "pretrained_" in os.path.basename(f)]
 
203
  ):
204
  global last_checkpoint, last_device, tts_api, last_ema
205
 
206
+ # Resolve checkpoint path (local or HF URL)
207
  ckpt_path = file_checkpoint
208
+ if isinstance(ckpt_path, str) and ckpt_path.startswith("hf://"):
209
+ try:
210
+ ckpt_resolved = str(cached_path(ckpt_path))
211
+ except Exception as e:
212
+ traceback.print_exc()
213
+ return None, f"Error downloading checkpoint: {str(e)}", ""
214
+ else:
215
+ ckpt_resolved = ckpt_path
216
 
217
  if not os.path.isfile(ckpt_resolved):
218
  return None, "Checkpoint not found!", ""