Spaces:
Running
on
Zero
Running
on
Zero
Update inference_gradio.py
Browse files- inference_gradio.py +25 -3
inference_gradio.py
CHANGED
|
@@ -12,6 +12,8 @@ import torchaudio
|
|
| 12 |
import soundfile as sf
|
| 13 |
from pathlib import Path
|
| 14 |
|
|
|
|
|
|
|
| 15 |
from lemas_tts.api import TTS, PRETRAINED_ROOT, CKPTS_ROOT
|
| 16 |
|
| 17 |
# Global variables
|
|
@@ -33,6 +35,9 @@ device = (
|
|
| 33 |
|
| 34 |
REPO_ROOT = Path(__file__).resolve().parent
|
| 35 |
|
|
|
|
|
|
|
|
|
|
| 36 |
# 1) 指向 `pretrained_models` 里的 libespeak-ng.so(本地路径)
|
| 37 |
ESPEAK_LIB = Path(PRETRAINED_ROOT) / "espeak-ng-lib" / "libespeak-ng.so"
|
| 38 |
os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = str(ESPEAK_LIB)
|
|
@@ -117,6 +122,11 @@ def cancel_denoise(audio_info):
|
|
| 117 |
def get_checkpoints_project(project_name=None, is_gradio=True):
|
| 118 |
"""Get available checkpoint files"""
|
| 119 |
checkpoint_dir = [str(CKPTS_ROOT)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
if project_name is None:
|
| 122 |
# Look for checkpoints in local directory
|
|
@@ -126,12 +136,17 @@ def get_checkpoints_project(project_name=None, is_gradio=True):
|
|
| 126 |
files_checkpoints.extend(glob(os.path.join(path, "**/*.pt"), recursive=True))
|
| 127 |
files_checkpoints.extend(glob(os.path.join(path, "**/*.safetensors"), recursive=True))
|
| 128 |
break
|
|
|
|
|
|
|
|
|
|
| 129 |
else:
|
| 130 |
if os.path.isdir(checkpoint_dir[0]):
|
| 131 |
files_checkpoints = glob(os.path.join(checkpoint_dir[0], project_name, "*.pt"))
|
| 132 |
files_checkpoints.extend(glob(os.path.join(checkpoint_dir[0], project_name, "*.safetensors")))
|
| 133 |
else:
|
| 134 |
-
|
|
|
|
|
|
|
| 135 |
print("files_checkpoints:", project_name, files_checkpoints)
|
| 136 |
# Separate pretrained and regular checkpoints
|
| 137 |
pretrained_checkpoints = [f for f in files_checkpoints if "pretrained_" in os.path.basename(f)]
|
|
@@ -188,9 +203,16 @@ def infer(
|
|
| 188 |
):
|
| 189 |
global last_checkpoint, last_device, tts_api, last_ema
|
| 190 |
|
| 191 |
-
# Resolve checkpoint path (local or HF
|
| 192 |
ckpt_path = file_checkpoint
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
if not os.path.isfile(ckpt_resolved):
|
| 196 |
return None, "Checkpoint not found!", ""
|
|
|
|
| 12 |
import soundfile as sf
|
| 13 |
from pathlib import Path
|
| 14 |
|
| 15 |
+
from cached_path import cached_path
|
| 16 |
+
|
| 17 |
from lemas_tts.api import TTS, PRETRAINED_ROOT, CKPTS_ROOT
|
| 18 |
|
| 19 |
# Global variables
|
|
|
|
| 35 |
|
| 36 |
REPO_ROOT = Path(__file__).resolve().parent
|
| 37 |
|
| 38 |
+
# HF location for large TTS checkpoints (too big for Space storage)
|
| 39 |
+
HF_PRETRAINED_ROOT = "hf://LEMAS-Project/LEMAS-TTS/pretrained_models"
|
| 40 |
+
|
| 41 |
# 1) 指向 `pretrained_models` 里的 libespeak-ng.so(本地路径)
|
| 42 |
ESPEAK_LIB = Path(PRETRAINED_ROOT) / "espeak-ng-lib" / "libespeak-ng.so"
|
| 43 |
os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = str(ESPEAK_LIB)
|
|
|
|
| 122 |
def get_checkpoints_project(project_name=None, is_gradio=True):
|
| 123 |
"""Get available checkpoint files"""
|
| 124 |
checkpoint_dir = [str(CKPTS_ROOT)]
|
| 125 |
+
# Remote ckpt locations on HF (used when local ckpts are not present)
|
| 126 |
+
remote_ckpts = {
|
| 127 |
+
"multilingual_grl": f"{HF_PRETRAINED_ROOT}/ckpts/multilingual_grl/multilingual_grl.safetensors",
|
| 128 |
+
"multilingual_prosody": f"{HF_PRETRAINED_ROOT}/ckpts/multilingual_prosody/multilingual_prosody.safetensors",
|
| 129 |
+
}
|
| 130 |
|
| 131 |
if project_name is None:
|
| 132 |
# Look for checkpoints in local directory
|
|
|
|
| 136 |
files_checkpoints.extend(glob(os.path.join(path, "**/*.pt"), recursive=True))
|
| 137 |
files_checkpoints.extend(glob(os.path.join(path, "**/*.safetensors"), recursive=True))
|
| 138 |
break
|
| 139 |
+
# Fallback to remote ckpts if none found locally
|
| 140 |
+
if not files_checkpoints:
|
| 141 |
+
files_checkpoints = list(remote_ckpts.values())
|
| 142 |
else:
|
| 143 |
if os.path.isdir(checkpoint_dir[0]):
|
| 144 |
files_checkpoints = glob(os.path.join(checkpoint_dir[0], project_name, "*.pt"))
|
| 145 |
files_checkpoints.extend(glob(os.path.join(checkpoint_dir[0], project_name, "*.safetensors")))
|
| 146 |
else:
|
| 147 |
+
# No local ckpts for this project, try remote mapping
|
| 148 |
+
ckpt = remote_ckpts.get(project_name)
|
| 149 |
+
files_checkpoints = [ckpt] if ckpt is not None else []
|
| 150 |
print("files_checkpoints:", project_name, files_checkpoints)
|
| 151 |
# Separate pretrained and regular checkpoints
|
| 152 |
pretrained_checkpoints = [f for f in files_checkpoints if "pretrained_" in os.path.basename(f)]
|
|
|
|
| 203 |
):
|
| 204 |
global last_checkpoint, last_device, tts_api, last_ema
|
| 205 |
|
| 206 |
+
# Resolve checkpoint path (local or HF URL)
|
| 207 |
ckpt_path = file_checkpoint
|
| 208 |
+
if isinstance(ckpt_path, str) and ckpt_path.startswith("hf://"):
|
| 209 |
+
try:
|
| 210 |
+
ckpt_resolved = str(cached_path(ckpt_path))
|
| 211 |
+
except Exception as e:
|
| 212 |
+
traceback.print_exc()
|
| 213 |
+
return None, f"Error downloading checkpoint: {str(e)}", ""
|
| 214 |
+
else:
|
| 215 |
+
ckpt_resolved = ckpt_path
|
| 216 |
|
| 217 |
if not os.path.isfile(ckpt_resolved):
|
| 218 |
return None, "Checkpoint not found!", ""
|