Spaces:
Running
on
Zero
Running
on
Zero
Update gradio_mix.py
Browse files
- gradio_mix.py +49 -18
gradio_mix.py
CHANGED
|
@@ -15,6 +15,7 @@ import jieba, zhconv
|
|
| 15 |
from pypinyin.core import Pinyin
|
| 16 |
from pypinyin import Style
|
| 17 |
|
|
|
|
| 18 |
from lemas_tts.api import TTS, PRETRAINED_ROOT, CKPTS_ROOT
|
| 19 |
from lemas_tts.infer.edit_multilingual import gen_wav_multilingual
|
| 20 |
from lemas_tts.infer.text_norm.txt2pinyin import (
|
|
@@ -46,21 +47,18 @@ DEMO_PATH = os.getenv("DEMO_PATH", "./pretrained_models/demo")
|
|
| 46 |
TMP_PATH = os.getenv("TMP_PATH", "./pretrained_models/demo/temp")
|
| 47 |
MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")
|
| 48 |
|
| 49 |
-
#
|
| 50 |
-
#
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
| 53 |
def _pick_device():
|
| 54 |
forced = os.getenv("LEMAS_DEVICE")
|
| 55 |
if forced:
|
| 56 |
return forced
|
| 57 |
-
if torch.cuda.is_available():
|
| 58 |
-
try:
|
| 59 |
-
torch.zeros(1).to("cuda")
|
| 60 |
-
return "cuda"
|
| 61 |
-
except Exception as e:
|
| 62 |
-
logging.warning("CUDA appears available but failed (%s); falling back to CPU.", e)
|
| 63 |
-
return "cpu"
|
| 64 |
|
| 65 |
device = _pick_device()
|
| 66 |
ASR_DEVICE = "cpu" # force whisperx/pyannote to CPU to avoid cuDNN issues
|
|
@@ -355,10 +353,8 @@ class MMSAlignModel:
|
|
| 355 |
def __init__(self):
|
| 356 |
from torchaudio.pipelines import MMS_FA as bundle
|
| 357 |
self.mms_model = bundle.get_model()
|
| 358 |
-
# MMS
|
| 359 |
-
|
| 360 |
-
# model still uses GPU.
|
| 361 |
-
self.mms_model.to("cpu")
|
| 362 |
self.mms_tokenizer = bundle.get_tokenizer()
|
| 363 |
self.mms_aligner = bundle.get_aligner()
|
| 364 |
self.text_normalizer = ur.Uroman()
|
|
@@ -380,7 +376,7 @@ class MMSAlignModel:
|
|
| 380 |
|
| 381 |
def compute_alignments(self, waveform: torch.Tensor, tokens):
|
| 382 |
with torch.inference_mode():
|
| 383 |
-
emission, _ = self.mms_model(waveform.to("cpu"))
|
| 384 |
token_spans = self.mms_aligner(emission[0], tokens)
|
| 385 |
return emission, token_spans
|
| 386 |
|
|
@@ -399,7 +395,7 @@ class MMSAlignModel:
|
|
| 399 |
assert len(text_normed) == len(raw_text), f"normalized text len != raw text len: {len(text_normed)} != {len(raw_text)}"
|
| 400 |
tokens = self.mms_tokenizer(text_normed)
|
| 401 |
with torch.inference_mode():
|
| 402 |
-
emission, _ = self.mms_model(waveform.to("cpu"))
|
| 403 |
token_spans = self.mms_aligner(emission[0], tokens)
|
| 404 |
num_frames = emission.size(1)
|
| 405 |
ratio = waveform.size(1) / num_frames
|
|
@@ -562,12 +558,41 @@ def load_models(lemas_model_name, whisper_model_name, alignment_model_name, deno
|
|
| 562 |
# Load LEMAS-TTS editing model (selected multilingual variant)
|
| 563 |
from pathlib import Path
|
| 564 |
|
|
|
|
| 565 |
ckpt_dir = Path(CKPTS_ROOT) / lemas_model_name
|
| 566 |
ckpt_candidates = sorted(
|
| 567 |
list(ckpt_dir.glob("*.safetensors")) + list(ckpt_dir.glob("*.pt"))
|
| 568 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
if not ckpt_candidates:
|
| 570 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 571 |
ckpt_file = str(ckpt_candidates[-1])
|
| 572 |
|
| 573 |
vocab_file = Path(PRETRAINED_ROOT) / "data" / lemas_model_name / "vocab.txt"
|
|
@@ -1201,6 +1226,12 @@ if __name__ == "__main__":
|
|
| 1201 |
parser.add_argument("--port", default=41020, type=int, help="App port")
|
| 1202 |
parser.add_argument("--share", action="store_true", help="Launch with public url")
|
| 1203 |
parser.add_argument("--server_name", default="0.0.0.0", type=str, help="Server name for launching the app. 127.0.0.1 for localhost; 0.0.0.0 to allow access from other machines in the local network. Might also give access to external users depends on the firewall settings.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1204 |
|
| 1205 |
os.environ["USER"] = os.getenv("USER", "user")
|
| 1206 |
args = parser.parse_args()
|
|
|
|
| 15 |
from pypinyin.core import Pinyin
|
| 16 |
from pypinyin import Style
|
| 17 |
|
| 18 |
+
from cached_path import cached_path
|
| 19 |
from lemas_tts.api import TTS, PRETRAINED_ROOT, CKPTS_ROOT
|
| 20 |
from lemas_tts.infer.edit_multilingual import gen_wav_multilingual
|
| 21 |
from lemas_tts.infer.text_norm.txt2pinyin import (
|
|
|
|
| 47 |
TMP_PATH = os.getenv("TMP_PATH", "./pretrained_models/demo/temp")
|
| 48 |
MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")
|
| 49 |
|
| 50 |
+
# HF location for large TTS checkpoints (too big for Space storage).
|
| 51 |
+
# Mirrors LEMAS-TTS `inference_gradio.py`.
|
| 52 |
+
HF_PRETRAINED_ROOT = "hf://LEMAS-Project/LEMAS-TTS/pretrained_models"
|
| 53 |
+
|
| 54 |
+
# Pick device for the TTS editing model.
|
| 55 |
+
# - Default: "cuda" if available, else "cpu"
|
| 56 |
+
# - You can override via LEMAS_DEVICE env (e.g. "cpu" or "cuda").
|
| 57 |
def _pick_device():
|
| 58 |
forced = os.getenv("LEMAS_DEVICE")
|
| 59 |
if forced:
|
| 60 |
return forced
|
| 61 |
+
return "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
device = _pick_device()
|
| 64 |
ASR_DEVICE = "cpu" # force whisperx/pyannote to CPU to avoid cuDNN issues
|
|
|
|
| 353 |
def __init__(self):
|
| 354 |
from torchaudio.pipelines import MMS_FA as bundle
|
| 355 |
self.mms_model = bundle.get_model()
|
| 356 |
+
# Keep MMS on the same device as the main edit model unless overridden.
|
| 357 |
+
self.mms_model.to(device)
|
|
|
|
|
|
|
| 358 |
self.mms_tokenizer = bundle.get_tokenizer()
|
| 359 |
self.mms_aligner = bundle.get_aligner()
|
| 360 |
self.text_normalizer = ur.Uroman()
|
|
|
|
| 376 |
|
| 377 |
def compute_alignments(self, waveform: torch.Tensor, tokens):
|
| 378 |
with torch.inference_mode():
|
| 379 |
+
emission, _ = self.mms_model(waveform.to(device))
|
| 380 |
token_spans = self.mms_aligner(emission[0], tokens)
|
| 381 |
return emission, token_spans
|
| 382 |
|
|
|
|
| 395 |
assert len(text_normed) == len(raw_text), f"normalized text len != raw text len: {len(text_normed)} != {len(raw_text)}"
|
| 396 |
tokens = self.mms_tokenizer(text_normed)
|
| 397 |
with torch.inference_mode():
|
| 398 |
+
emission, _ = self.mms_model(waveform.to(device))
|
| 399 |
token_spans = self.mms_aligner(emission[0], tokens)
|
| 400 |
num_frames = emission.size(1)
|
| 401 |
ratio = waveform.size(1) / num_frames
|
|
|
|
| 558 |
# Load LEMAS-TTS editing model (selected multilingual variant)
|
| 559 |
from pathlib import Path
|
| 560 |
|
| 561 |
+
# Local ckpt search under the standard CKPTS_ROOT layout
|
| 562 |
ckpt_dir = Path(CKPTS_ROOT) / lemas_model_name
|
| 563 |
ckpt_candidates = sorted(
|
| 564 |
list(ckpt_dir.glob("*.safetensors")) + list(ckpt_dir.glob("*.pt"))
|
| 565 |
)
|
| 566 |
+
# Fallbacks for simpler layouts: allow ckpts directly under CKPTS_ROOT,
|
| 567 |
+
# e.g. ./pretrained_models/ckpts/multilingual_grl.safetensors
|
| 568 |
+
if not ckpt_candidates:
|
| 569 |
+
root_candidates = sorted(
|
| 570 |
+
list(Path(CKPTS_ROOT).glob(f"{lemas_model_name}*.safetensors"))
|
| 571 |
+
+ list(Path(CKPTS_ROOT).glob(f"{lemas_model_name}*.pt"))
|
| 572 |
+
)
|
| 573 |
+
ckpt_candidates = root_candidates
|
| 574 |
+
|
| 575 |
+
# If no local ckpt is found, fall back to remote HF checkpoints
|
| 576 |
+
# (using the same mapping as LEMAS-TTS `inference_gradio.py`).
|
| 577 |
if not ckpt_candidates:
|
| 578 |
+
remote_ckpts = {
|
| 579 |
+
"multilingual_grl": f"{HF_PRETRAINED_ROOT}/ckpts/multilingual_grl/multilingual_grl.safetensors",
|
| 580 |
+
"multilingual_prosody": f"{HF_PRETRAINED_ROOT}/ckpts/multilingual_prosody/multilingual_prosody.safetensors",
|
| 581 |
+
}
|
| 582 |
+
remote_path = remote_ckpts.get(lemas_model_name)
|
| 583 |
+
if remote_path is not None:
|
| 584 |
+
try:
|
| 585 |
+
resolved = cached_path(remote_path)
|
| 586 |
+
ckpt_candidates = [Path(resolved)]
|
| 587 |
+
logging.info("Resolved remote ckpt %s -> %s", remote_path, resolved)
|
| 588 |
+
except Exception as e:
|
| 589 |
+
raise gr.Error(f"Failed to download remote ckpt {remote_path}: {e}")
|
| 590 |
+
|
| 591 |
+
if not ckpt_candidates:
|
| 592 |
+
raise gr.Error(
|
| 593 |
+
f"No LEMAS-TTS ckpt found for '{lemas_model_name}' under {ckpt_dir} "
|
| 594 |
+
f"or {CKPTS_ROOT}"
|
| 595 |
+
)
|
| 596 |
ckpt_file = str(ckpt_candidates[-1])
|
| 597 |
|
| 598 |
vocab_file = Path(PRETRAINED_ROOT) / "data" / lemas_model_name / "vocab.txt"
|
|
|
|
| 1226 |
parser.add_argument("--port", default=41020, type=int, help="App port")
|
| 1227 |
parser.add_argument("--share", action="store_true", help="Launch with public url")
|
| 1228 |
parser.add_argument("--server_name", default="0.0.0.0", type=str, help="Server name for launching the app. 127.0.0.1 for localhost; 0.0.0.0 to allow access from other machines in the local network. Might also give access to external users depends on the firewall settings.")
|
| 1229 |
+
parser.add_argument(
|
| 1230 |
+
"--models-path",
|
| 1231 |
+
default="./pretrained_models",
|
| 1232 |
+
dest="models_path",
|
| 1233 |
+
help="Path to pretrained_models root (mirrors LEMAS-TTS layout).",
|
| 1234 |
+
)
|
| 1235 |
|
| 1236 |
os.environ["USER"] = os.getenv("USER", "user")
|
| 1237 |
args = parser.parse_args()
|